U
    Ph
f                 2   @  s  d dl mZ d dlZd dlZd dlZd dlmZmZmZm	Z	m
Z
 d dlmZ d dlmZmZ d dlmZmZ d dlmZ d dlZd dlZd dlZd dlmZmZ d d	lmZmZ d d
lm Z  d dl!m"Z" d dl#m$Z$ d dl%m&Z&m'Z'm(Z( d dl)m*Z*m+Z+m,Z,m-Z-m.Z.m/Z/m0Z0m1Z1m2Z2m3Z3m4Z4 d dl5m6Z6m7Z7m8Z8m9Z9m:Z:m;Z;m<Z<m=Z=m>Z>m?Z?m@Z@mAZAmBZBmCZCmDZDmEZEmFZFmGZGmHZHmIZImJZJ d dlKmLZL d dlMmNZNmOZOmPZPmQZQmRZR eIddeH\ZSZTeId\ZUZVeId\ZWZXeId\ZYZZeIddd\Z[Z\eId\Z]Z^ddddddd d!d"d#d$d%d&d'd(d)d*d+d,d-d.d/d0d1d2d3d4d5d6d7d8d9d:d;d<d=d>d?d@dAdBdCdDdEdFdGdHdIdJdKg2Z_dLddej`fdMdNdOdOdPdMdQdRdKZad!dNdTdUdVd6ZbdWdXdYd/ZcdNdNdNdNdNdTdZd[d0Zdd\dTd]d^d1Zed_d2 ZfdWd`dTdadbd<ZgdcdLej`fdMdddddPdMdedfd7ZhdcdLej`fdWdddddgdWdedhd9ZiejjfdWdgdWdidjd8ZkdkdkdldmdndodZldcdpdqdWdrdNdTdsdtd:Zmd"dMdvdTdwdxdydZnd#dMdzdNd{d|d}d3Zod$dMdrdzdNdrd~ddd4Zpd%ddMd`ddddd;Zqd&dddkdTddddZrd'dd`dNdkdMdMddTdd	dd*Zsd(dd`dkddddTdTdd	dd)ZtddpeudeLjvfdkddTdPddMddd!ZwddpeufdkddTdPdddZxddpej`dfdkddTddddZydpeudeLjvfdkddTdgdddd ZzdeLjvfd`ddddMddd"Z{ej|ej}ej~fd`dddddMdddZdeLjvfd`dddMddd$Zej~fd`ddMdddZdeLjvfd`dddMddd#Zejfd`ddMdddZdeLjvfd`dddMddd%Zej~ejfd`ddMdddZe@ddpdudddefdd dpfdMddddTdddd+Zd)ddrd`dddd-Zd*dd`d`dTdTdddƜdd.Zd+dMdTdddʜddCZd,dWddrdWd͜dd'Zd-dMdd`dNddМdd,Zd.ddMddNdNdd՜dd&Zd/d`ddTdd؜dd5ZeddۜddZee8ee6 Zd0dddߜddZddG ZdkddddZd1dWdd`d`d`dWddd=ZG dd( d(Zd2dddd`ddd>Zdd? Zdd@ ZdMddddAZdddddBZd3dTdddDZd4ddEZd5dTdddFZdwdddZd d Zdd Zd6dNdddZdd Zd	d
 Zdd Zdd ZedddeLjfddddHZd7ddddddZd8ddddddIZd9dduddMddTdTdzdzddTdd	ddJZed kre  dS (:      )annotationsN)CallableHashableIterableMappingSequence)contextmanager)	lru_cachewraps)
getmembersisclass)Any)	DtypeLikeIndexSelection)NdarrayOrTensorNdarrayTensor)GaussianFilter)meshgrid_ij)Compose)MapTransform	Transformapply_transform)	any_np_ptascontiguousarraycumsumisfinitenonzeroravelsearchsortedsoftplusuniqueunravel_indexwhere)GridSampleModeGridSamplePadModeInterpolateModeNdimageModeNumpyPadModePostFixPytorchPadMode
SplineMode	TraceKeysTraceStatusKeysdeprecated_arg_defaultensure_tupleensure_tuple_repensure_tuple_sizefall_back_tupleget_equivalent_dtypeissequenceiterablelook_up_optionmin_versionoptional_importpytorch_after)TransformBackends)convert_data_typeconvert_to_cupyconvert_to_dst_typeconvert_to_numpyconvert_to_tensorzskimage.measurez0.14.2zskimage.morphologyzscipy.ndimagecupyndarraynamezskimage.exposureallow_missing_keys_modecheck_boundariescompute_divisible_spatial_sizeconvert_applied_interp_modecopypaste_arrayscheck_non_lazy_pending_opscreate_control_gridcreate_gridcreate_rotatecreate_scalecreate_shearcreate_translateextreme_points_to_image
fill_holesFourier#generate_label_classes_crop_centers#generate_pos_neg_label_crop_centersgenerate_spatial_bounding_boxget_extreme_points$get_largest_connected_component_maskremove_small_objects
img_bounds	in_boundsis_emptyis_positivemap_binary_to_indicesmap_classes_to_indicesmap_spatial_axesrand_choicerescale_arrayrescale_array_int_maxrescale_instance_arrayresize_centerweighted_patch_sampleszero_marginsequalize_hist!get_number_image_type_conversionsget_transform_backendsprint_transform_backendsconvert_pad_modeconvert_to_contiguousget_unique_labelsscale_affineattach_hooksync_meta_inforeset_ops_idresolves_modeshas_status_keysdistance_transform_edt	soft_clip      ?r   floatz$NdarrayOrTensor | float | int | NonezDtypeLike | torch.dtype)arrsharpness_factorminvmaxvdtypereturnc                 C  sb   |dk	rt | |d^} }| }|dk	r>|t| | |  |  }|dk	r^|t|| |  |  }|S )a  
    Apply soft clip to the input array or tensor.
    The intensity values will be soft clipped according to
    f(x) = x + (1/sharpness_factor)*softplus(- c(x - minv)) - (1/sharpness_factor)*softplus(c(x - maxv))
    From https://medium.com/life-at-hopper/clip-it-clip-it-good-1f1bf711b291

    To perform one-sided clipping, set either minv or maxv to None.
    Args:
        arr: input array to clip.
        sharpness_factor: the sharpness of the soft clip function, default to 1.
        minv: minimum value of target clipped array.
        maxv: maximum value of target clipped array.
        dtype: if not None, convert input array to dtype before computation.

    Nrz   )r9   r   )rv   rw   rx   ry   rz   _v r   K/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/transforms/utils.pyrs      s          ?bool)probr{   c                 C  s   t t | kS )zv
    Returns True if a randomly chosen number is less than or equal to `prob`, by default this is a 50/50 chance.
    )r   random)r   r   r   r   r^      s    z
np.ndarrayimgc                 C  sN   t j| dd}t j| dd}t t |d ddg t |d ddg fS )zt
    Returns the minimum and maximum indices of non-zero lines in axis 0 of `img`, followed by that for axis 1.
    r   )axis   )npanyconcatenater"   )r   ax0ax1r   r   r   rW      s    )xymarginmaxxmaxyr{   c                 C  s<   t ||   ko|| k n  o8||  ko4|| k n  S )zc
    Returns True if (x,y) is within the rectangle (margin, margin, maxx-margin, maxy-margin).
    )r   )r   r   r   r   r   r   r   r   rX      s    znp.ndarray | torch.Tensor)r   r{   c                 C  s   |   |  k S )zd
    Returns True if `img` is empty, that is its maximum value is not greater than its minimum.
    )maxminr   r   r   r   rY      s    c                 C  s   | dkS )z{
    Returns a boolean version of `img` where the positive values are converted into True, the other values are False.
    r   r   r   r   r   r   rZ      s    int)r   r   r{   c                 C  s   t | ddddd|f sBt | dddd| df rFdS t | ddd|ddf  ot | dd| dddf  S )zo
    Returns True if the values within `margin` indices of the edges of `img` in dimensions 1 and 2 are 0.
    NF)r   r   )r   r   r   r   r   rd      s    B        zfloat | None)rv   rx   ry   rz   r{   c                 C  sx   |dk	rt | |d^} }|  }|  }||krD|dk	r@| | S | S | | ||  }|dksd|dkrh|S |||  | S )a  
    Rescale the values of numpy array `arr` to be from `minv` to `maxv`.
    If either `minv` or `maxv` is None, it returns `(a - min_a) / (max_a - min_a)`.

    Args:
        arr: input array to rescale.
        minv: minimum value of target rescaled array.
        maxv: maximum value of target rescaled array.
        dtype: if not None, convert input array to dtype before computation.

    Nr|   )r9   r   r   )rv   rx   ry   rz   r}   Zminamaxanormr   r   r   r_      s    r   c                 C  sB   t | j|p| j}t| jd D ]}t| | |||||< q"|S )zT
    Rescale each array slice along the first dimension of `arr` independently.
    r   )r   zerosshaperz   ranger_   )rv   rx   ry   rz   outir   r   r   ra      s    )rv   rz   r{   c                 C  s0   t |p| j}t jt| |j|j|p*| jdS )zc
    Rescale the array `arr` to be between the minimum and maximum values of the type `dtype`.
    r|   )r   iinforz   asarrayr_   r   r   )rv   rz   infor   r   r   r`   	  s    zSequence[int]zSequence[int | None]z+tuple[tuple[slice, ...], tuple[slice, ...]])	srccenter
destcenterdimsr{   c              	   C  s   t | }t |}tdg| }tdg| }tt|| ||||D ]|\}	}
}}}}|rBt|d dt||}t|d d dt|
| || }t|| || ||	< t|| || ||	< qBt|t|fS )a  
    Calculate the slices to copy a sliced area of array in `src_shape` into array in `dest_shape`.

    The area has dimensions `dims` (use 0 or None to copy everything in that dimension),
    the source area is centered at `srccenter` index in `src` and copied into area centered at `destcenter` in `dest`.
    The dimensions of the copied area will be clipped to fit within the
    source and destination arrays so a smaller area may be copied than expected. Return value is the tuples of slice
    objects indexing the copied area in `src`, and those indexing the copy area in `dest`.

    Example

    .. code-block:: python

        src_shape = (6,6)
        src = np.random.randint(0,10,src_shape)
        dest = np.zeros_like(src)
        srcslices, destslices = copypaste_arrays(src_shape, dest.shape, (3, 2),(2, 1),(3, 4))
        dest[destslices] = src[srcslices]
        print(src)
        print(dest)

        >>> [[9 5 6 6 9 6]
             [4 3 5 6 1 2]
             [0 7 3 2 4 1]
             [3 0 0 1 5 1]
             [9 4 7 1 8 2]
             [6 6 5 8 6 7]]
            [[0 0 0 0 0 0]
             [7 3 2 4 0 0]
             [0 0 1 5 0 0]
             [4 7 1 8 0 0]
             [0 0 0 0 0 0]
             [0 0 0 0 0 0]]

    N   r   r   )lenslicezipr   r   clipr   tuple)Z	src_shapeZ
dest_shaper   r   r   Zs_ndimZd_ndim	srcslices
destslicesr   ssdsscdcdimd1d2r   r   r   rF     s    &&$T)
fill_valueinplace
int | None)r   resize_dimsr   r   c          	      G  st   t || j}t| jd  }t|d  }t| j||||\}}|slt||| j}| | ||< |S | | S )a  
    Resize `img` by cropping or expanding the image from the center. The `resize_dims` values are the output dimensions
    (or None to use original dimension of `img`). If a dimension is smaller than that of `img` then the result will be
    cropped and if larger padded with zeros, in both cases this is done relative to the center of `img`. The result is
    a new image with the specified dimensions and values from `img` copied into its center.
    r   )r1   r   r   r   tolistrF   fullrz   )	r   r   r   r   Zhalf_img_shapeZhalf_dest_shaper   r   destr   r   r   rb   I  s    Fz
None | strNone)input_arrayrA   raise_errorr{   c                 C  s>   t | tjjr:| jr:d|pd d}|r0t|t| dS )aN  
    Check whether the input array has pending operations, raise an error or warn when it has.

    Args:
        input_array: input array to be checked.
        name: an optional name to be included in the error message.
        raise_error: whether to raise an error, default to False, a warning message will be issued instead.
    zMThe input image is a MetaTensor and has pending operations,
but the function  z1 assumes non-lazy input, result may be incorrect.N)
isinstancemonaidata
MetaTensorpending_operations
ValueErrorwarningswarn)r   rA   r   msgr   r   r   rG   ^  s    zNdarrayOrTensor | Nonez'tuple[NdarrayOrTensor, NdarrayOrTensor])labelimageimage_thresholdr{   c                 C  s   t | dd | jd dkr&| dd } tt| d}t|}|dk	rt |dd tt||kd}t|| td^}}t|| @ }n
t| }t|t	dd^}}t|t	dd^}}||fS )	a  
    Compute the foreground and background of input label data, return the indices after fattening.
    For example:
    ``label = np.array([[[0, 1, 1], [1, 0, 1], [1, 1, 0]]])``
    ``foreground indices = np.array([1, 2, 3, 5, 6, 7])`` and ``background indices = np.array([0, 4, 8])``

    Args:
        label: use the label data to get the foreground/background information.
        image: if image is not None, use ``label = 0 & image > image_threshold``
            to define background. so the output items will not map to all the voxels in the label.
        image_threshold: if enabled `image`, use ``image > image_threshold`` to
            determine the valid image content area and select background only in this area.
    r[   r@   r   r   Nr|   cpudevice)
rG   r   r   r   r   r;   r   r9   torchr   )r   r   r   
label_flat
fg_indicesimg_flatr}   
bg_indicesr   r   r   r[   s  s    
zlist[NdarrayOrTensor])r   num_classesr   r   max_samples_per_classr{   c                 C  sD  t | dd d}|dk	r6t |dd t||kd}t| }|}|dkr^|dkrZtd|}g }t|D ]}	|dkrtt| |	 tdd }
nt| |	k}
|dk	r||
@ }
t| t	j
jrtjnd}tt|
|tdd	d }|r4t||kr4t|dkr4ttdt|d |t}|||  qj|| qj|S )
a`  
    Filter out indices of every class of the input label data, return the indices after fattening.
    It can handle both One-Hot format label and Argmax format label, must provide `num_classes` for
    Argmax label.

    For example:
    ``label = np.array([[[0, 1, 2], [2, 0, 1], [1, 2, 0]]])`` and `num_classes=3`, will return a list
    which contains the indices of the 3 classes:
    ``[np.array([0, 4, 8]), np.array([1, 5, 6]), np.array([2, 3, 7])]``

    Args:
        label: use the label data to get the indices of every class.
        num_classes: number of classes for argmax label, not necessary for One-Hot label.
        image: if image is not None, only return the indices of every class that are within the valid
            region of the image (``image > image_threshold``).
        image_threshold: if enabled `image`, use ``image > image_threshold`` to
            determine the valid image content area and select class indices only in this area.
        max_samples_per_class: maximum length of indices in each class to reduce memory consumption.
            Default is None, no subsampling.

    r\   r@   Nr   r   zSchannels==1 indicates not using One-Hot format label, must provide ``num_classes``.r|   r   output_typer   )rG   r   r   r   r   r   r9   r   r   r   r   r   r   Tensorr   r   r   roundlinspaceastyper   append)r   r   r   r   r   r   channelsZnum_classes_indicescr   r   Zcls_indicesZ	sample_idr   r   r   r\     s>      ""r   zint | Sequence[int]znp.random.RandomState | Nonelist)spatial_sizew	n_samplesr_stater{   c                   sF  t |dd |dkrtd|dkr.tj }tj|jtd}tjt| |td}t	dd t
||D }|| }|jt|}|dk  r|| 8 }t|}|d	 rt|d	 r|d	 dk r|jdt||d
}n*t|||^}	}
t||	|d	  dd}t||tjd^}}
t||d  t |^ }
 fdd|D S )a  
    Computes `n_samples` of random patch sampling locations, given the sampling weight map `w` and patch `spatial_size`.

    Args:
        spatial_size: length of each spatial dimension of the patch.
        w: weight map, the weights must be non-negative. each element denotes a sampling weight of the spatial location.
            0 indicates no sampling.
            The weight map shape is assumed ``(spatial_dim_0, spatial_dim_1, ..., spatial_dim_n)``.
        n_samples: number of patch samples
        r_state: a random state container

    Returns:
        a list of `n_samples` N-D integers representing the spatial sampling location of patches.

    rc   r@   Nz w must be an ND array, got None.r|   c                 s  sJ   | ]B\}}||kr,t |d  || |d   nt |d  |d  d V  qdS )r   r   N)r   ).0r   mr   r   r   	<genexpr>  s     z)weighted_patch_samples.<locals>.<genexpr>r   r   )sizeT)rightr   c                   s   g | ]}t |  qS r   )r!   r   r   diffv_sizer   r   
<listcomp>  s     z*weighted_patch_samples.<locals>.<listcomp>)rG   r   r   r   RandomStater   r   r   r1   r   r   r   r   r   r   r   randintr   r;   r   r   minimum)r   r   r   r   img_sizewin_sizesr~   idxrr}   r   r   r   rc     s,    
 z	list[int]zSequence[int] | intz
tuple[Any])centersr   label_spatial_shapeallow_smallerr{   c                 C  s   t ||d}tt||dk rR|s:td| d| dtdd t||D }t|d}t|td	 |td 	tj
}t|D ]$\}}||| kr||  d	7  < qg }t| ||D ],\}	}
}tt|	|
|d	 }|t| qt|S )
a~  
    Utility to correct the crop center if the crop size and centers are not compatible with the image size.

    Args:
        centers: pre-computed crop centers of every dim, will correct based on the valid region.
        spatial_size: spatial size of the ROIs to be sampled.
        label_spatial_shape: spatial shape of the original label data to compare with ROI.
        allow_smaller: if `False`, an exception will be raised if the image is smaller than
            the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
            match the cropped size (i.e., no cropping in that dimension).

    defaultr   zUThe size of the proposed random crop ROI is larger than the image size, got ROI size z and label image size z respectively.c                 s  s   | ]\}}t ||V  qd S N)r   )r   lr   r   r   r   r   !  s     z'correct_crop_centers.<locals>.<genexpr>r   r   )r1   r   r   subtractr   r   r   floor_dividearrayr   uint16	enumerater   r   r   r   r.   )r   r   r   r   valid_startZ	valid_endr   Zvalid_sZvalid_centersr   Zv_sZv_eZcenter_ir   r   r   correct_crop_centers  s"    (r  ztuple[tuple])	r   num_samples	pos_ratior   r   r   
rand_stater   r{   c              	   C  s  |dkrt jjj}g }t|tr*t |n|}t|trBt |n|}t|dkrft|dkrftdt|dks~t|dkrt|dkrdnd}t	dt| dt| d| d t
|D ]P}	| |k r|n|}
|t|
}|
| }t|| }|t|| || qt|S )	a  
    Generate valid sample locations based on the label with option for specifying foreground ratio
    Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W]

    Args:
        spatial_size: spatial size of the ROIs to be sampled.
        num_samples: total sample centers to be generated.
        pos_ratio: ratio of total locations generated that have center being foreground.
        label_spatial_shape: spatial shape of the original label data to unravel selected centers.
        fg_indices: pre-computed foreground indices in 1 dimension.
        bg_indices: pre-computed background indices in 1 dimension.
        rand_state: numpy randomState object to align with other modules.
        allow_smaller: if `False`, an exception will be raised if the image is smaller than
            the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
            match the cropped size (i.e., no cropping in that dimension).

    Raises:
        ValueError: When the proposed roi is larger than the image.
        ValueError: When the foreground and background indices lengths are 0.

    Nr   zNo sampling location available.r   zNum foregrounds z, Num backgrounds zD, unable to generate class balanced samples, setting `pos_ratio` to .)r   r   __self__r   r   r   r   r   r   r   r   randr   r!   r   r   r  r.   )r   r  r  r   r   r   r  r   r   r}   indices_to_use
random_intr   centerr   r   r   rR   4  s&    
zSequence[NdarrayOrTensor]zlist[float | int] | None)	r   r  r   r   ratiosr  r   r   r{   c                 C  s\  |dkrt jjj}|dk r*td| dtt|dkrDdgt| n|}t|t|krztdt| dt| dtdd |D rtd	| dt|D ]>\}	}
t|
d
kr||	 d
krd
||	< |rt	
d|	 d qg }|jt||t |t | d}|D ]B}	||	 }|t|}t|| | }|t|| || qt|S )a^  
    Generate valid sample locations based on the specified ratios of label classes.
    Valid: samples sitting entirely within image, expected input shape: [C, H, W, D] or [C, H, W]

    Args:
        spatial_size: spatial size of the ROIs to be sampled.
        num_samples: total sample centers to be generated.
        label_spatial_shape: spatial shape of the original label data to unravel selected centers.
        indices: sequence of pre-computed foreground indices of every class in 1 dimension.
        ratios: ratios of every class in the label to generate crop centers, including background class.
            if None, every class will have the same ratio to generate crop centers.
        rand_state: numpy randomState object to align with other modules.
        allow_smaller: if `False`, an exception will be raised if the image is smaller than
            the requested ROI in any dimension. If `True`, any smaller dimensions will be set to
            match the cropped size (i.e., no cropping in that dimension).
        warn: if `True` prints a warning if a class is not present in the label.

    Nr   z:num_samples must be an int number and greater than 0, got r  zDrandom crop ratios must match the number of indices of classes, got z and c                 s  s   | ]}|d k V  qdS )r   Nr   r   r   r   r   r     s     z6generate_label_classes_crop_centers.<locals>.<genexpr>z/ratios should not contain negative number, got r   zno available indices of class z7 to crop, setting the crop ratio of this class to zero.)r   p)r   r   r  r   r   r.   r   r   r  r   r   choicer   sumr   r!   r   r   r  )r   r  r   r   r  r  r   r   Zratios_r   r   r   classesr
  r  r  r   r   r   rQ   n  s6    
"
$zSequence[float] | Noneztorch.device | None)r   spacinghomogeneousrz   r   r{   c                 C  sX   t |t}|pt}|tjkr*t| |||S |tjkrDt| ||||S td| ddS )a  
    compute a `spatial_size` mesh.

        - when ``homogeneous=True``, the output shape is (N+1, dim_size_1, dim_size_2, ..., dim_size_N)
        - when ``homogeneous=False``, the output shape is (N, dim_size_1, dim_size_2, ..., dim_size_N)

    Args:
        spatial_size: spatial size of the grid.
        spacing: same len as ``spatial_size``, defaults to 1.0 (dense grid).
        homogeneous: whether to make homogeneous coordinates.
        dtype: output grid data type, defaults to `float`.
        device: device to compute and store the output (when the backend is "torch").
        backend: APIs to use, ``numpy`` or ``torch``.

    backend  is not supportedN)r4   r8   ru   NUMPY_create_grid_numpyTORCH_create_grid_torchr   )r   r  r  rz   r   backend_backend_dtyper   r   r   rI     s    


)r   r  r  rz   c                 C  sp   |pt dd | D }dd t| |D }tjtj|ddit|tjd}|sT|S t|t|dd	 gS )
z;
    compute a `spatial_size` mesh with the numpy API.
    c                 s  s   | ]
}d V  qdS rt   Nr   r   r}   r   r   r   r     s     z%_create_grid_numpy.<locals>.<genexpr>c                 S  s<   g | ]4\}}t |d   d | |d  d | t|qS )rt          @)r   r   r   r   dr   r   r   r   r     s     z&_create_grid_numpy.<locals>.<listcomp>indexingijr|   Nr   )	r   r   r   r   meshgridr2   r?   r   	ones_like)r   r  r  rz   rangescoordsr   r   r   r    s    	"r  )r   r  r  r   c                   s`   |pt dd | D } fddt| |D }t| }|sFt|S t|t|d fS )z;
    compute a `spatial_size` mesh with the torch API.
    c                 s  s   | ]
}d V  qdS r  r   r  r   r   r   r     s     z%_create_grid_torch.<locals>.<genexpr>c              
     sJ   g | ]B\}}t j|d   d | |d  d | t| tt jdqS )rt   r  r   rz   )r   r   r   r2   r   r   r(  r   r   r     s   
z&_create_grid_torch.<locals>.<listcomp>r   )r   r   r   r   stackr%  )r   r  r  rz   r   r&  r'  r   r(  r   r    s    


r  zSequence[float])spatial_shaper  r  rz   r   c                 C  s   t |ttjk}|rtjntj}g }t| |D ]x\}	}
|rHtj|	|dnt|	}	|	d dkr|	||	d d|
  d d d  q.|	||	d d|
  d d  q.t
||||||dS )	zB
    control grid with two additional point in each direction
    r   r   r   rt   r  r   g      @)r   r  r  rz   r   r  )r4   r8   r  r   ceilr   r   	as_tensorr   r   rI   )r*  r  r  rz   r   r  Ztorch_backendZ	ceil_funcZ
grid_shaper!  r   r   r   r   rH     s     ($     zSequence[float] | floatstr)spatial_dimsradiansr   r  r{   c                   st   t |t}|tjkr,t| |tjtjtjdS |tjkr`t| | fdd fdd fdddS t	d| ddS )	a  
    create a 2D or 3D rotation matrix

    Args:
        spatial_dims: {``2``, ``3``} spatial rank
        radians: rotation radians
            when spatial_dims == 3, the `radians` sequence corresponds to
            rotation in the 1st, 2nd, and 3rd dim respectively.
        device: device to compute and store the output (when the backend is "torch").
        backend: APIs to use, ``numpy`` or ``torch``.

    Raises:
        ValueError: When ``radians`` is empty.
        ValueError: When ``spatial_dims`` is not one of [2, 3].

    )r.  r/  sin_funccos_funceye_funcc                   s   t t j| t j dS N)rz   r   )r   sinr,  float32thr   r   r   <lambda>1      zcreate_rotate.<locals>.<lambda>c                   s   t t j| t j dS r3  )r   cosr,  r5  r6  r   r   r   r8  2  r9  c                   s   t j|  dS Nr   r   eyerankr   r   r   r8  3  r9  r  r  N)
r4   r8   r  _create_rotater   r4  r:  r=  r  r   )r.  r/  r   r  r  r   r   r   rJ     s$    

    



r   )r.  r/  r0  r1  r2  r{   c           
      C  s  t |}| dkrpt|dkrh||d ||d  }}|d}||  |d< |d< || |d< |d< |S td	| dkrd }t|dkr||d ||d  }}|d
}||  |d< |d< || |d< |d< t|dkrB||d ||d  }}|d krtd|d
}	|| |	d< |	d< | | |	d< |	d< ||	 }t|dkr||d ||d  }}|d kr|td|d
}	||  |	d< |	d< || |	d< |	d< ||	 }|d krtd	|S td|  dd S )Nr   r   r      )r   r   r   r   r   r   )r   r   zradians must be non empty.   r   r   r   r   )r   r   zAffine should be a matrix.r   r   r   r   zUnsupported spatial_dims: z, available options are [2, 3].)r.   r   r   )
r.  r/  r0  r1  r2  sin_cos_r   affine_affiner   r   r   r@  8  sH    



r@  )r.  coefsr   r{   c                   sX   t |t}|tjkr$t| |tjdS |tjkrDt| | fdddS td| ddS )a  
    create a shearing matrix

    Args:
        spatial_dims: spatial rank
        coefs: shearing factors, a tuple of 2 floats for 2D, a tuple of 6 floats for 3D),
            take a 3D affine as example::

                [
                    [1.0, coefs[0], coefs[1], 0.0],
                    [coefs[2], 1.0, coefs[3], 0.0],
                    [coefs[4], coefs[5], 1.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0],
                ]

        device: device to compute and store the output (when the backend is "torch").
        backend: APIs to use, ``numpy`` or ``torch``.

    Raises:
        NotImplementedError: When ``spatial_dims`` is not one of [2, 3].

    )r.  rM  r2  c                   s   t j|  dS r;  r<  r>  r   r   r   r8    r9  zcreate_shear.<locals>.<lambda>r  r  N)r4   r8   r  _create_shearr   r=  r  r   )r.  rM  r   r  r  r   r   r   rL   g  s    


  
)r.  rM  r{   c                 C  s   | dkr<t |ddd}|d}|d |d  |d< |d< |S | dkrt |d	dd}|d
}|d |d  |d< |d< |d |d  |d< |d< |d
 |d  |d< |d< |S tdd S )Nr   r   r   pad_valrA  r   r   rB  rC     rD  rG  rE     rH  rF  z4Currently only spatial_dims in [2, 3] are supported.)r0   NotImplementedError)r.  rM  r2  r   r   r   r   rN    s    rN  ztorch.device | str | None)r.  scaling_factorr   r{   c                   sX   t |t}|tjkr$t| |tjdS |tjkrDt| | fdddS td| ddS )a)  
    create a scaling matrix

    Args:
        spatial_dims: spatial rank
        scaling_factor: scaling factors for every spatial dim, defaults to 1.
        device: device to compute and store the output (when the backend is "torch").
        backend: APIs to use, ``numpy`` or ``torch``.
    r.  rT  
array_funcc                   s   t t j|  dS r;  )r   diagr,  r   r   r   r   r8    r9  zcreate_scale.<locals>.<lambda>r  r  N)r4   r8   r  _create_scaler   rW  r  r   )r.  rT  r   r  r  r   r   r   rK     s    



)r.  rT  r{   c                 C  s"   t || dd}||d |  d S )Nrt   rO  rt   )r0   rU  r   r   r   rY    s    rY  )r.  shiftr   r{   c                   sn   t |t}t| } |tjkr0t| |tjtjdS |tjkrZt| | fdd fdddS t	d| ddS )a*  
    create a translation matrix

    Args:
        spatial_dims: spatial rank
        shift: translate pixel/voxel for every spatial dim, defaults to 0.
        device: device to compute and store the output (when the backend is "torch").
        backend: APIs to use, ``numpy`` or ``torch``.
    )r.  r[  r2  rV  c                   s   t jt |  dS r;  )r   r=  r,  rX  r   r   r   r8    r9  z"create_translate.<locals>.<lambda>c                   s   t j|  dS r;  )r   r,  rX  r   r   r   r8    r9  r  r  N)
r4   r8   r   r  _create_translater   r=  r   r  r   )r.  r[  r   r  r  r   r   r   rM     s    




)r.  r[  r{   c                 C  sB   t |}|| d }t|d |  D ]\}}|||| f< q$||S )Nr   )r.   r  )r.  r[  r2  rV  rK  r   ar   r   r   r\    s
    r\  r   z1.2z1.5)old_defaultnew_defaultsincereplacedzIndexSelection | Noneztuple[list[int], list[int]])r   	select_fnchannel_indicesr   r   r{   c                 C  s  t | dd | jdd }|dk	r2| tt| n| }||d}t|j}t||}|D ]}|dk r\td| dq\dg| }	dg| }
tt	
tt||d D ]\}}|}t|dkrt||}| sdg| dg| f  S t|| kd }|d ||  }|d ||  d }|rBt|d}t||| }t|tjr`|   n||	|< t|tjr|   n||
|< q|	|
fS )	a  
    Generate the spatial bounding box of foreground in the image with start-end positions (inclusive).
    Users can define arbitrary function to select expected foreground from the whole image or specified channels.
    And it can also add margin to every dim of the bounding box.
    The output format of the coordinates is:

        [1st_spatial_dim_start, 2nd_spatial_dim_start, ..., Nth_spatial_dim_start],
        [1st_spatial_dim_end, 2nd_spatial_dim_end, ..., Nth_spatial_dim_end]

    This function returns [0, 0, ...], [0, 0, ...] if there's no positive intensity.

    Args:
        img: a "channel-first" image of shape (C, spatial_dim1[, spatial_dim2, ...]) to generate bounding box from.
        select_fn: function to select expected foreground, default is to select values > 0.
        channel_indices: if defined, select foreground only on the specified channels
            of image. if None, select foreground on the whole image.
        margin: add margin value to spatial dims of the bounding box, if only 1 value provided, use it for all dims.
        allow_smaller: when computing box size with `margin`, whether to allow the image edges to be smaller than the
                final box edges. If `True`, the bounding boxes edges are aligned with the input image edges, if `False`,
                the bounding boxes edges are aligned with the final box edges. Default to `True`.

    rS   r@   r   Nr   z0margin value should not be negative number, got r  r   )rG   r   r   r.   r   r   r/   r   r  	itertoolscombinationsreversedr   r   r"   r   r   r   r   r   detachr   item)r   rb  rc  r   r   r   r   ndimr   	box_startbox_enddiaxdtZarg_maxmin_dmax_dr   r   r   rS     s4    



$

&(r   )r   connectivitynum_componentsr{   c                 C  s   t d\}}to.|o.t| tjo.| jtdk}|rNt|  }|jj	}t
}n&tsZtdt| tj^}}	tj	}t}|||dd\}
}||kr|t}n@|
||
 }|||ddd }|d| }||
|}t|| |jdd	 S )
aS  
    Gets the largest connected component mask of an image.

    Args:
        img: Image to get largest connected component from. Shape is (spatial_dim1 [, spatial_dim2, ...])
        connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
            Accepted values are ranging from  1 to input.ndim. If ``None``, a full
            connectivity of ``input.ndim`` is used. for more details:
            https://scikit-image.org/docs/dev/api/skimage.measure.html#skimage.measure.label.
        num_components: The number of largest components to preserve.
    zcucim.skimager   zSkimage.measure required.T)rq  Z
return_numNr   )dstrz   r   )r6   has_cpr   r   r   r   r:   shortmeasurer   cphas_measureRuntimeErrorr9   r   r?   r   r   r   argsortbincountisinr;   rz   )r   rq  rr  skimage	has_cucimuse_cpimg_r   libr}   featuresnum_featuresr   nonzerosZfeatures_to_keepr   r   r   rU   %  s&    $@   z+Sequence[float] | float | np.ndarray | None)r   min_sizerq  independent_channels
by_measurepixdimr{   c                 C  s"  t t| dkr| S ts td|rt | jdd }t| tjjrL| j	}n&|dk	r`t
||}ntd d| }tt|}|dkrtd d}t|| }n|dk	rtd t| tj^}	}
|s|	dk}	n|	|	 dkrtntj}	t|	||}t|| ^}}
|s| | }|S )	a  
    Use `skimage.morphology.remove_small_objects` to remove small objects from images.
    See: https://scikit-image.org/docs/dev/api/skimage.morphology.html#remove-small-objects.

    Data should be one-hotted.

    Args:
        img: image to process. Expected shape: C, H,W,[D]. Expected to only have singleton channel dimension,
            i.e., not be one-hotted. Converted to type int.
        min_size: objects smaller than this size are removed.
        connectivity: Maximum number of orthogonal hops to consider a pixel/voxel as a neighbor.
            Accepted values are ranging from  1 to input.ndim. If ``None``, a full
            connectivity of ``input.ndim`` is used. For more details refer to linked scikit-image
            documentation.
        independent_channels: Whether to consider each channel independently.
        by_measure: Whether the specified min_size is in number of voxels. if this is True then min_size
            represents a surface area or volume value of whatever units your image is in (mm^3, cm^2, etc.)
            default is False.
        pixdim: the pixdim of the input image. if a single number, this is used for all axes.
            If a sequence of numbers, the length of the sequence must be equal to the image dimensions.
    r   zSkimage required.NzU`img` is not of type MetaTensor and `pixdim` is None, assuming affine to be identity.rZ  r   zPInvalid `pixdim` value detected, set it to 1. Please verify the pixdim settings.z?`pixdim` is specified but not in use when computing the volume.)r   r    has_morphologyry  r   r   r   r   r   r  r/   r   r   r   prodr   r+  r9   r?   r   r   r   int32
morphologyrV   r;   )r   r  rq  r  r  r  srZ_pixdimZvoxel_volumeimg_npr}   out_npr   r   r   r   rV   V  s8    



zint | Iterable[int] | Nonezset[int])r   	is_onehotdiscardr{   c                 C  sn   | j d }|r"dd t| D }n(|dkr:td| dtt|  }|dk	rjt|D ]}|| qZ|S )a  Get list of non-background labels in an image.

    Args:
        img: Image to be processed. Shape should be [C, W, H, [D]] with C=1 if not onehot else `num_classes`.
        is_onehot: Boolean as to whether input image is one-hotted. If one-hotted, only return channels with
        discard: Can be used to remove labels (e.g., background). Can be any value, sequence of values, or
            `None` (nothing is discarded).

    Returns:
        Set of labels
    r   c                 S  s    h | ]\}}|  d kr|qS )r   )r  )r   r   r   r   r   r   	<setcomp>  s      z$get_unique_labels.<locals>.<setcomp>r   z7If input not one-hotted, should only be 1 channel, got r  N)r   r  r   setr    r   r.   r  )r   r  r  
n_channelsapplied_labelsr   r   r   r   rk     s    
zIterable[int] | None)img_arrr  rq  r{   c              
   C  s   d}| j | }|dk}| jd }t||p,|}|dk	r@t|nt| |}d}|| |D ]t}	tj| j dd t	d}
tj
|
|d|rt| |	 n
| d |	kdd|
d |rt|
| |	< q\|	| dt|
f< q\| S )a  
    Fill the holes in the provided image.

    The label 0 will be treated as background and the enclosed holes will be set to the neighboring class label.
    What is considered to be an enclosed hole is defined by the connectivity.
    Holes on the edge are always considered to be open (not enclosed).

    Note:

        The performance of this method heavily depends on the number of labels.
        It is a bit faster if the list of `applied_labels` is provided.
        Limiting the number of `applied_labels` results in a big decrease in processing time.

        If the image is one-hot-encoded, then the `applied_labels` need to match the channel index.

    Args:
        img_arr: numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
        applied_labels: Labels for which to fill holes. Defaults to None,
            that is filling holes for all labels.
        connectivity: Maximum number of orthogonal hops to
            consider a pixel/voxel as a neighbor. Accepted values are ranging from  1 to input.ndim.
            Defaults to a full connectivity of ``input.ndim``.

    Returns:
        numpy array of shape [C, spatial_dim1[, spatial_dim2, ...]].
    r   r   Nr|   r   )	structure
iterationsmaskoriginborder_valueoutput)r   ri  ndimagegenerate_binary_structurer  rk   r  r   r   r   binary_dilationlogical_not)r  r  rq  Zchannel_axisnum_channelsZ
is_one_hotr.  r  Zbackground_labelr   tmpr   r   r   rO     s.    


	zlist[tuple[int, ...]])r   r  
backgroundpertr{   c                   s   t  dd dkrtjjjt |ktd dkrDtd fdd}g }t jD ]<}|	t
||  | |	t
||  | qd|S )a  
    Generate extreme points from an image. These are used to generate initial segmentation
    for annotation models. An optional perturbation can be passed to simulate user clicks.

    Args:
        img:
            Image to generate extreme points from. Expected Shape is ``(spatial_dim1, [, spatial_dim2, ...])``.
        rand_state: `np.random.RandomState` object used to select random indices.
        background: Value to be consider as background, defaults to 0.
        pert: Random perturbation amount to add to the points, defaults to 0.0.

    Returns:
        A list of extreme points, its length is equal to 2 * spatial dimension of input image.
        The output format of the coordinates is:

        [1st_spatial_dim_min, 1st_spatial_dim_max, 2nd_spatial_dim_min, ..., Nth_spatial_dim_max]

    Raises:
        ValueError: When the input image does not have any foreground pixel.
    rT   r@   Nr   z1get_extreme_points: no foreground object in mask!c                   s   t | | kd }t|tjr(| n|}dk	r>|n|}g }t jD ]X}t| | d dk	rv	 nd  } t
| d} t|  j| d } ||  qP|S )z
        Select one of the indices within slice containing val.

        Args:
            val : value for comparison
            dim : dimension in which to look for value
        r   Nr  r   r   )r"   r   r   r   r   r  r   ri  r   r	  r   r   r   r   )valr   r   ptjr   r   r  r  r   r   
_get_point  s    ,
z&get_extreme_points.<locals>._get_point)rG   r   r   r  r"   r   r   r   ri  r   r   r   r   )r   r  r  r  r  pointsr   r   r  r   rT     s    
      z?Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensorztorch.Tensor)r  r   sigmarescale_minrescale_maxr{   c           	        s   t jt |d t jd | D ]}d |< qt|trJ fdd|D }nt j| jd} dd t|j	d |d}| 
d    }  } | ||    ||  | S )	a"  
    Please refer to :py:class:`monai.transforms.AddExtremePointsChannel` for the usage.

    Applies a gaussian filter to the extreme points image. Then the pixel values in points image are rescaled
    to range [rescale_min, rescale_max].

    Args:
        points: Extreme points of the object/organ.
        label: label image to get extreme points from. Shape must be
            (1, spatial_dim1, [, spatial_dim2, ...]). Doesn't support one-hot labels.
        sigma: if a list of values, must match the count of spatial dimensions of input data,
            and apply every value in the list to 1 spatial dimension. if only 1 value provided,
            use it for all spatial dimensions.
        rescale_min: minimum value of output data.
        rescale_max: maximum value of output data.
    r   r|   rt   c                   s   g | ]}t j| jd qS )r   )r   r,  r   )r   r   Zpoints_imager   r   r   M  s     z+extreme_points_to_image.<locals>.<listcomp>r   r   )r  )r   
zeros_liker,  ru   r   r   r   	unsqueezer   ri  squeezerg  r   r   )	r  r   r  r  r  r  gaussian_filterZmin_intensityZmax_intensityr   r  r   rN   /  s    

zSequence[int] | int | None)img_ndimspatial_axeschannel_firstr{   c                 C  s   |dkr&t |rtd| n
t| d S g }t|D ]J}|rZ||dk rN||  n|d  q2||dk rv|d | d  n| q2|S )ae  
    Utility to map the spatial axes to real axes in channel first/last shape.
    For example:
    If `channel_first` is True, and `img` has 3 spatial dims, map spatial axes to real axes as below:
    None -> [1, 2, 3]
    [0, 1] -> [1, 2]
    [0, -1] -> [1, -1]
    If `channel_first` is False, and `img` has 3 spatial dims, map spatial axes to real axes as below:
    None -> [0, 1, 2]
    [0, 1] -> [0, 1]
    [0, -1] -> [0, -2]

    Args:
        img_ndim: dimension number of the target image.
        spatial_axes: spatial axes to be converted, default is None.
            The default `None` will convert to all the spatial axes of the image.
            If axis is negative it counts from the last to the first axis.
            If axis is a tuple of ints.
        channel_first: the image data is channel first or channel last, default to channel first.

    Nr   r   )r   r   r.   r   )r  r  r  Zspatial_axes_r]  r   r   r   r]   ]  s     $z=MapTransform | Compose | tuple[MapTransform] | tuple[Compose])	transformc              	   c  s   t | rt| } g }t| tr&| g}nt| trDdd |  jD }t|dkrXtddd |D }z|D ]
}d|_	qldV  W 5 t||D ]\}}||_	qX dS )a  Temporarily set all MapTransforms to not throw an error if keys are missing. After, revert to original states.

    Args:
        transform: either MapTransform or a Compose

    Example:

    .. code-block:: python

        data = {"image": np.arange(16, dtype=float).reshape(1, 4, 4)}
        t = SpatialPadd(["image", "label"], 10, allow_missing_keys=False)
        _ = t(data)  # would raise exception
        with allow_missing_keys_mode(t):
            _ = t(data)  # OK!
    c                 S  s   g | ]}t |tr|qS r   )r   r   r   tr   r   r   r     s     
 z+allow_missing_keys_mode.<locals>.<listcomp>r   z_allow_missing_keys_mode expects either MapTransform(s) or Compose(s) containing MapTransform(s)c                 S  s   g | ]
}|j qS r   )allow_missing_keysr  r   r   r   r     s     TN)
r3   r   r   r   flatten
transformsr   	TypeErrorr   r  )r  r  Zorig_statesr  Zo_sr   r   r   rB     s$    


nearestzbool | Nonemodealign_cornersc                   s  t ttfr"fddD S t ts0S tdkrd }t |tsZ|tkrdd< n8t |d ts~|d tkrfddttD d< dkrdkrt	j
n d }t|r؇ fddD n d< dkr
dkr
fd	d
D S S )a  
    Recursively change the interpolation mode in the applied operation stacks, default to "nearest".

    See also: :py:class:`monai.transform.inverse.InvertibleTransform`

    Args:
        trans_info: applied operation stack, tracking the previously applied invertible transform.
        mode: target interpolation mode to convert, default to "nearest" as it's usually used to save the mode output.
        align_corners: target align corner value in PyTorch interpolation API, need to align with the `mode`.

    c                   s   g | ]}t | d qS r  rE   )r   r   )r  r  r   r   r     s     z/convert_applied_interp_mode.<locals>.<listcomp>r  r   c                   s   g | ]} qS r   r   r  )r  r   r   r     s     r  Nc                   s   g | ]} qS r   r   r  )_align_cornersr   r   r     s     c                   s    i | ]}|t |  d qS r  r  )r   k)r  r  
trans_infor   r   
<dictcomp>  s     z/convert_applied_interp_mode.<locals>.<dictcomp>)r   r   r   r   dictr   _interp_modesr   r   r+   NONEr3   )r  r  r  Zcurrent_modecurrent_valuer   )r  r  r  r  r   rE     s*    

c                 C  sx   t | ttfrdd | D S t | tjjr:t| j| _| S t | tsH| S t	| } t
j| krft
j| t
j< dd |  D S )zbfind MetaTensors in list or dict `data` and (in-place) set ``TraceKeys.ID`` to ``Tracekeys.NONE``.c                 S  s   g | ]}t |qS r   ro   )r   r!  r   r   r   r     s     z reset_ops_id.<locals>.<listcomp>c                 S  s   i | ]\}}|t |qS r   r  r   r  r~   r   r   r   r    s      z reset_ops_id.<locals>.<dictcomp>)r   r   r   r   r   r   ro   applied_operationsr   r  r+   IDr  items)r   r   r   r   ro     s    

)r*  r  c                 C  s^   t |dt|  }g }t|| D ]4\}}|dkrFtt|| | n|}|| q t|S )aZ  
    Compute the target spatial size which should be divisible by `k`.

    Args:
        spatial_shape: original spatial shape.
        k: the target k for each spatial dimension.
            if `k` is negative or 0, the original size is preserved.
            if `k` is an int, the same `k` be applied to all the input spatial dimensions.

    r   r   )r1   r   r   r   r   r+  r   r   )r*  r  new_sizeZk_dr   new_dimr   r   r   rD     s    "      znp.ndarray | None)r   r  num_binsr   r   r{   c           
      C  s   | j }|dk	r | tj|td n| }tr>t| |\}}n0t| |\}}|dd |dd  d }| }	t	|	||d}	t
|  ||	} | |S )a  
    Utility to equalize input image based on the histogram.
    If `skimage` installed, will leverage `skimage.exposure.histogram`, otherwise, use
    `np.histogram` instead.

    Args:
        img: input image to equalize.
        mask: if provided, must be ndarray of bools or 0s and 1s, and same shape as `image`.
            only points at which `mask==True` are used for the equalization.
        num_bins: number of the bins to use in histogram, default to `256`. for more details:
            https://numpy.org/doc/stable/reference/generated/numpy.histogram.html.
        min: the min value to normalize input image, default to `0`.
        max: the max value to normalize input image, default to `255`.

    Nr|   r   r   r   )rv   rx   ry   )r   r   r   r   has_skimageexposure	histogramr  r   r_   interpreshape)
r   r  r  r   r   
orig_shapeZhist_imghistbinsZcumr   r   r   re     s    c                   @  s@   e Zd ZdZeddddddZedddddd	d
dZdS )rP   z/
    Helper class storing Fourier mappings
    r   r   )r   r.  r{   c                 C  s   t t| d}t| tjrlttjdrFtjjtjj| |d|d}qt	jjt	jj| 
  |d|d}nt	jjt	jj| |d|d}|S )a?  
        Applies fourier transform and shifts the zero-frequency component to the
        center of the spectrum. Only the spatial dimensions get transformed.

        Args:
            x: Image to transform.
            spatial_dims: Number of spatial dimensions.

        Returns
            k: K-space data.
        r   fftshiftr   axes)r   r   r   r   r   hasattrfftr  fftnr   r   numpy)r   r.  r   r  r   r   r   shift_fourier"  s    &zFourier.shift_fourierNr   )r  r.  n_dimsr{   c                 C  s   t t| d}t| tjrrttjdrJtjjtjj| |d|ddj	}qt
jjt
jj|   |d|dj	}nt
jjt
jj| |d|dj	}|S )a  
        Applies inverse shift and fourier transform. Only the spatial
        dimensions are transformed.

        Args:
            k: K-space data.
            spatial_dims: Number of spatial dimensions.

        Returns:
            x: Tensor in image space.
        r   	ifftshiftr  backward)r   r   r  )r   r   r   r   r   r  r  ifftnr  realr   r   r  )r  r.  r  r   r   r   r   r   inv_shift_fourier;  s    "(zFourier.inv_shift_fourier)N)__name__
__module____qualname____doc__staticmethodr  r  r   r   r   r   rP     s
   r   r   zHashable | None)r  	test_datakeyr{   c                   s   ddl m  dd }t|||ts(dnd}|  j}t|  sVt fdd|D r^td|D ]x}|||}t|}t|t	j
r|jnd	}	t||| j| j}|||}
t|
t	j
r|
jnd	}t|
|r||	krb|d
7 }qb|S )a  
    Get the number of times that the data need to be converted (e.g., numpy to torch).
    Conversions between different devices are also counted (e.g., CPU to GPU).

    Args:
        transform: composed transforms to be tested
        test_data: data to be used to count the number of conversions
        key: if using dictionary transforms, this key will be used to check the number of conversions.
    r   OneOfc                 S  s   |d kr| S | | S r   r   )objr  r   r   r   	_get_dataa  s    z4get_number_image_type_conversions.<locals>._get_datar   c                 3  s   | ]}t | V  qd S r   )r   r   r  r   r   r   j  s     z4get_number_image_type_conversions.<locals>.<genexpr>zRNot compatible with `OneOf`, as the applied transform is deterministically chosen.Nr   )monai.transforms.composer  r   r-  r  r  r   ry  typer   r   r   r   	map_itemsunpack_items)r  r  r  r  Znum_conversionstr
_transformZ	prev_dataZ	prev_typeprev_deviceZ	curr_dataZcurr_devicer   r  r   rf   U  s     

 


c                  C  sj   i } g }t tjD ]R\}}||kr$q|| t|rt|tr|dkrtj|j	ktj
|j	kg| |< q| S )a6  Get the backends of all MONAI transforms.

    Returns:
        Dictionary, where each key is a transform, and its
        corresponding values are a boolean list, stating
        whether that transform supports (1) `torch.Tensor`,
        and (2) `np.ndarray` as input without needing to
        convert.
    )BatchInverseTransformr   CuCIMCuCIMD
DecollatedInvertDInvertibleTransformLambdaLambdaDr   r  	RandCuCIM
RandCuCIMDRandomOrderPadListDataCollate
RandLambdaRandLambdaDRandTorchVisionDRandomizableTransformTorchVisionDr   )r   r   r  r   r   
issubclassr   r8   r  r  r  )backendsZunique_transformsnr  r   r   r   rg   z  s     

c                    s$  G dd d} dd  | j f fdd	}t }t|}d\}}}}|dd	d
 | D ]v\}}	t|	rv| j}
|d7 }n>|	d r| j}
|d7 }n&|	d r| j}
|d7 }n| j}
|d7 }|||	d |	d |
d qVtd|  d| | j  d| | j  d| | j  d| | j dS )z2Prints a list of backends of all MONAI transforms.c                   @  s   e Zd ZdZdZdZdZdS )z(print_transform_backends.<locals>.Colorsr   Z91Z92Z93N)r  r  r  noneredgreenyellowr   r   r   r   Colors  s   r  c                 S  s   t d| d|  d d S )Nz[r   z[00m)print)r  colorr   r   r   print_color  s    z-print_transform_backends.<locals>.print_colorc                   s$    | dd|dd|d| d S )Nz<50 z<8r   )rA   r   r  r  r  r   r   print_table_column  s    z4print_transform_backends.<locals>.print_table_column)r   r   r   r   r   zTorch?zNumpy?r   r   )r  zTotal number of transforms:z1Number transforms allowing both torch and numpy: zNumber of TorchTransform: zNumber of NumpyTransform: zNumber of uncategorized: N)	r  rg   r   r  allr  r  r  r  )r  r  r  Zn_totalZ	n_t_or_npZn_tZn_npZn_uncategorizedr  r~   r  r   r  r   rh     s2    



z
str | Noners  r  c                 C  sx   t | tjr0|dkrd}n|dkr&d}t|tS t | tjr`|dkrJd}n|dkrVd}t|tS tdt	|  ddS )z
    Utility to convert padding mode between numpy array and PyTorch Tensor.

    Args:
        dst: target data to convert padding mode for, should be numpy array or PyTorch Tensor.
        mode: current padding mode.

    wrapcircularedge	replicatezunsupported data type: r  N)
r   r   r   r4   r)   r   r?   r'   r   r  r!  r   r   r   ri     s    	

z7NdarrayOrTensor | str | bytes | Mapping | Sequence[Any]z)NdarrayOrTensor | Mapping | Sequence[Any])r   r{   c                   sn   t | tjtjttfr"t| f S t | trB fdd| 	 D S t | t
rft|  fdd| D S | S dS )a  
    Check and ensure the numpy array or PyTorch Tensor in data to be contiguous in memory.

    Args:
        data: input data to convert, will recursively convert the numpy array or PyTorch Tensor in dict and sequence.
        kwargs: if `x` is PyTorch Tensor, additional args for `torch.contiguous`, more details:
            https://pytorch.org/docs/stable/generated/torch.Tensor.contiguous.html#torch.Tensor.contiguous.

    c                   s   i | ]\}}|t |f qS r   rj   r  kwargsr   r   r    s      z)convert_to_contiguous.<locals>.<dictcomp>c                 3  s   | ]}t |f V  qd S r   r&  r   r'  r   r   r     s     z(convert_to_contiguous.<locals>.<genexpr>N)r   r   r?   r   r   r-  bytesr   r   r  r   r  )r   r(  r   r'  r   rj     s    

)centeredc                 C  s   t t|t| }| |kr(t|d S tjdd t| |D td}t|| }|r~t	|d| d d |d|df< |S )aB  
    Compute the scaling matrix according to the new spatial size

    Args:
        spatial_size: original spatial size.
        new_spatial_size: new spatial size.
        centered: whether the scaling is with respect to the image center (True, default) or corner (False).

    Returns:
        the scaling matrix.

    r   c                 S  s&   g | ]\}}t |t t|d  qS r  )ru   r   )r   or  r   r   r   r     s     z scale_affine.<locals>.<listcomp>r|   Nr  r   )
r   r   r   r=  r   r   ru   rK   r   rW  )r   Znew_spatial_sizer*  r   r   scaler   r   r   rl     s    &prec                   sF   ddh}t ||dkr"||   n
| |  t|  fdd}|S )z
    Adds `hook` before or after a `func` call. If mode is "pre", the wrapper will call hook then func.
    If the mode is "post", the wrapper will call func then hook.
    r-  postc                   s   | |} | |S r   r   )instr   _funcZ_hookr   r   wrapper%  s    
zattach_hook.<locals>.wrapper)r4   r
   )funchookr  	supportedr2  r   r0  r   rm     s    
)r  c           	      C  s@  t |ts|S t|}t| }||kr8tjj ||< t ||  tjjsltj||  || < || ||  _|| 	||  j tj
j| }||krtjj ||< ||  j||  }}|s| ||  _||< |S |s| ||  _||< |S |rt|t|kr
|n|}nt|t|kr&|n|}| ||  _||< |S )z
    Given the key, sync up between metatensor `data_dict[key]` and meta_dict `data_dict[key_transforms/meta_dict]`.
    t=True: the one with more applied_operations in metatensor vs meta_dict is the output, False: less is the output.
    )r   r   r  r(   metar   r   r   get_default_metaupdater  TraceableTransform	trace_keyget_default_applied_operationsr  r   )	r  	data_dictr  r!  Zmeta_dict_keyZ	xform_keyZ	from_meta	from_dictrefr   r   r   rn   -  s2    

)r{   c                 C  s4   t | tr(t| dkr(tdd | D s0tddS )z0
    Check boundaries for Signal transforms
    r   c                 s  s   | ]}t |tV  qd S r   )r   ru   r   r   r   r   r   W  s     z#check_boundaries.<locals>.<genexpr>z<Incompatible values: boundaries needs to be a list of float.N)r   r   r   r   r   )
boundariesr   r   r   rC   R  s    
c                 C  sn   | \}}}|j d }t|d}t|| |}t|d }|t|| | }|dkrV|nd}t||t||fS )z?
    given a tuple (pos,w,max_w), return a tuple of slices
    r   r   N)r   r   r   r   )tupposr   Zmax_wZorig_minZorig_maxZ	block_minZ	block_maxr   r   r   paste_slices\  s    


rB  c                 C  sV   t ||j| }t tt| \}}||d  | dd|d f< | jd dkrR|  } | S )zZ
    given a location (loc) and an original array (orig), paste a block array into it
    r   Nr   )r   r   maprB  r  )origblocklocZloc_zipZorig_slicesZblock_slicesr   r   r   pastej  s    rG  )dutyc           	      C  s   t | t | }}t |}t |}t|j}|dk|dk B }t|dtj }| ||d tj k @ }d||< | | @ }d||< |S )z
    compute squarepulse using pytorch
    equivalent to numpy implementation from
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.square.html
    r   r   r   r   )r=   r   r   r   	remainderpi)	sigrH  r  r   r   mask1Ztmodmask2Zmask3r   r   r   squarepulsex  s    rN  c                 C  s   t t| td d}|d k	r"t|S tjtjtjtjtjtj	tj
tj	tjtj	tjtjtjtji}t t| |d d}|d k	r||S t t| t|tt S Nr   )r4   r-  r*   r   r%   NEARESTZERONEAREST_EXACTLINEARONEBILINEAR	TRILINEARBICUBICTHREEAREAr   interp_moderet_mappingr   r   r   _to_numpy_resample_interp_mode  s,           	r^  c                 C  s   t t| td d}|d k	r|S tjtddr2tjntjtjtj	tj
tji}t t| |d d}|d k	rh|S t t| t|tt S )Nr   r      )r4   r-  r%   r*   rQ  r7   rR  rP  rT  rS  rX  rW  r   rZ  r   r   r   _to_torch_resample_interp_mode  s       r`  c                 C  sr   t t| td d}|d k	r|S tjtjtjtjtjtj	i}t t| |d d}|d k	rX|S t t| t
|t
t S rO  )r4   r-  r&   r$   ZEROSCONSTANTBORDERrP  
REFLECTIONREFLECTr   r   r\  r]  r   r   r   _to_numpy_resample_padding_mode  s       rg  c                 C  s   t t| td d}|d k	r|S tjtjtjtjtjtjtj	tj
tjtj
tjtj
tjtj
i}t t| |d d}|d k	rx|S t t| t|tt S rO  )r4   r-  r$   r&   rb  ra  GRID_CONSTANTrP  rc  re  rd  WRAP	GRID_WRAPGRID_MIRRORr   rf  r   r   r   _to_torch_resample_padding_mode  s,           	rl  constantr   )r[  c                 K  s  dd|p
i     }}}|dkrBtt| tdddk	r<tjntj}|tjkrht| }t|}||||fS t	| }t
|}t|dr|dd}|dkrtj}n|dkrtj}ntj}|dd	s||||fS |d
krdn|}|dkrd}nt|drd}nt|}||||fS )a  
    Automatically adjust the resampling interpolation mode and padding mode,
    so that they are compatible with the corresponding API of the `backend`.
    Depending on the availability of the backends, when there's no exact
    equivalent, a similar mode is returned.

    Args:
        interp_mode: interpolation mode.
        padding_mode: padding mode.
        backend: optional backend of `TransformBackends`. If None, the backend will be decided from `interp_mode`.
        kwargs: additional keyword arguments. currently support ``torch_interpolate_spatial_nd``, to provide
            additional information to determine ``linear``, ``bilinear`` and ``trilinear``;
            ``use_compiled`` to use MONAI's precompiled backend (pytorch c++ extensions), default to ``False``.
    Nr   linearZtorch_interpolate_spatial_ndr   r   rA  Zuse_compiledF
reflectionbicubic)copyr4   r-  r*   r8   r  r  r^  rg  r`  rl  endswithpopr%   rS  rV  rU  r#   )r[  padding_moder  r(  Z_interp_modeZ_padding_mode_kwargsndr   r   r   rp     s8    
No message providedzlist | dict)entry
status_keydefault_messagec                 C  s   t | tr0t }| D ]}|t||| q|S t|}tj| kr|| tj kr| tj | }|dkrl|gS t |trz|S |gS g S dS )a  
    Check the operations of a MetaTensor to determine whether there are any statuses
    Args:
        entry: a dictionary that may contain TraceKey.STATUS entries, or a list of such dictionaries
        status_key: the status key to search for. This must be an entry in `TraceStatusKeys`_
        default_message: The message to provide if no messages are provided for the given status key entry

    Returns:
        A list of status messages matching the providing status key

    N)r   r   extendcheck_applied_operationsr,   r+   STATUSES)rx  ry  rz  resultsZ	sub_entryZstatus_key_reasonr   r   r   r|    s    

r|  )r   ry  rz  c                 C  s   t  }t| t tfrB| D ]&}t|||\}}|dk	r|| qnht| tjjrp| jD ]}|t	||| qVn:t| t
r|  D ]&}t|||\}}|dk	r|| qt|dkrd|fS dS )a\  
    Checks whether a given tensor is has a particular status key message on any of its
    applied operations. If it doesn't, it returns the tuple `(False, None)`. If it does
    it returns a tuple of True and a list of status messages for that status key.

    Status keys are defined in :class:`TraceStatusKeys<monai.utils.enums.TraceStatusKeys>`.

    This function also accepts:

    * dictionaries of tensors
    * lists or tuples of tensors
    * list or tuples of dictionaries of tensors

    In any of the above scenarios, it iterates through the collections and executes itself recursively until it is
    operating on tensors.

    Args:
        data: a `torch.Tensor` or `MetaTensor` or collections of torch.Tensor or MetaTensor, as described above
        status_key: the status key to look for, from `TraceStatusKeys`
        default_message: a default message to use if the status key entry doesn't have a message set

    Returns:
        A tuple. The first entry is `False` or `True`. The second entry is the status messages that can be used for the
        user to help debug their pipelines.

    Nr   F)TN)r   r   r   rq   r{  r   r   r   r  r|  r  valuesr   )r   ry  rz  Zstatus_key_occurrencesr!  r}   reasonsopr   r   r   rq   !  s"    

)block_paramsfloat64_distanceszNone | float | list[float]ztuple[int, int, int] | Nonez@None | NdarrayOrTensor | tuple[NdarrayOrTensor, NdarrayOrTensor])	r   samplingreturn_distancesreturn_indices	distancesr   r  r  r{   c                C  sP  t ddd\}}	to.|	o.t| tjo.| jjdk}
|s@|s@td| jdkrT| jdks\td|| }}d	\}}|
rd	\}}|r|rtj	ntj
}|d
krtj| tj|d}n2t|tjs|j| jkrtd|j|kstdt|}|rXtj}|d
krtj|  f| j |d}n8t|tjs<|j| jkr<td|j|ksPtdt|}t| }t|jd D ]F}||| ||||d
k	r|| nd
|d
k	r|| nd
||d qnntstdt| }|r|d
krtj|tj	d}n,t|tjstd|jtj	kstd|rv|d
krJtj|jf|j tjd}n,t|tjs`td|jtjksvtdt|jd D ]D}tj|| ||||d
k	r|| nd
|d
k	r|| nd
d qg }|r|d
kr|| |r|d
kr|| |sd
S t| tjr | jnd
}tt|dkr<|d n|t| |dd S )a
  
    Euclidean distance transform, either GPU based with CuPy / cuCIM or CPU based with scipy.
    To use the GPU implementation, make sure cuCIM is available and that the data is a `torch.tensor` on a GPU device.

    Note that the results of the libraries can differ, so stick to one if possible.
    For details, check out the `SciPy`_ and `cuCIM`_ documentation.

    .. _SciPy: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.distance_transform_edt.html
    .. _cuCIM: https://docs.rapids.ai/api/cucim/nightly/api/#cucim.core.operations.morphology.distance_transform_edt

    Args:
        img: Input image on which the distance transform shall be run.
            Has to be a channel first array, must have shape: (num_channels, H, W [,D]).
            Can be of any type but will be converted into binary: 1 wherever image equates to True, 0 elsewhere.
            Input gets passed channel-wise to the distance-transform, thus results from this function will differ
            from directly calling ``distance_transform_edt()`` in CuPy or SciPy.
        sampling: Spacing of elements along each dimension. If a sequence, must be of length equal to the input rank -1;
            if a single number, this is used for all axes. If not specified, a grid spacing of unity is implied.
        return_distances: Whether to calculate the distance transform.
        return_indices: Whether to calculate the feature transform.
        distances: An output array to store the calculated distance transform, instead of returning it.
            `return_distances` must be True.
        indices: An output array to store the calculated feature transform, instead of returning it. `return_indicies` must be True.
        block_params: This parameter is specific to cuCIM and does not exist in SciPy. For details, look into `cuCIM`_.
        float64_distances: This parameter is specific to cuCIM and does not exist in SciPy.
            If True, use double precision in the distance computation (to match SciPy behavior).
            Otherwise, single precision will be used for efficiency.

    Returns:
        distances: The calculated distance transform. Returned only when `return_distances` is True and `distances` is not supplied.
            It will have the same shape and type as image. For cuCIM: Will have dtype torch.float64 if float64_distances is True,
            otherwise it will have dtype torch.float32. For SciPy: Will have dtype np.float64.
        indices: The calculated feature transform. It has an image-shaped array for each dimension of the image.
            The type will be equal to the type of the image.
            Returned only when `return_indices` is True and `indices` is not supplied. dtype np.float64.

    z cucim.core.operations.morphologyrr   r@   cudaz0Neither return_distances nor return_indices TruerA  rD  z9Wrong input dimensionality. Use (num_channels, H, W [,D]))NNN)memory_formatrz   z:distances must be a torch.Tensor on the same device as imgz<distances must be a torch.Tensor of dtype float32 or float64r|   z8indices must be a torch.Tensor on the same device as imgz-indices must be a torch.Tensor of dtype int32r   )r  r  r  r  r   r  r  z/scipy.ndimage required if cupy is not availablez!distances must be a numpy.ndarrayz2distances must be a numpy.ndarray of dtype float64zindices must be a numpy.ndarrayz.indices must be a numpy.ndarray of dtype int32)r  r  r  r  r   r   r   )r6   rt  r   r   r   r   r  ry  ri  float64r5  r  contiguous_formatr  rz   r:   r  r   r   r   r   has_ndimager<   r   r?   r  rr   r   r9   r   )r   r  r  r  r  r   r  r  rr   r~  r  Zdistances_originalZindices_originalZ
distances_indices_rz   r  Zchannel_idxr_valsr   r   r   r   rr   P  s    0 
 





	

__main__)r   )NF)Nr   )NNr   N)r   N)F)NF)NNFT)Nr   )r  r   TFN)N)NN)Nr   r   )r   r  rt   )NT)r  N)Nr  r   r  )N)T)r-  )T)r   )rw  )rw  )NTFNN)
__future__r   rd  r   r   collections.abcr   r   r   r   r   
contextlibr   	functoolsr	   r
   inspectr   r   typingr   r  r   r   r   monai.configr   r   monai.config.type_definitionsr   r   Zmonai.networks.layersr   Zmonai.networks.utilsr   r  r   monai.transforms.transformr   r   r   0monai.transforms.utils_pytorch_numpy_unificationr   r   r   r   r   r   r   r   r    r!   r"   monai.utilsr#   r$   r%   r&   r'   r(   r)   r*   r+   r,   r-   r.   r/   r0   r1   r2   r3   r4   r5   r6   r7   monai.utils.enumsr8   Zmonai.utils.type_conversionr9   r:   r;   r<   r=   rv  rx  r  r  r  r  rw  rt  
cp_ndarrayr}   r  r  __all__r5  rs   r^   rW   rX   rY   rZ   rd   r_   ra   r   r`   rF   rb   rG   r[   r\   rc   r  rR   rQ   ru   r  rI   r  r  rH   rJ   r4  r:  r=  r@  rL   rN  rK   rW  rY  rM   r   r\  rS   rU   rV   rk   rO   rT   rN   r]   rB   r   r  rE   ro   rD   re   rP   rf   rg   rh   ri   rj   rl   rm   rn   rC   rB  rG  rN  r^  r`  rg  rl  r  rp   r|  rq   rr   r  r   r   r   r   <module>   s  4\8#	   8    &   D 43  ?    A")2& 
A  3    J  <    =  /  #.$      $8%2)%
  21    . 	