o
    i"                     @  s  d Z ddlmZ ddlZddlmZmZmZmZm	Z	 ddl
mZ ddlmZ ddlZddlZddlmZ ddlmZmZmZ dd	lmZ dd
lmZ ddlmZ ddlmZmZm Z m!Z!m"Z"m#Z#m$Z$m%Z%m&Z&m'Z'm(Z(m)Z) ddl*m+Z+ ddl,m-Z- ddl.m/Z/m0Z0 ddl1m2Z2m3Z3m4Z4m5Z5 g dZ6e27 Z8G dd de+Z9G dd de+Z:G dd de+Z;G dd de+Z<G dd de+Z=G dd de+Z>G dd de+Z?G d d! d!e+Z@G d"d# d#e@ZAG d$d% d%e@ZBG d&d' d'e+ZCG d(d) d)e+ZDG d*d+ d+e+ZEG d,d- d-e+ZFG d.d/ d/e+ZGe9 ZHZIe: ZJZKe> ZLZMeD ZNZOe; ZPZQe< ZRZSe= ZTZUe? ZVZWeA ZXZYeC ZZZ[eE Z\Z]eB Z^Z_e@ Z`ZaeF ZbZceG ZdZedS )0z
A collection of dictionary-based wrappers around the "vanilla" transforms for model output tensors
defined in :py:class:`monai.transforms.utility.array`.

Class names are ended with 'd' to denote dictionary-based transforms.
    )annotationsN)CallableHashableIterableMappingSequence)deepcopy)Any)config)KeysCollectionNdarrayOrTensorPathLike)CSVSaver)
MetaTensor)InvertibleTransform)Activations
AsDiscreteDistanceTransformEDT	FillHolesKeepLargestConnectedComponentLabelFilterLabelToContourMeanEnsembleProbNMSRemoveSmallObjectsSobelGradientsVoteEnsemble)MapTransform)ToTensor)allow_missing_keys_modeconvert_applied_interp_mode)PostFixconvert_to_tensorensure_tupleensure_tuple_rep)-ActivationsDActivationsDictActivationsdAsDiscreteDAsDiscreteDictAsDiscreted	Ensembled	EnsembleDEnsembleDict
FillHolesDFillHolesDict
FillHolesdInvertD
InvertDictInvertdKeepLargestConnectedComponentD!KeepLargestConnectedComponentDictKeepLargestConnectedComponentdRemoveSmallObjectsDRemoveSmallObjectsDictRemoveSmallObjectsdLabelFilterDLabelFilterDictLabelFilterdLabelToContourDLabelToContourDictLabelToContourdMeanEnsembleDMeanEnsembleDictMeanEnsembledProbNMSDProbNMSDictProbNMSdSaveClassificationDSaveClassificationDictSaveClassificationdSobelGradientsDSobelGradientsDictSobelGradientsdVoteEnsembleDVoteEnsembleDictVoteEnsembledDistanceTransformEDTdDistanceTransformEDTDDistanceTransformEDTDictc                      s<   e Zd ZdZejZ				dd fddZdddZ  ZS )r'   z
    Dictionary-based wrapper of :py:class:`monai.transforms.AddActivations`.
    Add activation layers to the input data specified by `keys`.
    FNkeysr   sigmoidSequence[bool] | boolsoftmaxother$Sequence[Callable] | Callable | Noneallow_missing_keysboolreturnNonec                   sX   t  || t|t| j| _t|t| j| _t|t| j| _t | _	|| j	_
dS )a  
        Args:
            keys: keys of the corresponding items to model output and label.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            sigmoid: whether to execute sigmoid function on model output before transform.
                it also can be a sequence of bool, each element corresponds to a key in ``keys``.
            softmax: whether to execute softmax function on model output before transform.
                it also can be a sequence of bool, each element corresponds to a key in ``keys``.
            other: callable function to execute other activation layers,
                for example: `other = torch.tanh`. it also can be a sequence of Callable, each
                element corresponds to a key in ``keys``.
            allow_missing_keys: don't raise exception if key is missing.
            kwargs: additional parameters to `torch.softmax` (used when ``softmax=True``).
                Defaults to ``dim=0``, unrecognized parameters will be ignored.

        N)super__init__r$   lenrR   rS   rU   rV   r   	converterkwargs)selfrR   rS   rU   rV   rX   r`   	__class__ b/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/transforms/post/dictionary.pyr]   o   s   zActivationsd.__init__data"Mapping[Hashable, NdarrayOrTensor]dict[Hashable, NdarrayOrTensor]c                 C  sH   t |}| || j| j| jD ]\}}}}| || |||||< q|S N)dictkey_iteratorrS   rU   rV   r_   )ra   rf   dkeyrS   rU   rV   rd   rd   re   __call__   s   "zActivationsd.__call__)FFNF)rR   r   rS   rT   rU   rT   rV   rW   rX   rY   rZ   r[   rf   rg   rZ   rh   )	__name__
__module____qualname____doc__r   backendr]   rn   __classcell__rd   rd   rb   re   r'   g   s     r'   c                      s>   e Zd ZdZejZ					dd fddZdddZ  ZS )r*   zN
    Dictionary-based wrapper of :py:class:`monai.transforms.AsDiscrete`.
    FNrR   r   argmaxrT   	to_onehot!Sequence[int | None] | int | None	threshold%Sequence[float | None] | float | Nonerounding!Sequence[str | None] | str | NonerX   rY   rZ   r[   c           	        s   t  || t|t| j| _g | _t|t| jD ]}t|tr&t	d| j
| qg | _t|t| jD ]}t|trCt	d| j
| q8t|t| j| _t | _|| j_dS )aa  
        Args:
            keys: keys of the corresponding items to model output and label.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            argmax: whether to execute argmax function on input data before transform.
                it also can be a sequence of bool, each element corresponds to a key in ``keys``.
            to_onehot: if not None, convert input data into the one-hot format with specified number of classes.
                defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``.
            threshold: if not None, threshold the float values to int number 0 or 1 with specified threshold value.
                defaults to ``None``. it also can be a sequence, each element corresponds to a key in ``keys``.
            rounding: if not None, round the data according to the specified option,
                available options: ["torchrounding"]. it also can be a sequence of str or None,
                each element corresponds to a key in ``keys``.
            allow_missing_keys: don't raise exception if key is missing.
            kwargs: additional parameters to ``AsDiscrete``.
                ``dim``, ``keepdim``, ``dtype`` are supported, unrecognized parameters will be ignored.
                These default to ``0``, ``True``, ``torch.float`` respectively.

        zQ`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.zR`threshold_values=True/False` is deprecated, please use `threshold=value` instead.N)r\   r]   r$   r^   rR   rv   rw   
isinstancerY   
ValueErrorappendry   r{   r   r_   r`   )	ra   rR   rv   rw   ry   r{   rX   r`   flagrb   rd   re   r]      s   

zAsDiscreted.__init__rf   rg   rh   c                 C  sP   t |}| || j| j| j| jD ]\}}}}}| || ||||||< q|S ri   )rj   rk   rv   rw   ry   r{   r_   )ra   rf   rl   rm   rv   rw   ry   r{   rd   rd   re   rn      s   zAsDiscreted.__call__)FNNNF)rR   r   rv   rT   rw   rx   ry   rz   r{   r|   rX   rY   rZ   r[   ro   )	rp   rq   rr   rs   r   rt   r]   rn   ru   rd   rd   rb   re   r*      s    /r*   c                      s@   e Zd ZdZejZ						dd fddZdddZ  ZS )r6   za
    Dictionary-based wrapper of :py:class:`monai.transforms.KeepLargestConnectedComponent`.
    NT   FrR   r   applied_labelsSequence[int] | int | None	is_onehotbool | NoneindependentrY   connectivity
int | Nonenum_componentsintrX   rZ   r[   c                   s&   t  || t|||||d| _dS )a  
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            applied_labels: Labels for applying the connected component analysis on.
                If given, voxels whose value is in this list will be analyzed.
                If `None`, all non-zero values will be analyzed.
            is_onehot: if `True`, treat the input data as OneHot format data, otherwise, not OneHot format data.
                default to None, which treats multi-channel data as OneHot and single channel data as not OneHot.
            independent: whether to treat ``applied_labels`` as a union of foreground labels.
                If ``True``, the connected component analysis will be performed on each foreground label independently
                and return the intersection of the largest components.
                If ``False``, the analysis will be performed on the union of foreground labels.
                default is `True`.
            connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
                Accepted values are ranging from  1 to input.ndim. If ``None``, a full
                connectivity of ``input.ndim`` is used. for more details:
                https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.
            num_components: The number of largest components to preserve.
            allow_missing_keys: don't raise exception if key is missing.

        )r   r   r   r   r   N)r\   r]   r   r_   )ra   rR   r   r   r   r   r   rX   rb   rd   re   r]      s    z'KeepLargestConnectedComponentd.__init__rf   rg   rh   c                 C  .   t |}| |D ]}| || ||< q	|S ri   rj   rk   r_   ra   rf   rl   rm   rd   rd   re   rn        z'KeepLargestConnectedComponentd.__call__)NNTNr   F)rR   r   r   r   r   r   r   rY   r   r   r   r   rX   rY   rZ   r[   ro   )	rp   rq   rr   rs   r   rt   r]   rn   ru   rd   rd   rb   re   r6      s    )r6   c                      s@   e Zd ZdZejZ						dd fddZdddZ  ZS )r9   a  
    Dictionary-based wrapper of :py:class:`monai.transforms.RemoveSmallObjectsd`.

    Args:
        min_size: objects smaller than this size (in number of voxels; or surface area/volume value
            in whatever units your image is if by_measure is True) are removed.
        connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
            Accepted values are ranging from  1 to input.ndim. If ``None``, a full
            connectivity of ``input.ndim`` is used. For more details refer to linked scikit-image
            documentation.
        independent_channels: Whether or not to consider channels as independent. If true, then
            conjoining islands from different labels will be removed if they are below the threshold.
            If false, the overall size islands made from all non-background voxels will be used.
        by_measure: Whether the specified min_size is in number of voxels. if this is True then min_size
            represents a surface area or volume value of whatever units your image is in (mm^3, cm^2, etc.)
            default is False. e.g. if min_size is 3, by_measure is True and the units of your data is mm,
            objects smaller than 3mm^3 are removed.
        pixdim: the pixdim of the input image. if a single number, this is used for all axes.
            If a sequence of numbers, the length of the sequence must be equal to the image dimensions.
    @   r   TFNrR   r   min_sizer   r   independent_channelsrY   
by_measurepixdim+Sequence[float] | float | np.ndarray | NonerX   rZ   r[   c                   s$   t  || t|||||| _d S ri   )r\   r]   r   r_   )ra   rR   r   r   r   r   r   rX   rb   rd   re   r]   $  s   
zRemoveSmallObjectsd.__init__rf   rg   rh   c                 C  r   ri   r   r   rd   rd   re   rn   1  r   zRemoveSmallObjectsd.__call__)r   r   TFNF)rR   r   r   r   r   r   r   rY   r   rY   r   r   rX   rY   rZ   r[   ro   )	rp   rq   rr   rs   r   rt   r]   rn   ru   rd   rd   rb   re   r9     s    r9   c                      s6   e Zd ZdZejZ	dd fddZdddZ  ZS )r<   zO
    Dictionary-based wrapper of :py:class:`monai.transforms.LabelFilter`.
    FrR   r   r   Sequence[int] | intrX   rY   rZ   r[   c                   s   t  || t|| _dS )a%  
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            applied_labels: Label(s) to filter on.
            allow_missing_keys: don't raise exception if key is missing.

        N)r\   r]   r   r_   )ra   rR   r   rX   rb   rd   re   r]   ?  s   zLabelFilterd.__init__rf   rg   rh   c                 C  r   ri   r   r   rd   rd   re   rn   M  r   zLabelFilterd.__call__)F)rR   r   r   r   rX   rY   rZ   r[   ro   )	rp   rq   rr   rs   r   rt   r]   rn   ru   rd   rd   rb   re   r<   8  s    r<   c                      s:   e Zd ZdZejZ			dd fddZdddZ  ZS )r0   zM
    Dictionary-based wrapper of :py:class:`monai.transforms.FillHoles`.
    NFrR   r   r   Iterable[int] | int | Noner   r   rX   rY   rZ   r[   c                   s    t  || t||d| _dS )a  
        Initialize the connectivity and limit the labels for which holes are filled.

        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            applied_labels (Optional[Union[Iterable[int], int]], optional): Labels for which to fill holes. Defaults to None,
                that is filling holes for all labels.
            connectivity (int, optional): Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
                Accepted values are ranging from  1 to input.ndim. Defaults to a full
                connectivity of ``input.ndim``.
            allow_missing_keys: don't raise exception if key is missing.
        )r   r   N)r\   r]   r   r_   )ra   rR   r   r   rX   rb   rd   re   r]   [  s   zFillHolesd.__init__rf   rg   rh   c                 C  r   ri   r   r   rd   rd   re   rn   r  r   zFillHolesd.__call__)NNF)
rR   r   r   r   r   r   rX   rY   rZ   r[   ro   )	rp   rq   rr   rs   r   rt   r]   rn   ru   rd   rd   rb   re   r0   T  s    r0   c                      s4   e Zd ZdZejZdd fddZdddZ  ZS )r?   zR
    Dictionary-based wrapper of :py:class:`monai.transforms.LabelToContour`.
    LaplaceFrR   r   kernel_typestrrX   rY   rZ   r[   c                   s   t  || t|d| _dS )aJ  
        Args:
            keys: keys of the corresponding items to be transformed.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            kernel_type: the method applied to do edge detection, default is "Laplace".
            allow_missing_keys: don't raise exception if key is missing.

        )r   N)r\   r]   r   r_   )ra   rR   r   rX   rb   rd   re   r]     s   	zLabelToContourd.__init__rf   rg   rh   c                 C  r   ri   r   r   rd   rd   re   rn     r   zLabelToContourd.__call__)r   F)rR   r   r   r   rX   rY   rZ   r[   ro   )	rp   rq   rr   rs   r   rt   r]   rn   ru   rd   rd   rb   re   r?   y  s
    r?   c                      sJ   e Zd ZdZeeejeej@ Z		dd fddZ	dddZ
  ZS )r+   z>
    Base class of dictionary-based ensemble transforms.

    NFrR   r   ensembleHCallable[[Sequence[NdarrayOrTensor] | NdarrayOrTensor], NdarrayOrTensor]
output_key
str | NonerX   rY   rZ   r[   c                   sr   t  || t|stdt|j d|| _t| jdkr(|du r(t	d|dur1|| _
dS | jd | _
dS )a  
        Args:
            keys: keys of the corresponding items to be stack and execute ensemble.
                if only 1 key provided, suppose it's a PyTorch Tensor with data stacked on dimension `E`.
            output_key: the key to store ensemble result in the dictionary.
            ensemble: callable method to execute ensemble on specified data.
                if only 1 key provided in `keys`, `output_key` can be None and use `keys` as default.
            allow_missing_keys: don't raise exception if key is missing.

        Raises:
            TypeError: When ``ensemble`` is not ``callable``.
            ValueError: When ``len(keys) > 1`` and ``output_key=None``. Incompatible values.

        z!ensemble must be callable but is .r   Nz<Incompatible values: len(self.keys) > 1 and output_key=None.r   )r\   r]   callable	TypeErrortyperp   r   r^   rR   r~   r   )ra   rR   r   r   rX   rb   rd   re   r]     s   "zEnsembled.__init__rf   rg   rh   c                   sl   t | t| jdkr| jd  v r | jd  }n fdd|  D }t|dkr4| | | j<  S )Nr   r   c                   s   g | ]} | qS rd   rd   ).0rm   rl   rd   re   
<listcomp>  s    z&Ensembled.__call__.<locals>.<listcomp>)rj   r^   rR   rk   r   r   )ra   rf   itemsrd   r   re   rn     s   zEnsembled.__call__)NF)
rR   r   r   r   r   r   rX   rY   rZ   r[   ro   )rp   rq   rr   rs   listsetr   rt   r   r]   rn   ru   rd   rd   rb   re   r+     s    r+   c                      s.   e Zd ZdZejZ		dd fddZ  ZS )rB   zP
    Dictionary-based wrapper of :py:class:`monai.transforms.MeanEnsemble`.
    NrR   r   r   r   weights(Sequence[float] | NdarrayOrTensor | NonerZ   r[   c                      t |d}t ||| dS )a  
        Args:
            keys: keys of the corresponding items to be stack and execute ensemble.
                if only 1 key provided, suppose it's a PyTorch Tensor with data stacked on dimension `E`.
            output_key: the key to store ensemble result in the dictionary.
                if only 1 key provided in `keys`, `output_key` can be None and use `keys` as default.
            weights: can be a list or tuple of numbers for input data with shape: [E, C, H, W[, D]].
                or a Numpy ndarray or a PyTorch Tensor data.
                the `weights` will be added to input data from highest dimension, for example:
                1. if the `weights` only has 1 dimension, it will be added to the `E` dimension of input data.
                2. if the `weights` has 2 dimensions, it will be added to `E` and `C` dimensions.
                it's a typical practice to add weights for different classes:
                to ensemble 3 segmentation model outputs, every output has 4 channels(classes),
                so the input data shape can be: [3, 4, H, W, D].
                and add different `weights` for different classes, so the `weights` shape can be: [3, 4].
                for example: `weights = [[1, 2, 3, 4], [4, 3, 2, 1], [1, 1, 1, 1]]`.

        )r   N)r   r\   r]   )ra   rR   r   r   r   rb   rd   re   r]     s   
zMeanEnsembled.__init__NN)rR   r   r   r   r   r   rZ   r[   )rp   rq   rr   rs   r   rt   r]   ru   rd   rd   rb   re   rB     s    rB   c                      s*   e Zd ZdZejZdd fddZ  ZS )rN   zP
    Dictionary-based wrapper of :py:class:`monai.transforms.VoteEnsemble`.
    NrR   r   r   r   num_classesr   rZ   r[   c                   r   )aK  
        Args:
            keys: keys of the corresponding items to be stack and execute ensemble.
                if only 1 key provided, suppose it's a PyTorch Tensor with data stacked on dimension `E`.
            output_key: the key to store ensemble result in the dictionary.
                if only 1 key provided in `keys`, `output_key` can be None and use `keys` as default.
            num_classes: if the input is single channel data instead of One-Hot, we can't get class number
                from channel, need to explicitly specify the number of classes to vote.

        )r   N)r   r\   r]   )ra   rR   r   r   r   rb   rd   re   r]     s   
zVoteEnsembled.__init__r   )rR   r   r   r   r   r   rZ   r[   )rp   rq   rr   rs   r   rt   r]   ru   rd   rd   rb   re   rN     s    rN   c                      s>   e Zd ZdZejZ					dd fddZdddZ  ZS )rE   a-  
    Performs probability based non-maximum suppression (NMS) on the probabilities map via
    iteratively selecting the coordinate with highest probability and then move it as well
    as its surrounding values. The remove range is determined by the parameter `box_size`.
    If multiple coordinates have the same highest probability, only one of them will be
    selected.

    Args:
        spatial_dims: number of spatial dimensions of the input probabilities map.
            Defaults to 2.
        sigma: the standard deviation for gaussian filter.
            It could be a single value, or `spatial_dims` number of values. Defaults to 0.0.
        prob_threshold: the probability threshold, the function will stop searching if
            the highest probability is no larger than the threshold. The value should be
            no less than 0.0. Defaults to 0.5.
        box_size: the box size (in pixel) to be removed around the pixel with the maximum probability.
            It can be an integer that defines the size of a square or cube,
            or a list containing different values for each dimensions. Defaults to 48.

    Return:
        a list of selected lists, where inner lists contain probability and coordinates.
        For example, for 3D input, the inner lists are in the form of [probability, x, y, z].

    Raises:
        ValueError: When ``prob_threshold`` is less than 0.0.
        ValueError: When ``box_size`` is a list or tuple, and its length is not equal to `spatial_dims`.
        ValueError: When ``box_size`` has a less than 1 value.

                     ?0   FrR   r   spatial_dimsr   sigma?Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensorprob_thresholdfloatbox_sizeint | Sequence[int]rX   rY   rZ   r[   c                   s$   t  || t||||d| _d S )N)r   r   r   r   )r\   r]   r   prob_nms)ra   rR   r   r   r   r   rX   rb   rd   re   r]      s   	zProbNMSd.__init__rf   rg   c                 C  r   ri   )rj   rk   r   r   rd   rd   re   rn   .  r   zProbNMSd.__call__)r   r   r   r   F)rR   r   r   r   r   r   r   r   r   r   rX   rY   rZ   r[   )rf   rg   )	rp   rq   rr   rs   r   rt   r]   rn   ru   rd   rd   rb   re   rE     s    rE   c                	      s@   e Zd ZdZdddedddddf	d! fddZd"dd Z  ZS )#r3   a  
    Utility transform to invert the previously applied transforms.

    Taking the ``transform`` previously applied on ``orig_keys``, this ``Invertd`` will apply the inverse of it
    to the data stored at ``keys``.

    ``Invertd``'s output will also include a copy of the metadata
    dictionary (originally from  ``orig_meta_keys`` or the metadata of ``orig_keys``),
    with the relevant fields inverted and stored at ``meta_keys``.

    A typical usage is to apply the inverse of the preprocessing (``transform=preprocessings``) on
    input ``orig_keys=image`` to the model predictions ``keys=pred``.

    A detailed usage example is available in the tutorial:
    https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/torch/unet_inference_dict.py

    Note:

        - The output of the inverted data and metadata will be stored at ``keys`` and ``meta_keys`` respectively.
        - To correctly invert the transforms, the information of the previously applied transforms should be
          available at ``{orig_keys}_transforms``, and the original metadata at ``orig_meta_keys``.
          (``meta_key_postfix`` is an optional string to conveniently construct "meta_keys" and/or "orig_meta_keys".)
          see also: :py:class:`monai.transforms.TraceableTransform`.
        - The transform will not change the content in ``orig_keys`` and ``orig_meta_key``.
          These keys are only used to represent the data status of ``key`` before inverting.

    NTFrR   r   	transformr   	orig_keysKeysCollection | None	meta_keysorig_meta_keysmeta_key_postfixr   nearest_interpbool | Sequence[bool]	to_tensordevice8str | torch.device | Sequence[str | torch.device] | None	post_func$Callable | Sequence[Callable] | NonerX   rY   rZ   r[   c                   s   t  || t|tstd|| _|durt|t| jn| j| _	|du r/tdt| jnt
|| _t| jt| jkrBtdt|t| j| _t|t| j| _t|t| j| _t|t| j| _t|	t| j| _t|
t| j| _t | _dS )aD
  
        Args:
            keys: the key of expected data in the dict, the inverse of ``transforms`` will be applied on it in-place.
                It also can be a list of keys, will apply the inverse transform respectively.
            transform: the transform applied to ``orig_key``, its inverse will be applied on ``key``.
            orig_keys: the key of the original input data in the dict. These keys default to `self.keys` if not set.
                the transform trace information of ``transforms`` should be stored at ``{orig_keys}_transforms``.
                It can also be a list of keys, each matches the ``keys``.
            meta_keys: The key to output the inverted metadata dictionary.
                The metadata is a dictionary optionally containing: filename, original_shape.
                It can be a sequence of strings, maps to ``keys``.
                If None, will try to create a metadata dict with the default key: `{key}_{meta_key_postfix}`.
            orig_meta_keys: the key of the metadata of original input data.
                The metadata is a dictionary optionally containing: filename, original_shape.
                It can be a sequence of strings, maps to the `keys`.
                If None, will try to create a metadata dict with the default key: `{orig_key}_{meta_key_postfix}`.
                This metadata dict will also be included in the inverted dict, stored in `meta_keys`.
            meta_key_postfix: if `orig_meta_keys` is None, use `{orig_key}_{meta_key_postfix}` to fetch the
                metadata from dict, if `meta_keys` is None, use `{key}_{meta_key_postfix}`. Default: ``"meta_dict"``.
            nearest_interp: whether to use `nearest` interpolation mode when inverting the spatial transforms,
                default to `True`. If `False`, use the same interpolation mode as the original transform.
                It also can be a list of bool, each matches to the `keys` data.
            to_tensor: whether to convert the inverted data into PyTorch Tensor first, default to `True`.
                It also can be a list of bool, each matches to the `keys` data.
            device: if converted to Tensor, move the inverted results to target device before `post_func`,
                default to None, it also can be a list of string or `torch.device`, each matches to the `keys` data.
            post_func: post processing for the inverted data, should be a callable function.
                It also can be a list of callable, each matches to the `keys` data.
            allow_missing_keys: don't raise exception if key is missing.

        zAtransform is not invertible, can't invert transform for the data.Nz.meta_keys should have the same length as keys.)r\   r]   r}   r   r~   r   r$   r^   rR   r   r#   r   r   r   r   r   r   r   r   	_totensor)ra   rR   r   r   r   r   r   r   r   r   r   rX   rb   rd   re   r]   R  s   -
 "zInvertd.__init__rf   Mapping[Hashable, Any]dict[Hashable, Any]c                 C  s  t |}| || j| j| j| j| j| j| j| j		D ],\	}}}}}}}	}
}t
|| tr>||vr=td| d| d qnt|}||vrQtd| d q|pY| d| }||v rpt
|| trp|| j}|| j}n|t| }||i }|rt|dd d}|| }t
|tjr| }t
|tst|dd	}t||_t||_||i}tjr||t|< ||t|< t| j | j|}W d    n1 sw   Y  || }|	rt
|ts| |}t
|t j!r|
d urt|
j"d
krt#d|
 dt
|tjr|j$|
d}t%|r||n|||< t||v r0|j|t|< ||v rF|p>| d| }||||< q|S )Nztransform info of `z!` is not available in MetaTensor r   z5` is not available or no InvertibleTransform applied._nearest)
trans_infomodealign_cornersT)
track_metacpuzEInverted data with type of 'numpy.ndarray' support device='cpu', got )r   )&rj   rk   r   r   r   r   r   r   r   r   r}   r   warningswarnr   	trace_keyapplied_operationsmetagetr    torchTensordetachr"   r   r
   USE_META_DICTr!   r   r   inverser   npndarrayr   r~   tor   )ra   rf   rl   rm   orig_keymeta_keyZorig_meta_keyr   r   r   r   r   transform_keytransform_info	meta_infoinputs
input_dictinvertedZinverted_datard   rd   re   rn     s   





*
zInvertd.__call__)rR   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rX   rY   rZ   r[   )rf   r   rZ   r   )rp   rq   rr   rs   DEFAULT_POST_FIXr]   rn   ru   rd   rd   rb   re   r3   5  s     =r3   c                	      sF   e Zd ZdZdedddddddf	d  fddZdd Zdd Z  ZS )!rH   zW
    Save the classification results and metadata into CSV file or other storage.

    Nz./zpredictions.csv,TFrR   r   r   r   r   r   saverCSVSaver | None
output_dirr   filename	delimiter	overwriterY   flushrX   rZ   r[   c                   sj   t  ||
 t| jdkrtd|pt||||	|d| _|	| _t|t| j| _	t|t| j| _
dS )am	  
        Args:
            keys: keys of the corresponding items to model output, this transform only supports 1 key.
                See also: :py:class:`monai.transforms.compose.MapTransform`
            meta_keys: explicitly indicate the key of the corresponding metadata dictionary.
                for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
                the metadata is a dictionary object which contains: filename, original_shape, etc.
                it can be a sequence of string, map to the `keys`.
                if None, will try to construct meta_keys by `key_{meta_key_postfix}`.
                will extract the filename of input image to save classification results.
            meta_key_postfix: `key_{postfix}` was used to store the metadata in `LoadImaged`.
                so need the key to extract the metadata of input image, like filename, etc. default is `meta_dict`.
                for example, for data with key `image`, the metadata by default is in `image_meta_dict`.
                the metadata is a dictionary object which contains: filename, original_shape, etc.
                this arg only works when `meta_keys=None`. if no corresponding metadata, set to `None`.
            saver: the saver instance to save classification results, if None, create a CSVSaver internally.
                the saver must provide `save(data, meta_data)` and `finalize()` APIs.
            output_dir: if `saver=None`, specify the directory to save the CSV file.
            filename: if `saver=None`, specify the name of the saved CSV file.
            delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`.
                to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
            overwrite: if `saver=None`, indicate whether to overwriting existing CSV file content, if True,
                will clear the file before saving. otherwise, will append new content to the CSV file.
            flush: if `saver=None`, indicate whether to write the cache data to CSV file immediately
                in this transform and clear the cache. default to True.
                If False, may need user to call `saver.finalize()` manually or use `ClassificationSaver` handler.
            allow_missing_keys: don't raise exception if key is missing.

        r   z<only 1 key is allowed when saving the classification result.)r   r   r   r   r   N)r\   r]   r^   rR   r~   r   r   r   r$   r   r   )ra   rR   r   r   r   r   r   r   r   r   rX   rb   rd   re   r]     s   *
zSaveClassificationd.__init__c                 C  s   t |}| || j| jD ]0\}}}|d u r!|d ur!| d| }|d ur)|| nd }| jj|| |d | jr=| j  q|S )Nr   )rf   	meta_data)rj   rk   r   r   r   saver   finalize)ra   rf   rl   rm   r   r   r   rd   rd   re   rn     s   
zSaveClassificationd.__call__c                 C  s   | j S )z
        If want to write content into file, may need to call `finalize` of saver when epoch completed.
        Or users can also get the cache content from `saver` instead of writing into file.

        )r   )ra   rd   rd   re   	get_saver&  s   zSaveClassificationd.get_saver)rR   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rY   r   rY   rX   rY   rZ   r[   )	rp   rq   rr   rs   r   r]   rn   r   ru   rd   rd   rb   re   rH     s    4rH   c                      sF   e Zd ZdZejZdddddejddfd  fddZd!ddZ	  Z
S )"rK   a  Calculate Sobel horizontal and vertical gradients of a grayscale image.

    Args:
        keys: keys of the corresponding items to model output.
        kernel_size: the size of the Sobel kernel. Defaults to 3.
        spatial_axes: the axes that define the direction of the gradient to be calculated. It calculate the gradient
            along each of the provide axis. By default it calculate the gradient for all spatial axes.
        normalize_kernels: if normalize the Sobel kernel to provide proper gradients. Defaults to True.
        normalize_gradients: if normalize the output gradient to 0 and 1. Defaults to False.
        padding_mode: the padding mode of the image when convolving with Sobel kernels. Defaults to `"reflect"`.
            Acceptable values are ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``.
            See ``torch.nn.Conv1d()`` for more information.
        dtype: kernel data type (torch.dtype). Defaults to `torch.float32`.
        new_key_prefix: this prefix be prepended to the key to create a new key for the output and keep the value of
            key intact. By default not prefix is set and the corresponding array to the key will be replaced.
        allow_missing_keys: don't raise exception if key is missing.

       NTFreflectrR   r   kernel_sizer   spatial_axesr   normalize_kernelsrY   normalize_gradientspadding_moder   dtypetorch.dtypenew_key_prefixr   rX   rZ   r[   c
           
        sB   t  ||	 t||||||d| _|| _| jj| _| jj| _d S )N)r   r   r   r   r   r   )r\   r]   r   r   r   kernel_diffkernel_smooth)
ra   rR   r   r   r   r   r   r   r   rX   rb   rd   re   r]   E  s   
zSobelGradientsd.__init__rf   rg   rh   c                 C  sF   t |}| |D ]}| jd u r|n| j| }| || ||< q	|S ri   )rj   rk   r   r   )ra   rf   rl   rm   new_keyrd   rd   re   rn   ^  s
   zSobelGradientsd.__call__)rR   r   r   r   r   r   r   rY   r   rY   r   r   r   r   r   r   rX   rY   rZ   r[   ro   )rp   rq   rr   rs   r   rt   r   float32r]   rn   ru   rd   rd   rb   re   rK   /  s    rK   c                      s6   e Zd ZdZejZ	dd fddZdddZ  ZS )rO   a  
    Applies the Euclidean distance transform on the input.
    Either GPU based with CuPy / cuCIM or CPU based with scipy.
    To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.

    Note that the results of the libraries can differ, so stick to one if possible.
    For details, check out the `SciPy`_ and `cuCIM`_ documentation and / or :func:`monai.transforms.utils.distance_transform_edt`.


    Note on the input shape:
        Has to be a channel first array, must have shape: (num_channels, H, W [,D]).
        Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.
        Input gets passed channel-wise to the distance-transform, thus results from this function will differ
        from directly calling ``distance_transform_edt()`` in CuPy or SciPy.

    Args:
        keys: keys of the corresponding items to be transformed.
        allow_missing_keys: don't raise exception if key is missing.
        sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;
            if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.

    .. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
    .. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt


    FNrR   r   rX   rY   samplingNone | float | list[float]rZ   r[   c                   s&   t  || || _t| jd| _d S )N)r  )r\   r]   r  r   distance_transform)ra   rR   rX   r  rb   rd   re   r]     s   zDistanceTransformEDTd.__init__rf   rg   c                 C  s0   t |}| |D ]}| j|| d||< q	|S )N)img)rj   rk   r  r   rd   rd   re   rn     s   zDistanceTransformEDTd.__call__)FN)rR   r   rX   rY   r  r  rZ   r[   )rf   rg   rZ   rg   )	rp   rq   rr   rs   r   rt   r]   rn   ru   rd   rd   rb   re   rO   g  s    rO   )frs   
__future__r   r   collections.abcr   r   r   r   r   copyr   typingr	   numpyr   r   monair
   monai.config.type_definitionsr   r   r   Zmonai.data.csv_saverr   monai.data.meta_tensorr   monai.transforms.inverser   monai.transforms.post.arrayr   r   r   r   r   r   r   r   r   r   r   r   monai.transforms.transformr   monai.transforms.utility.arrayr   monai.transforms.utilsr   r    monai.utilsr!   r"   r#   r$   __all__r   r   r'   r*   r6   r9   r<   r0   r?   r+   rB   rN   rE   r3   rH   rK   rO   r%   r&   r(   r)   r.   r/   r1   r2   r4   r5   r7   r8   r:   r;   r=   r>   r@   rA   rC   rD   rF   rG   rL   rM   r,   r-   rI   rJ   rP   rQ   rd   rd   rd   re   <module>   sf   
80/?7,%3#6 ,O8-