U
    Phc                     @  s0  d dl mZ d dlmZ d dlZd dlmZ d dlZd dlm	Z	 d dl
mZmZmZmZmZmZ ed\ZZddd	d
dgZG dd dejjZd!dddddddZG dd dejjZd"dddddd	ZG dd dejjZd#ddddd
ZG dd dejjZd$ddddddZG d d dejZdS )%    )annotations)SequenceN)to_norm_affine)GridSampleModeGridSamplePadModeconvert_to_dst_typeensure_tuplelook_up_optionoptional_importzmonai._CAffineTransform	grid_pull	grid_push
grid_count	grid_gradc                   @  s$   e Zd Zedd Zedd ZdS )	_GridPullc                 C  s>   |||f}t j||f| }|js(|jr:|| _| || |S N)_Cr   requires_gradoptsave_for_backwardctxinputgridinterpolationboundextrapolater   output r   ]/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/layers/spatial_transforms.pyforward%   s    
z_GridPull.forwardc                 C  s   | j d s| j d sdS | j}| j}tj|f|| }| j d rd|d | j d rX|d nd d d d fS | j d rd |d d d d fS d S Nr      NNNNN)needs_input_gradsaved_tensorsr   r   Zgrid_pull_backwardr   gradvarr   gradsr   r   r   backward/   s    
$
z_GridPull.backwardN__name__
__module____qualname__staticmethodr    r*   r   r   r   r   r   #   s   
	r   linearzeroTtorch.Tensorbool)r   r   r   returnc                 C  sX   dd t |D }dd t |D }t| ||||}t| tjjrTt|| dd }|S )a	  
    Sample an image with respect to a deformation field.

    `interpolation` can be an int, a string or an InterpolationType.
    Possible values are::

        - 0 or 'nearest'    or InterpolationType.nearest
        - 1 or 'linear'     or InterpolationType.linear
        - 2 or 'quadratic'  or InterpolationType.quadratic
        - 3 or 'cubic'      or InterpolationType.cubic
        - 4 or 'fourth'     or InterpolationType.fourth
        - 5 or 'fifth'      or InterpolationType.fifth
        - 6 or 'sixth'      or InterpolationType.sixth
        - 7 or 'seventh'    or InterpolationType.seventh

    A list of values can be provided, in the order [W, H, D],
    to specify dimension-specific interpolation orders.

    `bound` can be an int, a string or a BoundType.
    Possible values are::

        - 0 or 'replicate' or 'nearest'      or BoundType.replicate or 'border'
        - 1 or 'dct1'      or 'mirror'       or BoundType.dct1
        - 2 or 'dct2'      or 'reflect'      or BoundType.dct2
        - 3 or 'dst1'      or 'antimirror'   or BoundType.dst1
        - 4 or 'dst2'      or 'antireflect'  or BoundType.dst2
        - 5 or 'dft'       or 'wrap'         or BoundType.dft
        - 7 or 'zero'      or 'zeros'        or BoundType.zero

    A list of values can be provided, in the order [W, H, D],
    to specify dimension-specific boundary conditions.
    `sliding` is a specific condition than only applies to flow fields
    (with as many channels as dimensions). It cannot be dimension-specific.
    Note that:

        - `dft` corresponds to circular padding
        - `dct2` corresponds to Neumann boundary conditions (symmetric)
        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)

    See Also:
        - https://en.wikipedia.org/wiki/Discrete_cosine_transform
        - https://en.wikipedia.org/wiki/Discrete_sine_transform
        - ``help(monai._C.BoundType)``
        - ``help(monai._C.InterpolationType)``

    Args:
        input: Input image. `(B, C, Wi, Hi, Di)`.
        grid: Deformation field. `(B, Wo, Ho, Do, 1|2|3)`.
        interpolation (int or list[int] , optional): Interpolation order.
            Defaults to `'linear'`.
        bound (BoundType, or list[BoundType], optional): Boundary conditions.
            Defaults to `'zero'`.
        extrapolate: Extrapolate out-of-bound data.
            Defaults to `True`.

    Returns:
        output (torch.Tensor): Deformed image `(B, C, Wo, Ho, Do)`.

    c                 S  s,   g | ]$}t |trtjj| nt|qS r   
isinstancestrr   Z	BoundType__members__.0br   r   r   
<listcomp>{   s     zgrid_pull.<locals>.<listcomp>c                 S  s,   g | ]$}t |trtjj| nt|qS r   r6   r7   r   ZInterpolationTyper8   r:   ir   r   r   r<   |   s   dstr   )r   r   applyr6   monaidata
MetaTensorr   r   r   r   r   r   outr   r   r   r   <   s    ?c                   @  s$   e Zd Zedd Zedd ZdS )	_GridPushc           	      C  s@   |||f}t j|||f| }|js*|jr<|| _| || |S r   )r   r   r   r   r   )	r   r   r   shaper   r   r   r   r   r   r   r   r       s    
z_GridPush.forwardc                 C  s   | j d s| j d sdS | j}| j}tj|f|| }| j d rf|d | j d rX|d nd d d d d fS | j d rd |d d d d d fS d S )Nr   r"   )NNNNNN)r$   r%   r   r   Zgrid_push_backwardr&   r   r   r   r*      s    
&
z_GridPush.backwardNr+   r   r   r   r   rH      s   
	rH   )r   r   r   c                 C  st   dd t |D }dd t |D }|dkr>t| jdd }t| |||||}t| tjjrpt	|| dd }|S )a	  
    Splat an image with respect to a deformation field (pull adjoint).

    `interpolation` can be an int, a string or an InterpolationType.
    Possible values are::

        - 0 or 'nearest'    or InterpolationType.nearest
        - 1 or 'linear'     or InterpolationType.linear
        - 2 or 'quadratic'  or InterpolationType.quadratic
        - 3 or 'cubic'      or InterpolationType.cubic
        - 4 or 'fourth'     or InterpolationType.fourth
        - 5 or 'fifth'      or InterpolationType.fifth
        - 6 or 'sixth'      or InterpolationType.sixth
        - 7 or 'seventh'    or InterpolationType.seventh

    A list of values can be provided, in the order `[W, H, D]`,
    to specify dimension-specific interpolation orders.

    `bound` can be an int, a string or a BoundType.
    Possible values are::

        - 0 or 'replicate' or 'nearest'      or BoundType.replicate
        - 1 or 'dct1'      or 'mirror'       or BoundType.dct1
        - 2 or 'dct2'      or 'reflect'      or BoundType.dct2
        - 3 or 'dst1'      or 'antimirror'   or BoundType.dst1
        - 4 or 'dst2'      or 'antireflect'  or BoundType.dst2
        - 5 or 'dft'       or 'wrap'         or BoundType.dft
        - 7 or 'zero'                        or BoundType.zero

    A list of values can be provided, in the order `[W, H, D]`,
    to specify dimension-specific boundary conditions.
    `sliding` is a specific condition than only applies to flow fields
    (with as many channels as dimensions). It cannot be dimension-specific.
    Note that:

        - `dft` corresponds to circular padding
        - `dct2` corresponds to Neumann boundary conditions (symmetric)
        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)

    See Also:

        - https://en.wikipedia.org/wiki/Discrete_cosine_transform
        - https://en.wikipedia.org/wiki/Discrete_sine_transform
        - ``help(monai._C.BoundType)``
        - ``help(monai._C.InterpolationType)``

    Args:
        input: Input image `(B, C, Wi, Hi, Di)`.
        grid: Deformation field `(B, Wi, Hi, Di, 1|2|3)`.
        shape: Shape of the source image.
        interpolation (int or list[int] , optional): Interpolation order.
            Defaults to `'linear'`.
        bound (BoundType, or list[BoundType], optional): Boundary conditions.
            Defaults to `'zero'`.
        extrapolate: Extrapolate out-of-bound data.
            Defaults to `True`.

    Returns:
        output (torch.Tensor): Splatted image `(B, C, Wo, Ho, Do)`.

    c                 S  s,   g | ]$}t |trtjj| nt|qS r   r5   r9   r   r   r   r<      s     zgrid_push.<locals>.<listcomp>c                 S  s,   g | ]$}t |trtjj| nt|qS r   r=   r>   r   r   r   r<      s   N   r@   r   )
r   tuplerI   rH   rB   r6   rC   rD   rE   r   )r   r   rI   r   r   r   rG   r   r   r   r      s    Ac                   @  s$   e Zd Zedd Zedd ZdS )
_GridCountc                 C  s6   |||f}t j||f| }|jr2|| _| | |S r   )r   r   r   r   r   )r   r   rI   r   r   r   r   r   r   r   r   r       s    

z_GridCount.forwardc                 C  s6   | j d r2| j}| j}tj|f|| d d d d fS dS )Nr   r#   )r$   r%   r   r   Zgrid_count_backward)r   r'   r(   r   r   r   r   r*      s
    
z_GridCount.backwardNr+   r   r   r   r   rL      s   
	rL   )r   r   c                 C  sr   dd t |D }dd t |D }|dkr>t| jdd }t| ||||}tttjj	rnt
|tdd }|S )a
  
    Splatting weights with respect to a deformation field (pull adjoint).

    This function is equivalent to applying grid_push to an image of ones.

    `interpolation` can be an int, a string or an InterpolationType.
    Possible values are::

        - 0 or 'nearest'    or InterpolationType.nearest
        - 1 or 'linear'     or InterpolationType.linear
        - 2 or 'quadratic'  or InterpolationType.quadratic
        - 3 or 'cubic'      or InterpolationType.cubic
        - 4 or 'fourth'     or InterpolationType.fourth
        - 5 or 'fifth'      or InterpolationType.fifth
        - 6 or 'sixth'      or InterpolationType.sixth
        - 7 or 'seventh'    or InterpolationType.seventh

    A list of values can be provided, in the order [W, H, D],
    to specify dimension-specific interpolation orders.

    `bound` can be an int, a string or a BoundType.
    Possible values are::

        - 0 or 'replicate' or 'nearest'      or BoundType.replicate
        - 1 or 'dct1'      or 'mirror'       or BoundType.dct1
        - 2 or 'dct2'      or 'reflect'      or BoundType.dct2
        - 3 or 'dst1'      or 'antimirror'   or BoundType.dst1
        - 4 or 'dst2'      or 'antireflect'  or BoundType.dst2
        - 5 or 'dft'       or 'wrap'         or BoundType.dft
        - 7 or 'zero'                        or BoundType.zero

    A list of values can be provided, in the order [W, H, D],
    to specify dimension-specific boundary conditions.
    `sliding` is a specific condition than only applies to flow fields
    (with as many channels as dimensions). It cannot be dimension-specific.
    Note that:

        - `dft` corresponds to circular padding
        - `dct2` corresponds to Neumann boundary conditions (symmetric)
        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)

    See Also:

        - https://en.wikipedia.org/wiki/Discrete_cosine_transform
        - https://en.wikipedia.org/wiki/Discrete_sine_transform
        - ``help(monai._C.BoundType)``
        - ``help(monai._C.InterpolationType)``

    Args:
        grid: Deformation field `(B, Wi, Hi, Di, 2|3)`.
        shape: shape of the source image.
        interpolation (int or list[int] , optional): Interpolation order.
            Defaults to `'linear'`.
        bound (BoundType, or list[BoundType], optional): Boundary conditions.
            Defaults to `'zero'`.
        extrapolate (bool, optional): Extrapolate out-of-bound data.
            Defaults to `True`.

    Returns:
        output (torch.Tensor): Splat weights `(B, 1, Wo, Ho, Do)`.

    c                 S  s,   g | ]$}t |trtjj| nt|qS r   r5   r9   r   r   r   r<   E  s     zgrid_count.<locals>.<listcomp>c                 S  s,   g | ]$}t |trtjj| nt|qS r   r=   r>   r   r   r   r<   F  s   NrJ   r@   r   )r   rK   rI   rL   rB   r6   r   rC   rD   rE   r   )r   rI   r   r   r   rG   r   r   r   r     s    @c                   @  s$   e Zd Zedd Zedd ZdS )	_GridGradc                 C  s>   |||f}t j||f| }|js(|jr:|| _| || |S r   )r   r   r   r   r   r   r   r   r   r    V  s    
z_GridGrad.forwardc                 C  s   | j d s| j d sdS | j}| j}tj|f|| }| j d rd|d | j d rX|d nd d d d fS | j d rd |d d d d fS d S r!   )r$   r%   r   r   Zgrid_grad_backwardr&   r   r   r   r*   `  s    
$
z_GridGrad.backwardNr+   r   r   r   r   rM   T  s   
	rM   c                 C  sX   dd t |D }dd t |D }t| ||||}t| tjjrTt|| dd }|S )a	  
    Sample an image with respect to a deformation field.

    `interpolation` can be an int, a string or an InterpolationType.
    Possible values are::

        - 0 or 'nearest'    or InterpolationType.nearest
        - 1 or 'linear'     or InterpolationType.linear
        - 2 or 'quadratic'  or InterpolationType.quadratic
        - 3 or 'cubic'      or InterpolationType.cubic
        - 4 or 'fourth'     or InterpolationType.fourth
        - 5 or 'fifth'      or InterpolationType.fifth
        - 6 or 'sixth'      or InterpolationType.sixth
        - 7 or 'seventh'    or InterpolationType.seventh

    A list of values can be provided, in the order [W, H, D],
    to specify dimension-specific interpolation orders.

    `bound` can be an int, a string or a BoundType.
    Possible values are::

        - 0 or 'replicate' or 'nearest'      or BoundType.replicate
        - 1 or 'dct1'      or 'mirror'       or BoundType.dct1
        - 2 or 'dct2'      or 'reflect'      or BoundType.dct2
        - 3 or 'dst1'      or 'antimirror'   or BoundType.dst1
        - 4 or 'dst2'      or 'antireflect'  or BoundType.dst2
        - 5 or 'dft'       or 'wrap'         or BoundType.dft
        - 7 or 'zero'                        or BoundType.zero

    A list of values can be provided, in the order [W, H, D],
    to specify dimension-specific boundary conditions.
    `sliding` is a specific condition than only applies to flow fields
    (with as many channels as dimensions). It cannot be dimension-specific.
    Note that:

        - `dft` corresponds to circular padding
        - `dct2` corresponds to Neumann boundary conditions (symmetric)
        - `dst2` corresponds to Dirichlet boundary conditions (antisymmetric)

    See Also:

        - https://en.wikipedia.org/wiki/Discrete_cosine_transform
        - https://en.wikipedia.org/wiki/Discrete_sine_transform
        - ``help(monai._C.BoundType)``
        - ``help(monai._C.InterpolationType)``


    Args:
        input: Input image. `(B, C, Wi, Hi, Di)`.
        grid: Deformation field. `(B, Wo, Ho, Do, 2|3)`.
        interpolation (int or list[int] , optional): Interpolation order.
            Defaults to `'linear'`.
        bound (BoundType, or list[BoundType], optional): Boundary conditions.
            Defaults to `'zero'`.
        extrapolate: Extrapolate out-of-bound data. Defaults to `True`.

    Returns:
        output (torch.Tensor): Sampled gradients (B, C, Wo, Ho, Do, 1|2|3).

    c                 S  s,   g | ]$}t |trtjj| nt|qS r   r5   r9   r   r   r   r<     s     zgrid_grad.<locals>.<listcomp>c                 S  s,   g | ]$}t |trtjj| nt|qS r   r=   r>   r   r   r   r<     s   r@   r   )r   rM   rB   r6   rC   rD   rE   r   rF   r   r   r   r   m  s    >c                
      sZ   e Zd Zddejejdddfddddddddd	 fd
dZddddddddZ  Z	S )r   NFTzSequence[int] | int | Noner3   r7   zbool | NoneNone)spatial_size
normalizedmodepadding_modealign_cornersreverse_indexingzero_centeredr4   c                   sv   t    |dk	rt|nd| _|| _t|t| _t|t| _	|| _
|| _|dk	r`| jr`td|dk	rl|nd| _dS )a  
        Apply affine transformations with a batch of affine matrices.

        When `normalized=False` and `reverse_indexing=True`,
        it does the commonly used resampling in the 'pull' direction
        following the ``scipy.ndimage.affine_transform`` convention.
        In this case `theta` is equivalent to (ndim+1, ndim+1) input ``matrix`` of ``scipy.ndimage.affine_transform``,
        operates on homogeneous coordinates.
        See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.affine_transform.html

        When `normalized=True` and `reverse_indexing=False`,
        it applies `theta` to the normalized coordinates (coords. in the range of [-1, 1]) directly.
        This is often used with `align_corners=False` to achieve resolution-agnostic resampling,
        thus useful as a part of trainable modules such as the spatial transformer networks.
        See also: https://pytorch.org/tutorials/intermediate/spatial_transformer_tutorial.html

        Args:
            spatial_size: output spatial shape, the full output shape will be
                `[N, C, *spatial_size]` where N and C are inferred from the `src` input of `self.forward`.
            normalized: indicating whether the provided affine matrix `theta` is defined
                for the normalized coordinates. If `normalized=False`, `theta` will be converted
                to operate on normalized coordinates as pytorch affine_grid works with the normalized
                coordinates.
            mode: {``"bilinear"``, ``"nearest"``}
                Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
                Padding mode for outside grid values. Defaults to ``"zeros"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
            align_corners: see also https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html.
            reverse_indexing: whether to reverse the spatial indexing of image and coordinates.
                set to `False` if `theta` follows pytorch's default "D, H, W" convention.
                set to `True` if `theta` follows `scipy.ndimage` default "i, j, k" convention.
            zero_centered: whether the affine is applied to coordinates in a zero-centered value range.
                With `zero_centered=True`, for example, the center of rotation will be the
                spatial center of the input; with `zero_centered=False`, the center of rotation will be the
                origin of the input. This option is only available when `normalized=False`,
                where the default behaviour is `False` if unspecified.
                See also: :py:func:`monai.networks.utils.normalize_transform`.
        NzD`normalized=True` is not compatible with the `zero_centered` option.F)super__init__r   rO   rP   r	   r   rQ   r   rR   rS   rT   
ValueErrorrU   )selfrO   rP   rQ   rR   rS   rT   rU   	__class__r   r   rW     s    2
zAffineTransform.__init__r2   )srcthetarO   r4   c                 C  s  t |tjs"tdt|j d| dkr@td|j d| dkrT|d }|	 }t
|jdd }|dkrt|d	 dkrd	d	dgn
d	d	d	dg}||jd	 dd|}d
|_tj||gdd}t
|jdd dkrtd|j dt|std|j t |tjs8tdt|j d| d }|dkr^td| dt
|j}|}| jdk	r|dd | j }|dk	r|dd t| }| jst||dd |dd d
| jd}| jrDtjt|d dd|jd}	|dd|	f |ddd|f< |dddd|	f |ddddd|f< |jd	 dkrt|d	 dkrt||d	 dd}|jd	 |d	 krtd|jd	  d|d	  dtjj|ddd|f t|| jd}
tjj |! |
| j"| j#| jd}|S )a   
        ``theta`` must be an affine transformation matrix with shape
        3x3 or Nx3x3 or Nx2x3 or 2x3 for spatial 2D transforms,
        4x4 or Nx4x4 or Nx3x4 or 3x4 for spatial 3D transforms,
        where `N` is the batch size. `theta` will be converted into float Tensor for the computation.

        Args:
            src (array_like): image in spatial 2D or 3D (N, C, spatial_dims),
                where N is the batch dim, C is the number of channels.
            theta (array_like): Nx3x3, Nx2x3, 3x3, 2x3 for spatial 2D inputs,
                Nx4x4, Nx3x4, 3x4, 4x4 for spatial 3D inputs. When the batch dimension is omitted,
                `theta` will be repeated N times, N is the batch dim of `src`.
            spatial_size: output spatial shape, the full output shape will be
                `[N, C, *spatial_size]` where N and C are inferred from the `src`.

        Raises:
            TypeError: When ``theta`` is not a ``torch.Tensor``.
            ValueError: When ``theta`` is not one of [Nxdxd, dxd].
            ValueError: When ``theta`` is not one of [Nx3x3, Nx4x4].
            TypeError: When ``src`` is not a ``torch.Tensor``.
            ValueError: When ``src`` spatially is not one of [2D, 3D].
            ValueError: When affine and image batch dimension differ.

        z"theta must be torch.Tensor but is .rJ      z theta must be Nxdxd or dxd, got rJ   Nr"   )r_   )r`      r   F)dim))r`   r`   )ra   ra   z"theta must be Nx3x3 or Nx4x4, got z'theta must be floating point data, got z src must be torch.Tensor but is zUnsupported src dimension: z, available options are [2, 3].)affinesrc_sizedst_sizerS   rU   )devicez8affine and image batch dimension must match, got affine=z image=)r]   sizerS   )r   r   rQ   rR   rS   )$r6   torchTensor	TypeErrortyper,   rb   rX   rI   clonerK   tensorrepeattor   catis_floating_pointdtyperO   r   rP   r   rU   rT   	as_tensorrangerg   nn
functionalaffine_gridlistrS   grid_sample
contiguousrQ   rR   )rY   r\   r]   rO   Ztheta_shapeZ
pad_affinesrrd   re   rev_idxr   rA   r   r   r   r      sn    (




 ,(zAffineTransform.forward)N)
r,   r-   r.   r   BILINEARr   ZEROSrW   r    __classcell__r   r   rZ   r   r     s   "> )r0   r1   T)Nr0   r1   T)Nr0   r1   T)r0   r1   T) 
__future__r   collections.abcr   ri   torch.nnrv   rC   Zmonai.networksr   monai.utilsr   r   r   r   r	   r
   r   ___all__autogradFunctionr   r   rH   r   rL   r   rM   r   Moduler   r   r   r   r   <module>   s0    	     K       POJ