U
    PhƯ                     @  s  d Z ddlmZ ddlZddlmZmZmZ ddlZ	ddl
Z
ddlm  mZ ddlmZ ddlmZ ddlmZ ddlmZ dd	lmZmZmZ dd
lmZ ddlmZ ddlm Z  ddl!m"Z"m#Z#m$Z$m%Z%m&Z&m'Z' ddl(m)Z) ddl*m+Z+m,Z,m-Z-m.Z.m/Z/ ddl0m1Z1 dddddddddddddgZ2G dd deZ3G dd deZ4G d d deZ5G d!d deZ6G d"d deZ7G d#d deZ8G d$d deZ9G d%d& d&Z:G d'd de:eZ;G d(d de:eZ<G d)d deZ=G d*d deZ>G d+d deZ?G d,d deZ@dS )-zD
A collection of "vanilla" transforms for the model output tensors.
    )annotationsN)CallableIterableSequence)NdarrayOrTensor)get_track_meta)
MetaTensor)one_hot)GaussianFilterapply_filterseparable_filtering)InvertibleTransform)	Transform)ToTensor)convert_applied_interp_modedistance_transform_edt
fill_holes$get_largest_connected_component_maskget_unique_labelsremove_small_objects)unravel_index)TransformBackendsconvert_data_typeconvert_to_tensorensure_tuplelook_up_option)convert_to_dst_typeActivations
AsDiscrete	FillHolesKeepLargestConnectedComponentRemoveSmallObjectsLabelFilterLabelToContourMeanEnsembleProbNMSSobelGradientsVoteEnsembleInvertDistanceTransformEDTc                   @  sF   e Zd ZdZejgZdddddddd	Zdd
dddd
dddZdS )r   a  
    Activation operations, typically `Sigmoid` or `Softmax`.

    Args:
        sigmoid: whether to execute sigmoid function on model output before transform.
            Defaults to ``False``.
        softmax: whether to execute softmax function on model output before transform.
            Defaults to ``False``.
        other: callable function to execute other activation layers, for example:
            `other = lambda x: torch.tanh(x)`. Defaults to ``None``.
        kwargs: additional parameters to `torch.softmax` (used when ``softmax=True``).
            Defaults to ``dim=0``, unrecognized parameters will be ignored.

    Raises:
        TypeError: When ``other`` is not an ``Optional[Callable]``.

    FNboolCallable | NoneNone)sigmoidsoftmaxotherreturnc                 K  sB   || _ || _|| _|d k	r8t|s8tdt|j d|| _d S )N&other must be None or callable but is .)r-   r.   kwargscallable	TypeErrortype__name__r/   )selfr-   r.   r/   r3    r9   P/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/transforms/post/array.py__init__R   s    zActivations.__init__r   bool | None)imgr-   r.   r/   r0   c           	      C  s   |r|rt d|dk	r6t|s6tdt|j dt|t d}t|tj	tj
d^}}|sd| jrnt|}|sx| jrtj|| jddd	}|dkr| jn|}|dk	r||}t||^}}|S )
a  
        Args:
            sigmoid: whether to execute sigmoid function on model output before transform.
                Defaults to ``self.sigmoid``.
            softmax: whether to execute softmax function on model output before transform.
                Defaults to ``self.softmax``.
            other: callable function to execute other activation layers, for example:
                `other = torch.tanh`. Defaults to ``self.other``.

        Raises:
            ValueError: When ``sigmoid=True`` and ``softmax=True``. Incompatible values.
            TypeError: When ``other`` is not an ``Optional[Callable]``.
            ValueError: When ``self.other=None`` and ``other=None``. Incompatible values.

        z3Incompatible values: sigmoid=True and softmax=True.Nr1   r2   
track_metadtypedimr   rB   )
ValueErrorr4   r5   r6   r7   r   r   r   torchTensorfloatr-   r.   r3   getr/   r   )	r8   r=   r-   r.   r/   img_t_Zact_funcoutr9   r9   r:   __call__Z   s    


zActivations.__call__)FFN)NNN	r7   
__module____qualname____doc__r   TORCHbackendr;   rL   r9   r9   r9   r:   r   =   s      c                   @  sJ   e Zd ZdZejgZddddddd	d
dZddddddddddZdS )r   aX  
    Convert the input tensor/array into discrete values, possible operations are:

        -  `argmax`.
        -  threshold input value to binary values.
        -  convert input value to One-Hot format (set ``to_one_hot=N``, `N` is the number of classes).
        -  round the value to the closest integer.

    Args:
        argmax: whether to execute argmax function on input data before transform.
            Defaults to ``False``.
        to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
            Defaults to ``None``.
        threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold.
            Defaults to ``None``.
        rounding: if not None, round the data according to the specified option,
            available options: ["torchrounding"].
        kwargs: additional parameters to `torch.argmax`, `monai.networks.one_hot`.
            currently ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored.
            These default to ``0``, ``True``, ``torch.float`` respectively.

    Example:

        >>> transform = AsDiscrete(argmax=True)
        >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]])))
        # [[[1.0, 1.0]]]

        >>> transform = AsDiscrete(threshold=0.6)
        >>> print(transform(np.array([[[0.0, 0.5], [0.8, 3.0]]])))
        # [[[0.0, 0.0], [1.0, 1.0]]]

        >>> transform = AsDiscrete(argmax=True, to_onehot=2, threshold=0.5)
        >>> print(transform(np.array([[[0.0, 1.0]], [[2.0, 3.0]]])))
        # [[[0.0, 0.0]], [[1.0, 1.0]]]

    FNr*   
int | Nonezfloat | Nonez
str | Noner,   )argmax	to_onehot	thresholdroundingr0   c                 K  s4   || _ t|trtd|| _|| _|| _|| _d S )NQ`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.)rT   
isinstancer*   rD   rU   rV   rW   r3   )r8   rT   rU   rV   rW   r3   r9   r9   r:   r;      s    
zAsDiscrete.__init__r   r<   )r=   rT   rU   rV   rW   r0   c                 C  s2  t |trtdt|t d}t|tj^}}|s:| jr^tj|| j	
dd| j	
ddd}|dkrl| jn|}|dk	rt |tstd	t| d
t||| j	
dd| j	
dtjd}|dkr| jn|}|dk	r||k}|dkr| jn|}|dk	rt|dg t|}t||| j	
dtjd^}}|S )a   
        Args:
            img: the input tensor data to convert, if no channel dimension when converting to `One-Hot`,
                will automatically add it.
            argmax: whether to execute argmax function on input data before transform.
                Defaults to ``self.argmax``.
            to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
                Defaults to ``self.to_onehot``.
            threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
                Defaults to ``self.threshold``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"].

        rX   r>   rB   r   keepdimTrB   rZ   Nz:the number of classes for One-Hot must be an integer, got r2   rA   )num_classesrB   rA   Ztorchroundingr@   )rY   r*   rD   r   r   r   rE   rF   rT   r3   rH   rU   intr6   r	   rG   rV   rW   r   roundr   )r8   r=   rT   rU   rV   rW   rI   rJ   r9   r9   r:   rL      s2    

$
   

zAsDiscrete.__call__)FNNN)NNNNrM   r9   r9   r9   r:   r      s   %        c                      sN   e Zd ZdZejejgZdddddd	d
d fddZdddddZ	  Z
S )r    a  
    Keeps only the largest connected component in the image.
    This transform can be used as a post-processing step to clean up over-segment areas in model output.

    The input is assumed to be a channel-first PyTorch Tensor:
      1) For not OneHot format data, the values correspond to expected labels,
      0 will be treated as background and the over-segment pixels will be set to 0.
      2) For OneHot format data, the values should be 0, 1 on each labels,
      the over-segment pixels will be set to 0 in its channel.

    For example:
    Use with applied_labels=[1], is_onehot=False, connectivity=1::

       [1, 0, 0]         [0, 0, 0]
       [0, 1, 1]    =>   [0, 1 ,1]
       [0, 1, 1]         [0, 1, 1]

    Use with applied_labels=[1, 2], is_onehot=False, independent=False, connectivity=1::

      [0, 0, 1, 0 ,0]           [0, 0, 1, 0 ,0]
      [0, 2, 1, 1 ,1]           [0, 2, 1, 1 ,1]
      [1, 2, 1, 0 ,0]    =>     [1, 2, 1, 0 ,0]
      [1, 2, 0, 1 ,0]           [1, 2, 0, 0 ,0]
      [2, 2, 0, 0 ,2]           [2, 2, 0, 0 ,0]

    Use with applied_labels=[1, 2], is_onehot=False, independent=True, connectivity=1::

      [0, 0, 1, 0 ,0]           [0, 0, 1, 0 ,0]
      [0, 2, 1, 1 ,1]           [0, 2, 1, 1 ,1]
      [1, 2, 1, 0 ,0]    =>     [0, 2, 1, 0 ,0]
      [1, 2, 0, 1 ,0]           [0, 2, 0, 0 ,0]
      [2, 2, 0, 0 ,2]           [2, 2, 0, 0 ,0]

    Use with applied_labels=[1, 2], is_onehot=False, independent=False, connectivity=2::

      [0, 0, 1, 0 ,0]           [0, 0, 1, 0 ,0]
      [0, 2, 1, 1 ,1]           [0, 2, 1, 1 ,1]
      [1, 2, 1, 0 ,0]    =>     [1, 2, 1, 0 ,0]
      [1, 2, 0, 1 ,0]           [1, 2, 0, 1 ,0]
      [2, 2, 0, 0 ,2]           [2, 2, 0, 0 ,2]

    NT   Sequence[int] | int | Noner<   r*   rS   r]   r,   )applied_labels	is_onehotindependentconnectivitynum_componentsr0   c                   s<   t    |dk	rt|nd| _|| _|| _|| _|| _dS )a1  
        Args:
            applied_labels: Labels for applying the connected component analysis on.
                If given, voxels whose value is in this list will be analyzed.
                If `None`, all non-zero values will be analyzed.
            is_onehot: if `True`, treat the input data as OneHot format data, otherwise, not OneHot format data.
                default to None, which treats multi-channel data as OneHot and single channel data as not OneHot.
            independent: whether to treat ``applied_labels`` as a union of foreground labels.
                If ``True``, the connected component analysis will be performed on each foreground label independently
                and return the intersection of the largest components.
                If ``False``, the analysis will be performed on the union of foreground labels.
                default is `True`.
            connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
                Accepted values are ranging from  1 to input.ndim. If ``None``, a full
                connectivity of ``input.ndim`` is used. for more details:
                https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.
            num_components: The number of largest components to preserve.

        N)superr;   r   ra   rb   rc   rd   re   )r8   ra   rb   rc   rd   re   	__class__r9   r:   r;     s    
z&KeepLargestConnectedComponent.__init__r   r=   r0   c           
      C  s~  | j dkr|jd dkn| j }| jdk	r0| j}ntt||dd}t|t d}t|dd}| jr|D ]V}|rz|| dkn
|d |k}t|| j	| j
}|rd|| ||k< qfd|d ||k< qft||dd S |s,t||dd	^}}	|d
 |kdd }t|| j	| j
}d|d ||k< t||dd S ||df dkd}t|| j	| j
}|D ]}d|| ||k< qVt||dd S )z
        Args:
            img: shape must be (C, spatial_dim1[, spatial_dim2, ...]).

        Returns:
            An array with shape (C, spatial_dim1[, spatial_dim2, ...]).
        Nr   r_   )discardr>   FdstT)rl   wrap_sequence).N.)rb   shapera   tupler   r   r   rc   r   rd   re   r   any)
r8   r=   rb   ra   img_i
foregroundmasklabelsrJ   r9   r9   r:   rL   >  s2    
z&KeepLargestConnectedComponent.__call__)NNTNr_   r7   rN   rO   rP   r   NUMPYCUPYrR   r;   rL   __classcell__r9   r9   rg   r:   r       s   +     "c                   @  sB   e Zd ZdZejgZdddddd	d
dddZdddddZdS )r!   a  
    Use `skimage.morphology.remove_small_objects` to remove small objects from images.
    See: https://scikit-image.org/docs/dev/api/skimage.morphology.html#remove-small-objects.

    Data should be one-hotted.

    Args:
        min_size: objects smaller than this size (in number of voxels; or surface area/volume value
            in whatever units your image is if by_measure is True) are removed.
        connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
            Accepted values are ranging from  1 to input.ndim. If ``None``, a full
            connectivity of ``input.ndim`` is used. For more details refer to linked scikit-image
            documentation.
        independent_channels: Whether or not to consider channels as independent. If true, then
            conjoining islands from different labels will be removed if they are below the threshold.
            If false, the overall size islands made from all non-background voxels will be used.
        by_measure: Whether the specified min_size is in number of voxels. if this is True then min_size
            represents a surface area or volume value of whatever units your image is in (mm^3, cm^2, etc.)
            default is False. e.g. if min_size is 3, by_measure is True and the units of your data is mm,
            objects smaller than 3mm^3 are removed.
        pixdim: the pixdim of the input image. if a single number, this is used for all axes.
            If a sequence of numbers, the length of the sequence must be equal to the image dimensions.

    Example::

        .. code-block:: python

            from monai.transforms import RemoveSmallObjects, Spacing, Compose
            from monai.data import MetaTensor

            data1 = torch.tensor([[[0, 0, 0, 0, 0], [0, 1, 1, 0, 1], [0, 0, 0, 1, 1]]])
            affine = torch.as_tensor([[2,0,0,0],
                                      [0,1,0,0],
                                      [0,0,1,0],
                                      [0,0,0,1]], dtype=torch.float64)
            data2 = MetaTensor(data1, affine=affine)

            # remove objects smaller than 3mm^3, input is MetaTensor
            trans = RemoveSmallObjects(min_size=3, by_measure=True)
            out = trans(data2)
            # remove objects smaller than 3mm^3, input is not MetaTensor
            trans = RemoveSmallObjects(min_size=3, by_measure=True, pixdim=(2, 1, 1))
            out = trans(data1)

            # remove objects smaller than 3 (in pixel)
            trans = RemoveSmallObjects(min_size=3)
            out = trans(data2)

            # If the affine of the data is not identity, you can also add Spacing before.
            trans = Compose([
                Spacing(pixdim=(1, 1, 1)),
                RemoveSmallObjects(min_size=3)
            ])

    @   r_   TFNr]   r*   z+Sequence[float] | float | np.ndarray | Noner,   )min_sizerd   independent_channels
by_measurepixdimr0   c                 C  s"   || _ || _|| _|| _|| _d S N)r|   rd   r}   r~   r   )r8   r|   rd   r}   r~   r   r9   r9   r:   r;     s
    zRemoveSmallObjects.__init__r   ri   c                 C  s   t || j| j| j| j| jS )z
        Args:
            img: shape must be (C, spatial_dim1[, spatial_dim2, ...]). Data
                should be one-hotted.

        Returns:
            An array with shape (C, spatial_dim1[, spatial_dim2, ...]).
        )r   r|   rd   r}   r~   r   r8   r=   r9   r9   r:   rL     s    
     zRemoveSmallObjects.__call__)r{   r_   TFN	r7   rN   rO   rP   r   rx   rR   r;   rL   r9   r9   r9   r:   r!   d  s   8     c                   @  s<   e Zd ZdZejejgZdddddZdddd	d
Z	dS )r"   a  
    This transform filters out labels and can be used as a processing step to view only certain labels.

    The list of applied labels defines which labels will be kept.

    Note:
        All labels which do not match the `applied_labels` are set to the background label (0).

    For example:

    Use LabelFilter with applied_labels=[1, 5, 9]::

        [1, 2, 3]         [1, 0, 0]
        [4, 5, 6]    =>   [0, 5 ,0]
        [7, 8, 9]         [0, 0, 9]
    zIterable[int] | intr,   )ra   r0   c                 C  s   t || _dS )z
        Initialize the LabelFilter class with the labels to filter on.

        Args:
            applied_labels: Label(s) to filter on.
        N)r   ra   )r8   ra   r9   r9   r:   r;     s    zLabelFilter.__init__r   ri   c                 C  s   t |tjtjfs,t| j dt| dt |tjrt|t	 d}t|dd}t
tdrtj| j|jd}tt|||td|}t||dd	 S | |   }t||d	 }|S ttt|| j|d	S )
aQ  
        Filter the image on the `applied_labels`.

        Args:
            img: Pytorch tensor or numpy array of any shape.

        Raises:
            NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.

        Returns:
            Pytorch tensor or numpy array of the same shape as the input.
        z can not handle data of type r2   r>   Fisindevice        rk   r   )rY   npndarrayrE   rF   NotImplementedErrorrh   r6   r   r   hasattr	as_tensorra   r   wherer   tensortor   detachcpunumpyasarray)r8   r=   rr   Z	appl_lblsrK   r9   r9   r:   rL     s    
"zLabelFilter.__call__N)
r7   rN   rO   rP   r   rQ   rx   rR   r;   rL   r9   r9   r9   r:   r"     s   	c                      sD   e Zd ZdZejgZddddd fddZd	d	d
ddZ  Z	S )r   a*  
    This transform fills holes in the image and can be used to remove artifacts inside segments.

    An enclosed hole is defined as a background pixel/voxel which is only enclosed by a single class.
    The definition of enclosed can be defined with the connectivity parameter::

        1-connectivity     2-connectivity     diagonal connection close-up

             [ ]           [ ]  [ ]  [ ]             [ ]
              |               \  |  /                 |  <- hop 2
        [ ]--[x]--[ ]      [ ]--[x]--[ ]        [x]--[ ]
              |               /  |  \             hop 1
             [ ]           [ ]  [ ]  [ ]

    It is possible to define for which labels the hole filling should be applied.
    The input image is assumed to be a PyTorch Tensor or numpy array with shape [C, spatial_dim1[, spatial_dim2, ...]].
    If C = 1, then the values correspond to expected labels.
    If C > 1, then a one-hot-encoding is expected where the index of C matches the label indexing.

    Note:

        The label 0 will be treated as background and the enclosed holes will be set to the neighboring class label.

        The performance of this method heavily depends on the number of labels.
        It is a bit faster if the list of `applied_labels` is provided.
        Limiting the number of `applied_labels` results in a big decrease in processing time.

    For example:

        Use FillHoles with default parameters::

            [1, 1, 1, 2, 2, 2, 3, 3]         [1, 1, 1, 2, 2, 2, 3, 3]
            [1, 0, 1, 2, 0, 0, 3, 0]    =>   [1, 1 ,1, 2, 0, 0, 3, 0]
            [1, 1, 1, 2, 2, 2, 3, 3]         [1, 1, 1, 2, 2, 2, 3, 3]

        The hole in label 1 is fully enclosed and therefore filled with label 1.
        The background label near label 2 and 3 is not fully enclosed and therefore not filled.
    NzIterable[int] | int | NonerS   r,   )ra   rd   r0   c                   s&   t    |rt|nd| _|| _dS )a  
        Initialize the connectivity and limit the labels for which holes are filled.

        Args:
            applied_labels: Labels for which to fill holes. Defaults to None, that is filling holes for all labels.
            connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
                Accepted values are ranging from  1 to input.ndim. Defaults to a full connectivity of ``input.ndim``.
        N)rf   r;   r   ra   rd   )r8   ra   rd   rg   r9   r:   r;      s    	
zFillHoles.__init__r   ri   c                 C  s@   t |t d}t|tj^}}t|| j| j}t||^}}|S )a  
        Fill the holes in the provided image.

        Note:
            The value 0 is assumed as background label.

        Args:
            img: Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].

        Raises:
            NotImplementedError: The provided image was not a Pytorch Tensor or numpy array.

        Returns:
            Pytorch Tensor or numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
        r>   )	r   r   r   r   r   r   ra   rd   r   )r8   r=   img_nprJ   out_nprK   r9   r9   r:   rL   -  s
    zFillHoles.__call__)NN)
r7   rN   rO   rP   r   rx   rR   r;   rL   rz   r9   r9   rg   r:   r     s   'c                   @  s:   e Zd ZdZejgZddddddZddd	d
dZdS )r#   a  
    Return the contour of binary input images that only compose of 0 and 1, with Laplacian kernel
    set as default for edge detection. Typical usage is to plot the edge of label or segmentation output.

    Args:
        kernel_type: the method applied to do edge detection, default is "Laplace".

    Raises:
        NotImplementedError: When ``kernel_type`` is not "Laplace".

    Laplacestrr,   )kernel_typer0   c                 C  s   |dkrt d|| _d S )Nr   z2Currently only kernel_type="Laplace" is supported.)r   r   )r8   r   r9   r9   r:   r;   S  s    zLabelToContour.__init__r   ri   c                 C  s   t |t d}t |dd}t|jd }|d}|dkrdtjdddgdddgdddggtjd}n:|d	krd
tjd	d	d	tjd }d|d< nt	| j
 dt||}|jddd t|d|^}}|S )a  
        Args:
            img: torch tensor data to extract the contour, with shape: [channels, height, width[, depth]]

        Raises:
            ValueError: When ``image`` ndim is not one of [3, 4].

        Returns:
            A torch tensor with the same shape as img, note:
                1. it's the binary classification result of whether a pixel is edge or not.
                2. in order to keep the original shape of mask image, we use padding as default.
                3. the edge detection is just approximate because it defects inherent to Laplace kernel,
                   ideally the edge should be thin enough, but now it has a thickness.

        r>   Fr_   r      rn      r@      g      g      :@)r_   r_   r_   z! can only handle 2D or 3D images.r   g      ?)minmax)r   r   lenro   	unsqueezerE   r   float32onesrD   rh   r   clamp_r   squeeze)r8   r=   rr   spatial_dimskernelZcontour_imgoutputrJ   r9   r9   r:   rL   X  s    
*

zLabelToContour.__call__N)r   rM   r9   r9   r9   r:   r#   D  s   c                   @  s6   e Zd ZedddddZedddddd	Zd
S )Ensemble+Sequence[NdarrayOrTensor] | NdarrayOrTensortorch.Tensorri   c                 C  s\   t | tr*t | d tjr*dd | D } nt | tjr@t| } t | trTt| n| }|S )z`Get either a sequence or single instance of np.ndarray/torch.Tensor. Return single torch.Tensor.r   c                 S  s   g | ]}t |qS r9   )rE   r   .0rs   r9   r9   r:   
<listcomp>  s     z.Ensemble.get_stacked_torch.<locals>.<listcomp>)rY   r   r   r   rE   r   stack)r=   rK   r9   r9   r:   get_stacked_torch{  s    
zEnsemble.get_stacked_torchr   )r=   orig_imgr0   c                 C  s(   t |tr|d n|}t| |^}}|S )Nr   )rY   r   r   )r=   r   Z	orig_img_rK   rJ   r9   r9   r:   post_convert  s    zEnsemble.post_convertN)r7   rN   rO   staticmethodr   r   r9   r9   r9   r:   r   y  s   	r   c                   @  s:   e Zd ZdZejgZddddddZdd	d
ddZdS )r$   a  
    Execute mean ensemble on the input data.
    The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]],
    Or a single PyTorch Tensor with shape: [E, C[, H, W, D]], the `E` dimension represents
    the output data from different models.
    Typically, the input data is model output of segmentation task or classification task.
    And it also can support to add `weights` for the input data.

    Args:
        weights: can be a list or tuple of numbers for input data with shape: [E, C, H, W[, D]].
            or a Numpy ndarray or a PyTorch Tensor data.
            the `weights` will be added to input data from highest dimension, for example:
            1. if the `weights` only has 1 dimension, it will be added to the `E` dimension of input data.
            2. if the `weights` has 2 dimensions, it will be added to `E` and `C` dimensions.
            it's a typical practice to add weights for different classes:
            to ensemble 3 segmentation model outputs, every output has 4 channels(classes),
            so the input data shape can be: [3, 4, H, W, D].
            and add different `weights` for different classes, so the `weights` shape can be: [3, 4].
            for example: `weights = [[1, 2, 3, 4], [4, 3, 2, 1], [1, 1, 1, 1]]`.

    Nz(Sequence[float] | NdarrayOrTensor | Noner,   )weightsr0   c                 C  s"   |d k	rt j|t jdnd | _d S )Nr@   )rE   r   rG   r   )r8   r   r9   r9   r:   r;     s    zMeanEnsemble.__init__r   r   ri   c                 C  s   |  |}| jd k	rv| j|j| _t| jj}t| | j  D ]}|d7 }qF| jj| }|| |j	ddd }t
j	|dd}| ||S )N)r_   r   Tr[   rC   )r   r   r   r   rp   ro   range
ndimensionreshapemeanrE   r   )r8   r=   rr   ro   rJ   r   out_ptr9   r9   r:   rL     s    


zMeanEnsemble.__call__)NrM   r9   r9   r9   r:   r$     s   c                   @  s:   e Zd ZdZejgZddddddZdd	d
ddZdS )r'   af  
    Execute vote ensemble on the input data.
    The input data can be a list or tuple of PyTorch Tensor with shape: [C[, H, W, D]],
    Or a single PyTorch Tensor with shape: [E[, C, H, W, D]], the `E` dimension represents
    the output data from different models.
    Typically, the input data is model output of segmentation task or classification task.

    Note:
        This vote transform expects the input data is discrete values. It can be multiple channels
        data in One-Hot format or single channel data. It will vote to select the most common data
        between items.
        The output data has the same shape as every item of the input data.

    Args:
        num_classes: if the input is single channel data instead of One-Hot, we can't get class number
            from channel, need to explicitly specify the number of classes to vote.

    NrS   r,   )r\   r0   c                 C  s
   || _ d S r   )r\   )r8   r\   r9   r9   r:   r;     s    zVoteEnsemble.__init__r   r   ri   c                 C  s   |  |}| jd k	r^d}| dkr>|jd dkr>td n | dkrNd}t|| jdd}tj|	 dd}| jd k	rtj
|d|d}n
t|}| ||S )NTr_   z7no need to specify num_classes for One-Hot format data.FrC   r   r[   )r   r\   r   ro   warningswarnr	   rE   r   rG   rT   r^   r   )r8   r=   rr   Z
has_ch_dimr   r9   r9   r:   rL     s    



zVoteEnsemble.__call__)NrM   r9   r9   r9   r:   r'     s   c                   @  s>   e Zd ZdZejgZddddd	d
dddZddddZdS )r%   a-  
    Performs probability based non-maximum suppression (NMS) on the probabilities map via
    iteratively selecting the coordinate with highest probability and then move it as well
    as its surrounding values. The remove range is determined by the parameter `box_size`.
    If multiple coordinates have the same highest probability, only one of them will be
    selected.

    Args:
        spatial_dims: number of spatial dimensions of the input probabilities map.
            Defaults to 2.
        sigma: the standard deviation for gaussian filter.
            It could be a single value, or `spatial_dims` number of values. Defaults to 0.0.
        prob_threshold: the probability threshold, the function will stop searching if
            the highest probability is no larger than the threshold. The value should be
            no less than 0.0. Defaults to 0.5.
        box_size: the box size (in pixel) to be removed around the pixel with the maximum probability.
            It can be an integer that defines the size of a square or cube,
            or a list containing different values for each dimensions. Defaults to 48.

    Return:
        a list of selected lists, where inner lists contain probability and coordinates.
        For example, for 3D input, the inner lists are in the form of [probability, x, y, z].

    Raises:
        ValueError: When ``prob_threshold`` is less than 0.0.
        ValueError: When ``box_size`` is a list or tuple, and its length is not equal to `spatial_dims`.
        ValueError: When ``box_size`` has a less than 1 value.

    r   r         ?0   r]   z?Sequence[float] | float | Sequence[torch.Tensor] | torch.TensorrG   zint | Sequence[int]r,   )r   sigmaprob_thresholdbox_sizer0   c                 C  s   || _ || _| j dkr$t||d| _|dk r4td|| _t|trXt	|g| | _
n"t||krntdnt	|| _
| j
 dkrtd| j
d | _| j
| j | _d S )Nr   )r   r   z*prob_threshold should be no less than 0.0.zCthe sequence length of box_size should be the same as spatial_dims.z!box_size should be larger than 0.r   )r   r   r
   filterrD   r   rY   r]   r   r   r   r   r   box_lower_bdbox_upper_bd)r8   r   r   r   r   r9   r9   r:   r;   	  s     


zProbNMS.__init__r   )prob_mapc                   s
  | j dkr>t|tjs&tj|tjd}| j|j | |}|j	}g }|
 | jkrt| |}|t| }t|tjr|  n|}t|tjr| n|}||gt|  || j dd|| j d| t fddt| jD }d||< qH|S )zZ
        prob_map: the input probabilities map, it must have shape (H[, W, ...]).
        r   r@   Nc                 3  s    | ]}t |  | V  qd S r   )slicer   Zidx_max_rangeZidx_min_ranger9   r:   	<genexpr>:  s     z#ProbNMS.__call__.<locals>.<genexpr>)r   rY   rE   rF   r   rG   r   r   r   ro   r   r   r   rT   rp   r   r   itemappendlistr   clipr   r   r   )r8   r   Zprob_map_shapeoutputsmax_idxZprob_maxslicesr9   r   r:   rL   #  s$    


zProbNMS.__call__N)r   r   r   r   r   r9   r9   r9   r:   r%     s       c                   @  s:   e Zd ZdZejgZdddddddd	d
dZdd ZdS )r(   zV
    Utility transform to automatically invert the previously applied transforms.
    NTzInvertibleTransform | Nonezbool | Sequence[bool]zstr | torch.device | Noner+   r,   )	transformnearest_interpr   	post_func	to_tensorr0   c                 C  s<   t |tstd|| _|| _|| _|| _|| _t | _	dS )aY  
        Args:
            transform: the previously applied transform.
            nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
                default to `True`. If `False`, use the same interpolation mode as the original transform.
            device: move the inverted results to a target device before `post_func`, default to `None`.
            post_func: postprocessing for the inverted result, should be a callable function.
            to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.
        zAtransform is not invertible, can't invert transform for the data.N)
rY   r   rD   r   r   r   r   r   r   	_totensor)r8   r   r   r   r   r   r9   r9   r:   r;   G  s    
zInvert.__init__c                 C  s   t |ts|S | jr&t|jdd d|_| }| j|}| jrTt |tsT| 	|}t |t
jrn|j| jd}t| jr| |}|S )Nnearest)
trans_infomodealign_cornersr   )rY   r   r   r   applied_operationsr   r   inverser   r   rE   rF   r   r   r4   r   )r8   datainvertedr9   r9   r:   rL   a  s"    
  


zInvert.__call__)NTNNTrM   r9   r9   r9   r:   r(   @  s        c                	      sh   e Zd ZdZejgZdddddejfddd	d	d
ddd fddZ	ddddZ
dddddZ  ZS )r&   a  Calculate Sobel gradients of a grayscale image with the shape of CxH[xWxDx...] or BxH[xWxDx...].

    Args:
        kernel_size: the size of the Sobel kernel. Defaults to 3.
        spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient
            along each of the provide axis. By default it calculate the gradient for all spatial axes.
        normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True.
        normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False.
        padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`.
            Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
            See ``torch.nn.Conv1d()`` for more information.
        dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.

    r   NTFreflectr]   r`   r*   r   ztorch.dtyper,   )kernel_sizespatial_axesnormalize_kernelsnormalize_gradientspadding_moderA   r0   c                   s:   t    || _|| _|| _|| _| ||\| _| _d S r   )	rf   r;   paddingr   r   r   _get_kernelkernel_diffkernel_smooth)r8   r   r   r   r   r   rA   rg   r9   r:   r;     s    	
zSobelGradients.__init__z!tuple[torch.Tensor, torch.Tensor])r0   c                 C  s   |dk rt d| d|d dkr4t d| dtjdddggg|d	}tjdddggg|d	}tjdddggg|d	}| jr|jst d
| d|d }|d }|d }|d d }t|D ]$}tj||dd}tj||dd}q| | fS )Nr   z,Sobel kernel size should be at least three. z was given.r   r   z+Sobel kernel size should be an odd number. rn   r_   r@   zQ`dtype` for Sobel kernel should be floating point when `normalize_kernel==True`. g       @g      @)r   )	rD   rE   r   r   is_floating_pointr   Fconv1dr   )r8   sizerA   r   r   Zkernel_expansionexpandrJ   r9   r9   r:   r     s(    
zSobelGradients._get_kernelr   r   )imager0   c                   s`  t |t d}|jd  tt tt  d }| jd krLtt }nNtt| jt| }|rtd| d  d| d fdd	t| jD }|	d}| j
|j}| j|j}g }|D ]p}	|g  }
||
|	< t||
| jd
}| jr.| }|| kr||8 }| }|dkr.|| }|| qtj|dd}t|d|d }|S )Nr>   r_   r   z5The provide axes to calculate gradient is not valid: z. The image has z% spatial dimensions so it should be: r2   c                   s    g | ]}|d k r|  n|qS )r   r9   )r   axn_spatial_dimsr9   r:   r     s     z+SobelGradients.__call__.<locals>.<listcomp>)r   rC   )r   r   ndimr   r   r   setr   rD   r   r   r   r   r   r   r   r   r   r   r   rE   catr   r   )r8   r   Zimage_tensorZvalid_spatial_axesr   Zinvalid_axisr   r   Z	grad_listr   kernelsgradZgrad_minZgrad_maxgradsr9   r   r:   rL     s<    




zSobelGradients.__call__)r7   rN   rO   rP   r   rQ   rR   rE   r   r;   r   rL   rz   r9   r9   rg   r:   r&   u  s    c                      sF   e Zd ZdZejejgZdddd fddZddd	d
dZ	  Z
S )r)   a  
    Applies the Euclidean distance transform on the input.
    Either GPU based with CuPy / cuCIM or CPU based with scipy.
    To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.

    Note that the results of the libraries can differ, so stick to one if possible.
    For details, check out the `SciPy`_ and `cuCIM`_ documentation and / or :func:`monai.transforms.utils.distance_transform_edt`.

    .. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
    .. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt
    NzNone | float | list[float]r,   )samplingr0   c                   s   t    || _d S r   )rf   r;   r   )r8   r   rg   r9   r:   r;     s    
zDistanceTransformEDT.__init__r   ri   c                 C  s   t || jdS )aA  
        Args:
            img: Input image on which the distance transform shall be run.
                Has to be a channel first array, must have shape: (num_channels, H, W [,D]).
                Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.
                Input gets passed channel-wise to the distance-transform, thus results from this function will differ
                from directly calling ``distance_transform_edt()`` in CuPy or SciPy.
            sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;
                if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.

        Returns:
            An array with the same shape and data type as img
        )r=   r   )r   r   r   r9   r9   r:   rL     s    zDistanceTransformEDT.__call__)Nrw   r9   r9   rg   r:   r)     s   )ArP   
__future__r   r   collections.abcr   r   r   r   r   rE   torch.nn.functionalnn
functionalr   monai.config.type_definitionsr   monai.data.meta_objr   monai.data.meta_tensorr   monai.networksr	   monai.networks.layersr
   r   r   monai.transforms.inverser   monai.transforms.transformr   monai.transforms.utility.arrayr   monai.transforms.utilsr   r   r   r   r   r   0monai.transforms.utils_pytorch_numpy_unificationr   monai.utilsr   r   r   r   r   monai.utils.type_conversionr   __all__r   r   r    r!   r"   r   r#   r   r$   r'   r%   r(   r&   r)   r9   r9   r9   r:   <module>   s\    GjvX:N5+1X5n