U
    Ph                    @  s  d Z ddlmZ ddlZddlmZ ddlmZ ddlm	Z	 ddl
mZmZmZmZmZmZ ddlZddlZddlmZmZ dd	lmZ dd
lmZmZ ddlmZ ddlmZm Z m!Z!m"Z"m#Z#m$Z$ ddl%m&Z&m'Z'm(Z( ddl)m*Z* ddl+m,Z,m-Z- ddl.m/Z/ ddl0m1Z1m2Z2m3Z3m4Z4m5Z5m6Z6m7Z7m8Z8 ddl9m:Z: ddl;m<Z<m=Z=m>Z>m?Z? ddl@mAZAmBZBmCZCmDZDmEZEmFZFmGZGmHZHmIZI ddlJmKZKmLZLmMZMmNZN ddlOmPZPmQZQmRZRmSZSmTZTmUZUmVZVmWZWmXZXmYZYmZZZm[Z[m\Z\m]Z] ddl^m_Z_m`Z`maZambZb ddlcmdZe ddlfmgZg ddlhmiZimjZjmkZk e]d\ZlZme]d\ZnZoe]d\ZpZoe]d\ZqZodd d!d"d#d$d%d&d'd(d)d*d+d,d-d.d/d0d1d2d3d4d5d6d7d8d9d:gZreeeeeesesf esf  esf  ZtG d;d de/e<ZuG d<d  d euZvG d=d! d!e/e<ZwG d>d" d"e/e<ZxG d?d# d#e/e<ZyG d@d( d(e/e<ZzG dAd) d)e/e<Z{G dBd* d*e/e<Z|G dCd+ d+e/e<Z}G dDd, d,e>e/e<Z~G dEd- d-e>e/e<ZG dFd. d.e>e/e<ZG dGd0 d0e>e/e<ZG dHd1 d1e>e/e<ZG dId2 d2e<ZG dJd3 d3e=e<ZG dKd4 d4e=e?ZG dLd5 d5e?ZG dMd6 d6e/e<ZG dNd7 d7e>e/e<ZG dOd8 d8e>ZG dPd9 d9e>ZG dQd$ d$e?ZG dRd/ d/e>ZG dSd% d%e?e:ZG dTd& d&e?e:ZG dUd' d'ee>e:ZG dVd: d:e>ZdS )Wz>
A collection of "vanilla" transforms for spatial operations.
    )annotationsN)Callable)deepcopy)zip_longest)AnyOptionalSequenceTupleUnioncast)USE_COMPILED	DtypeLike)NdarrayOrTensor)get_track_metaset_track_meta)
MetaTensor)
AFFINE_TOLaffine_to_spacingcompute_shape_offset
iter_patchto_affine_ndzoom_affine)AffineTransformGaussianFilter	grid_pull)meshgrid_ij)CenterSpatialCropResizeWithPadOrCrop)InvertibleTransform)affine_funcfliporientationresizerotaterotate90spatial_resamplezoom)MultiSampleTrait)LazyTransformRandomizableRandomizableTransform	Transform)	create_control_gridcreate_gridcreate_rotatecreate_scalecreate_shearcreate_translatemap_spatial_axesresolves_modesscale_affine)argsortargwhere
linalg_invmoveaxis)GridSampleModeGridSamplePadModeInterpolateModeNumpyPadModeconvert_to_cupyconvert_to_dst_typeconvert_to_numpyconvert_to_tensorensure_tupleensure_tuple_repensure_tuple_sizefall_back_tupleissequenceiterableoptional_import)GridPatchSort	PatchKeys	TraceKeysTransformBackends)ImageMetaKey)look_up_option)convert_data_typeget_equivalent_dtypeget_torch_dtype_from_stringnibabelcupyzcupyx.scipy.ndimagezscipy.ndimageSpatialResampleResampleToMatchSpacingOrientationFlipGridDistortion	GridSplit	GridPatchRandGridPatchResizeRotateZoomRotate90RandRotate90
RandRotateRandFlipRandGridDistortionRandAxisFlipRandZoom
AffineGridRandAffineGridRandDeformGridResampleAffine
RandAffineRand2DElasticRand3DElasticRandSimulateLowResolutionc                   @  sx   e Zd ZdZejejejgZe	j
ejdejdfdddddddd	Zddddddddddd	ddZdddddZd
S )rR   a  
    Resample input image from the orientation/spacing defined by ``src_affine`` affine matrix into
    the ones specified by ``dst_affine`` affine matrix.

    Internally this transform computes the affine transform matrix from ``src_affine`` to ``dst_affine``,
    by ``xform = linalg.solve(src_affine, dst_affine)``, and call ``monai.transforms.Affine`` with ``xform``.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.
    F	str | intstrboolr   modepadding_modealign_cornersdtypelazyc                 C  s*   t j| |d || _|| _|| _|| _dS )a%  
        Args:
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"border"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
                If ``None``, use the data type of input data. To be compatible with other modules,
                the output data type is always ``float32``.
            lazy: a flag to indicate whether this transform should execute lazily or not.
                Defaults to False
        rv   N)r(   __init__rr   rs   rt   ru   )selfrr   rs   rt   ru   rv    rz   S/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/transforms/spatial/array.pyrx      s
    zSpatialResample.__init__Ntorch.Tensortorch.Tensor | Nonez)Sequence[int] | torch.Tensor | int | Nonestr | int | None
str | Nonebool | None)	img
dst_affinespatial_sizerr   rs   rt   ru   rv   returnc	                 C  s~   t |p| jp|jtj}	|dk	r$|n| j}|dk	r6|n| j}|dk	rH|n| j}|dkr\| jn|}
t|||||||	|
| 	 d	S )a  
        Args:
            img: input image to be resampled. It currently supports channel-first arrays with
                at most three spatial dimensions.
            dst_affine: destination affine matrix. Defaults to ``None``, which means the same as `img.affine`.
                the shape should be `(r+1, r+1)` where `r` is the spatial rank of ``img``.
                when `dst_affine` and `spatial_size` are None, the input will be returned without resampling,
                but the data type will be `float32`.
            spatial_size: output image spatial size.
                if `spatial_size` and `self.spatial_size` are not defined,
                the transform will compute a spatial size automatically containing the previous field of view.
                if `spatial_size` is ``-1`` are the transform will use the corresponding input img size.
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``self.padding_mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                Defaults to ``None``, effectively using the value of `self.align_corners`.
            dtype: data type for resampling computation. Defaults to ``self.dtype`` or
                ``np.float64`` (for best precision). If ``None``, use the data type of input data.
                To be compatible with other modules, the output data type is always `float32`.
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.
        The spatial rank is determined by the smallest among ``img.ndim -1``, ``len(src_affine) - 1``, and ``3``.

        When both ``monai.config.USE_COMPILED`` and ``align_corners`` are set to ``True``,
        MONAI's resampling implementation will be used.
        Set `dst_affine` and `spatial_size` to `None` to turn off the resampling step.
        Nrv   transform_info)
rN   ru   torchTensorrt   rr   rs   rv   r%   get_transform_info)ry   r   r   r   rr   rs   rt   ru   rv   Zdtype_ptlazy_rz   rz   r{   __call__   s     3zSpatialResample.__call__datar   c              	   C  s   |  |}|tj }t|d |d< |d|d< |tj |d< |dtjkrXd|d< | d t	j
| |f|}W 5 Q R X |d|d< |S )Nru   Z
src_affiner   r   rt   F)pop_transformrI   
EXTRA_INFOrO   pop	ORIG_SIZEgetNONEtrace_transformrR   r   )ry   r   	transformkw_argsoutrz   rz   r{   inverse   s    

zSpatialResample.inverse)NNNNNNN)__name__
__module____qualname____doc__rJ   TORCHNUMPYCUPYbackendr9   BILINEARr:   BORDERnpfloat64rx   r   r   rz   rz   rz   r{   rR   w   s"   %        Dc                
      s6   e Zd ZdZdddddddddd fd	d
Z  ZS )rS   a
  
    Resample an image to match given metadata. The affine matrix will be aligned,
    and the size of the output image will match.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.
    Nr|   r~   r   r   r   )r   img_dstrr   rs   rt   ru   rv   r   c              
     s"  |dkrt dt|tr"| ntd}|dkr:| jn|}	t j||t|trZ|	 n|j
dd |||||	d}|	st|tr||_t|tr|jtjd}
t|j|_|
|jtj< n`t|trt|tr|jtjd}
t|j}dD ]}||d q|j| |
|jtj< |S )a>	  
        Args:
            img: input image to be resampled to match ``img_dst``. It currently supports channel-first arrays with
                at most three spatial dimensions.
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"border"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
                Defaults to ``None``, effectively using the value of `self.align_corners`.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            dtype: data type for resampling computation. Defaults to ``self.dtype`` or
                ``np.float64`` (for best precision). If ``None``, use the data type of input data.
                To be compatible with other modules, the output data type is always `float32`.
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.

        Raises:
            ValueError: When the affine matrix of the source image is not invertible.
        Returns:
            Resampled input tensor or MetaTensor.
        Nz`img_dst` is missing.      )r   r   r   rr   rs   rt   ru   rv   Zresample_to_match_source)affinespatial_shape)RuntimeError
isinstancer   peek_pending_affiner   eyerv   superr   peek_pending_shapeshaper   metar   KeyFILENAME_OR_OBJr   r   update)ry   r   r   rr   rs   rt   ru   rv   r   r   Zoriginal_fname	meta_dictk	__class__rz   r{   r     s:    )



zResampleToMatch.__call__)NNNNN)r   r   r   r   r   __classcell__rz   rz   r   r{   rS      s        c                   @  s   e Zd ZdZejZdejej	de
jdddddf
ddddddddd	d	dd
dddZejjdd
dddZddddddddddd	ddZdddddZdS )rT   z
    Resample input image into the specified `pixdim`.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.
    FNz$Sequence[float] | float | np.ndarrayrp   rn   ro   r   z+Sequence[float] | float | np.ndarray | NoneNone)pixdimdiagonalrr   rs   rt   ru   scale_extentrecompute_affine
min_pixdim
max_pixdimrv   r   c                 C  s   t j| |d tjt|tjd| _tjt|	tjd| _tjt|
tjd| _|| _	|| _
|| _t| j| jD ]F\}}t|spt|sp||k s|dk rptd| j d| j dqpt|||||d| _dS )	a  
        Args:
            pixdim: output voxel spacing. if providing a single number, will use it for the first dimension.
                items of the pixdim sequence map to the spatial dimensions of input image, if length
                of pixdim sequence is longer than image spatial dimensions, will ignore the longer part,
                if shorter, will pad with the last value. For example, for 3D image if pixdim is [1.0, 2.0] it
                will be padded to [1.0, 2.0, 2.0]
                if the components of the `pixdim` are non-positive values, the transform will use the
                corresponding components of the original pixdim, which is computed from the `affine`
                matrix of input image.
            diagonal: whether to resample the input to have a diagonal affine matrix.
                If True, the input data is resampled to the following affine::

                    np.diag((pixdim_0, pixdim_1, ..., pixdim_n, 1))

                This effectively resets the volume to the world coordinate system (RAS+ in nibabel).
                The original orientation, rotation, shearing are not preserved.

                If False, this transform preserves the axes orientation, orthogonal rotation and
                translation components from the original affine. This option will not flip/swap axes
                of the original data.
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"border"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
                If None, use the data type of input data. To be compatible with other modules,
                the output data type is always ``float32``.
            scale_extent: whether the scale is computed based on the spacing or the full extent of voxels,
                default False. The option is ignored if output spatial size is specified when calling this transform.
                See also: :py:func:`monai.data.utils.compute_shape_offset`. When this is True, `align_corners`
                should be `True` because `compute_shape_offset` already provides the corner alignment shift/scaling.
            recompute_affine: whether to recompute affine based on the output shape. The affine computed
                analytically does not reflect the potential quantization errors in terms of the output shape.
                Set this flag to True to recompute the output affine based on the actual pixdim. Default to ``False``.
            min_pixdim: minimal input spacing to be resampled. If provided, input image with a larger spacing than this
                value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the
                value of `pixdim`. Default to `None`.
            max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this
                value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the
                value of `pixdim`. Default to `None`.
            lazy: a flag to indicate whether this transform should execute lazily or not.
                Defaults to False
        rw   ru   r   zmin_pixdim z$ must be positive, smaller than max .rq   N)r(   rx   r   arrayrA   r   r   r   r   r   r   r   zipisnan
ValueErrorrR   sp_resample)ry   r   r   rr   rs   rt   ru   r   r   r   r   rv   mnmxrz   rz   r{   rx   Y  s"    D$    zSpacing.__init__valr   c                 C  s   || _ || j_d S N)_lazyr   rv   ry   r   rz   rz   r{   rv     s    zSpacing.lazyr|   r~   r   r   z'Sequence[int] | np.ndarray | int | None)	
data_arrayrr   rs   rt   ru   r   output_spatial_shaperv   r   c	                 C  s  t |tr| n|jdd }	t|	}
|
dkr@td|	 dt |trR| nd}|dkr|td t	j
|
d t	jd}t|
t|t	jd }| jd|
  }|j|
k rt	||d g|
|j  }t||
|j}tt|| jd|
 | jd|
 t	jd	D ]\}\}}}|| }t	|r(|nt||}t	|rB|nt||}||krxtd
| d| d| d| d	|t |  kr|t krn n|n|||< q|s|rtd t||| jd}|dkr| j n|}t!|	|||\}}|d|
 |d|
df< |dkr"t"|n|}|dkr6| j#n|}| j$|t%&|||||||d}| j'rt |tr|r|t(dt)|	|}t*||d |_+|S )a  
        Args:
            data_array: in shape (num_channels, H[, W, ...]).
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``"self.mode"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"self.padding_mode"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                Defaults to ``None``, effectively using the value of `self.align_corners`.
            dtype: data type for resampling computation. Defaults to ``self.dtype``.
                If None, use the data type of input data. To be compatible with other modules,
                the output data type is always ``float32``.
            scale_extent: whether the scale is computed based on the spacing or the full extent of voxels,
                The option is ignored if output spatial size is specified when calling this transform.
                See also: :py:func:`monai.data.utils.compute_shape_offset`. When this is True, `align_corners`
                should be `True` because `compute_shape_offset` already provides the corner alignment shift/scaling.
            output_spatial_shape: specify the shape of the output data_array. This is typically useful for
                the inverse of `Spacingd` where sometimes we could not compute the exact shape due to the quantization
                error with the affine.
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.

        Raises:
            ValueError: When ``data_array`` has no spatial dimensions.
            ValueError: When ``pixdim`` is nonpositive.

        Returns:
            data tensor or MetaTensor (resampled into `self.pixdim`).

        r   Nr   9data_array must have at least one spatial dimension, got r   zG`data_array` is not of type MetaTensor, assuming affine to be identity.r   )	fillvaluez,min_pixdim is larger than max_pixdim at dim z: min z max z out z=align_corners=False is not compatible with scale_extent=True.)r   )r   r   rr   rs   rt   ru   rv   z7recompute_affine is not supported with lazy evaluation.),r   r   r   r   lenr   r   warningswarnr   r   r   r   rM   ndarrayr   copysizeappendr   ru   	enumerater   r   r   nanr   minmaxr   r   r   r   r   listrv   r   r   	as_tensorr   NotImplementedErrorr4   r>   r   )ry   r   rr   rs   rt   ru   r   r   rv   Zoriginal_spatial_shapesrZinput_affineaffine_out_dZorig_didx_dr   r   target
new_affineoutput_shapeoffsetactual_shaper   arz   rz   r{   r     s^    4

$
"4


zSpacing.__call__r   c                 C  s   | j |S r   )r   r   ry   r   rz   rz   r{   r     s    zSpacing.inverse)NNNNNNN)r   r   r   r   rR   r   r9   r   r:   r   r   r   rx   r(   rv   setterr   r   rz   rz   rz   r{   rT   O  s0   &T        lc                   @  sX   e Zd ZdZejejgZddddddd	d
dZdddddddZ	dddddZ
dS )rU   z
    Change the input image's orientation into the specified based on `axcodes`.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.
    NF)LR)PA)ISr   rp   z Sequence[tuple[str, str]] | Noner   )axcodesas_closest_canonicallabelsrv   r   c                 C  sN   t j| |d |dkr"|s"td|dk	r8|r8td || _|| _|| _dS )a  
        Args:
            axcodes: N elements sequence for spatial ND input's orientation.
                e.g. axcodes='RAS' represents 3D orientation:
                (Left, Right), (Posterior, Anterior), (Inferior, Superior).
                default orientation labels options are: 'L' and 'R' for the first dimension,
                'P' and 'A' for the second, 'I' and 'S' for the third.
            as_closest_canonical: if True, load the image as closest to canonical axis format.
            labels: optional, None or sequence of (2,) sequences
                (2,) sequences are labels for (beginning, end) of output axis.
                Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
            lazy: a flag to indicate whether this transform should execute lazily or not.
                Defaults to False

        Raises:
            ValueError: When ``axcodes=None`` and ``as_closest_canonical=True``. Incompatible values.

        See Also: `nibabel.orientations.ornt2axcodes`.

        rw   N@Incompatible values: axcodes=None and as_closest_canonical=True.z1using as_closest_canonical=True, axcodes ignored.)r(   rx   r   r   r   r   r   r   )ry   r   r   r   rv   rz   rz   r{   rx   ,  s    
zOrientation.__init__r|   r   )r   rv   r   c                 C  s  t |tr| n|jdd }t|}|dkr@td| dt |trjt| tj	^}}t
||}n2td tj|d tjd}tj|d tjd}t|}| jr|}	n| jdkrtd|t| jk r
td	| j d
| d| jj d| d|jd  d tjj| jd| | jd}
t|
|k rPtdt| j d| dtj||
}	|dkrn| jn|}t|||	||  dS )a  
        If input type is `MetaTensor`, original affine is extracted with `data_array.affine`.
        If input type is `torch.Tensor`, original affine is assumed to be identity.

        Args:
            data_array: in shape (num_channels, H[, W, ...]).
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.

        Raises:
            ValueError: When ``data_array`` has no spatial dimensions.
            ValueError: When ``axcodes`` spatiality differs from ``data_array``.

        Returns:
            data_array [reoriented in `self.axcodes`]. Output type will be `MetaTensor`
                unless `get_track_meta() == False`, in which case it will be
                `torch.Tensor`.

        r   Nr   r   r   zH`data_array` is not of type `MetaTensor, assuming affine to be identity.r   r   z
axcodes ('z?') length is smaller than number of input spatial dimensions D=z.
z: spatial shape = z, channels = z;,please make sure the input is in the channel-first format.)r   z5axcodes must match data_array spatially, got axcodes=zD data_array=Dr   )r   r   r   r   r   r   rM   r   r   r   r   r   r   r   r   nibio_orientationr   r   r   r   orientationsaxcodes2orntr   ornt_transformrv   r!   r   )ry   r   rv   r   r   	affine_np_r   srcZspatial_orntdstr   rz   rz   r{   r   P  s6     



.zOrientation.__call__r   c              	   C  sV   |  |}|tj d }tj|}t|d| jd}|d ||}W 5 Q R X |S )Noriginal_affineF)r   r   r   )	r   rI   r   r   r   aff2axcodesrU   r   r   )ry   r   r   orig_affineZorig_axcodesinverse_transformrz   rz   r{   r     s    
zOrientation.inverse)NFr   F)N)r   r   r   r   rJ   r   r   r   rx   r   r   rz   rz   rz   r{   rU   "  s       $9c                   @  sP   e Zd ZdZejgZddddddd	Zdd
dd
dddZd
d
dddZ	dS )rV   a7  
    Reverses the order of elements along the given spatial axis. Preserves shape.
    See `torch.flip` documentation for additional details:
    https://pytorch.org/docs/stable/generated/torch.flip.html

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.

    Args:
        spatial_axis: spatial axes along which to flip over. Default is None.
            The default `axis=None` will flip over all of the axes of the input array.
            If axis is negative it counts from the last to the first axis.
            If axis is a tuple of ints, flipping is performed on all of the axes
            specified in the tuple.
        lazy: a flag to indicate whether this transform should execute lazily or not.
            Defaults to False

    NFSequence[int] | int | Nonerp   r   )spatial_axisrv   r   c                 C  s   t j| |d || _d S Nrw   )r(   rx   r	  )ry   r	  rv   rz   rz   r{   rx     s    zFlip.__init__r|   r   r   rv   r   c                 C  s6   t |t d}|dkr| jn|}t|| j||  dS )al  
        Args:
            img: channel first array, must have shape: (num_channels, H[, W, ..., ])
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.
        
track_metaNr   )r@   r   rv   r    r	  r   )ry   r   rv   r   rz   rz   r{   r     s    zFlip.__call__r   c              
   C  s@   |  | t| jd}|d ||W  5 Q R  S Q R X d S )Nr	  F)r   rV   r	  r   )ry   r   flipperrz   rz   r{   r     s    
zFlip.inverse)NF)N)
r   r   r   r   rJ   r   r   rx   r   r   rz   rz   rz   r{   rV     s
   c                   @  s   e Zd ZdZejgZdejddde	j
dfdddddd	d
ddd	ddZdddddd	d
dddddZdddddZdddddZdS )r[   a	  
    Resize the input image to given spatial size (with scaling, not cropping/padding).
    Implemented using :py:class:`torch.nn.functional.interpolate`.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.

    Args:
        spatial_size: expected shape of spatial dimensions after resize operation.
            if some components of the `spatial_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        size_mode: should be "all" or "longest", if "all", will use `spatial_size` for all the spatial dims,
            if "longest", rescale the image so that only the longest side is equal to specified `spatial_size`,
            which must be an int number in this case, keeping the aspect ratio of the initial image, refer to:
            https://albumentations.ai/docs/api_reference/augmentations/geometric/resize/
            #albumentations.augmentations.geometric.resize.LongestMaxSize.
        mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
            The interpolation mode. Defaults to ``"area"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
        align_corners: This only has an effect when mode is
            'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
        anti_aliasing: bool
            Whether to apply a Gaussian filter to smooth the image prior
            to downsampling. It is crucial to filter when downsampling
            the image to avoid aliasing artifacts. See also ``skimage.transform.resize``
        anti_aliasing_sigma: {float, tuple of floats}, optional
            Standard deviation for Gaussian filtering used when anti-aliasing.
            By default, this value is chosen as (s - 1) / 2 where s is the
            downsampling factor, where s > 1. For the up-size case, s < 1, no
            anti-aliasing is performed prior to rescaling.
        dtype: data type for resampling computation. Defaults to ``float32``.
            If None, use the data type of input data.
        lazy: a flag to indicate whether this transform should execute lazily or not.
            Defaults to False
    allNFzSequence[int] | intro   r   rp   Sequence[float] | float | NoneDtypeLike | torch.dtyper   )	r   	size_moderr   rt   anti_aliasinganti_aliasing_sigmaru   rv   r   c	           	      C  sF   t j| |d t|ddg| _|| _|| _|| _|| _|| _|| _	d S )Nrw   r  longest)
r(   rx   rL   r  r   rr   rt   r  r  ru   )	ry   r   r  rr   rt   r  r  ru   rv   rz   rz   r{   rx     s    zResize.__init__r|   r   )r   rr   rt   r  r  ru   rv   r   c                   s  |dkr| j n|}|dkr | jn|}|jd }| jdkrtt| j}	|	|krlt|j|	d d}
|	|
}n|	|k rt
d|	 d| dt|tr| n|jdd }t| j|}nXt|tr| n|jdd }t| jtst
d| jt|  t fdd	|D }|dkr | jn|}|dkr4| jn|}t|pL| jpL|jtj}|dkrd| jn|}t|td
d	 |D ||||||||  
S )aP  
        Args:
            img: channel first array, must have shape: (num_channels, H[, W, ..., ]).
            mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``,
                ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
                The interpolation mode. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
            align_corners: This only has an effect when mode is
                'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
            anti_aliasing: bool, optional
                Whether to apply a Gaussian filter to smooth the image prior
                to downsampling. It is crucial to filter when downsampling
                the image to avoid aliasing artifacts. See also ``skimage.transform.resize``
            anti_aliasing_sigma: {float, tuple of floats}, optional
                Standard deviation for Gaussian filtering used when anti-aliasing.
                By default, this value is chosen as (s - 1) / 2 where s is the
                downsampling factor, where s > 1. For the up-size case, s < 1, no
                anti-aliasing is performed prior to rescaling.
            dtype: data type for resampling computation. Defaults to ``self.dtype``.
                If None, use the data type of input data.
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.
        Raises:
            ValueError: When ``self.spatial_size`` length is less than ``img`` spatial dimensions.

        Nr   r  zWlen(spatial_size) must be greater or equal to img spatial dimensions, got spatial_size=z img=r   z=spatial_size must be an int number if size_mode is 'longest'.c                 3  s   | ]}t t|  V  qd S r   )intround).0sscalerz   r{   	<genexpr>;  s     z"Resize.__call__.<locals>.<genexpr>c                 s  s   | ]}t |V  qd S r   )r  )r  _srz   rz   r{   r  C  s     )r  r  ndimr  r   rA   r   rC   r   reshaper   r   r   r   rD   r  r   tuplerr   rt   rN   ru   r   r   rv   r"   r   )ry   r   rr   rt   r  r  ru   rv   
input_ndimZoutput_ndiminput_shapeZ_spsp_sizeimg_size_mode_align_corners_dtyper   rz   r  r{   r      sF    &

  zResize.__call__r   c                 C  s   |  |}| ||S r   r   r  ry   r   r   rz   rz   r{   r   N  s    
zResize.inversec           	   	   C  s   |t j }|t j d }|t j d }|t j d }t|||t jkrHd n||d}|d ||}W 5 Q R X t|t j d D ]}|d}q|S )Nrr   rt   ru   )r   rr   rt   ru   Fnew_dimr   )rI   r   r   r[   r   r   rangesqueeze)	ry   r   r   	orig_sizerr   rt   ru   xformr  rz   rz   r{   r  R  s    
zResize.inverse_transform)NNNNNN)r   r   r   r   rJ   r   r   r;   AREAr   float32rx   r   r   r  rz   rz   rz   r{   r[     s&   &       Nc                
   @  s   e Zd ZdZejgZdeje	j
dejdfddddddddd	d
dZdddddddddddZdddddZdddddZdS )r\   a  
    Rotates an input image by given angle using :py:class:`monai.networks.layers.AffineTransform`.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.

    Args:
        angle: Rotation angle(s) in radians. should a float for 2D, three floats for 3D.
        keep_size: If it is True, the output shape is kept the same as the input.
            If it is False, the output shape is adapted so that the
            input array is contained completely in the output. Default is True.
        mode: {``"bilinear"``, ``"nearest"``}
            Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
            Padding mode for outside grid values. Defaults to ``"border"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        align_corners: Defaults to False.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        dtype: data type for resampling computation. Defaults to ``float32``.
            If None, use the data type of input data. To be compatible with other modules,
            the output data type is always ``float32``.
        lazy: a flag to indicate whether this transform should execute lazily or not.
            Defaults to False
    TFSequence[float] | floatrp   ro   r  r   )angle	keep_sizerr   rs   rt   ru   rv   r   c                 C  s6   t j| |d || _|| _|| _|| _|| _|| _d S r
  )r(   rx   r3  r4  rr   rs   rt   ru   )ry   r3  r4  rr   rs   rt   ru   rv   rz   rz   r{   rx     s    
zRotate.__init__Nr|   r   r   r   rr   rs   rt   ru   rv   r   c                 C  s   t |t d}t|p| jp|jtj}|p.| j}|p8| j}	|dkrH| jn|}
t	|t
r^| n|jdd }| jrv|nd}|dkr| jn|}t|| j|||	|
|||  d	S )a  
        Args:
            img: channel first array, must have shape: [chns, H, W] or [chns, H, W, D].
            mode: {``"bilinear"``, ``"nearest"``}
                Interpolation mode to calculate output values. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``self.padding_mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                align_corners: Defaults to ``self.align_corners``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            align_corners: Defaults to ``self.align_corners``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            dtype: data type for resampling computation. Defaults to ``self.dtype``.
                If None, use the data type of input data. To be compatible with other modules,
                the output data type is always ``float32``.
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.

        Raises:
            ValueError: When ``img`` spatially is not one of [2D, 3D].

        r  Nr   r   )r@   r   rN   ru   r   r   rr   rs   rt   r   r   r   r   r4  rv   r#   r3  r   )ry   r   rr   rs   rt   ru   rv   r(  r&  _padding_moder'  im_shaper   r   rz   rz   r{   r     s&    !

 zRotate.__call__r   c                 C  s   |  |}| ||S r   r)  r*  rz   rz   r{   r     s    
zRotate.inversec                 C  s*  |t j d }|t j d }|t j d }|t j d }|t j d }tt|}t||\}	}
}}	td|
||t jkrzdn|dd}t|t|d	d
 }t	||^}}	|t j
 }||d
||d d
}t	|||jdd
 }t|tr&t| dd}tt|d |}| jt	||d
   _|S )Nrot_matrr   rs   rt   ru   FT)
normalizedrr   rs   rt   reverse_indexingr   r   r   r  ru   r  r   )rI   r   r7   r?   r3   r   r   rM   r   r>   r   	unsqueezefloatr-  ru   r   r@   r   r   r   r   )ry   r   r   Zfwd_rot_matrr   rs   rt   ru   Zinv_rot_matr  _m_pr/  img_tZtransform_tr$  r   r   matrz   rz   r{   r    s0    
zRotate.inverse_transform)NNNNN)r   r   r   r   rJ   r   r   r9   r   r:   r   r   r1  rx   r   r   r  rz   rz   rz   r{   r\   d  s"        5c                
   @  s   e Zd ZdZejgZeje	j
dejddfdddddd	d	d
dddZdddddddddddZdddddZdddddZdS )r]   a:	  
    Zooms an ND image using :py:class:`torch.nn.functional.interpolate`.
    For details, please see https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html.

    Different from :py:class:`monai.transforms.resize`, this transform takes scaling factors
    as input, and provides an option of preserving the input spatial size.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.

    Args:
        zoom: The zoom factor along the spatial axes.
            If a float, zoom is the same for each spatial axis.
            If a sequence, zoom should contain one value for each spatial axis.
        mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
            The interpolation mode. Defaults to ``"area"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
        padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"edge"``.
            The mode to pad data after zooming.
            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        align_corners: This only has an effect when mode is
            'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
        dtype: data type for resampling computation. Defaults to ``float32``.
            If None, use the data type of input data.
        keep_size: Should keep original size (padding/slicing if needed), default is True.
        lazy: a flag to indicate whether this transform should execute lazily or not.
            Defaults to False
        kwargs: other arguments for the `np.pad` or `torch.pad` function.
            note that `np.pad` treats channel dimension as the first dimension.
    NTFr2  ro   r   r  rp   r   )r&   rr   rs   rt   ru   r4  rv   r   c           	      K  s<   t j| |d || _|| _|| _|| _|| _|| _|| _d S r
  )	r(   rx   r&   rr   rs   rt   ru   r4  kwargs)	ry   r&   rr   rs   rt   ru   r4  rv   rC  rz   rz   r{   rx     s    zZoom.__init__r|   r   r5  c                 C  s   t |t d}t| j|jd }|dkr.| jn|}|p:| j}	|dkrJ| jn|}
t|p^| j	p^|j	t
j}|dkrt| jn|}t||| j||	|
|||  d	S )a  
        Args:
            img: channel first array, must have shape: (num_channels, H[, W, ..., ]).
            mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``,
                ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
                The interpolation mode. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
            padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
                ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
                available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
                One of the listed string values or a user supplied function. Defaults to ``"edge"``.
                The mode to pad data after zooming.
                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
            align_corners: This only has an effect when mode is
                'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
            dtype: data type for resampling computation. Defaults to ``self.dtype``.
                If None, use the data type of input data.
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.
        r  r   Nr   )r@   r   rB   r&   r  rr   rs   rt   rN   ru   r   r   rv   r4  r   )ry   r   rr   rs   rt   ru   rv   _zoomr&  r6  r'  r(  r   rz   rz   r{   r   #  s$     
zZoom.__call__r   c                 C  s   |  |}| ||S r   r)  r*  rz   rz   r{   r   V  s    
zZoom.inversec              	   C  s   |t j d rj|t j }t|dd}|t j d }t j|t j d t j< t j|t j d t j< |||}|t j d }|t j d }|t j d	 }t|t j d
}	|	d$ |	|||t jkrd n||d}
W 5 Q R X |
S )NZ
do_padcropedge)r   rr   Zpadcroppad_info	crop_inforr   rt   ru   r;  F)rr   rt   ru   )	rI   r   r   r   r   IDr  r[   r   )ry   r   r   r.  Zpad_or_cropZpadcrop_xformrr   rt   ru   r  r   rz   rz   r{   r  Z  s&    
   zZoom.inverse_transform)NNNNN)r   r   r   r   rJ   r   r   r;   r0  r<   EDGEr   r1  rx   r   r   r  rz   rz   rz   r{   r]     s"   $     3c                   @  sb   e Zd ZdZejgZdddddd	d
dZdddddddZdddddZ	dddddZ
dS )r^   aE  
    Rotate an array by 90 degrees in the plane specified by `axes`.
    See `torch.rot90` for additional details:
    https://pytorch.org/docs/stable/generated/torch.rot90.html#torch-rot90.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.
    r   r   r   Fr  tuple[int, int]rp   r   )r   spatial_axesrv   r   c                 C  sN   t j| |d d|d  d | _t|}t|dkrDtd| d|| _dS )a  
        Args:
            k: number of times to rotate by 90 degrees.
            spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.
                Default: (0, 1), this is the first two axis in spatial dimensions.
                If axis is negative it counts from the last to the first axis.
            lazy: a flag to indicate whether this transform should execute lazily or not.
                Defaults to False
        rw   r      zBspatial_axes must be 2 numbers to define the plane to rotate, got r   N)r(   rx   r   rA   r   r   rL  )ry   r   rL  rv   spatial_axes_rz   rz   r{   rx   |  s    
zRotate90.__init__Nr|   r   r  c                 C  sF   t |t d}t|j| j}|dkr*| jn|}t||| j||  dS )am  
        Args:
            img: channel first array, must have shape: (num_channels, H[, W, ..., ]),
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.
        r  Nr   )	r@   r   r2   r  rL  rv   r$   r   r   )ry   r   rv   axesr   rz   rz   r{   r     s    zRotate90.__call__r   c                 C  s   |  |}| ||S r   r)  r*  rz   rz   r{   r     s    
zRotate90.inversec              
   C  s^   |t j d }|t j d }d|d  }t||d}|d ||W  5 Q R  S Q R X d S )NrO  r   r   )r   rL  F)rI   r   r^   r   )ry   r   r   rO  r   Zinv_kr/  rz   rz   r{   r    s    zRotate90.inverse_transform)r   rJ  F)N)r   r   r   r   rJ   r   r   rx   r   r   r  rz   rz   rz   r{   r^   p  s   	c                      sn   e Zd ZdZejZddddd	d
dddZddd
d fddZddd	dddddZdddddZ	  Z
S )r_   z
    With probability `prob`, input arrays are rotated by 90 degrees
    in the plane specified by `spatial_axes`.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.
    皙?   rJ  Fr>  r  rK  rp   r   )probmax_krL  rv   r   c                 C  s0   t | | tj| |d || _|| _d| _dS )a#  
        Args:
            prob: probability of rotating.
                (Default 0.1, with 10% probability it returns a rotated array)
            max_k: number of rotations will be sampled from `np.random.randint(max_k) + 1`, (Default 3).
            spatial_axes: 2 int numbers, defines the plane to rotate with 2 spatial axes.
                Default: (0, 1), this is the first two axis in spatial dimensions.
            lazy: a flag to indicate whether this transform should execute lazily or not.
                Defaults to False
        rw   r   N)r*   rx   r(   rS  rL  _rand_k)ry   rR  rS  rL  rv   rz   rz   r{   rx     s
    zRandRotate90.__init__N
Any | Noner   c                   s.   t  d  | jsd S | j| jd | _d S Nr   )r   	randomize_do_transformr   randintrS  rT  r   r   rz   r{   rW    s    zRandRotate90.randomizeTr|   r   r   rW  rv   r   c                 C  sb   |r|    |dkr| jn|}| jr@t| j| j|d}||}nt|t d}| j|d|d |S a  
        Args:
            img: channel first array, must have shape: (num_channels, H[, W, ..., ]),
            randomize: whether to execute `randomize()` function first, default to True.
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.
        Nrw   r  Treplacerv   )	rW  rv   rX  r^   rT  rL  r@   r   push_transform)ry   r   rW  rv   r   r/  r   rz   rz   r{   r     s    

zRandRotate90.__call__c                 C  s0   |  |}|tj s|S |tj }t ||S r   )r   rI   DO_TRANSFORMr   r^   r  )ry   r   
xform_infoZrotate_xformrz   rz   r{   r     s
    


zRandRotate90.inverse)rP  rQ  rJ  F)N)TN)r   r   r   r   r^   r   rx   rW  r   r   r   rz   rz   r   r{   r_     s          c                      s   e Zd ZdZejZdddddejej	de
jdf
dddddd	d	dd
dddddZdddd fddZdddddd
dddddZdddddZ  ZS )r`   a  
    Randomly rotate the input arrays.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.

    Args:
        range_x: Range of rotation angle in radians in the plane defined by the first and second axes.
            If single number, angle is uniformly sampled from (-range_x, range_x).
        range_y: Range of rotation angle in radians in the plane defined by the first and third axes.
            If single number, angle is uniformly sampled from (-range_y, range_y). only work for 3D data.
        range_z: Range of rotation angle in radians in the plane defined by the second and third axes.
            If single number, angle is uniformly sampled from (-range_z, range_z). only work for 3D data.
        prob: Probability of rotation.
        keep_size: If it is False, the output shape is adapted so that the
            input array is contained completely in the output.
            If it is True, the output shape is the same as the input. Default is True.
        mode: {``"bilinear"``, ``"nearest"``}
            Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
            Padding mode for outside grid values. Defaults to ``"border"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        align_corners: Defaults to False.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        dtype: data type for resampling computation. Defaults to ``float32``.
            If None, use the data type of input data. To be compatible with other modules,
            the output data type is always ``float32``.
        lazy: a flag to indicate whether this transform should execute lazily or not.
            Defaults to False
            rP  TFtuple[float, float] | floatr>  rp   ro   r  r   )range_xrange_yrange_zrR  r4  rr   rs   rt   ru   rv   r   c                 C  s   t | | tj| |
d t|| _t| jdkrRtt| jd  | jd g| _t|| _t| jdkrtt| jd  | jd g| _t|| _	t| j	dkrtt| j	d  | j	d g| _	|| _
|| _|| _|| _|	| _d| _d| _d| _d S )Nrw   r   r   ra  )r*   rx   r(   rA   rc  r   r!  sortedrd  re  r4  rr   rs   rt   ru   xyz)ry   rc  rd  re  rR  r4  rr   rs   rt   ru   rv   rz   rz   r{   rx     s&    
 
 
 zRandRotate.__init__NrU  r   c                   st   t  d  | jsd S | jj| jd | jd d| _| jj| jd | jd d| _| jj| j	d | j	d d| _
d S )Nr   r   lowhigh)r   rW  rX  r   uniformrc  rg  rd  rh  re  ri  r   r   rz   r{   rW  1  s    zRandRotate.randomizer|   r   r   )r   rr   rs   rt   ru   rW  rv   c              	   C  s   |r|    |dkr| jn|}| jrtt|tr8| n|jdd }	t|	dkrX| j	n| j	| j
| jf| j|pr| j|pz| j|dkr| jn||p| jp|j|d}
|
|}nt|t tjd}| j|d|d |S )aq  
        Args:
            img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D).
            mode: {``"bilinear"``, ``"nearest"``}
                Interpolation mode to calculate output values. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``self.padding_mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            align_corners: Defaults to ``self.align_corners``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            dtype: data type for resampling computation. Defaults to ``self.dtype``.
                If None, use the data type of input data. To be compatible with other modules,
                the output data type is always ``float32``.
            randomize: whether to execute `randomize()` function first, default to True.
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.
        Nr   rM  )r3  r4  rr   rs   rt   ru   rv   r  ru   Tr\  )rW  rv   rX  r   r   r   r   r   r\   rg  rh  ri  r4  rr   rs   rt   ru   r@   r   r   r1  r^  )ry   r   rr   rs   rt   ru   rW  rv   r   r  Zrotatorr   rz   rz   r{   r   9  s$    $	
zRandRotate.__call__c                 C  s.   |  |}|tj s|S td||tj S )Nr   )r   rI   r_  r\   r  r   ry   r   r`  rz   rz   r{   r   k  s    

zRandRotate.inverse)N)NNNNTN)r   r   r   r   r\   r   r9   r   r:   r   r   r1  rx   rW  r   r   r   rz   rz   r   r{   r`     s,    $#      2c                   @  sh   e Zd ZdZejZdddddd	d
dZejj	ddddZddddddddZ
dddddZdS )ra   a)  
    Randomly flips the image along axes. Preserves shape.
    See numpy.flip for additional details.
    https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.

    Args:
        prob: Probability of flipping.
        spatial_axis: Spatial axes along which to flip over. Default is None.
        lazy: a flag to indicate whether this transform should execute lazily or not.
            Defaults to False
    rP  NFr>  r  rp   r   )rR  r	  rv   r   c                 C  s,   t | | tj| |d t||d| _d S )Nrw   )r	  rv   )r*   rx   r(   rV   r  )ry   rR  r	  rv   rz   rz   r{   rx     s    zRandFlip.__init__r   c                 C  s   || j _|| _d S r   r  rv   r   r   rz   rz   r{   rv     s    zRandFlip.lazyTr|   r   rZ  c                 C  sZ   |r|  d |dkr| jn|}| jr4| j||dn|}t|t d}| j|d|d |S r[  )rW  rv   rX  r  r@   r   r^  ry   r   rW  rv   r   r   rz   rz   r{   r     s    	
zRandFlip.__call__r   c                 C  s6   |  |}|tj s|S |j|tj  | j|S r   )r   rI   r_  applied_operationsr   r   r  r   r*  rz   rz   r{   r     s
    

zRandFlip.inverse)rP  NF)TN)r   r   r   r   rV   r   rx   r(   rv   r   r   r   rz   rz   rz   r{   ra   r  s   c                      s~   e Zd ZdZejZddddddd	Zejj	dd
ddZddd fddZ
ddddddddZdddddZ  ZS )rc   a  
    Randomly select a spatial axis and flip along it.
    See numpy.flip for additional details.
    https://docs.scipy.org/doc/numpy/reference/generated/numpy.flip.html

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.

    Args:
        prob: Probability of flipping.
        lazy: a flag to indicate whether this transform should execute lazily or not.
            Defaults to False
    rP  Fr>  rp   r   )rR  rv   r   c                 C  s2   t | | tj| |d d | _t| jd| _d S )Nrw   r  )r*   rx   r(   _axisrV   r  )ry   rR  rv   rz   rz   r{   rx     s    zRandAxisFlip.__init__rp  c                 C  s   || j _|| _d S r   rq  r   rz   rz   r{   rv     s    zRandAxisFlip.lazyr   r   c                   s.   t  d  | jsd S | j|jd | _d S rV  )r   rW  rX  r   rY  r  rt  r   r   rz   r{   rW    s    zRandAxisFlip.randomizeTNr|   r   rZ  c                 C  sd   |r| j |d |dkr| jn|}| jrB| j| j_| j||d}nt|t d}| j|d|d |S )a  
        Args:
            img: channel first array, must have shape: (num_channels, H[, W, ..., ])
            randomize: whether to execute `randomize()` function first, default to True.
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.
        )r   Nrw   r  Tr\  )	rW  rv   rX  rt  r  r	  r@   r   r^  rr  rz   rz   r{   r     s    	
zRandAxisFlip.__call__c              
   C  s\   |  |}|tj s|S t|tj tj d d}|d ||W  5 Q R  S Q R X d S )NrO  r  F)r   rI   r_  rV   r   r   )ry   r   r   r  rz   rz   r{   r     s    

zRandAxisFlip.inverse)rP  F)TN)r   r   r   r   rV   r   rx   r(   rv   r   rW  r   r   r   rz   rz   r   r{   rc     s   c                      s   e Zd ZdZejZdddejej	de
jddf	dd	d	d
d
dddddd
ddZddd fddZddddddddddddZdddddZ  ZS )rd   a
  
    Randomly zooms input arrays with given probability within given zoom range.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.

    Args:
        prob: Probability of zooming.
        min_zoom: Min zoom factor. Can be float or sequence same size as image.
            If a float, select a random factor from `[min_zoom, max_zoom]` then apply to all spatial dims
            to keep the original spatial shape ratio.
            If a sequence, min_zoom should contain one value for each spatial axis.
            If 2 values provided for 3D data, use the first value for both H & W dims to keep the same zoom ratio.
        max_zoom: Max zoom factor. Can be float or sequence same size as image.
            If a float, select a random factor from `[min_zoom, max_zoom]` then apply to all spatial dims
            to keep the original spatial shape ratio.
            If a sequence, max_zoom should contain one value for each spatial axis.
            If 2 values provided for 3D data, use the first value for both H & W dims to keep the same zoom ratio.
        mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``, ``"area"``}
            The interpolation mode. Defaults to ``"area"``.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
        padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function. Defaults to ``"constant"``.
            The mode to pad data after zooming.
            See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        align_corners: This only has an effect when mode is
            'linear', 'bilinear', 'bicubic' or 'trilinear'. Default: None.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
        dtype: data type for resampling computation. Defaults to ``float32``.
            If None, use the data type of input data.
        keep_size: Should keep original size (pad if needed), default is True.
        lazy: a flag to indicate whether this transform should execute lazily or not.
            Defaults to False
        kwargs: other arguments for the `np.pad` or `torch.pad` function.
            note that `np.pad` treats channel dimension as the first dimension.

    rP  g?g?NTFr>  r2  ro   r   r  rp   r   )
rR  min_zoommax_zoomrr   rs   rt   ru   r4  rv   r   c
                 K  s   t | | tj| |	d t|| _t|| _t| jt| jkrdtdt| j dt| j d|| _|| _	|| _
|| _|| _|
| _dg| _d S )Nrw   z1min_zoom and max_zoom must have same length, got z and r         ?)r*   rx   r(   rA   ru  rv  r   r   rr   rs   rt   ru   r4  rC  rD  )ry   rR  ru  rv  rr   rs   rt   ru   r4  rv   rC  rz   rz   r{   rx     s    

zRandZoom.__init__r   )r   r   c                   s   t  d   jsd S  fddt j jD  _t jdkr\t jd |j	d  _n>t jdkr|j	dkrt jd |j	d t
 jd   _d S )Nc                   s   g | ]\}} j ||qS rz   )r   rm  )r  lhry   rz   r{   
<listcomp>5  s     z&RandZoom.randomize.<locals>.<listcomp>r   r   rM  rQ  r   )r   rW  rX  r   ru  rv  rD  r   rB   r  rA   )ry   r   r   rz  r{   rW  1  s    zRandZoom.randomizer|   r   )r   rr   rs   rt   ru   rW  rv   r   c              	   C  s   |r| j |d |dkr| jn|}| js<t|t tjd}	nLt| jf| j	|pP| j
|pX| j|dkrf| jn||pp| j|d| j}
|
|}	| j|	d|d |	S )aQ  
        Args:
            img: channel first array, must have shape 2D: (nchannels, H, W), or 3D: (nchannels, H, W, D).
            mode: {``"nearest"``, ``"nearest-exact"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``,
                ``"area"``}, the interpolation mode. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
            padding_mode: available modes for numpy array:{``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
                ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
                available modes for PyTorch Tensor: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
                One of the listed string values or a user supplied function. Defaults to ``"constant"``.
                The mode to pad data after zooming.
                See also: https://numpy.org/doc/1.18/reference/generated/numpy.pad.html
                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
            align_corners: This only has an effect when mode is
                'linear', 'bilinear', 'bicubic' or 'trilinear'. Defaults to ``self.align_corners``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
            dtype: data type for resampling computation. Defaults to ``self.dtype``.
                If None, use the data type of input data.
            randomize: whether to execute `randomize()` function first, default to True.
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.
        )r   Nrn  )r4  rr   rs   rt   ru   rv   Tr\  )rW  rv   rX  r@   r   r   r1  r]   rD  r4  rr   rs   rt   ru   rC  r^  )ry   r   rr   rs   rt   ru   rW  rv   r   r   r/  rz   rz   r{   r   =  s(    "
zRandZoom.__call__r   c                 C  s0   |  |}|tj s|S t| j||tj S r   )r   rI   r_  r]   rD  r  r   ro  rz   rz   r{   r   t  s    

zRandZoom.inverse)NNNNTN)r   r   r   r   r]   r   r;   r0  r<   rI  r   r1  rx   rW  r   r   r   rz   rz   r   r{   rd     s*   )"      7c                   @  sd   e Zd ZdZejgZdddddejdddf	dddddddddd	d

ddZ	ddddddddZ
dS )re   a  
    Affine transforms on the coordinates.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.

    Args:
        rotate_params: a rotation angle in radians, a scalar for 2D image, a tuple of 3 floats for 3D.
            Defaults to no rotation.
        shear_params: shearing factors for affine matrix, take a 3D affine as example::

            [
                [1.0, params[0], params[1], 0.0],
                [params[2], 1.0, params[3], 0.0],
                [params[4], params[5], 1.0, 0.0],
                [0.0, 0.0, 0.0, 1.0],
            ]

            a tuple of 2 floats for 2D, a tuple of 6 floats for 3D. Defaults to no shearing.
        translate_params: a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Translation is in
            pixel/voxel relative to the center of the input image. Defaults to no translation.
        scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D,
            a tuple of 3 floats for 3D. Defaults to `1.0`.
        dtype: data type for the grid computation. Defaults to ``float32``.
            If ``None``, use the data type of input data (if `grid` is provided).
        device: device on which the tensor will be allocated, if a new grid is generated.
        align_corners: Defaults to False.
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
        affine: If applied, ignore the params (`rotate_params`, etc.) and use the
            supplied matrix. Should be square with each side = num of image spatial
            dimensions + 1.
        lazy: a flag to indicate whether this transform should execute lazily or not.
            Defaults to False
    NFr  torch.device | Noner   rp   NdarrayOrTensor | Noner   )
rotate_paramsshear_paramstranslate_paramsscale_paramsdeviceru   rt   r   rv   r   c
                 C  sf   t j| |	d || _|| _|| _|| _|| _t|tj	}
|
tj
tjd fkrN|
ntj| _|| _|| _d S r
  )r(   rx   r~  r  r  r  r  rN   r   r   float16r   r1  ru   rt   r   )ry   r~  r  r  r  r  ru   rt   r   rv   r(  rz   rz   r{   rx     s    zAffineGrid.__init__Sequence[int] | Noner}   r   z(tuple[torch.Tensor | None, torch.Tensor])r   gridrv   r   c                 C  s   |dkr| j n|}|sz|dkrD|dkr.tdt|| jd| jd}n|}| jpR|j}t||t d}|j}t|jd }n| j}t|}t	j
}	| jdkr"tj|d |d}
| jr|
t|| j||	d }
| jr|
t|| j||	d }
| jr|
t|| j||	d }
| jr(|
t|| j||	d }
n| j}
t||
}
|r@d|
fS t|
|j|jd	d
}
| jrt|dd |jdd D ||	d}t||
d }|
| ||jd df dgt|jdd  }n2|
||jd df dgt|jdd  }||
fS )a  
        The grid can be initialized with a `spatial_size` parameter, or provided directly as `grid`.
        Therefore, either `spatial_size` or `grid` must be provided.
        When initialising from `spatial_size`, the backend "torch" will be used.

        Args:
            spatial_size: output grid size.
            grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D.
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.
        Raises:
            ValueError: When ``grid=None`` and ``spatial_size=None``. Incompatible values.

        Nz5Incompatible values: grid=None and spatial_size=None.r   )r  r   ru   )ru   r  r   r  r  r   F)r  ru   r  c                 S  s$   g | ]}t |d t |d d  qS )rM  r   )r   r  drz   rz   r{   r{    s     z'AffineGrid.__call__.<locals>.<listcomp>r   r   )rv   r   r-   r  ru   r@   r   r   r   rJ   r   r   r   r   r~  r.   r  r0   r  r1   r  r/   r   rt   r>   viewr   )ry   r   r  rv   r   grid_r(  _devicespatial_dims_br   scrz   rz   r{   r     sP    
   82zAffineGrid.__call__)NNN)r   r   r   r   rJ   r   r   r   r1  rx   r   rz   rz   rz   r{   re   {  s    #"     c                
   @  s   e Zd ZdZejZdddddejdfddddddddd	d
dZd!ddddZ	d"dddddZ
d#ddddddddZdddd ZdS )$rf   z
    Generate randomised affine grid.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.
    NF	RandRanger|  r   rp   r   )rotate_rangeshear_rangetranslate_rangescale_ranger  ru   rv   r   c                 C  sp   t j| |d t|| _t|| _t|| _t|| _d| _d| _d| _	d| _
|| _|| _tjdtjd| _dS )a  
        Args:
            rotate_range: angle range in radians. If element `i` is a pair of (min, max) values, then
                `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter
                for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used.
                This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be
                in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]`
                for dim0 and nothing for the remaining dimensions.
            shear_range: shear range with format matching `rotate_range`, it defines the range to randomly select
                shearing factors(a tuple of 2 floats for 2D, a tuple of 6 floats for 3D) for affine matrix,
                take a 3D affine as example::

                    [
                        [1.0, params[0], params[1], 0.0],
                        [params[2], 1.0, params[3], 0.0],
                        [params[4], params[5], 1.0, 0.0],
                        [0.0, 0.0, 0.0, 1.0],
                    ]

            translate_range: translate range with format matching `rotate_range`, it defines the range to randomly
                select voxels to translate for every spatial dims.
            scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select
                the scale factor to translate for every spatial dims. A value of 1.0 is added to the result.
                This allows 0 to correspond to no change (i.e., a scaling of 1.0).
            device: device to store the output grid data.
            dtype: data type for the grid computation. Defaults to ``np.float32``.
                If ``None``, use the data type of input data (if `grid` is provided).
            lazy: a flag to indicate whether this transform should execute lazily or not.
                Defaults to False

        See also:
            - :py:meth:`monai.transforms.utils.create_rotate`
            - :py:meth:`monai.transforms.utils.create_shear`
            - :py:meth:`monai.transforms.utils.create_translate`
            - :py:meth:`monai.transforms.utils.create_scale`

        rw   Nr   r   )r(   rx   rA   r  r  r  r  r~  r  r  r  r  ru   r   r   r   r   )ry   r  r  r  r  r  ru   rv   rz   rz   r{   rx     s    /



zRandAffineGrid.__init__ra  r>  )
add_scalarc                 C  sz   g }|D ]l}t |rRt|dkr0td| d|| j|d |d |  q|d k	r|| j| ||  q|S )NrM  zBIf giving range as [min,max], should have 2 elements per dim, got r   r   r   )rE   r   r   r   r   rm  )ry   Zparam_ranger  	out_paramfrz   rz   r{   _get_rand_param?  s    "zRandAffineGrid._get_rand_paramrU  r   c                 C  s>   |  | j| _|  | j| _|  | j| _|  | jd| _d S )Nrw  )	r  r  r~  r  r  r  r  r  r  r   rz   rz   r{   rW  J  s    zRandAffineGrid.randomizeTr  r}  r   r|   )r   r  rW  rv   r   c              	   C  sl   |r|    |dkr| jn|}t| j| j| j| j| j| j|d}|rX|||d | _	dS |||\}| _	|S )aO  
        Args:
            spatial_size: output grid size.
            grid: grid to be transformed. Shape must be (3, H, W) for 2D or (4, H, W, D) for 3D.
            randomize: boolean as to whether the grid parameters governing the grid should be randomized.
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.

        Returns:
            a 2D (3xHxW) or 3D (4xHxWxD) grid.
        N)r~  r  r  r  r  ru   rv   r   )
rW  rv   re   r~  r  r  r  r  ru   r   )ry   r   r  rW  rv   r   affine_gridZ_gridrz   rz   r{   r   P  s"    	zRandAffineGrid.__call__r}   )r   c                 C  s   | j S )z3Get the most recently applied transformation matrixr   rz  rz   rz   r{   get_transformation_matrixv  s    z(RandAffineGrid.get_transformation_matrix)ra  )N)NNTN)r   r   r   r   re   r   r   r1  rx   r  rW  r   r  rz   rz   rz   r{   rf     s$   >    &c                   @  sN   e Zd ZdZejgZdddddddd	Zd
ddddZd
ddddZ	dS )rg   z+
    Generate random deformation grid.
    Nr2  tuple[float, float]r|  r   )spacingmagnitude_ranger  r   c                 C  s    || _ || _d| _|  || _dS )a  
        Args:
            spacing: spacing of the grid in 2D or 3D.
                e.g., spacing=(1, 1) indicates pixel-wise deformation in 2D,
                spacing=(1, 1, 1) indicates voxel-wise deformation in 3D,
                spacing=(2, 2) indicates deformation field defined on every other pixel in 2D.
            magnitude_range: the random offsets will be generated from
                `uniform[magnitude[0], magnitude[1])`.
            device: device to store the output grid data.
        rw  N)r  	magnituderand_magr  )ry   r  r  r  rz   rz   r{   rx     s
    zRandDeformGrid.__init__Sequence[int]	grid_sizer   c                 C  sJ   | j jt|gt| djtjdd| _| j | j	d | j	d | _
d S )N)r   Fr   r   r   )r   normalr   r   astyper   r1  random_offsetrm  r  r  ry   r  rz   rz   r{   rW    s    *zRandDeformGrid.randomizer|   r   r   c                 C  sp   t | jdt| | _t|| j| jdd}| |jdd  t| j| j	 |^}}|dt|  |7  < |S )zK
        Args:
            spatial_size: spatial size of the grid.
        rw  r   r  r   N)
rD   r  r   r,   r  rW  r   r>   r  r  )ry   r   Zcontrol_grid_offsetr  rz   rz   r{   r     s    zRandDeformGrid.__call__)N)
r   r   r   r   rJ   r   r   rx   rW  r   rz   rz   rz   r{   rg   {  s    c                	   @  sb   e Zd ZejejgZeje	j
dddejfddddddd	d
ddZdddddddddddZdS )rh   TNFrn   ro   rp   r|  r   r   )rr   rs   norm_coordsr  rt   ru   r   c                 C  s(   || _ || _|| _|| _|| _|| _dS )a)
  
        computes output image using values from `img`, locations from `grid` using pytorch.
        supports spatially 2D or 3D (num_channels, H, W[, D]).

        Args:
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `USE_COMPILED` is `True`, this argument uses
                ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
                See also: https://docs.monai.io/en/stable/networks.html#grid-pull (experimental).
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"border"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `USE_COMPILED` is `True`, this argument uses an integer to represent the padding mode.
                See also: https://docs.monai.io/en/stable/networks.html#grid-pull (experimental).
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            norm_coords: whether to normalize the coordinates from `[-(size-1)/2, (size-1)/2]` to
                `[0, size - 1]` (for ``monai/csrc`` implementation) or
                `[-1, 1]` (for torch ``grid_sample`` implementation) to be compatible with the underlying
                resampling API.
            device: device on which the tensor will be allocated.
            align_corners: Defaults to False.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            dtype: data type for resampling computation. Defaults to ``float64`` for best precision.
                If ``None``, use the data type of input data. To be compatible with other modules,
                the output data type is always `float32`.

        N)rr   rs   r  r  rt   ru   )ry   rr   rs   r  r  rt   ru   rz   rz   r{   rx     s    +zResample.__init__r|   r}   r~   r   r   )r   r  rr   rs   ru   rt   r   c                   s  t |t d}|dkr|S t|tjr,|jn| j}|p@| jp@|j}|dkrP| jn|}	t|tj||d^}
}t	t
t|
tr|
 n|
jdd d}t|dkr| jn||dkr| jn|dtd\} }ts|tjkrt|d| |
|jdd^}}t|tjr$| | kr$|jtjd	}t|
jdd|  D ]z\}}td
|}|d d }| jr|	r~|d | ||  | n
|| | ||< n"|	r:|d | || d  ||< q:tr|tjkrt|dd}t|
d|d|
d dd }n|tjkrh|
j }|rt!nt"|
dd}t|||jdd^}|rFt#nt$j%|rVt&nt'( fdd|D }t||
d }nt|t)t*|d dd dd}t||
ddd d}t|tjr| | kr|jtjd	}| jr6t|
j|d dd D ]*\}}|dd|f  dtd
| 9  < q
tj+j,j-|
d| |	t.j/kr\dn|	dd }t||t'j0d^}}|S )a	  
        Args:
            img: shape must be (num_channels, H, W[, D]).
            grid: shape must be (3, H, W) for 2D or (4, H, W, D) for 3D.
                if ``norm_coords`` is True, the grid values must be in `[-(size-1)/2, (size-1)/2]`.
                if ``USE_COMPILED=True`` and ``norm_coords=False``, grid values must be in `[0, size-1]`.
                if ``USE_COMPILED=False`` and ``norm_coords=False``, grid values must be in `[-1, 1]`.
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `USE_COMPILED` is `True`, this argument uses
                ``"nearest"``, ``"bilinear"``, ``"bicubic"`` to indicate 0, 1, 3 order interpolations.
                See also: https://docs.monai.io/en/stable/networks.html#grid-pull (experimental).
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``self.padding_mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `USE_COMPILED` is `True`, this argument uses an integer to represent the padding mode.
                See also: https://docs.monai.io/en/stable/networks.html#grid-pull (experimental).
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            dtype: data type for resampling computation. Defaults to ``self.dtype``.
                To be compatible with other modules, the output data type is always `float32`.
            align_corners: Defaults to ``self.align_corners``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

        See also:
            :py:const:`monai.config.USE_COMPILED`
        r  N)ru   r  r   rQ  )r   use_compiledT)ru   wrap_sequence)memory_formatrM         @      ?r   r   )boundextrapolateinterpolation)r  c                   s   g | ]}| d qS ))orderrr   rz   )r  c_interp_modeZ
_map_coordr6  Zgrid_nprz   r{   r{  /  s     z%Resample.__call__.<locals>.<listcomp>.)rr   rs   rt   r<  )1r@   r   r   r   r   r  ru   rt   rM   r   r   r   r   r   r3   rr   rs   r   rJ   r   r>   data_ptrclonecontiguous_formatr   r   r  r   r8   r   r=  tois_cudar=   r?   cupy_ndinp_ndimap_coordinatesrQ   r   stackr   r,  nn
functionalgrid_samplerI   r   r1  )ry   r   r  rr   rs   ru   rt   r  r(  r'  rA  r  r   r   Zgrid_tidim_dimtr   r  img_npZout_valrz   r  r{   r     s~    )* 
0    $zResample.__call__)NNNNN)r   r   r   rJ   r   r   r   r9   r   r:   r   r   r   rx   r   rz   rz   rz   r{   rh     s   5     c                   @  s   e Zd ZdZeeejeej@ Zdddddde	j
ejddejdddfddddddddd	d
dd	d	d	ddddZejjd	ddddZd dddddddddZedd ZdddddZdS )!ri   a/  
    Transform ``img`` given the affine parameters.
    A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.
    NFr  r}  r  rn   ro   rp   r|  r   r   )r~  r  r  r  r   r   rr   rs   r9  r  ru   rt   
image_onlyrv   r   c                 C  sb   t j| |d t||||||||
|d	| _|| _|	 | _t| j|
||d| _|| _|| _	|| _
dS )a"  
        The affine transformations are applied in rotate, shear, translate, scale order.

        Args:
            rotate_params: a rotation angle in radians, a scalar for 2D image, a tuple of 3 floats for 3D.
                Defaults to no rotation.
            shear_params: shearing factors for affine matrix, take a 3D affine as example::

                [
                    [1.0, params[0], params[1], 0.0],
                    [params[2], 1.0, params[3], 0.0],
                    [params[4], params[5], 1.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0],
                ]

                a tuple of 2 floats for 2D, a tuple of 6 floats for 3D. Defaults to no shearing.
            translate_params: a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Translation is in
                pixel/voxel relative to the center of the input image. Defaults to no translation.
            scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D,
                a tuple of 3 floats for 3D. Defaults to `1.0`.
            affine: If applied, ignore the params (`rotate_params`, etc.) and use the
                supplied matrix. Should be square with each side = num of image spatial
                dimensions + 1.
            spatial_size: output image spatial size.
                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
                the transform will use the spatial size of `img`.
                if some components of the `spatial_size` are non-positive values, the transform will use the
                corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
                to `(32, 64)` if the second spatial dimension size of img is `64`.
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"reflection"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            normalized: indicating whether the provided `affine` is defined to include a normalization
                transform converting the coordinates from `[-(size-1)/2, (size-1)/2]` (defined in ``create_grid``) to
                `[0, size - 1]` or `[-1, 1]` in order to be compatible with the underlying resampling API.
                If `normalized=False`, additional coordinate normalization will be applied before resampling.
                See also: :py:func:`monai.networks.utils.normalize_transform`.
            device: device on which the tensor will be allocated.
            dtype: data type for resampling computation. Defaults to ``float32``.
                If ``None``, use the data type of input data. To be compatible with other modules,
                the output data type is always `float32`.
            align_corners: Defaults to False.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            image_only: if True return only the image volume, otherwise return (image, affine).
            lazy: a flag to indicate whether this transform should execute lazily or not.
                Defaults to False
        rw   )	r~  r  r  r  r   ru   rt   r  rv   )r  r  ru   rt   N)r(   rx   re   r  r  Z
norm_coordrh   	resamplerr   rr   rs   )ry   r~  r  r  r  r   r   rr   rs   r9  r  ru   rt   r  rv   rz   rz   r{   rx   P  s$    IzAffine.__init__r   c                 C  s   || j _|| _d S r   )r  rv   r   r   rz   rz   r{   rv     s    zAffine.lazyr|   r~   r   r   z3torch.Tensor | tuple[torch.Tensor, NdarrayOrTensor])r   r   rr   rs   rv   r   c                 C  s   t |t d}t|tr | n|jdd }t|dkr>| jn||}|dkrT| jn|}|dk	rd|n| j	}	|dk	rv|n| j
}
| j||d\}}t|||| j||	|
d| j||  dS )a  
        Args:
            img: shape must be (num_channels, H, W[, D]),
            spatial_size: output image spatial size.
                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
                the transform will use the spatial size of `img`.
                if `img` has two spatial dimensions, `spatial_size` should have 2 elements [h, w].
                if `img` has three spatial dimensions, `spatial_size` should have 3 elements [h, w, d].
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``self.padding_mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.
        r  r   Nr   rv   Tr   )r@   r   r   r   r   r   rD   r   rv   rr   rs   r  r   r  r  r   )ry   r   r   rr   rs   rv   r%  r$  r   r&  r6  r  r   rz   rz   r{   r     s(      zAffine.__call__c                 C  sf   t |}t||}t|dd |d | D }t|dd |d | D }|t|tjd  | }|S )Nc                 S  s   g | ]}t |d  d qS r   rM  r>  r  rz   rz   r{   r{    s     z+Affine.compute_w_affine.<locals>.<listcomp>c                 S  s   g | ]}t |d   d qS r  r  r  rz   rz   r{   r{    s     r   )r  r   r1   rM   r   r   )clsspatial_rankrB  r%  r$  rZshift_1Zshift_2rz   rz   r{   compute_w_affine  s    
zAffine.compute_w_affiner   c                 C  s  |  |}|tj }|tj d }|tj d }|tj d }|tj d }tt|}t|||jdd }t||d}	|	|\}
}| j	||
|||d}t
|tst|}|j|_t| tjd }ttt|d	 ||jd	d  ||^}}| j|  _|S )
Nr   rr   rs   rt   r   r   )r   rt   )rt   r   )r   rI   r   r   r7   r?   r>   ru   re   r  r   r   r   rM   r   r   r   ri   r  r   r   r   )ry   r   r   r.  
fwd_affinerr   rs   rt   
inv_affiner  r  r  r   r   r/  rz   rz   r{   r     s*    


  zAffine.inverse)NNNN)r   r   r   r   r   setre   r   rh   r9   r   r:   
REFLECTIONr   r1  rx   r(   rv   r   r   classmethodr  r   rz   rz   rz   r{   ri   E  s6   ,\    6
c                      s   e Zd ZdZejZddddddejej	dddfdddddddd	d
dd
ddddZ
ejjd
ddddZd
dddZdd
dddZd-ddd d fddZd.ddd  fd!d"Zd/d$dd%d&d
d'd$d(d)d*Zd$d$d d+d,Z  ZS )0rj   a  
    Random affine transform.
    A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb.

    This transform is capable of lazy execution. See the :ref:`Lazy Resampling topic<lazy_resampling>`
    for more information.
    rP  NFr>  r  r  rn   ro   rp   r|  r   )rR  r  r  r  r  r   rr   rs   
cache_gridr  rv   r   c                 C  sd   t | | tj| |d t|||||
|d| _t|
d| _|| _|	| _| 	|| _
|| _|| _dS )a  
        Args:
            prob: probability of returning a randomized affine grid.
                defaults to 0.1, with 10% chance returns a randomized grid.
            rotate_range: angle range in radians. If element `i` is a pair of (min, max) values, then
                `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter
                for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used.
                This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be
                in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]`
                for dim0 and nothing for the remaining dimensions.
            shear_range: shear range with format matching `rotate_range`, it defines the range to randomly select
                shearing factors(a tuple of 2 floats for 2D, a tuple of 6 floats for 3D) for affine matrix,
                take a 3D affine as example::

                    [
                        [1.0, params[0], params[1], 0.0],
                        [params[2], 1.0, params[3], 0.0],
                        [params[4], params[5], 1.0, 0.0],
                        [0.0, 0.0, 0.0, 1.0],
                    ]

            translate_range: translate range with format matching `rotate_range`, it defines the range to randomly
                select pixel/voxel to translate for every spatial dims.
            scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select
                the scale factor to translate for every spatial dims. A value of 1.0 is added to the result.
                This allows 0 to correspond to no change (i.e., a scaling of 1.0).
            spatial_size: output image spatial size.
                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
                the transform will use the spatial size of `img`.
                if some components of the `spatial_size` are non-positive values, the transform will use the
                corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
                to `(32, 64)` if the second spatial dimension size of img is `64`.
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``bilinear``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``reflection``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            cache_grid: whether to cache the identity sampling grid.
                If the spatial size is not dynamically defined by input image, enabling this option could
                accelerate the transform.
            device: device on which the tensor will be allocated.
            lazy: a flag to indicate whether this transform should execute lazily or not.
                Defaults to False

        See also:
            - :py:class:`RandAffineGrid` for the random affine parameters configurations.
            - :py:class:`Affine` for the affine transformation parameters configurations.

        rw   r  r  r  r  r  rv   r  N)r*   rx   r(   rf   rand_affine_gridrh   r  r   r  _init_identity_cache_cached_gridrr   rs   )ry   rR  r  r  r  r  r   rr   rs   r  r  rv   rz   rz   r{   rx   	  s     FzRandAffine.__init__r   c                 C  s   || _ || j_d S r   )r   r  rv   r   rz   rz   r{   rv   m	  s    zRandAffine.lazyrw   c                 C  s   |rdS | j dkr&| jr"td dS t| j }t|}|t|dg| ks`|t|dg| kr~| jrztd| j  d dS t|| jj	ddS )	za
        Create cache of the identity grid if cache_grid=True and spatial_size is known.
        Nz_cache_grid=True is not compatible with the dynamic spatial_size, please specify 'spatial_size'.r   rM  zNcache_grid=True is not compatible with the dynamic spatial_size 'spatial_size=z!', please specify 'spatial_size'.r   r   r  r   )
r   r  r   r   rA   r   rD   r-   r  r  )ry   rv   Z_sp_size_ndimrz   rz   r{   r  r	  s"    

(zRandAffine._init_identity_cacher  r  c                 C  sj   |rdS t |}|t|dg| ks8|t|dg| krHtd| d| jdkrdt|| jjddS | jS )z
        Return a cached or new identity grid depends on the availability.

        Args:
            spatial_size: non-dynamic spatial size
        Nr   rM  z(spatial_size should not be dynamic, got r   r   r  )r   rD   r   r  r-   r  r  )ry   r   rv   r  rz   rz   r{   get_identity_grid	  s     zRandAffine.get_identity_grid
int | Nonenp.random.RandomState | Noneseedstater   c                   s    | j || t || | S r   r  set_random_stater   ry   r  r  r   rz   r{   r  	  s    zRandAffine.set_random_staterU  r   c                   s$   t  d  | jsd S | j  d S r   )r   rW  rX  r  r   r   rz   r{   rW  	  s    zRandAffine.randomizeTr|   r~   r   r   )r   r   rr   rs   rW  rv   r   c                 C  sJ  |r|    t|tr| n|jdd }t|dkr<| jn||}	| jpT|	t|k}
|dk	rb|n| j	}|dk	rt|n| j
}|dkr| jn|}t|t d}|r| jr|dkr| j|	|dd | j }n$ttt|	d || jjdd }n8|dkr| |	|}| jr| j|||d}| j }t|||| j|	|||
d||  d	S )
a4  
        Args:
            img: shape must be (num_channels, H, W[, D]),
            spatial_size: output image spatial size.
                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
                the transform will use the spatial size of `img`.
                if `img` has two spatial dimensions, `spatial_size` should have 2 elements [h, w].
                if `img` has three spatial dimensions, `spatial_size` should have 3 elements [h, w, d].
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``self.padding_mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            randomize: whether to execute `randomize()` function first, default to True.
            grid: precomputed grid to be used (mainly to accelerate `RandAffined`).
            lazy: a flag to indicate whether this transform should execute lazily or not
                during this call. Setting this to False or True overrides the ``lazy`` flag set
                during initialization for this call. Defaults to None.
        r   Nr  T)rW  rv   r   r   )r  rW  rv   r   )rW  r   r   r   r   rD   r   rX  rA   rr   rs   rv   r@   r   r  r  r>   r   r   r   ru   r  r   r  r   )ry   r   r   rr   rs   rW  r  rv   Zori_sizer$  do_resamplingr&  r6  r   r   rz   rz   r{   r   	  sB    $ &

zRandAffine.__call__c                 C  s  |  |}|tj d s|S |tj }t||jdd  }|tj d }|tj d }|tj d }tt|}t|||j	dd }t
|d}||\}	}
| ||	||}t|tst|}|j|_t| tjd }ttt|d ||jdd  ||^}}
| j|  _|S )	Nr  r   r   rr   rs   r   r   r  )r   rI   r   r   rD   r   r7   r?   r>   ru   re   r  r   r   r   rM   r   r   r   ri   r  r   r   )ry   r   r   r.  r  rr   rs   r  r  r  r  r   r   r/  rz   rz   r{   r   	  s.    



  zRandAffine.inverse)NN)N)NNNTNN)r   r   r   r   ri   r   r9   r   r:   r  rx   r(   rv   r   r  r  r  rW  r   r   r   rz   rz   r   r{   rj   
	  s8   &X	      Jc                      s   e Zd ZdZejZddddddejej	df	ddddddddd	d
dddddZ
d#ddd d fddZdd Zddd fddZd$ddddddd d!d"Z  ZS )%rk   z
    Random elastic deformation and affine in 2D.
    A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb.

    rP  Nrb  r  r>  r  ztuple[int, int] | int | Nonern   ro   r|  r   )r  r  rR  r  r  r  r  r   rr   rs   r  r   c                 C  sZ   t | | t|||d| _t|||||dd| _t|d| _|| _|| _	|	| _
|
| _dS )al  
        Args:
            spacing : distance in between the control points.
            magnitude_range: the random offsets will be generated from ``uniform[magnitude[0], magnitude[1])``.
            prob: probability of returning a randomized elastic transform.
                defaults to 0.1, with 10% chance returns a randomized elastic transform,
                otherwise returns a ``spatial_size`` centered area extracted from the input image.
            rotate_range: angle range in radians. If element `i` is a pair of (min, max) values, then
                `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter
                for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used.
                This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be
                in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]`
                for dim0 and nothing for the remaining dimensions.
            shear_range: shear range with format matching `rotate_range`, it defines the range to randomly select
                shearing factors(a tuple of 2 floats for 2D) for affine matrix, take a 2D affine as example::

                    [
                        [1.0, params[0], 0.0],
                        [params[1], 1.0, 0.0],
                        [0.0, 0.0, 1.0],
                    ]

            translate_range: translate range with format matching `rotate_range`, it defines the range to randomly
                select pixel to translate for every spatial dims.
            scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select
                the scale factor to translate for every spatial dims. A value of 1.0 is added to the result.
                This allows 0 to correspond to no change (i.e., a scaling of 1.0).
            spatial_size: specifying output image spatial size [h, w].
                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
                the transform will use the spatial size of `img`.
                if some components of the `spatial_size` are non-positive values, the transform will use the
                corresponding components of img size. For example, `spatial_size=(32, -1)` will be adapted
                to `(32, 64)` if the second spatial dimension size of img is `64`.
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"reflection"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            device: device on which the tensor will be allocated.

        See also:
            - :py:class:`RandAffineGrid` for the random affine parameters configurations.
            - :py:class:`Affine` for the affine transformation parameters configurations.

        )r  r  r  Fr  r  N)r*   rx   rg   deform_gridrf   r  rh   r  r  r   rr   rs   )ry   r  r  rR  r  r  r  r  r   rr   rs   r  rz   rz   r{   rx   
  s    BzRand2DElastic.__init__r  r  r  c                   s.   | j || | j|| t || | S r   )r  r  r  r   r  r   rz   r{   r  l
  s    zRand2DElastic.set_random_statec                 C  s"   || j _|| j_|| j_|| _d S r   )r  r  r  r  ry   r  rz   rz   r{   
set_devicer
  s    zRand2DElastic.set_devicer  r  c                   s0   t  d  | jsd S | j| | j  d S r   )r   rW  rX  r  r  )ry   r   r   rz   r{   rW  x
  s
    zRand2DElastic.randomizeTr|   r~   r   rp   r   r   rr   rs   rW  r   c           
      C  s   t |dkr| jn||jdd }|r2| j|d | jr| j|d}| j|d}tjj	j
d|dtt| jjtjjdd}t|d	|d }n.t|tjr|jn| j}ttjt||d
d}| j|||dk	r|n| j|dk	r|n| jd}	|	S )a"  
        Args:
            img: shape must be (num_channels, H, W),
            spatial_size: specifying output image spatial size [h, w].
                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
                the transform will use the spatial size of `img`.
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``self.padding_mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            randomize: whether to execute `randomize()` function first, default to True.
        Nr   r;  r  Tr   F)recompute_scale_factorinputscale_factorrr   rt   )roi_sizer   r  rr   rs   )rD   r   r   rW  rX  r  r  r   r  r  interpolater=  r   rA   r  r;   BICUBICvaluer   r   r   r  r   r-   r  rr   rs   )
ry   r   r   rr   rs   rW  r$  r  r  r   rz   rz   r{   r   
  s.    "zRand2DElastic.__call__)NN)NNNTr   r   r   r   rh   r   r9   r   r:   r  rx   r  r  rW  r   r   rz   rz   r   r{   rk   
  s(   &S
    c                      s   e Zd ZdZejZddddddejej	df	dddddddddd	d
ddddZ
d"ddd d fddZdd Zddd fddZd#dddddddd d!Z  ZS )$rl   z
    Random elastic deformation and affine in 3D.
    A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/0.6.0/modules/transforms_demo_2d.ipynb.

    rP  Nr  r>  r  z!tuple[int, int, int] | int | Nonern   ro   r|  r   )sigma_ranger  rR  r  r  r  r  r   rr   rs   r  r   c                 C  sf   t | | t|||||dd| _t|d| _|| _|| _|| _|	| _	|
| _
|| _|  d| _d| _dS )ad  
        Args:
            sigma_range: a Gaussian kernel with standard deviation sampled from
                ``uniform[sigma_range[0], sigma_range[1])`` will be used to smooth the random offset grid.
            magnitude_range: the random offsets on the grid will be generated from
                ``uniform[magnitude[0], magnitude[1])``.
            prob: probability of returning a randomized elastic transform.
                defaults to 0.1, with 10% chance returns a randomized elastic transform,
                otherwise returns a ``spatial_size`` centered area extracted from the input image.
            rotate_range: angle range in radians. If element `i` is a pair of (min, max) values, then
                `uniform[-rotate_range[i][0], rotate_range[i][1])` will be used to generate the rotation parameter
                for the `i`th spatial dimension. If not, `uniform[-rotate_range[i], rotate_range[i])` will be used.
                This can be altered on a per-dimension basis. E.g., `((0,3), 1, ...)`: for dim0, rotation will be
                in range `[0, 3]`, and for dim1 `[-1, 1]` will be used. Setting a single value will use `[-x, x]`
                for dim0 and nothing for the remaining dimensions.
            shear_range: shear range with format matching `rotate_range`, it defines the range to randomly select
                shearing factors(a tuple of 6 floats for 3D) for affine matrix, take a 3D affine as example::

                    [
                        [1.0, params[0], params[1], 0.0],
                        [params[2], 1.0, params[3], 0.0],
                        [params[4], params[5], 1.0, 0.0],
                        [0.0, 0.0, 0.0, 1.0],
                    ]

            translate_range: translate range with format matching `rotate_range`, it defines the range to randomly
                select voxel to translate for every spatial dims.
            scale_range: scaling range with format matching `rotate_range`. it defines the range to randomly select
                the scale factor to translate for every spatial dims. A value of 1.0 is added to the result.
                This allows 0 to correspond to no change (i.e., a scaling of 1.0).
            spatial_size: specifying output image spatial size [h, w, d].
                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
                the transform will use the spatial size of `img`.
                if some components of the `spatial_size` are non-positive values, the transform will use the
                corresponding components of img size. For example, `spatial_size=(32, 32, -1)` will be adapted
                to `(32, 32, 64)` if the third spatial dimension size of img is `64`.
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"reflection"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            device: device on which the tensor will be allocated.

        See also:
            - :py:class:`RandAffineGrid` for the random affine parameters configurations.
            - :py:class:`Affine` for the affine transformation parameters configurations.

        Fr  r  rw  N)r*   rx   rf   r  rh   r  r  r  r   rr   rs   r  r  sigma)ry   r  r  rR  r  r  r  r  r   rr   rs   r  rz   rz   r{   rx   
  s&    EzRand3DElastic.__init__r  r  r  c                   s    | j || t || | S r   r  r  r   rz   r{   r    s    zRand3DElastic.set_random_statec                 C  s   || j _|| j_|| _d S r   )r  r  r  r  rz   rz   r{   r    s    zRand3DElastic.set_devicer  r  c                   s   t  d  | jsd S | jdddgt| jtjdd| _	| j| j
d | j
d | _| j| jd | jd | _| j  d S )Ng      rw  rQ  Fr  r   r   )r   rW  rX  r   rm  r   r  r   r1  rand_offsetr  r  r  r  r  r  r   rz   r{   rW  $  s    (zRand3DElastic.randomizeTr|   r~   r   rp   r  c                 C  s   t |dkr| jn||jdd }|r2| j|d t|tjrD|jn| j}t||dd}| j	r| j
dkrptdtd| jdj|d	}	tj| j
|d	d
}
|dd  |	|
d
 | j 7  < | j|d}| j|||dk	r|n| j|dk	r|n| jd}|S )a3  
        Args:
            img: shape must be (num_channels, H, W, D),
            spatial_size: specifying spatial 3D output image spatial size [h, w, d].
                if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1,
                the transform will use the spatial size of `img`.
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``self.padding_mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            randomize: whether to execute `randomize()` function first, default to True.
        Nr   )r  r   r  zrand_offset is not initialized.rQ  g      @r  r   r  r  )rD   r   r   rW  r   r   r   r  r-   rX  r  r   r   r  r  r   r=  r  r  r  rr   rs   )ry   r   r   rr   rs   rW  r$  r  r  gaussianr   r   rz   rz   r{   r   -  s&    "
"zRand3DElastic.__call__)NN)NNNTr  rz   rz   r   r{   rl   
  s(   &[    c                   @  sP   e Zd ZejgZejej	dfdddddddd	d
Z
dddddddddZdS )rW   Ntuple[int] | intzSequence[Sequence[float]]rn   ro   r|  r   )	num_cellsdistort_stepsrr   rs   r  r   c                 C  s&   t |||d| _|| _|| _|| _dS )a  
        Grid distortion transform. Refer to:
        https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/transforms.py

        Args:
            num_cells: number of grid cells on each dimension.
            distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the
                corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`.
                Each value in the tuple represents the distort step of the related cell.
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"border"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            device: device on which the tensor will be allocated.

        )rr   rs   r  N)rh   r  r  r  r  )ry   r  r  rr   rs   r  rz   rz   r{   rx   b  s     zGridDistortion.__init__r|   zSequence[Sequence] | Noner   )r   r  rr   rs   r   c                 C  s`  |dkr| j n|}t|jt|d kr0tdg }t| jt|jd }t|trd|jrdt	
d t|jdd D ]\}}|| }	tj|tjd}
|||  }d}t|| d D ]X}t|| }|| }||kr|}|}n|||	|   }t|||| |
||< |}q|
|d d  }
||
 qvt| }t|t|d f}| j||||d	S )
a,  
        Args:
            img: shape must be (num_channels, H, W[, D]).
            distort_steps: This argument is a list of tuples, where each tuple contains the distort steps of the
                corresponding dimensions (in the order of H, W[, D]). The length of each tuple equals to `num_cells + 1`.
                Each value in the tuple represents the distort step of the related cell.
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``self.padding_mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html

        Nr   zKthe spatial size of `img` does not match with the length of `distort_steps`NMetaTensor img has pending operations, transform may return incorrect results.r   r   rw  r  )r  rr   rs   )r  r   r   r   rB   r  r   r   pending_operationsr   r   r   r   zerosr1  r,  r  linspacer   r   r  	ones_liker  )ry   r   r  rr   rs   
all_rangesr  Zdim_idxdim_sizeZdim_distort_stepsranges	cell_sizeprevr   startendcurcoordsr  rz   rz   r{   r     s4    
zGridDistortion.__call__)NNN)r   r   r   rJ   r   r   r9   r   r:   r   rx   r   rz   rz   rz   r{   rW   _  s   (   c                	      sp   e Zd ZejgZdddejej	dfddddd	d
ddddZ
ddd fddZdddddddddZ  ZS )rb      rP  )gQgQ?Nr  r>  rb  rn   ro   r|  r   )r  rR  distort_limitrr   rs   r  r   c                 C  sn   t | | || _t|ttfr<t| |t| |f| _nt|t|f| _d| _	t
|| j	|||d| _dS )a  
        Random grid distortion transform. Refer to:
        https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/transforms.py

        Args:
            num_cells: number of grid cells on each dimension.
            prob: probability of returning a randomized grid distortion transform. Defaults to 0.1.
            distort_limit: range to randomly distort.
                If single number, distort_limit is picked from (-distort_limit, distort_limit).
                Defaults to (-0.03, 0.03).
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"border"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            device: device on which the tensor will be allocated.

        )r  )r  r  rr   rs   r  N)r*   rx   r  r   r  r>  r   r   r  r  rW   grid_distortion)ry   r  rR  r  rr   rs   r  rz   rz   r{   rx     s    "    zRandGridDistortion.__init__r  )r   r   c                   s>   t  d   jsd S t fddt jt|D  _d S )Nc                 3  s8   | ]0}t d  jj jd  jd |d d V  qdS )rw  r   r   )rk  rl  r   N)r!  r   rm  r  )r  Zn_cellsrz  rz   r{   r    s   z/RandGridDistortion.randomize.<locals>.<genexpr>)r   rW  rX  r!  rB   r  r   r  )ry   r   r   rz  r{   rW    s    zRandGridDistortion.randomizeTr|   r   rp   )r   rr   rs   rW  r   c                 C  sZ   |r2t |tr|jrtd | |jdd  | jsFt|t	 dS | j
|| j||dS )aJ  
        Args:
            img: shape must be (num_channels, H, W[, D]).
            mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
                Interpolation mode to calculate output values. Defaults to ``self.mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
                and the value represents the order of the spline interpolation.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``self.padding_mode``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
                When `mode` is an integer, using numpy/cupy backends, this argument accepts
                {'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
                See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
            randomize: whether to shuffle the random factors using `randomize()`, default to True.
        r  r   Nr  )r  rr   rs   )r   r   r  r   r   rW  r   rX  r@   r   r  r  )ry   r   rr   rs   rW  rz   rz   r{   r     s    
zRandGridDistortion.__call__)NNT)r   r   r   rJ   r   r   r9   r   r:   r   rx   rW  r   r   rz   rz   r   r{   rb     s   -
     c                   @  sT   e Zd ZdZejejgZddddddZdd	d
ddddZ	ddddddZ
dS )rX   a  
    Split the image into patches based on the provided grid in 2D.

    Args:
        grid: a tuple define the shape of the grid upon which the image is split. Defaults to (2, 2)
        size: a tuple or an integer that defines the output patch sizes.
            If it's an integer, the value will be repeated for each dimension.
            The default is None, where the patch size will be inferred from the grid shape.

    Example:
        Given an image (torch.Tensor or numpy.ndarray) with size of (3, 10, 10) and a grid of (2, 2),
        it will return a Tensor or array with the size of (4, 3, 5, 5).
        Here, if the `size` is provided, the returned shape will be (4, 3, size, size)

    Note: This transform currently support only image with two spatial dimensions.
    rM  rM  NrK  zint | tuple[int, int] | None)r  r   c                 C  s(   || _ |d krd nt|t| j | _d S r   )r  rB   r   r   )ry   r  r   rz   rz   r{   rx   ,  s    zGridSplit.__init__r   z)int | tuple[int, int] | np.ndarray | Nonezlist[NdarrayOrTensor])imager   r   c                 C  s^  |d kr| j nt|t| j}| jdkr6|d kr6|gS t|trP|jrPtd | 	|j
dd  |\}}t|tjrtj}| \}}}	n8t|tjrtjjj}|j\}}}	ntdt| d|\}
}|j
d }||| j||d |d f||
 |	| |||	f}|jd|j
dd   }t|tjr>d	d
 |D }nt|tjrZdd
 |D }|S )N)r   r   r  r   zInput type [z] is not supported.r   r   rM  c                 S  s   g | ]}|  qS rz   )
contiguousr  prz   rz   r{   r{  S  s     z&GridSplit.__call__.<locals>.<listcomp>c                 S  s   g | ]}t |qS rz   )r   ascontiguousarrayr  rz   rz   r{   r{  U  s     )r   )r   rB   r   r  r   r   r  r   r   _get_paramsr   r   r   
as_stridedstrider   r   libstride_tricksstridesr   typer   )ry   r  r   
input_size
split_sizestepsZas_strided_funcZc_stridex_strideZy_strideZx_stepy_step
n_channelsZstrided_imagepatchesrz   rz   r{   r   3  s4    


zGridSplit.__call__zSequence[int] | np.ndarrayz!Sequence[int] | np.ndarray | None
image_sizer   c                   s   dkr*t  fddttjD t fddttjD rbtd  d dt  fddttjD }|fS )	z
        Calculate the size and step required for splitting the image
        Args:
            The size of the input image
        Nc                 3  s    | ]} | j |  V  qd S r   r  r  r  )r$  ry   rz   r{   r  a  s     z(GridSplit._get_params.<locals>.<genexpr>c                 3  s   | ]}|  | kV  qd S r   rz   r%  r#  rz   r{   r  c  s     zThe image size (z+)is smaller than the requested split size ()c                 3  sB   | ]:}j | d kr2 | |  j | d   n | V  qdS )r   Nr  r%  r$  ry   r   rz   r{   r  f  s   )r!  r,  r   r  anyr   )ry   r$  r   r  rz   r'  r{   r  Y  s    ""zGridSplit._get_params)r  N)N)N)r   r   r   r   rJ   r   r   r   rx   r   r  rz   rz   rz   r{   rX     s    &c                	   @  sl   e Zd ZdZejejgZddddddd	dd
ddZddddddZ	ddddddZ
dddddZdS )rY   a  
    Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps.
    It can sort the patches and return all or a subset of them.

    Args:
        patch_size: size of patches to generate slices for, 0 or None selects whole dimension
        offset: offset of starting position in the array, default is 0 for each dimension.
        num_patches: number of patches (or maximum number of patches) to return.
            If the requested number of patches is greater than the number of available patches,
            padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.
            When `threshold` is set, this value is treated as the maximum number of patches.
            Defaults to None, which does not limit number of the patches.
        overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
            If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
        sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
            lowest values (`"min"`), or in their default order (`None`). Default to None.
        threshold: a value to keep only the patches whose sum of intensities are less than the threshold.
            Defaults to no filtering.
        pad_mode: the  mode for padding the input image by `patch_size` to include patches that cross boundaries.
            Available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function.
            Defaults to `None`, which means no padding will be applied.
            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
            requires pytorch >= 1.10 for best compatibility.
        pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
            note that `np.pad` treats channel dimension as the first dimension.
    Returns:
        MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),
            with following metadata:

            - `PatchKeys.LOCATION`: the starting location of the patch in the image,
            - `PatchKeys.COUNT`: total number of patches in the image,
            - "spatial_shape": spatial size of the extracted patch, and
            - "offset": the amount of offset for the patches in the image (starting position of the first patch)

    Nra  r  r  r  r2  r   float | None
patch_sizer   num_patchesoverlapsort_fn	thresholdpad_modec           	      K  sZ   t || _|rt |ndt| j | _|| _|| _|| _|| _|rJ| nd | _	|| _
d S )Nr   )rA   r+  r   r   r0  
pad_kwargsr-  r,  lowerr.  r/  )	ry   r+  r   r,  r-  r.  r/  r0  r2  rz   rz   r{   rx     s    
zGridPatch.__init__r   
np.ndarray"tuple[NdarrayOrTensor, np.ndarray]image_np	locationsr   c                 C  sN   t |j}t|ttd|| jk d}t|t	j
d }|| || fS )a  
        Filter the patches and their locations according to a threshold.

        Args:
            image_np: a numpy.ndarray or torch.Tensor representing a stack of patches.
            locations: a numpy.ndarray representing the stack of location of each patch.

        Returns:
            tuple[NdarrayOrTensor, numpy.ndarray]:  tuple of filtered patches and locations.
        r   r   r   )r   r   r6   sumr!  r,  r/  r   rM   r   r   ry   r7  r8  n_dimsr   idx_nprz   rz   r{   filter_threshold  s    
$zGridPatch.filter_thresholdc                 C  s   | j dkr(|d| j }|d| j }n| jdk	rt|j}| j tjkrbt|tt	d|}n:| j tj
krt|tt	d| }ntd| j  d|d| j }t|tjd }|| }|| }||fS )a0  
        Sort the patches based on the sum of their intensity, and just keep `self.num_patches` of them.

        Args:
            image_np: a numpy.ndarray or torch.Tensor representing a stack of patches.
            locations: a numpy.ndarray representing the stack of location of each patch.
        Nr   z2`sort_fn` should be either "min", "max", or None! 
 provided!r   )r.  r,  r   r   rG   MINr5   r9  r!  r,  MAXr   rM   r   r   r:  rz   rz   r{   filter_count  s    


zGridPatch.filter_countr   )r   r   c                 C  s  t |fd| j d| j | jd| jd| j}tt| }t|t	j
rTt	|d nt|d }t	|d dddddf }| jdk	r| ||\}}| jrz| ||\}}| jdkrz| jt| }|dkrz| jdd}|ft|jdd }t|t	j
r2t	j|||jd	}	t	j||	gdd
}n,tj|||j|j|jd}	tj||	gdd}t	j|d|gddggdd}t|tr|jnt }
|j|
tj < t||
tj!< t	"t	#| jt|dfj|
d< | j|
d< t||
d}d|_$|S )a  
        Extract the patches (sweeping the entire image in a row-major sliding-window manner with possible overlaps).

        Args:
            array: a input image as `numpy.ndarray` or `torch.Tensor`

        Return:
            MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),
                with defined `PatchKeys.LOCATION` and `PatchKeys.COUNT` metadata.
        r   r1  F)r+  	start_posr-  	copy_backrr   r   r   Nconstant_valuesr   )axis)ru   layoutr  )r  )rD  r   r   )rg  r   T)%r   r+  r   r-  r0  r2  r   r   r   r   r   r  r   r/  r=  r,  rA  r   r   r   fullru   concatenaterF  r  catpadr   r   get_default_metaTrH   LOCATIONCOUNTtiler   is_batch)ry   r   Zpatch_iteratorr"  Zpatched_imager8  paddingrD  Zpadding_shapeZconstant_paddingmetadataoutputrz   rz   r{   r     sV    	( 

"
zGridPatch.__call__)NNra  NNN)r   r   r   r   rJ   r   r   r   rx   r=  rA  r   rz   rz   rz   r{   rY   n  s   (      c                
      sv   e Zd ZdZejejgZdddddddd	dd
 fddZdd Z	dddd fddZ
dddd fddZ  ZS )rZ   a  
    Extract all the patches sweeping the entire image in a row-major sliding-window manner with possible overlaps,
    and with random offset for the minimal corner of the image, (0,0) for 2D and (0,0,0) for 3D.
    It can sort the patches and return all or a subset of them.

    Args:
        patch_size: size of patches to generate slices for, 0 or None selects whole dimension
        min_offset: the minimum range of offset to be selected randomly. Defaults to 0.
        max_offset: the maximum range of offset to be selected randomly.
            Defaults to image size modulo patch size.
        num_patches: number of patches (or maximum number of patches) to return.
            If the requested number of patches is greater than the number of available patches,
            padding will be applied to provide exactly `num_patches` patches unless `threshold` is set.
            When `threshold` is set, this value is treated as the maximum number of patches.
            Defaults to None, which does not limit number of the patches.
        overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
            If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
        sort_fn: when `num_patches` is provided, it determines if keep patches with highest values (`"max"`),
            lowest values (`"min"`), in random ("random"), or in their default order (`None`). Default to None.
        threshold: a value to keep only the patches whose sum of intensities are less than the threshold.
            Defaults to no filtering.
        pad_mode: the  mode for padding the input image by `patch_size` to include patches that cross boundaries.
            Available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function.
            Defaults to `None`, which means no padding will be applied.
            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
            requires pytorch >= 1.10 for best compatibility.
        pad_kwargs: other arguments for the `np.pad` or `torch.pad` function.
            note that `np.pad` treats channel dimension as the first dimension.

    Returns:
        MetaTensor: the extracted patches as a single tensor (with patch dimension as the first dimension),
            with following metadata:

            - `PatchKeys.LOCATION`: the starting location of the patch in the image,
            - `PatchKeys.COUNT`: total number of patches in the image,
            - "spatial_shape": spatial size of the extracted patch, and
            - "offset": the amount of offset for the patches in the image (starting position of the first patch)

    Nra  r  r  r  r2  r   r)  )r+  
min_offset
max_offsetr,  r-  r.  r/  r0  c	           
   
     s>   t  jf |d|||||d|	 || _|| _|| _|| _d S )Nrz   r*  )r   rx   rT  rU  r,  r.  )
ry   r+  rT  rU  r,  r-  r.  r/  r0  r2  r   rz   r{   rx   H  s    
zRandGridPatch.__init__c                   s    j d krdt j }nt j t j} jd kr\tdd t|jdd   jD }nt jt j}t fddt||D  _d S )Nr1  c                 s  s   | ]\}}|| V  qd S r   rz   )r  r  r  rz   rz   r{   r  i  s     z*RandGridPatch.randomize.<locals>.<genexpr>r   c                 3  s&   | ]\}} j j||d  dV  qdS )r   rj  N)r   rY  )r  rk  rl  rz  rz   r{   r  m  s     )	rT  r   r+  rB   rU  r!  r   r   r   )ry   r   rT  rU  rz   rz  r{   rW  c  s    

&zRandGridPatch.randomizer   r4  r5  r6  c                   s   | j tjkrT| j|jd }|d | j }t|tj	d }|| }|| }||fS | j d tj
tjfkrztd| j  dt ||S )Nr   z;`sort_fn` should be either "min", "max", "random" or None! r>  )r.  rG   RANDOMr   permutationr   r,  rM   r   r   r?  r@  r   r   rA  )ry   r7  r8  r   r<  r   rz   r{   rA  o  s    zRandGridPatch.filter_countTrp   )r   rW  c                   s   |r|  | t |S r   )rW  r   r   )ry   r   rW  r   rz   r{   r   {  s    
zRandGridPatch.__call__)NNNra  NNN)T)r   r   r   r   rJ   r   r   r   rx   rW  rA  r   r   rz   rz   r   r{   rZ     s   ,       "c                      sn   e Zd ZdZejZdejejdddfddddd	d
dddZ	ddd
d fddZ
dddddddZ  ZS )rm   a  
    Random simulation of low resolution corresponding to nnU-Net's SimulateLowResolutionTransform
    (https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23)
    First, the array/tensor is resampled at lower resolution as determined by the zoom_factor which is uniformly sampled
    from the `zoom_range`. Then, the array/tensor is resampled at the original resolution.
    rP  )r  rw  FNr>  zInterpolateMode | strzSequence[float]r|  r   )rR  downsample_modeupsample_mode
zoom_ranger  r   c                 C  s4   t | | || _|| _|| _|| _|| _d| _dS )a  
        Args:
            prob: probability of performing this augmentation
            downsample_mode: interpolation mode for downsampling operation
            upsample_mode: interpolation mode for upsampling operation
            zoom_range: range from which the random zoom factor for the downsampling and upsampling operation is
            sampled. It determines the shape of the downsampled tensor.
            align_corners: This only has an effect when downsample_mode or upsample_mode  is 'linear', 'bilinear',
                'bicubic' or 'trilinear'. Default: False
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
            device: device on which the tensor will be allocated.

        rw  N)r*   rx   rX  rY  rZ  rt   r  zoom_factor)ry   rR  rX  rY  rZ  rt   r  rz   rz   r{   rx     s    z"RandSimulateLowResolution.__init__rU  r   c                   s6   t  d  | j| jd | jd | _| js2d S d S )Nr   r   )r   rW  r   rm  rZ  r[  rX  r   r   rz   r{   rW    s    z#RandSimulateLowResolution.randomizeTr|   rp   )r   rW  r   c           
      C  s   |r|    | jr|jdd }tt|| j tj}t	|d| j
dd}t	|d| jd| jd}t }td ||}||}	t| t|	}	|	| |	S |S dS )z
        Args:
            img: shape must be (num_channels, H, W[, D]),
            randomize: whether to execute `randomize()` function first, defaults to True.
        r   Nr  F)r   r  rr   r  )r   r  rr   r  rt   )rW  rX  r   r   r  r   r[  r  int_r[   rX  rY  rt   r   r   r   copy_meta_from)
ry   r   rW  r#  target_shapeZresize_tfm_downsampleZresize_tfm_upsampleZoriginal_tack_meta_valueZimg_downsampledZimg_upsampledrz   rz   r{   r     s6       	
z"RandSimulateLowResolution.__call__)N)T)r   r   r   r   ri   r   r;   NEAREST	TRILINEARrx   rW  r   r   rz   rz   r   r{   rm     s   )r   
__future__r   r   collections.abcr   r   r   	itertoolsr   typingr   r   r   r	   r
   r   numpyr   r   monai.configr   r   monai.config.type_definitionsr   monai.data.meta_objr   r   monai.data.meta_tensorr   monai.data.utilsr   r   r   r   r   r   monai.networks.layersr   r   r   monai.networks.utilsr   Zmonai.transforms.croppad.arrayr   r   monai.transforms.inverser   Z#monai.transforms.spatial.functionalr   r    r!   r"   r#   r$   r%   r&   monai.transforms.traitsr'   monai.transforms.transformr(   r)   r*   r+   monai.transforms.utilsr,   r-   r.   r/   r0   r1   r2   r3   r4   0monai.transforms.utils_pytorch_numpy_unificationr5   r6   r7   r8   monai.utilsr9   r:   r;   r<   r=   r>   r?   r@   rA   rB   rC   rD   rE   rF   monai.utils.enumsrG   rH   rI   rJ   monai.utils.miscrK   r   monai.utils.modulerL   monai.utils.type_conversionrM   rN   rO   r   Zhas_nibrQ   r  r  r  __all__r>  r  rR   rS   rT   rU   rV   r[   r\   r]   r^   r_   r`   ra   rc   rd   re   rf   rg   rh   ri   rj   rk   rl   rW   rb   rX   rY   rZ   rm   rz   rz   rz   r{   <module>   s     (
,@$ R Tt- "  	7D 5@ | ,  F   ' *cVV ,h