U
    PhW                     @  s   d dl mZ d dlZd dlZd dlmZ d dlmZ d dlm	Z	 d dl
mZ d dlmZ d dlmZmZmZ ed	\ZZd
dgZG dd
 d
ejZG dd dejZdS )    )annotationsN)nn)
functional)USE_COMPILED)	grid_pull)meshgrid_ij)GridSampleModeGridSamplePadModeoptional_importzmonai._CWarpDVF2DDFc                      sV   e Zd ZdZejjejjdf fdd	Z	dddddd	d
dZ
dddddZ  ZS )r   zB
    Warp an image with given dense displacement field (DDF).
    Fc                   s   t    trd|dd tD kr\t|}|tjkr8d}n$|tjkrHd}n|tjkrXd}nd}|| _nt	d t|j
| _tr|dd tD krt|}|tjkrd}n$|tjkrd}n|tjkrd}nd}|| _nt|j
| _d	| _|| _d	S )
ac  
        For pytorch native APIs, the possible values are:

            - mode: ``"nearest"``, ``"bilinear"``, ``"bicubic"``.
            - padding_mode: ``"zeros"``, ``"border"``, ``"reflection"``

        See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html

        For MONAI C++/CUDA extensions, the possible values are:

            - mode: ``"nearest"``, ``"bilinear"``, ``"bicubic"``, 0, 1, ...
            - padding_mode: ``"zeros"``, ``"border"``, ``"reflection"``, 0, 1, ...

        See also: :py:class:`monai.networks.layers.grid_pull`

        - jitter: bool, default=False
            Define reference grid on non-integer values
            Reference: B. Likar and F. Pernus. A heirarchical approach to elastic registration
            based on mutual information. Image and Vision Computing, 19:33-44, 2001.
        c                 s  s   | ]}|j V  qd S Nvalue).0inter r   O/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/blocks/warp.py	<genexpr><   s     z Warp.__init__.<locals>.<genexpr>   r      z=monai.networks.blocks.Warp: Using PyTorch native grid_sample.c                 s  s   | ]}|j V  qd S r   r   )r   padr   r   r   r   M   s        N)super__init__r   r   BILINEARNEARESTBICUBIC_interp_modewarningswarnr   r	   ZEROSBORDER
REFLECTION_padding_moderef_gridjitter)selfmodepadding_moder&   	__class__r   r   r   #   s8    







zWarp.__init__r   torch.Tensorboolint)ddfr&   seedreturnc              	   C  s   | j d k	rD| j jd |jd krD| j jdd  |jdd  krD| j S dd |jdd  D }tjt| dd}tj|g|jd  dd}||| _ |rtjj|d  tj| |t	|7 }W 5 Q R X d| j _
| j S )	Nr   r      c                 S  s   g | ]}t d |qS )r   )torcharange)r   dimr   r   r   
<listcomp>e   s     z+Warp.get_reference_grid.<locals>.<listcomp>)r5   )enabledF)r%   shaper3   stackr   torandomfork_rngmanual_seed	rand_likerequires_grad)r'   r/   r&   r0   Zmesh_pointsgridr   r   r   get_reference_grid^   s"    zWarp.get_reference_gridimager/   c           	   
   C  sD  t |jd }|dkr&td| d|jd |ft|jdd  }|j|krvtd| d|j d	| d
|j d	| j|| jd| }|dgtt	dd|  dg }t
s.t|jdd D ],\}}|d|f d |d  d |d|f< qtt	|d dd}|d|f }tj||| j| j ddS t||| jd| jdS )a+  
        Args:
            image: Tensor in shape (batch, num_channels, H, W[, D])
            ddf: Tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, H, W[, D])

        Returns:
            warped_image in the same shape as image (batch, num_channels, H, W[, D])
        r2   )r2   r   zgot unsupported spatial_dims=z, currently support 2 or 3.r   NzGiven input z-d image shape z, the input DDF shape must be z, Got z	 instead.)r&   r   .T)r(   r)   align_corners)boundextrapolateinterpolation)lenr8   NotImplementedErrortuple
ValueErrorrA   r&   permutelistranger   	enumerateFgrid_sampler   r$   r   )	r'   rC   r/   spatial_dimsZ	ddf_shaper@   ir5   Zindex_orderingr   r   r   forwardq   s.    	 
 $&    zWarp.forward)Fr   )__name__
__module____qualname____doc__r   r   r   r	   r"   r   rA   rU   __classcell__r   r   r*   r   r      s   ;c                      sF   e Zd ZdZdejjejjfdd fddZ	dddd	d
Z
  ZS )r   z
    Layer calculates a dense displacement field (DDF) from a dense velocity field (DVF)
    with scaling and squaring.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)

    r   r.   )	num_stepsc                   s8   t    |dkr td| || _t||d| _d S )Nr   z"expecting positive num_steps, got )r(   r)   )r   r   rL   r[   r   
warp_layer)r'   r[   r(   r)   r*   r   r   r      s
    
zDVF2DDF.__init__r,   )dvfr1   c                 C  s4   |d| j   }t| j D ]}|| j||d }q|S )z
        Args:
            dvf: dvf to be transformed, in shape (batch, ``spatial_dims``, H, W[,D])

        Returns:
            a dense displacement field
        r2   rB   )r[   rO   r\   )r'   r]   r/   _r   r   r   rU      s    zDVF2DDF.forward)rV   rW   rX   rY   r   r   r   r	   r!   r   rU   rZ   r   r   r*   r   r      s   
  	)
__future__r   r   r3   r   torch.nnr   rQ   Zmonai.config.deviceconfigr   Z(monai.networks.layers.spatial_transformsr   monai.networks.utilsr   monai.utilsr   r	   r
   _Cr^   __all__Moduler   r   r   r   r   r   <module>   s   u