U
    Ph                 !   @  s2  d Z ddlmZ ddlZddlZddlZddlZddlmZm	Z	 ddl
mZ ddlmZ ddlmZmZ ddlZddlZddlmZ ddlmZ dd	lmZ dd
lmZ ddlmZ ddlmZm Z  ddl!m"Z"m#Z#m$Z$m%Z%m&Z&m'Z'm(Z(m)Z) ddl*m+Z+ ddl,m-Z- ddl.m/Z/m0Z0m1Z1m2Z2 ddl3m4Z4m5Z5m6Z6m7Z7 ddl8m9Z9m:Z:m;Z;m<Z< ddl=m>Z>m?Z?m@Z@mAZAmBZBmCZCmDZDmEZEmFZFmGZG ddlHmIZI ddlJmKZK ddlLmMZMmNZN eGddd\ZOZPeGddd\ZQZReGd\ZSZTddddd d!d"d#d$d%d&d'd(d)d*d+d,d-d.d/d0d1d2d3d4d5d6d7d8d9d:d;d<g!ZUG d=d de2ZVG d>d de0ZWG d?d de2ZXG d@d  d e2ZYG dAd" d"e2ZZG dBd# d#e2Z[G dCd$ d$e2e-Z\G dDd% d%e2Z]G dEd& d&e2Z^G dFd! d!e2Z_G dGd' d'e2Z`G dHd: d:e2ZaG dId( d(e2ZbG dJd) d)e2ZcG dKd* d*e2ZdG dLd+ d+e2ZeG dMd, d,e2ZfG dNd- d-e+ZgG dOd. d.ege1ZhG dPd/ d/e2ZiG dQd0 d0e2e-ZjG dRd1 d1e2e-ZkG dSd2 d2e2ZlG dTd3 d3e/e2ZmG dUd4 d4ZnG dVd5 d5ZoG dWd6 d6e2ZpG dXd7 d7e2ZqG dYd8 d8e2ZrG dZd9 d9ere0ZsG d[d de2ZtG d\d; d;e2ZuG d]d< d<e1ZvdS )^z=
A collection of "vanilla" transforms for utility functions.
    )annotationsN)MappingSequence)deepcopy)partial)AnyCallable)	DtypeLike)NdarrayOrTensor)get_track_meta)
MetaTensor)is_no_channelno_collation)ApplyFilterEllipticalFilterGaussianFilterLaplaceFilter
MeanFilterSavitzkyGolayFilterSharpenFiltermedian_filter)InvertibleTransform)MultiSampleTrait)RandomizableRandomizableTraitRandomizableTransform	Transform)extreme_points_to_imageget_extreme_pointsmap_binary_to_indicesmap_classes_to_indices)concatenatein1dmoveaxisunravel_indices)
MetaKeys	TraceKeysconvert_data_typeconvert_to_cupyconvert_to_numpyconvert_to_tensorensure_tuplelook_up_optionmin_versionoptional_import)TransformBackends)is_module_ver_at_least)convert_to_dst_typeget_equivalent_dtypez	PIL.ImageImagename	fromarraycupyIdentityRandIdentityAsChannelLastAddCoordinateChannelsEnsureChannelFirst
EnsureTypeRepeatChannelRemoveRepeatedChannelSplitDim
CastToTypeToTensorToNumpyToPIL	Transpose
SqueezeDim	DataStatsSimulateDelayLambda
RandLambdaLabelToMaskFgBgToIndicesClassesToIndices(ConvertToMultiChannelBasedOnBratsClassesAddExtremePointsChannelTorchVisionMapLabelValueIntensityStatsToDeviceCuCIM	RandCuCIMToCupyImageFilterRandImageFilterc                   @  s,   e Zd ZdZejejgZdddddZdS )r8   z
    Do nothing to the data.
    As the output value is same as input, it can be used as a testing tool to verify the transform chain,
    Compose or transform adaptor, etc.
    r
   imgreturnc                 C  s   |S /
        Apply the transform to `img`.
         selfrZ   r^   r^   S/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/transforms/utility/array.py__call__y   s    zIdentity.__call__N	__name__
__module____qualname____doc__r/   TORCHNUMPYbackendrb   r^   r^   r^   ra   r8   p   s   c                   @  s,   e Zd ZdZejejgZdddddZdS )r9   z
    Do nothing to the data. This transform is random, so can be used to stop the caching of any
    subsequent transforms.
    r   )datar[   c                 C  s   |S Nr^   r`   rk   r^   r^   ra   rb      s    zRandIdentity.__call__Nrc   r^   r^   r^   ra   r9      s   c                   @  s>   e Zd ZdZejejgZddddddZddd	d
dZ	dS )r:   am  
    Change the channel dimension of the image to the last dimension.

    Some of other 3rd party transforms assume the input image is in the channel-last format with shape
    (spatial_dim_1[, spatial_dim_2, ...], num_channels).

    This transform could be used to convert, for example, a channel-first image array in shape
    (num_channels, spatial_dim_1[, spatial_dim_2, ...]) into the channel-last format,
    so that MONAI transforms can construct a chain with other 3rd party transforms together.

    Args:
        channel_dim: which dimension of input image is the channel, default is the first dimension.
    r   intNone)channel_dimr[   c                 C  s,   t |tr|dks"td| d|| _d S )Nzinvalid channel dimension (z).)
isinstancern   
ValueErrorrp   )r`   rp   r^   r^   ra   __init__   s    zAsChannelLast.__init__r
   rY   c                 C  s   t t|| jdt d}|S )r]   rq   
track_meta)r*   r#   rp   r   r`   rZ   outr^   r^   ra   rb      s    zAsChannelLast.__call__N)r   
rd   re   rf   rg   r/   rh   ri   rj   rt   rb   r^   r^   r^   ra   r:      s   c                   @  sB   e Zd ZdZejejgZddddddZdd	d
d	dddZ	dS )r<   a  
    Adjust or add the channel dimension of input data to ensure `channel_first` shape.

    This extracts the `original_channel_dim` info from provided meta_data dictionary or MetaTensor input. This value
    should state which dimension is the channel dimension so that it can be moved forward, or contain "no_channel" to
    state no dimension is the channel and so a 1-size first dimension is to be added.

    Args:
        strict_check: whether to raise an error when the meta information is insufficient.
        channel_dim: This argument can be used to specify the original channel dimension (integer) of the input array.
            It overrides the `original_channel_dim` from provided MetaTensor input.
            If the input array doesn't have a channel dim, this value should be ``'no_channel'``.
            If this is set to `None`, this class relies on `img` or `meta_dict` to provide the channel dimension.
    TNboolzNone | str | int)strict_checkrp   c                 C  s   || _ || _d S rl   )r{   input_channel_dim)r`   r{   rp   r^   r^   ra   rt      s    zEnsureChannelFirst.__init__torch.TensorMapping | NonerZ   	meta_dictr[   c                 C  s   t |tsFt |tsF| jdkr>d}| jr0t|t| |S t|}t |trV|j}t |trn|	t
jdnd}| jdk	r| jdkrtdn| j}|dkrd}| jrt|t| |S t |tr||t
j< t|r|d }nt|t|d}t|t dS )r]   NzNMetadata not available and channel_dim=None, EnsureChannelFirst is not in use.
no_channelnanzYUnknown original_channel_dim in the MetaTensor meta dict or `meta_dict` or `channel_dim`.r   ru   )rr   r   r   r|   r{   rs   warningswarnmetagetr%   ORIGINAL_CHANNEL_DIMfloatdictr   r#   rn   r*   r   )r`   rZ   r   msgrp   resultr^   r^   ra   rb      s2    







zEnsureChannelFirst.__call__)TN)Nry   r^   r^   r^   ra   r<      s   c                   @  s8   e Zd ZdZejgZdddddZdddd	d
ZdS )r>   a5  
    Repeat channel data to construct expected input shape for models.
    The `repeats` count includes the origin data, for example:
    ``RepeatChannel(repeats=2)([[1, 2], [3, 4]])`` generates: ``[[1, 2], [1, 2], [3, 4], [3, 4]]``

    Args:
        repeats: the number of repetitions for each element.
    rn   ro   repeatsr[   c                 C  s"   |dkrt d| d|| _d S Nr   z*repeats count must be greater than 0, got .rs   r   r`   r   r^   r^   ra   rt      s    zRepeatChannel.__init__r
   rY   c                 C  s0   t |tjrtjntj}t||| jdt dS )Z
        Apply the transform to `img`, assuming `img` is a "channel-first" array.
        r   ru   )	rr   torchTensorrepeat_interleavenprepeatr*   r   r   )r`   rZ   Z	repeat_fnr^   r^   ra   rb      s    zRepeatChannel.__call__N	rd   re   rf   rg   r/   rh   rj   rt   rb   r^   r^   r^   ra   r>      s   	c                   @  s<   e Zd ZdZejejgZdddddZdddd	d
Z	dS )r?   aK  
    RemoveRepeatedChannel data to undo RepeatChannel
    The `repeats` count specifies the deletion of the origin data, for example:
    ``RemoveRepeatedChannel(repeats=2)([[1, 2], [1, 2], [3, 4], [3, 4]])`` generates: ``[[1, 2], [3, 4]]``

    Args:
        repeats: the number of repetitions to be deleted for each element.
    rn   ro   r   c                 C  s"   |dkrt d| d|| _d S r   r   r   r^   r^   ra   rt     s    zRemoveRepeatedChannel.__init__r
   rY   c                 C  sJ   |j d dk r$td|j d  dt|dd| jddf t d}|S )r   r      z+Image must have more than one channel, got z
 channels.Nru   )shapers   r*   r   r   rw   r^   r^   ra   rb     s    "zRemoveRepeatedChannel.__call__Nry   r^   r^   r^   ra   r?     s   	c                   @  s@   e Zd ZdZejejgZddddddd	Zd
ddddZ	dS )r@   aV  
    Given an image of size X along a certain dimension, return a list of length X containing
    images. Useful for converting 3D images into a stack of 2D images, splitting multichannel inputs into
    single channels, for example.

    Note: `torch.split`/`np.split` is used, so the outputs are views of the input (shallow copy).

    Args:
        dim: dimension on which to split
        keepdim: if `True`, output will have singleton in the split dimension. If `False`, this
            dimension will be squeezed.
        update_meta: whether to update the MetaObj in each split result.
    rq   Trn   rz   ro   )dimkeepdimr[   c                 C  s   || _ || _|| _d S rl   )r   r   update_meta)r`   r   r   r   r^   r^   ra   rt   /  s    zSplitDim.__init__r}   zlist[torch.Tensor]rY   c                 C  s   |j | j }t|tjr.tt|d| j}nt||| j}t|D ]\}}| j	sd|
| j||< | jrFt|trFt|tst||jd}| jdkrqFt|j}tj||jj|jjd}||| jd df< |j| |_qF|S )r]      )r   r   )devicedtyperq   )r   r   rr   r   r   listsplitr   	enumerater   squeezer   r   r   lenaffineeyer   r   )r`   rZ   n_outoutputsidxitemndimshiftr^   r^   ra   rb   4  s"    


zSplitDim.__call__N)rq   TTry   r^   r^   r^   ra   r@     s   c                   @  sD   e Zd ZdZejejgZej	fddddZ
ddddd	d
dZdS )rA   a  
    Cast the Numpy data to specified numpy data type, or cast the PyTorch Tensor to
    specified PyTorch data type.

    Example:
        >>> import numpy as np
        >>> import torch
        >>> transform = CastToType(dtype=np.float32)

        >>> # Example with a numpy array
        >>> img_np = np.array([0, 127, 255], dtype=np.uint8)
        >>> img_np_casted = transform(img_np)
        >>> img_np_casted
        array([  0. , 127. , 255. ], dtype=float32)

        >>> # Example with a PyTorch tensor
        >>> img_tensor = torch.tensor([0, 127, 255], dtype=torch.uint8)
        >>> img_tensor_casted = transform(img_tensor)
        >>> img_tensor_casted
        tensor([  0., 127., 255.])  # dtype is float32
    ro   )r[   c                 C  s
   || _ dS )zd
        Args:
            dtype: convert image to this data type, default is `np.float32`.
        N)r   )r`   r   r^   r^   ra   rt   e  s    zCastToType.__init__Nr
   DtypeLike | torch.dtype)rZ   r   r[   c                 C  s   t |t||p| jdd S )a+  
        Apply the transform to `img`, assuming `img` is a numpy array or PyTorch Tensor.

        Args:
            dtype: convert image to this data type, default is `self.dtype`.

        Raises:
            TypeError: When ``img`` type is not in ``Union[numpy.ndarray, torch.Tensor]``.

        )output_typer   r   )r'   typer   )r`   rZ   r   r^   r^   ra   rb   l  s    zCastToType.__call__)N)rd   re   rf   rg   r/   rh   ri   rj   r   float32rt   rb   r^   r^   r^   ra   rA   L  s   c                      sF   e Zd ZdZejgZddddddd	 fd
dZddddZ  Z	S )rB   a  
    Converts the input image to a tensor without applying any other transformations.
    Input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
    Will convert Tensor, Numpy array, float, int, bool to Tensor, strings and objects keep the original.
    For dictionary, list or tuple, convert every item to a Tensor if applicable and `wrap_sequence=False`.

    Args:
        dtype: target data type to when converting to Tensor.
        device: target device to put the converted Tensor data.
        wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
            E.g., if `False`, `[1, 2]` -> `[tensor(1), tensor(2)]`, if `True`, then `[1, 2]` -> `tensor([1, 2])`.
        track_meta: whether to convert to `MetaTensor` or regular tensor, default to `None`,
            use the return value of ``get_track_meta``.

    NTztorch.dtype | Noneztorch.device | str | Nonerz   bool | Nonero   )r   r   wrap_sequencerv   r[   c                   s8   t    || _|| _|| _|d kr*t nt|| _d S rl   )superrt   r   r   r   r   rz   rv   )r`   r   r   r   rv   	__class__r^   ra   rt     s
    
zToTensor.__init__r
   rZ   c                 C  s*   t |trg |_t|| j| j| j| jdS )F
        Apply the transform to `img` and make it contiguous.
        )r   r   r   rv   )rr   r   applied_operationsr*   r   r   r   rv   r_   r^   r^   ra   rb     s    
    zToTensor.__call__)NNTN
rd   re   rf   rg   r/   rh   rj   rt   rb   __classcell__r^   r^   r   ra   rB   z  s       c                   @  sH   e Zd ZdZejejgZdddddd	d
dddZddddddZ	dS )r=   a  
    Ensure the input data to be a PyTorch Tensor or numpy array, support: `numpy array`, `PyTorch Tensor`,
    `float`, `int`, `bool`, `string` and `object` keep the original.
    If passing a dictionary, list or tuple, still return dictionary, list or tuple will recursively convert
    every item to the expected data type if `wrap_sequence=False`.

    Args:
        data_type: target data type to convert, should be "tensor" or "numpy".
        dtype: target data content type to convert, for example: np.float32, torch.float, etc.
        device: for Tensor data type, specify the target device.
        wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
        track_meta: if `True` convert to ``MetaTensor``, otherwise to Pytorch ``Tensor``,
            if ``None`` behave according to return value of py:func:`monai.data.meta_obj.get_track_meta`.

    Example with wrap_sequence=True:
        >>> import numpy as np
        >>> import torch
        >>> transform = EnsureType(data_type="tensor", wrap_sequence=True)
        >>> # Converting a list to a tensor
        >>> data_list = [1, 2., 3]
        >>> tensor_data = transform(data_list)
        >>> tensor_data
        tensor([1., 2., 3.])    # All elements have dtype float32

    Example with wrap_sequence=False:
        >>> transform = EnsureType(data_type="tensor", wrap_sequence=False)
        >>> # Converting each element in a list to individual tensors
        >>> data_list = [1, 2, 3]
        >>> tensors_list = transform(data_list)
        >>> tensors_list
        [tensor(1), tensor(2.), tensor(3)]  # Only second element is float32 rest are int64
    tensorNTstrr   ztorch.device | Nonerz   r   ro   )	data_typer   r   r   rv   r[   c                 C  sB   t | ddh| _|| _|| _|| _|d kr4t nt|| _d S )Nr   numpy)	r,   lowerr   r   r   r   r   rz   rv   )r`   r   r   r   r   rv   r^   r^   ra   rt     s
    zEnsureType.__init__r
   )rk   r   c                 C  sN   | j dkr| jrtntj}ntj}t|||dkr6| jn|| j	| j
d^}}|S )a  
        Args:
            data: input data can be PyTorch Tensor, numpy array, list, dictionary, int, float, bool, str, etc.
                will ensure Tensor, Numpy array, float, int, bool as Tensors or numpy arrays, strings and
                objects keep the original. for dictionary, list or tuple, ensure every item as expected type
                if applicable and `wrap_sequence=False`.
            dtype: target data content type to convert, for example: np.float32, torch.float, etc.

        r   N)rk   r   r   r   r   )r   rv   r   r   r   r   ndarrayr'   r   r   r   )r`   rk   r   r   rx   _r^   r^   ra   rb     s    


zEnsureType.__call__)r   NNTN)Nry   r^   r^   r^   ra   r=     s   !     c                      sB   e Zd ZdZejgZddddd fdd	Zd
dddZ  Z	S )rC   a  
    Converts the input data to numpy array, can support list or tuple of numbers and PyTorch Tensor.

    Args:
        dtype: target data type when converting to numpy array.
        wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
            E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`.

    NTr	   rz   ro   r   r   r[   c                   s   t    || _|| _d S rl   r   rt   r   r   r`   r   r   r   r^   ra   rt     s    
zToNumpy.__init__r
   r   c                 C  s   t || j| jdS )r   r   r   )r)   r   r   r_   r^   r^   ra   rb     s    zToNumpy.__call__)NT)
rd   re   rf   rg   r/   ri   rj   rt   rb   r   r^   r^   r   ra   rC     s   
c                      sB   e Zd ZdZejgZddddd fdd	Zd
dddZ  Z	S )rV   aA  
    Converts the input data to CuPy array, can support list or tuple of numbers, NumPy and PyTorch Tensor.

    Args:
        dtype: data type specifier. It is inferred from the input by default.
            if not None, must be an argument of `numpy.dtype`, for more details:
            https://docs.cupy.dev/en/stable/reference/generated/cupy.array.html.
        wrap_sequence: if `False`, then lists will recursively call this function, default to `True`.
            E.g., if `False`, `[1, 2]` -> `[array(1), array(2)]`, if `True`, then `[1, 2]` -> `array([1, 2])`.

    NTznp.dtype | Nonerz   ro   r   c                   s   t    || _|| _d S rl   r   r   r   r^   ra   rt     s    
zToCupy.__init__r
   rk   c                 C  s   t || j| jdS )zH
        Create a CuPy array from `data` and make it contiguous
        r   )r(   r   r   rm   r^   r^   ra   rb     s    zToCupy.__call__)NT)
rd   re   rf   rg   r/   CUPYrj   rt   rb   r   r^   r^   r   ra   rV   	  s   c                   @  s    e Zd ZdZejgZdd ZdS )rD   z^
    Converts the input image (in the form of NumPy array or PyTorch Tensor) to PIL image
    c                 C  s2   t |tr|S t |tjr*|   }t|S r\   )rr   PILImageImager   r   detachcpur   pil_image_fromarrayr_   r^   r^   ra   rb   +  s
    
zToPIL.__call__N)rd   re   rf   rg   r/   ri   rj   rb   r^   r^   r^   ra   rD   $  s   c                   @  s8   e Zd ZdZejgZdddddZdddd	d
ZdS )rE   zU
    Transposes the input image based on the given `indices` dimension ordering.
    Sequence[int] | Nonero   )indicesr[   c                 C  s   |d krd nt || _d S rl   )tupler   )r`   r   r^   r^   ra   rt   =  s    zTranspose.__init__r
   rY   c                 C  s2   t |t d}|| jp.tt|jddd S )r]   ru   Nrq   )r*   r   permuter   r   ranger   r_   r^   r^   ra   rb   @  s    zTranspose.__call__Nr   r^   r^   r^   ra   rE   6  s   c                   @  s>   e Zd ZdZejejgZddddddZd	d	d
ddZ	dS )rF   z&
    Squeeze a unitary dimension.
    r   T
int | Nonero   )r   r[   c                 C  s8   |dk	r(t |ts(tdt|j d|| _|| _dS )aI  
        Args:
            dim: dimension to be squeezed. Default = 0
                "None" works when the input is numpy array.
            update_meta: whether to update the meta info if the input is a metatensor. Default is ``True``.

        Raises:
            TypeError: When ``dim`` is not an ``Optional[int]``.

        Nz!dim must be None or a int but is r   )rr   rn   	TypeErrorr   rd   r   r   )r`   r   r   r^   r^   ra   rt   O  s    zSqueezeDim.__init__r
   rY   c                 C  s  t |t d}| jdkr0| jr(td | S | jdk rJ| jt|j n| j}|j| dkr|t	d|j|  d|j d||}| jr~t
|tr~|dkr~t|jjd	kr~|jj\}}|jt
|jtjr|jjnd }}||kr
|tjd||d
|d k }||kr8|ddtjd||d
|d kf }|jd |jd krxtjt|ddsxtd|j d ||_|S )z[
        Args:
            img: numpy arrays with required dimension `dim` removed
        ru   Nz*update_meta=True is ignored when dim=None.r   r   z0Can only squeeze singleton dimension, got shape z of r   r   r   T)r   z,After SqueezeDim, img.affine is ill-posed: 
)r*   r   r   r   r   r   r   r   r   rs   rr   r   r   r   r   r   aranger   linalgdetr)   )r`   rZ   r   hwr   r   r^   r^   ra   rb   _  s(    

 
0"

$,zSqueezeDim.__call__N)r   Try   r^   r^   r^   ra   rF   H  s   c                
   @  sX   e Zd ZdZejejgZddddddddd	d
ddZddddddddddddZ	dS )rG   aJ  
    Utility transform to show the statistics of data for debug or analysis.
    It can be inserted into any place of a transform chain and check results of previous transforms.
    It support both `numpy.ndarray` and `torch.tensor` as input data,
    so it can be used in pre-processing and post-processing.

    It gets logger from `logging.getLogger(name)`, we can setup a logger outside first with the same `name`.
    If the log level of `logging.RootLogger` is higher than `INFO`, will add a separate `StreamHandler`
    log handler with `INFO` level and record to `stdout`.

    DataTFNr   rz   Callable | Nonero   )prefixr   
data_shapevalue_range
data_valueadditional_infor5   r[   c                 C  s   t |tstdt| d|| _|| _|| _|| _|| _|dk	rbt	|sbt
dt|j d|| _|| _t| j}|tj tj tjkrtdd |jD }	|	sttj}
|
tj d|
_||
 dS )a  
        Args:
            prefix: will be printed in format: "{prefix} statistics".
            data_type: whether to show the type of input data.
            data_shape: whether to show the shape of input data.
            value_range: whether to show the value range of input data.
            data_value: whether to show the raw value of input data.
                a typical example is to print some properties of Nifti image: affine, pixdim, etc.
            additional_info: user can define callable function to extract additional info from input data.
            name: identifier of `logging.logger` to use, defaulting to "DataStats".

        Raises:
            TypeError: When ``additional_info`` is not an ``Optional[Callable]``.

        zprefix must be a string, got r   Nz0additional_info must be None or callable but is c                 s  s   | ]}t |d o|jV  qdS )is_data_stats_handlerN)hasattrr   ).0r   r^   r^   ra   	<genexpr>  s    z%DataStats.__init__.<locals>.<genexpr>T)rr   r   rs   r   r   r   r   r   r   callabler   rd   r   _logger_namelogging	getLoggersetLevelINFOrootgetEffectiveLevelanyhandlersStreamHandlersysstdoutr   
addHandler)r`   r   r   r   r   r   r   r5   _loggerZhas_console_handlerconsoler^   r^   ra   rt     s,    
zDataStats.__init__r
   z
str | Noner   )rZ   r   r   r   r   r   r   r[   c                 C  s  |p| j  dg}|dkr"| jrPn|rP|dt| dt|drF|jnd  |dkr`| jrn|r|dt|drz|jnd  |dkr| jrn|rt	|t
jr|dt
| d	t
| d
 nJt	|tjr|dt| d	t| d
 n|dt| d
 |dkr,| jrBn|rB|d|  |dkrR| jn|}|dk	rt|d||  d}	|	| }
t| j|
 |S )zk
        Apply the transform to `img`, optionally take arguments similar to the class constructor.
        z statistics:NzType:  r   zShape: r   zValue range: (z, )z2Value range: (not a PyTorch or Numpy array, type: zValue: zAdditional info: 
)r   r   appendr   r   r   r   r   r   rr   r   r   minmaxr   r   r   r   joinr   r   r   info)r`   rZ   r   r   r   r   r   r   lines	separatoroutputr^   r^   ra   rb     s(    * &&
zDataStats.__call__)r   TTTFNrG   )NNNNNNry   r^   r^   r^   ra   rG   {  s"          5      c                      sJ   e Zd ZdZejejgZdddd fddZdd	d
d	dddZ	  Z
S )rH   a  
    This is a pass through transform to be used for testing purposes. It allows
    adding fake behaviors that are useful for testing purposes to simulate
    how large datasets behave without needing to test on large data sets.

    For example, simulating slow NFS data transfers, or slow network transfers
    in testing by adding explicit timing delays. Testing of small test data
    can lead to incomplete understanding of real world issues, and may lead
    to sub-optimal design choices.
            r   ro   )
delay_timer[   c                   s   t    || _dS )z
        Args:
            delay_time: The minimum amount of time, in fractions of seconds,
                to accomplish this delay task.
        N)r   rt   r  )r`   r  r   r^   ra   rt     s    
zSimulateDelay.__init__Nr
   zfloat | None)rZ   r  r[   c                 C  s   t |dkr| jn| |S )z
        Args:
            img: data remain unchanged throughout this transform.
            delay_time: The minimum amount of time, in fractions of seconds,
                to accomplish this delay task.
        N)timesleepr  )r`   rZ   r  r^   r^   ra   rb     s    zSimulateDelay.__call__)r   )N)rd   re   rf   rg   r/   rh   ri   rj   rt   rb   r   r^   r^   r   ra   rH     s   	c                   @  sX   e Zd ZdZejejgZdedfdddddd	d
Z	ddddddZ
ddddZdS )rI   a  
    Apply a user-defined lambda as a transform.

    For example:

    .. code-block:: python
        :emphasize-lines: 2

        image = np.ones((10, 2, 2))
        lambd = Lambda(func=lambda x: x[:4, :, :])
        print(lambd(image).shape)
        (4, 2, 2)

    Args:
        func: Lambda/function to be applied.
        inv_func: Lambda/function of inverse operation, default to `lambda x: x`.
        track_meta:  If `False`, then standard data objects will be returned (e.g., torch.Tensor` and `np.ndarray`)
            as opposed to MONAI's enhanced objects. By default, this is `True`.

    Raises:
        TypeError: When ``func`` is not an ``Optional[Callable]``.

    NTr   r   rz   ro   )funcinv_funcrv   r[   c                 C  s<   |d k	r&t |s&tdt|j d|| _|| _|| _d S )N%func must be None or callable but is r   )r   r   r   rd   r  r  rv   r`   r  r  rv   r^   r^   ra   rt     s
    zLambda.__init__r
   rZ   r  c                 C  sz   |dk	r|n| j }t|s0tdt|j d||}t|tjtj	frbt|t
sb| jrbt
|}t|t
rv| | |S )z
        Apply `self.func` to `img`.

        Args:
            func: Lambda/function to be applied. Defaults to `self.func`.

        Raises:
            TypeError: When ``func`` is not an ``Optional[Callable]``.

        Nr  r   )r  r   r   r   rd   rr   r   r   r   r   r   rv   push_transform)r`   rZ   r  fnrx   r^   r^   ra   rb   '  s    "

zLambda.__call__r}   r   c                 C  s   t |tr| | | |S rl   )rr   r   pop_transformr  rm   r^   r^   ra   inverse=  s    

zLambda.inverse)N)rd   re   rf   rg   r/   rh   ri   rj   r   rt   rb   r  r^   r^   r^   ra   rI     s     	c                      sb   e Zd ZdZejZddedfddddd	d
ddZdddd fddZdd fddZ	  Z
S )rJ   a  
    Randomizable version :py:class:`monai.transforms.Lambda`, the input `func` may contain random logic,
    or randomly execute the function based on `prob`.

    Args:
        func: Lambda/function to be applied.
        prob: probability of executing the random function, default to 1.0, with 100% probability to execute.
        inv_func: Lambda/function of inverse operation, default to `lambda x: x`.
        track_meta:  If `False`, then standard data objects will be returned (e.g., torch.Tensor` and `np.ndarray`)
            as opposed to MONAI's enhanced objects. By default, this is `True`.

    For more details, please check :py:class:`monai.transforms.Lambda`.
    N      ?Tr   r   r   rz   ro   )r  probr  rv   r[   c                 C  s$   t j| |||d tj| |d d S )Nr  )r`   r  )rI   rt   r   )r`   r  r  r  rv   r^   r^   ra   rt   T  s    zRandLambda.__init__r
   r  c                   sn   |  | t| jr t ||n|}t|ts>| jr>t|}t|trj| jrX| |ni }| j	||d |S )N)
extra_info)
	randomizer   _do_transformr   rb   rr   r   rv   r  r	  )r`   rZ   r  rx   Zlambda_infor   r^   ra   rb   ^  s    

zRandLambda.__call__r}   r   c                   s2   |  |tj}|r$t |}n
| | |S rl   )get_most_recent_transformpopr&   DO_TRANSFORMr   r  r  )r`   rk   do_transformr   r^   ra   r  i  s
    
zRandLambda.inverse)N)rd   re   rf   rg   rI   rj   r   rt   rb   r  r   r^   r^   r   ra   rJ   C  s   
c                   @  sF   e Zd ZdZejejgZdddddddZdd
ddd
dddZ	d	S )rK   a  
    Convert labels to mask for other tasks. A typical usage is to convert segmentation labels
    to mask data to pre-process images and then feed the images into classification network.
    It can support single channel labels or One-Hot labels with specified `select_labels`.
    For example, users can select `label value = [2, 3]` to construct mask data, or select the
    second and the third channels of labels to construct mask data.
    The output mask data can be a multiple channels binary data or a single channel binary
    data that merges all the channels.

    Args:
        select_labels: labels to generate mask from. for 1 channel label, the `select_labels`
            is the expected label values, like: [1, 2, 3]. for One-Hot format label, the
            `select_labels` is the expected channel indices.
        merge_channels: whether to use `np.any()` to merge the result on channel dim. if yes,
            will return a single channel mask with binary data.

    FzSequence[int] | intrz   ro   )select_labelsmerge_channelsr[   c                 C  s   t || _|| _d S rl   )r+   r  r  )r`   r  r  r^   r^   ra   rt     s    
zLabelToMask.__init__Nr
   zSequence[int] | int | None)rZ   r  r  r[   c                 C  s  t |t d}|dkr| j}nt|}|jd dkr@|| }n|t|tjrRtjnt	j}t|tjsnt
t	dr|t||dd|j}n2|t||t	jd|jdt	jd|jd|j}|s| jrt|tjst
t	dr|dd S |t	jdd tS |S )	a  
        Args:
            select_labels: labels to generate mask from. for 1 channel label, the `select_labels`
                is the expected label values, like: [1, 2, 3]. for One-Hot format label, the
                `select_labels` is the expected channel indices.
            merge_channels: whether to use `np.any()` to merge the result on channel dim. if yes,
                will return a single channel mask with binary data.
        ru   Nr   r   )r      r   TFr   )r*   r   r  r+   r   rr   r   r   wherer   r0   r"   reshaper   r   r  r   touint8rz   )r`   rZ   r  r  rk   r  r^   r^   ra   rb     s*      zLabelToMask.__call__)F)NFry   r^   r^   r^   ra   rK   r  s       c                   @  sF   e Zd ZdZejejgZddddddd	Zdd
ddddddZ	dS )rL   a  
    Compute foreground and background of the input label data, return the indices.
    If no output_shape specified, output data will be 1 dim indices after flattening.
    This transform can help pre-compute foreground and background regions for other transforms.
    A typical usage is to randomly select foreground and background to crop.
    The main logic is based on :py:class:`monai.transforms.utils.map_binary_to_indices`.

    Args:
        image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to
            determine the valid image content area and select background only in this area.
        output_shape: expected shape of output indices. if not None, unravel indices to specified shape.

    r   Nr   r   ro   )image_thresholdoutput_shaper[   c                 C  s   || _ || _d S rl   )r  r  )r`   r  r  r^   r^   ra   rt     s    zFgBgToIndices.__init__r
   NdarrayOrTensor | Nonez'tuple[NdarrayOrTensor, NdarrayOrTensor]labelimager  r[   c                 C  sD   |dkr| j }t||| j\}}|dk	r<t||}t||}||fS )a  
        Args:
            label: input data to compute foreground and background indices.
            image: if image is not None, use ``label = 0 & image > image_threshold``
                to define background. so the output items will not map to all the voxels in the label.
            output_shape: expected shape of output indices. if None, use `self.output_shape` instead.

        N)r  r   r  r$   )r`   r!  r"  r  
fg_indices
bg_indicesr^   r^   ra   rb     s    

zFgBgToIndices.__call__)r   N)NN)
rd   re   rf   rg   r/   ri   rh   rj   rt   rb   r^   r^   r^   ra   rL     s      c                   @  sF   e Zd ZejejgZddddddddd	Zdd
ddddddZdS )rM   Nr   r   r   r   ro   )num_classesr  r  max_samples_per_classr[   c                 C  s   || _ || _|| _|| _dS )a  
        Compute indices of every class of the input label data, return a list of indices.
        If no output_shape specified, output data will be 1 dim indices after flattening.
        This transform can help pre-compute indices of the class regions for other transforms.
        A typical usage is to randomly select indices of classes to crop.
        The main logic is based on :py:class:`monai.transforms.utils.map_classes_to_indices`.

        Args:
            num_classes: number of classes for argmax label, not necessary for One-Hot label.
            image_threshold: if enabled `image` at runtime, use ``image > image_threshold`` to
                determine the valid image content area and select only the indices of classes in this area.
            output_shape: expected shape of output indices. if not None, unravel indices to specified shape.
            max_samples_per_class: maximum length of indices to sample in each class to reduce memory consumption.
                Default is None, no subsampling.

        N)r%  r  r  r&  )r`   r%  r  r  r&  r^   r^   ra   rt     s    zClassesToIndices.__init__r
   r  zlist[NdarrayOrTensor]r   c                   sB    dkr| j  t|| j|| j| j} dk	r> fdd|D }|S )ai  
        Args:
            label: input data to compute the indices of every class.
            image: if image is not None, use ``image > image_threshold`` to define valid region, and only select
                the indices within the valid region.
            output_shape: expected shape of output indices. if None, use `self.output_shape` instead.

        Nc                   s   g | ]}t | qS r^   )r$   )r   cls_indicesr  r^   ra   
<listcomp>  s     z-ClassesToIndices.__call__.<locals>.<listcomp>)r  r    r%  r  r&  )r`   r!  r"  r  r   r^   r(  ra   rb     s        zClassesToIndices.__call__)Nr   NN)NN)	rd   re   rf   r/   ri   rh   rj   rt   rb   r^   r^   r^   ra   rM     s          c                   @  s,   e Zd ZdZejejgZdddddZdS )rN   a%  
    Convert labels to multi channels based on brats18 classes:
    label 1 is the necrotic and non-enhancing tumor core
    label 2 is the peritumoral edema
    label 4 is the GD-enhancing tumor
    The possible classes are TC (Tumor core), WT (Whole tumor)
    and ET (Enhancing tumor).
    r
   rY   c                 C  sx   |j dkr"|jd dkr"|d}|dk|dkB |dk|dkB |dkB |dkg}t|tjrjtj|ddS tj|ddS )N   r   r   r   )r   axis)r   r   r   rr   r   r   stackr   )r`   rZ   r   r^   r^   ra   rb     s    
.z1ConvertToMultiChannelBasedOnBratsClasses.__call__Nrc   r^   r^   r^   ra   rN     s   	c                   @  sV   e Zd ZdZejgZddddddd	Zd
ddddZdd
ddddd
dddZ	dS )rO   ar  
    Add extreme points of label to the image as a new channel. This transform generates extreme
    point from label and applies a gaussian filter. The pixel values in points image are rescaled
    to range [rescale_min, rescale_max] and added as a new channel to input image. The algorithm is
    described in Roth et al., Going to Extremes: Weakly Supervised Medical Image Segmentation
    https://arxiv.org/abs/2009.11988.

    This transform only supports single channel labels (1, spatial_dim1, [spatial_dim2, ...]). The
    background ``index`` is ignored when calculating extreme points.

    Args:
        background: Class index of background label, defaults to 0.
        pert: Random perturbation amount to add to the points, defaults to 0.0.

    Raises:
        ValueError: When no label image provided.
        ValueError: When label image is not single channel.
    r   r   rn   r   ro   )
backgroundpertr[   c                 C  s   || _ || _g | _d S rl   )_background_pert_points)r`   r.  r/  r^   r^   ra   rt   @  s    z AddExtremePointsChannel.__init__r
   )r!  r[   c                 C  s   t || j| j| jd| _d S )N)
rand_stater.  r/  )r   Rr0  r1  r2  )r`   r!  r^   r^   ra   r  E  s    z!AddExtremePointsChannel.randomizeN      @      r  r  z?Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor)rZ   r!  sigmarescale_minrescale_maxr[   c                 C  sn   |dkrt d|jd dkr&t d| |dddf  t| j||||d}t||^}}t||fddS )a`  
        Args:
            img: the image that we want to add new channel to.
            label: label image to get extreme points from. Shape must be
                (1, spatial_dim1, [, spatial_dim2, ...]). Doesn't support one-hot labels.
            sigma: if a list of values, must match the count of spatial dimensions of input data,
                and apply every value in the list to 1 spatial dimension. if only 1 value provided,
                use it for all spatial dimensions.
            rescale_min: minimum value of output data.
            rescale_max: maximum value of output data.
        Nz&This transform requires a label array!r   r   z$Only supports single channel labels!)pointsr!  r7  r8  r9  r+  )rs   r   r  r   r2  r1   r!   )r`   rZ   r!  r7  r8  r9  points_imager   r^   r^   ra   rb   H  s        z AddExtremePointsChannel.__call__)r   r   )Nr5  r6  r  )
rd   re   rf   rg   r/   rh   rj   rt   r  rb   r^   r^   r^   ra   rO   *  s       c                      s>   e Zd ZdZejgZddd fddZddd	d
Z  Z	S )rP   aY  
    This is a wrapper transform for PyTorch TorchVision transform based on the specified transform name and args.
    As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input
    data to be PyTorch Tensor, users can easily call `ToTensor` transform to convert a Numpy array to Tensor.

    r   ro   r5   r[   c                   s4   t    || _tddt|d\}}|||| _dS )z
        Args:
            name: The transform name in TorchVision package.
            args: parameters for the TorchVision transform.
            kwargs: parameters for the TorchVision transform.

        ztorchvision.transformsz0.8.0r4   N)r   rt   r5   r.   r-   trans)r`   r5   argskwargs	transformr   r   r^   ra   rt   t  s    
zTorchVision.__init__r
   r   c                 C  s.   t |tj^}}| |}t||d^}}|S )z\
        Args:
            img: PyTorch Tensor data for the TorchVision transform.

        )srcdst)r'   r   r   r=  r1   )r`   rZ   img_tr   rx   r^   r^   ra   rb     s    
zTorchVision.__call__r   r^   r^   r   ra   rP   j  s   c                   @  sD   e Zd ZdZejejgZej	fdddddddZ
dd	d
dZdS )rQ   aW  
    Utility to map label values to another set of values.
    For example, map [3, 2, 1] to [0, 1, 2], [1, 2, 3] -> [0.5, 1.5, 2.5], ["label3", "label2", "label1"] -> [0, 1, 2],
    [3.5, 2.5, 1.5] -> ["label0", "label1", "label2"], etc.
    The label data must be numpy array or array-like data and the output data will be numpy array.

    r   r	   ro   )orig_labelstarget_labelsr   r[   c                 C  s   t |t |krtd|| _|| _tdd t| j| jD | _t|}t|dddkrrd| _	t
|tjd| _nd	| _	t
|tjd| _d
S )a^  
        Args:
            orig_labels: original labels that map to others.
            target_labels: expected label values, 1: 1 map to the `orig_labels`.
            dtype: convert the output data to dtype, default to float32.
                if dtype is from PyTorch, the transform will use the pytorch backend, else with numpy backend.

        z8orig_labels and target_labels must have the same length.c                 s  s"   | ]\}}||kr||fV  qd S rl   r^   )r   otr^   r^   ra   r     s      z)MapLabelValue.__init__.<locals>.<genexpr>re    r   F)r   TN)r   rs   rD  rE  r   zippairr   getattr	use_numpyr2   r   r   r   r   r   )r`   rD  rE  r   Z
type_dtyper^   r^   ra   rt     s    	zMapLabelValue.__init__r
   r   c                 C  s   | j rt|tj^}}|j}| }z|| j}W n& tk
rZ   tj	|j| jd}Y nX | j
D ]\}}||||k< qb||}	n@t|tj^}
}|
  | j}	| j
D ]\}}||	|
|k< qt|	|| jd^}}|S )N)r   r   )rA  rB  r   )rL  r'   r   r   r   flattenastyper   rs   zerosrJ  r  r   r   r   cloner  r1   )r`   rZ   img_npr   Z
_out_shapeimg_flatout_flatrF  rG  out_trC  rx   r^   r^   ra   rb     s"    zMapLabelValue.__call__N)rd   re   rf   rg   r/   ri   rh   rj   r   r   rt   rb   r^   r^   r^   ra   rQ     s   c                   @  sD   e Zd ZdZejgZdddddddd	ZddddddddZd
S )rR   a9  
    Compute statistics for the intensity values of input image and store into the metadata dictionary.
    For example: if `ops=[lambda x: np.mean(x), "max"]` and `key_prefix="orig"`, may generate below stats:
    `{"orig_custom_0": 1.5, "orig_max": 3.0}`.

    Args:
        ops: expected operations to compute statistics for the intensity.
            if a string, will map to the predefined operations, supported: ["mean", "median", "max", "min", "std"]
            mapping to `np.nanmean`, `np.nanmedian`, `np.nanmax`, `np.nanmin`, `np.nanstd`.
            if a callable function, will execute the function on input image.
        key_prefix: the prefix to combine with `ops` name to generate the key to store the results in the
            metadata dictionary. if some `ops` are callable functions, will use "{key_prefix}_custom_{index}"
            as the key, where index counts from 0.
        channel_wise: whether to compute statistics for every channel of input image separately.
            if True, return a list of values for every operation, default to False.

    FzSequence[str | Callable]r   rz   ro   )ops
key_prefixchannel_wiser[   c                 C  s   t || _|| _|| _d S rl   )r+   rU  rV  rW  )r`   rU  rV  rW  r^   r^   ra   rt     s    
zIntensityStats.__init__Nr
   dict | Noneznp.ndarray | Noneztuple[NdarrayOrTensor, dict])rZ   	meta_datamaskr[   c           
        s"  t |tj^}}|dkri }|dk	rn|j|jkrJtd|j d|j d|jtkrftd|j d|| }tjtj	tj
tjtjd}ddd	 fd
d}d} jD ]r}	t|	trt|	| }	|||	 || jd |	 < qt|	r||	|| jd t| < |d7 }qtdq||fS )a  
        Compute statistics for the intensity of input image.

        Args:
            img: input image to compute intensity stats.
            meta_data: metadata dictionary to store the statistics data, if None, will create an empty dictionary.
            mask: if not None, mask the image to extract only the interested area to compute statistics.
                mask must have the same shape as input `img`.

        Nz2mask must have the same shape as input `img`, got z and r   z"mask must be bool array, got type )meanmedianr   r   stdr   z
np.ndarrayoprk   c                   s    j r fdd|D S  |S )Nc                   s   g | ]} |qS r^   r^   )r   cr_  r^   ra   r)    s     z=IntensityStats.__call__.<locals>._compute.<locals>.<listcomp>)rW  r^  r`   ra  ra   _compute   s    z)IntensityStats.__call__.<locals>._computer   r   Z_custom_r   zFops must be key string for predefined operations or callable function.)r'   r   r   r   rs   r   rz   r   nanmean	nanmediannanmaxnanminnanstdrU  rr   r   r,   keysrV  r   )
r`   rZ   rY  rZ  rQ  r   Zsupported_opsrc  Zcustom_indexrF  r^   rb  ra   rb     s4    





zIntensityStats.__call__)F)NN	rd   re   rf   rg   r/   ri   rj   rt   rb   r^   r^   r^   ra   rR     s      c                   @  s6   e Zd ZdZejgZdddddZddd	d
ZdS )rS   a  
    Move PyTorch Tensor to the specified device.
    It can help cache data into GPU and execute following logic on GPU directly.

    Note:
        If moving data to GPU device in the multi-processing workers of DataLoader, may got below CUDA error:
        "RuntimeError: Cannot re-initialize CUDA in forked subprocess. To use CUDA with multiprocessing,
        you must use the 'spawn' start method."
        So usually suggest to set `num_workers=0` in the `DataLoader` or `ThreadDataLoader`.

    ztorch.device | strro   )r   r[   c                 K  s   || _ || _dS )a	  
        Args:
            device: target device to move the Tensor, for example: "cuda:1".
            kwargs: other args for the PyTorch `Tensor.to()` API, for more details:
                https://pytorch.org/docs/stable/generated/torch.Tensor.to.html.

        N)r   r?  )r`   r   r?  r^   r^   ra   rt   "  s    zToDevice.__init__r}   r   c                 C  s&   t |tjstd|j| jf| jS )NzTimg must be PyTorch Tensor, consider converting img by `EnsureType` transform first.)rr   r   r   rs   r  r   r?  r_   r^   r^   ra   rb   -  s    zToDevice.__call__Nr   r^   r^   r^   ra   rS     s   c                      s0   e Zd ZdZddd fddZdd Z  ZS )	rT   a%  
    Wrap a non-randomized cuCIM transform, defined based on the transform name and args.
    For randomized transforms use :py:class:`monai.transforms.RandCuCIM`.

    Args:
        name: the transform name in CuCIM package
        args: parameters for the CuCIM transform
        kwargs: parameters for the CuCIM transform

    Note:
        CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`.
        Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array.
    r   ro   r<  c                   s2   t    || _td|d\| _}|| _|| _d S )Nz&cucim.core.operations.expose.transformr4   )r   rt   r5   r.   r@  r>  r?  )r`   r5   r>  r?  r   r   r^   ra   rt   C  s
    
zCuCIM.__init__c                 C  s   | j |f| j| jS )z
        Args:
            data: a CuPy array (`cupy.ndarray`) for the cuCIM transform

        Returns:
            `cupy.ndarray`

        )r@  r>  r?  rm   r^   r^   ra   rb   J  s    	zCuCIM.__call__)rd   re   rf   rg   rt   rb   r   r^   r^   r   ra   rT   4  s   c                   @  s    e Zd ZdZdddddZdS )rU   a  
    Wrap a randomized cuCIM transform, defined based on the transform name and args
    For deterministic non-randomized transforms use :py:class:`monai.transforms.CuCIM`.

    Args:
        name: the transform name in CuCIM package.
        args: parameters for the CuCIM transform.
        kwargs: parameters for the CuCIM transform.

    Note:
        - CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`.
          Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array.
        - If the random factor of the underlying cuCIM transform is not derived from `self.R`,
          the results may not be deterministic. See Also: :py:class:`monai.transforms.Randomizable`.
    r   ro   r<  c                 O  s   t j| |f|| d S rl   )rT   rt   )r`   r5   r>  r?  r^   r^   ra   rt   g  s    zRandCuCIM.__init__N)rd   re   rf   rg   rt   r^   r^   r^   ra   rU   V  s   c                   @  s8   e Zd ZdZejgZdddddZdddd	d
ZdS )r;   a  
    Appends additional channels encoding coordinates of the input. Useful when e.g. training using patch-based sampling,
    to allow feeding of the patch's location into the network.

    This can be seen as a input-only version of CoordConv:

    Liu, R. et al. An Intriguing Failing of Convolutional Neural Networks and the CoordConv Solution, NeurIPS 2018.

    Args:
        spatial_dims: the spatial dimensions that are to have their coordinates encoded in a channel and
            appended to the input image. E.g., `(0, 1, 2)` represents `H, W, D` dims and append three channels
            to the input image, encoding the coordinates of the input's three spatial dimensions.

    zSequence[int]ro   )spatial_dimsr[   c                 C  s
   || _ d S rl   )rk  )r`   rk  r^   r^   ra   rt   }  s    zAddCoordinateChannels.__init__r
   rY   c                 C  s   t | j|jd ks"t| jdk r8td|jd  d|jdd }ttjt	dd |D d	d
i}t
||^}}|t| j }t||fddS )za
        Args:
            img: data to be transformed, assuming `img` is channel first.
        r   r   z)`spatial_dims` values must be within [0, ]r   Nc                 s  s   | ]}t d d|V  qdS )g      g      ?N)r   linspace)r   sr^   r^   ra   r     s     z1AddCoordinateChannels.__call__.<locals>.<genexpr>indexingijr+  )r   rk  r   r   rs   r   r   arraymeshgridr   r1   r   r!   )r`   rZ   spatial_sizeZcoord_channelsr   r^   r^   ra   rb     s    "$zAddCoordinateChannels.__call__Nrj  r^   r^   r^   ra   r;   k  s   c                	   @  s   e Zd ZdZejejgZedddddddd	gZ	d+ddddddZ
d,dddddddZdddddZd-ddddddZdddddd Zd!d"d"d#d$d%d&Zd'd'd(d)d*Zd
S ).rW   at  
    Applies a convolution filter to the input image.

    Args:
        filter:
            A string specifying the filter, a custom filter as ``torch.Tenor`` or ``np.ndarray`` or a ``nn.Module``.
            Available options for string are: ``mean``, ``laplace``, ``elliptical``, ``sobel``, ``sharpen``, ``median``, ``gauss``
            See below for short explanations on every filter.
        filter_size:
            A single integer value specifying the size of the quadratic or cubic filter.
            Computational complexity scales to the power of 2 (2D filter) or 3 (3D filter), which
            should be considered when choosing filter size.
        kwargs:
            Additional arguments passed to filter function, required by ``sobel`` and ``gauss``.
            See below for details.

    Raises:
        ValueError: When ``filter_size`` is not an uneven integer
        ValueError: When ``filter`` is an array and ``ndim`` is not in [1,2,3]
        ValueError: When ``filter`` is an array and any dimension has an even shape
        NotImplementedError: When ``filter`` is a string and not in ``self.supported_filters``
        KeyError: When necessary ``kwargs`` are not passed to a filter that requires additional arguments.


    **Mean Filtering:** ``filter='mean'``

    Mean filtering can smooth edges and remove aliasing artifacts in an segmentation image.
    See also py:func:`monai.networks.layers.simplelayers.MeanFilter`
    Example 2D filter (5 x 5)::

        [[1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1],
         [1, 1, 1, 1, 1]]

    If smoothing labels with this filter, ensure they are in one-hot format.

    **Outline Detection:** ``filter='laplace'``

    Laplacian filtering for outline detection in images. Can be used to transform labels to contours.
    See also py:func:`monai.networks.layers.simplelayers.LaplaceFilter`

    Example 2D filter (5x5)::

        [[-1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., 24., -1., -1.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., -1., -1., -1.]]


    **Dilation:** ``filter='elliptical'``

    An elliptical filter can be used to dilate labels or label-contours.
    Example 2D filter (5x5)::

        [[0., 0., 1., 0., 0.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [0., 0., 1., 0., 0.]]


    **Edge Detection:** ``filter='sobel'``

    This filter allows for additional arguments passed as ``kwargs`` during initialization.
    See also py:func:`monai.transforms.post.SobelGradients`

    *kwargs*

    * ``spatial_axes``: the axes that define the direction of the gradient to be calculated.
      It calculates the gradient along each of the provide axis.
      By default it calculate the gradient for all spatial axes.
    * ``normalize_kernels``: if normalize the Sobel kernel to provide proper gradients. Defaults to True.
    * ``normalize_gradients``: if normalize the output gradient to 0 and 1. Defaults to False.
    * ``padding_mode``: the padding mode of the image when convolving with Sobel kernels. Defaults to ``"reflect"``.
      Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
      See ``torch.nn.Conv1d()`` for more information.
    * ``dtype``: kernel data type (torch.dtype). Defaults to ``torch.float32``.


    **Sharpening:** ``filter='sharpen'``

    Sharpen an image with a 2D or 3D filter.
    Example 2D filter (5x5)::

        [[ 0.,  0., -1.,  0.,  0.],
         [-1., -1., -1., -1., -1.],
         [-1., -1., 17., -1., -1.],
         [-1., -1., -1., -1., -1.],
         [ 0.,  0., -1.,  0.,  0.]]


    **Gaussian Smooth:** ``filter='gauss'``

    Blur/smooth an image with 2D or 3D gaussian filter.
    This filter requires additional arguments passed as ``kwargs`` during initialization.
    See also py:func:`monai.networks.layers.simplelayers.GaussianFilter`

    *kwargs*

    * ``sigma``: std. could be a single value, or spatial_dims number of values.
    * ``truncated``: spreads how many stds.
    * ``approx``: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".


    **Median Filter:** ``filter='median'``

    Blur an image with 2D or 3D median filter to remove noise.
    Useful in image preprocessing to improve results of later processing.
    See also py:func:`monai.networks.layers.simplelayers.MedianFilter`


    **Savitzky Golay Filter:** ``filter = 'savitzky_golay'``

    Convolve a Tensor along a particular axis with a Savitzky-Golay kernel.
    This filter requires additional arguments passed as ``kwargs`` during initialization.
    See also py:func:`monai.networks.layers.simplelayers.SavitzkyGolayFilter`

    *kwargs*

    * ``order``: Order of the polynomial to fit to each window, must be less than ``window_length``.
    * ``axis``: (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension).
    * ``mode``: (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or
      ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information.

    r[  laplace
ellipticalsobelsharpenr\  gausssavitzky_golayNz!str | NdarrayOrTensor | nn.Moduler   ro   )filterfilter_sizer[   c                 K  s0   |  || | j|f| || _|| _|| _d S rl   )_check_filter_format_check_kwargs_are_presentrz  r{  additional_args_for_filter)r`   rz  r{  r?  r^   r^   ra   rt     s
    zImageFilter.__init__r
   rX  zlist | None)rZ   r   r   r[   c           	      C  s   t |tr|j}|j}t|tj\}}}|jd }t | jt	rT| 
| j| j|| _n t | jtjtjfrtt| j| _| |}|dk	s|dk	rt|||d}nt|||^}}|S )ay  
        Args:
            img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]]
            meta_dict: An optional dictionary with metadata
            applied_operations: An optional list of operations that have been applied to the data

        Returns:
            A MetaTensor with the same shape as `img` and identical metadata
        r   N)r   r   )rr   r   r   r   r'   r   r   r   rz  r   _get_filter_from_stringr{  r   r   r   _apply_filter)	r`   rZ   r   r   img_	prev_typer   r   r   r^   r^   ra   rb     s    


zImageFilter.__call__r   )xr[   c                 C  s(   |D ]}|d dkrt d| qd S )Nr   r   z6Only uneven filters are supported, but filter size is )rs   )r`   r  valuer^   r^   ra   _check_all_values_uneven<  s    z$ImageFilter._check_all_values_unevenc                 C  s   t |trL|std|d dkr*td|| jkrt| d| j dnTt |tjtjfr~|j	dkrptd| 
|j n"t |tjtfstt| d	d S )
NzB`filter_size` must be specified when specifying filters by string.r   r   z0`filter_size` should be a single uneven integer.z. Supported filters are r   )r   r      z*Only 1D, 2D, and 3D filters are supported.z is not supported.Supported types are `class 'str'`, `class 'torch.Tensor'`, `class 'np.ndarray'`, `class 'torch.nn.modules.module.Module'`, `class 'monai.transforms.Transform'`)rr   r   rs   supported_filtersNotImplementedErrorr   r   r   r   r   r  r   nnModuler   r   r   )r`   rz  r{  r^   r^   ra   r|  A  s    


z ImageFilter._check_filter_formatr   )rz  r?  r[   c                 K  sJ   t |tsdS |dkr*d| kr*td|dkrFd| krFtddS )a  
        Perform sanity checks on the kwargs if the filter contains the required keys.
        If the filter is ``gauss``, kwargs should contain ``sigma``.
        If the filter is ``savitzky_golay``, kwargs should contain ``order``.

        Args:
            filter: A string specifying the filter, a custom filter as ``torch.Tenor`` or ``np.ndarray`` or a ``nn.Module``.
            kwargs: additional arguments defining the filter.

        Raises:
            KeyError if the filter doesn't contain the requirement key.
        Nrx  r7  zA`filter='gauss', requires the additional keyword argument `sigma`ry  orderzJ`filter='savitzky_golay', requires the additional keyword argument `order`)rr   r   ri  KeyError)r`   rz  r?  r^   r^   ra   r}  T  s    
z%ImageFilter._check_kwargs_are_presentr   rn   znn.Module | Callable)rz  sizer   r[   c                   s(  |dkrt ||S |dkr$t||S |dkr6t||S |dkrzddlm} |jj   fdd| j	 D }||f|S |d	krt
||S |d
krtjj   fdd| j	 D }t|f|S |dkrtt||dS |dkrtjj   fdd| j	 D }t|f|S td| dd S )Nr[  rt  ru  rv  r   )SobelGradientsc                   s   i | ]\}}| kr||qS r^   r^   r   kvallowed_keysr^   ra   
<dictcomp>t  s       z7ImageFilter._get_filter_from_string.<locals>.<dictcomp>rw  rx  c                   s   i | ]\}}| kr||qS r^   r^   r  r  r^   ra   r  z  s       r\  )kernel_sizerk  ry  c                   s   i | ]\}}| kr||qS r^   r^   r  r  r^   ra   r    s       zFilter z not implemented)r   r   r   Zmonai.transforms.post.arrayr  rt   __annotations__ri  r~  itemsr   r   r   r   r   r  )r`   rz  r  r   r  r?  r^   r  ra   r  i  s0    




z#ImageFilter._get_filter_from_stringr}   rY   c                 C  s4   t | jtr| |}n| |d}|d }|S )Nr   )rr   rz  r   	unsqueezer_   r^   r^   ra   r    s
    zImageFilter._apply_filter)N)NN)N)rd   re   rf   rg   r/   rh   ri   rj   sortedr  rt   rb   r  r|  r}  r  r  r^   r^   r^   ra   rW     s       c                      sH   e Zd ZdZejZdddddd fd	d
ZdddddddZ  ZS )rX   a  
    Randomly apply a convolutional filter to the input data.

    Args:
        filter:
            A string specifying the filter or a custom filter as `torch.Tenor` or `np.ndarray`.
            Available options are: `mean`, `laplace`, `elliptical`, `gaussian``
            See below for short explanations on every filter.
        filter_size:
            A single integer value specifying the size of the quadratic or cubic filter.
            Computational complexity scales to the power of 2 (2D filter) or 3 (3D filter), which
            should be considered when choosing filter size.
        prob:
            Probability the transform is applied to the data
    N皙?zstr | NdarrayOrTensorr   r   ro   )rz  r{  r  r[   c                   s    t  | t||f|| _d S rl   )r   rt   rW   rz  )r`   rz  r{  r  r?  r   r^   ra   rt     s    zRandImageFilter.__init__r
   r~   r   c                 C  s   |  d | jr| |}|S )a  
        Args:
            img: torch tensor data to apply filter to with shape: [channels, height, width[, depth]]
            meta_dict: An optional dictionary with metadata
            kwargs: optional arguments required by specific filters. E.g. `sigma`if filter is `gauss`.
                see py:func:`monai.transforms.utility.array.ImageFilter` for more details

        Returns:
            A MetaTensor with the same shape as `img` and identical metadata
        N)r  r  rz  )r`   rZ   r   r^   r^   ra   rb     s    

zRandImageFilter.__call__)Nr  )N)	rd   re   rf   rg   rW   rj   rt   rb   r   r^   r^   r   ra   rX     s      )wrg   
__future__r   r   r   r  r   collections.abcr   r   copyr   	functoolsr   typingr   r   r   r   r   torch.nnr  monai.configr	   monai.config.type_definitionsr
   monai.data.meta_objr   monai.data.meta_tensorr   monai.data.utilsr   r   Z"monai.networks.layers.simplelayersr   r   r   r   r   r   r   r   monai.transforms.inverser   monai.transforms.traitsr   monai.transforms.transformr   r   r   r   monai.transforms.utilsr   r   r   r    0monai.transforms.utils_pytorch_numpy_unificationr!   r"   r#   r$   monai.utilsr%   r&   r'   r(   r)   r*   r+   r,   r-   r.   monai.utils.enumsr/   monai.utils.miscr0   monai.utils.type_conversionr1   r2   r   has_pilr   r   cphas_cp__all__r8   r9   r:   r<   r>   r?   r@   rA   rB   r=   rC   rV   rD   rE   rF   rG   rH   rI   rJ   rK   rL   rM   rN   rO   rP   rQ   rR   rS   rT   rU   r;   rW   rX   r^   r^   r^   ra   <module>   s   (
0%>..+K3f"@/A)7@$8M!"$  