o
    &iW                     @  s   d dl mZ d dlZd dlZd dlmZ d dlmZ d dlm	Z	 d dl
mZ d dlmZ d dlmZmZmZ ed	\ZZd
dgZG dd
 d
ejZG dd dejZdS )    )annotationsN)nn)
functional)USE_COMPILED)	grid_pull)meshgrid_ij)GridSampleModeGridSamplePadModeoptional_importzmonai._CWarpDVF2DDFc                      sF   e Zd ZdZejjejjdf fdd	Z	ddddZ
dddZ  ZS )r   zB
    Warp an image with given dense displacement field (DDF).
    Fc                   s   t    tr2|dd tD v r.t|}|tjkrd}n|tjkr$d}n
|tjkr,d}nd}|| _nt	d t|j
| _trj|dd tD v rft|}|tjkrTd}n|tjkr\d}n
|tjkrdd}nd}|| _nt|j
| _d	| _|| _d	S )
ac  
        For pytorch native APIs, the possible values are:

            - mode: ``"nearest"``, ``"bilinear"``, ``"bicubic"``.
            - padding_mode: ``"zeros"``, ``"border"``, ``"reflection"``

        See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

        For MONAI C++/CUDA extensions, the possible values are:

            - mode: ``"nearest"``, ``"bilinear"``, ``"bicubic"``, 0, 1, ...
            - padding_mode: ``"zeros"``, ``"border"``, ``"reflection"``, 0, 1, ...

        See also: :py:class:`monai.networks.layers.grid_pull`

        - jitter: bool, default=False
            Define reference grid on non-integer values
            Reference: B. Likar and F. Pernus. A heirarchical approach to elastic registration
            based on mutual information. Image and Vision Computing, 19:33-44, 2001.
        c                 s      | ]}|j V  qd S Nvalue).0inter r   \/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/warp.py	<genexpr><       z Warp.__init__.<locals>.<genexpr>   r      z=monai.networks.blocks.Warp: Using PyTorch native grid_sample.c                 s  r   r   r   )r   padr   r   r   r   M   r      N)super__init__r   r   BILINEARNEARESTBICUBIC_interp_modewarningswarnr   r	   ZEROSBORDER
REFLECTION_padding_moderef_gridjitter)selfmodepadding_moder(   	__class__r   r   r   #   s8   








zWarp.__init__r   ddftorch.Tensorr(   boolseedintreturnc                 C  s   | j d ur"| j jd |jd kr"| j jdd  |jdd  kr"| j S dd |jdd  D }tjt| dd}tj|g|jd  dd}||| _ |rptjj|d tj| |t	|7 }W d    n1 skw   Y  d| j _
| j S )	Nr   r      c                 S  s   g | ]}t d |qS )r   )torcharange)r   dimr   r   r   
<listcomp>e   s    z+Warp.get_reference_grid.<locals>.<listcomp>)r7   )enabledF)r'   shaper5   stackr   torandomfork_rngmanual_seed	rand_likerequires_grad)r)   r.   r(   r1   Zmesh_pointsgridr   r   r   get_reference_grid^   s   
zWarp.get_reference_gridimagec           	   
   C  sB  t |jd }|dvrtd| d|jd |ft|jdd  }|j|kr;td| d|j d	| d
|j d	| j|| jd| }|dgtt	dd|  dg }t
st|jdd D ]\}}|d|f d |d  d |d|f< qbtt	|d dd}|d|f }tj||| j| j ddS t||| jd| jdS )a+  
        Args:
            image: Tensor in shape (batch, num_channels, H, W[, D])
            ddf: Tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, H, W[, D])

        Returns:
            warped_image in the same shape as image (batch, num_channels, H, W[, D])
        r4   )r4   r   zgot unsupported spatial_dims=z, currently support 2 or 3.r   NzGiven input z-d image shape z, the input DDF shape must be z, Got z	 instead.)r(   r   .T)r*   r+   align_corners)boundextrapolateinterpolation)lenr:   NotImplementedErrortuple
ValueErrorrC   r(   permutelistranger   	enumerateFgrid_sampler    r&   r   )	r)   rD   r.   spatial_dimsZ	ddf_shaperB   ir7   Zindex_orderingr   r   r   forwardq   s*   	 
$&zWarp.forward)Fr   )r.   r/   r(   r0   r1   r2   r3   r/   )rD   r/   r.   r/   )__name__
__module____qualname____doc__r   r   r   r	   r$   r   rC   rV   __classcell__r   r   r,   r   r      s
    ;c                      s<   e Zd ZdZdejjejjfd fddZ	dd
dZ
  ZS )r   z
    Layer calculates a dense displacement field (DDF) from a dense velocity field (DVF)
    with scaling and squaring.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)

    r   	num_stepsr2   c                   s8   t    |dkrtd| || _t||d| _d S )Nr   z"expecting positive num_steps, got )r*   r+   )r   r   rM   r\   r   
warp_layer)r)   r\   r*   r+   r,   r   r   r      s
   
zDVF2DDF.__init__dvfr/   r3   c                 C  s4   |d| j   }t| j D ]}|| j||d }q|S )z
        Args:
            dvf: dvf to be transformed, in shape (batch, ``spatial_dims``, H, W[,D])

        Returns:
            a dense displacement field
        r4   )rD   r.   )r\   rP   r]   )r)   r^   r.   _r   r   r   rV      s   zDVF2DDF.forward)r\   r2   )r^   r/   r3   r/   )rW   rX   rY   rZ   r   r   r   r	   r#   r   rV   r[   r   r   r,   r   r      s
    
	)
__future__r   r!   r5   r   torch.nnr   rR   Zmonai.config.deviceconfigr   Z(monai.networks.layers.spatial_transformsr   monai.networks.utilsr   monai.utilsr   r	   r
   _Cr_   __all__Moduler   r   r   r   r   r   <module>   s   u