o
    io                     @  s   d dl mZ d dlmZ d dlZd dlm  mZ d dlm	Z	 d dl
mZ d dlmZmZmZ ddgZddddZddddZd ddZdd ddZdd ddZdS )!    )annotations)SequenceN)Tensor)NdarrayOrTensor)convert_data_typeconvert_to_dst_typeensure_tuple_reperodedilate         ?maskr   filter_sizeint | Sequence[int]	pad_valuefloatreturnc                 C  2   t | tj^}}t|||d}t|| d^}}|S )a]  
    Erode 2D/3D binary mask.

    Args:
        mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.
        filter_size: erosion filter size, has to be odd numbers, default to be 3.
        pad_value: the filled value for padding. We need to pad the input before filtering
                   to keep the output with the same size as input. Usually use default value
                   and not changed.

    Return:
        eroded mask, same shape and data type as input.

    Example:

        .. code-block:: python

            # define a naive mask
            mask = torch.zeros(3,2,3,3,3)
            mask[:,:,1,1,1] = 1.0
            filter_size = 3
            erode_result = erode(mask, filter_size)  # expect torch.zeros(3,2,3,3,3)
            dilate_result = dilate(mask, filter_size)  # expect torch.ones(3,2,3,3,3)
    r   r   srcdst)r   torchr   erode_tr   r   r   r   mask_t_Z
res_mask_tres_mask r   j/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/transforms/utils_morphological_ops.pyr	                 c                 C  r   )a\  
    Dilate 2D/3D binary mask.

    Args:
        mask: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor or ndarray.
        filter_size: dilation filter size, has to be odd numbers, default to be 3.
        pad_value: the filled value for padding. We need to pad the input before filtering
                   to keep the output with the same size as input. Usually use default value
                   and not changed.

    Return:
        dilated mask, same shape and data type as input.

    Example:

        .. code-block:: python

            # define a naive mask
            mask = torch.zeros(3,2,3,3,3)
            mask[:,:,1,1,1] = 1.0
            filter_size = 3
            erode_result = erode(mask,filter_size) # expect torch.zeros(3,2,3,3,3)
            dilate_result = dilate(mask,filter_size) # expect torch.ones(3,2,3,3,3)
    r   r   )r   r   r   dilate_tr   r   r   r   r   r
   :   r    r   r   c           	      C  s   t | jd }|dvrtd| d| j dt||}tdd |D r-td| dt| jd	 | jd	 f| | j}d
d |D }t	j
|  |d|d}|dkrYt	jnt	j}|||ddt|d  }|S )a  
    Apply a morphological filter to a 2D/3D binary mask tensor.

    Args:
        mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
        filter_size: morphological filter size, has to be odd numbers.
        pad_value: the filled value for padding. We need to pad the input before filtering
                   to keep the output with the same size as input.

    Return:
        Tensor: Morphological filter result mask, same shape as input.
       )r#   r   z5spatial_dims must be either 2 or 3, got spatial_dims=z for mask tensor with shape of .c                 s  s    | ]	}|d  dkV  qdS )r#   r   Nr   ).0sizer   r   r   	<genexpr>p   s    z4get_morphological_filter_result_t.<locals>.<genexpr>z7All dimensions in filter_size must be odd numbers, got    c                 S  s"   g | ]}t d D ]}|d  qqS )r#   )range)r%   r&   r   r   r   r   
<listcomp>w   s   " z5get_morphological_filter_result_t.<locals>.<listcomp>constant)modevaluer   )padding)r   .)lenshape
ValueErrorr   anyr   onestodeviceFpadr   conv2dconv3dsum)	r   r   r   spatial_dimsZstructuring_elementpad_sizeZinput_paddedZconv_fnoutputr   r   r   !get_morphological_filter_result_tZ   s$   
&r>   c                 C  s,   t | ||}tt|d dk dd}|S )a	  
    Erode 2D/3D binary mask with data type as torch tensor.

    Args:
        mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
        filter_size: erosion filter size, has to be odd numbers, default to be 3.
        pad_value: the filled value for padding. We need to pad the input before filtering
                   to keep the output with the same size as input. Usually use default value
                   and not changed.

    Return:
        Tensor: eroded mask, same shape as input.
    r   gHz>r!   )r>   r   whereabsr   r   r   r=   r   r   r   r      s   r   c                 C  s"   t | ||}t|dkdd}|S )a  
    Dilate 2D/3D binary mask with data type as torch tensor.

    Args:
        mask_t: input 2D/3D binary mask, [N,C,M,N] or [N,C,M,N,P] torch tensor.
        filter_size: dilation filter size, has to be odd numbers, default to be 3.
        pad_value: the filled value for padding. We need to pad the input before filtering
                   to keep the output with the same size as input. Usually use default value
                   and not changed.

    Return:
        Tensor: dilated mask, same shape as input.
    r   r   r!   )r>   r   r?   rA   r   r   r   r"      s   r"   )r   r   )r   r   r   r   r   r   r   r   )r   r!   )r   r   r   r   r   r   r   r   )
__future__r   collections.abcr   r   torch.nn.functionalnn
functionalr6   r   monai.configr   monai.utilsr   r   r   __all__r	   r
   r>   r   r"   r   r   r   r   <module>   s    
 (