U
    PhH                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlmZm	Z	m
Z
mZmZ d dlmZ ddd	d
gZG dd dejZG dd dejZG dd	 d	eZG dd dejZG dd
 d
eZdS )    )annotationsN)nn)
functional)RegistrationDownSampleBlockRegistrationExtractionBlockRegistrationResidualConvBlockget_conv_blockget_deconv_block)meshgrid_ijRegUNet
AffineHead	GlobalNetLocalNetc                      s   e Zd ZdZd&dddddddd	d
d
dd fddZdd Zdd Zdd ZddddZdddddZ	dd Z
dddddd Zdd!d"d#Zd$d% Z  ZS )'r   u  
    Class that implements an adapted UNet. This class also serve as the parent class of LocalNet and GlobalNet

    Reference:
        O. Ronneberger, P. Fischer, and T. Brox,
        “U-net: Convolutional networks for biomedical image segmentation,”,
        Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
        https://arxiv.org/abs/1505.04597

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    kaiming_uniformN   TFint
str | Noneztuple[int] | Noneboolint | list[int])spatial_dimsin_channelsnum_channel_initialdepthout_kernel_initializerout_activationout_channelsextract_levelspoolingconcat_skipencode_kernel_sizesc                   s   t    |s|f}t||kr$t| _| _| _| _| _| _	| _
| _|	 _|
 _t|trz|g jd  }t| jd krt| _ fddt jd D  _t j _               dS )a,  
        Args:
            spatial_dims: number of spatial dims
            in_channels: number of input channels
            num_channel_initial: number of initial channels
            depth: input is at level 0, bottom is at level depth.
            out_kernel_initializer: kernel initializer for the last layer
            out_activation: activation at the last layer
            out_channels: number of channels for the output
            extract_levels: list, which levels from net to extract. The maximum level must equal to ``depth``
            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv
            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
            encode_kernel_sizes: kernel size for down-sampling
           c                   s   g | ]} j d |  qS    )r   .0dself P/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/regunet.py
<listcomp>`   s     z$RegUNet.__init__.<locals>.<listcomp>N)super__init__maxAssertionErrorr   r   r   r   r   r   r   r   r   r   
isinstancer   lenr   rangenum_channelsminmin_extract_levelbuild_layers)r'   r   r   r   r   r   r   r   r   r   r   r   	__class__r&   r)   r,   ,   s:    

zRegUNet.__init__c                 C  s   |    |   d S )N)build_encode_layersbuild_decode_layersr&   r(   r(   r)   r5   o   s    zRegUNet.build_layersc                   s`   t  fddt jD  _t  fddt jD  _ j jd  jd d _d S )Nc                   s@   g | ]8} j |d kr jn j|d   j|  j| dqS )r   r    r   r   kernel_size)build_conv_blockr   r2   r   r#   r&   r(   r)   r*   v   s   z/RegUNet.build_encode_layers.<locals>.<listcomp>c                   s   g | ]} j  j| d qS )channels)build_down_sampling_blockr2   r#   r&   r(   r)   r*      s     r   r   )	r   
ModuleListr1   r   encode_convsencode_poolsbuild_bottom_blockr2   bottom_blockr&   r(   r&   r)   r8   s   s    

 zRegUNet.build_encode_layersc              	   C  s(   t t| j|||dt| j|||dS N)r   r   r   r;   )r   
Sequentialr   r   r   r'   r   r   r;   r(   r(   r)   r<      s    zRegUNet.build_conv_blockr=   c                 C  s   t | j|| jdS )N)r   r>   r   )r   r   r   )r'   r>   r(   r(   r)   r?      s    z!RegUNet.build_down_sampling_blockrB   c              	   C  s4   | j | j }tt| j|||dt| j|||dS rH   )r   r   r   rI   r   r   r   rJ   r(   r(   r)   rF      s    zRegUNet.build_bottom_blockc                   sj   t  fddt jd  jd dD  _t  fddt jd  jd dD  _   _d S )Nc                   s*   g | ]"} j  j|d    j| dqS )r    rB   )build_up_sampling_blockr2   r#   r&   r(   r)   r*      s   z/RegUNet.build_decode_layers.<locals>.<listcomp>r    rA   c                   s<   g | ]4} j  jr d  j|  n j|  j| ddqS )r"   r   r:   )r<   r   r2   r#   r&   r(   r)   r*      s   )	r   rC   r1   r   r4   decode_deconvsdecode_convsbuild_output_blockoutput_blockr&   r(   r&   r)   r9      s    

zRegUNet.build_decode_layers	nn.Moduler   r   returnc                 C  s   t | j||dS Nr   r   r   )r	   r   r'   r   r   r(   r(   r)   rK      s    zRegUNet.build_up_sampling_block)rR   c                 C  s    t | j| j| j| j| j| jdS )N)r   r   r2   r   kernel_initializer
activation)r   r   r   r2   r   r   r   r&   r(   r(   r)   rN      s    zRegUNet.build_output_blockc                 C  s   |j dd }g }|}t| j| jD ]"\}}||}||}|| q$| |}|g}	tt| j| jD ]\\}
\}}||}| j	rt
j|||
 d  gdd}n|||
 d   }||}|	| qj| j|	|d}|S )z
        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])

        Returns:
            Tensor in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3]), with the same spatial size as ``x``
        r"   Nr    dim)
image_size)shapeziprD   rE   appendrG   	enumeraterL   rM   r   torchcatrO   )r'   xrZ   skipsencodedZencode_convZencode_poolskipdecodedoutsiZdecode_deconvZdecode_convoutr(   r(   r)   forward   s$    
zRegUNet.forward)r   Nr   NTFr   )__name__
__module____qualname____doc__r,   r5   r8   r<   r?   rF   r9   rK   rN   ri   __classcell__r(   r(   r6   r)   r      s$          (C
c                      s`   e Zd Zddddddd fddZedd	d
ddZd	dddZddd	dddZ  ZS )r   Fr   	list[int]r   r   rZ   decode_sizer   
save_thetac           	        s   t    || _|dkrN||d  |d  }d}tjddddddgtjd}n`|dkr||d  |d  |d  }d}tjddddddddddddgtjd}ntd| tj||d	| _	| 
|| _| j	jj  | j	jj| || _t | _d
S )aR  
        Args:
            spatial_dims: number of spatial dimensions
            image_size: output spatial size
            decode_size: input spatial size (two or three integers depending on ``spatial_dims``)
            in_channels: number of input channels
            save_theta: whether to save the theta matrix estimation
        r"   r   r       dtyper      z/only support 2D/3D operation, got spatial_dims=)in_featuresout_featuresN)r+   r,   r   r_   tensorfloat
ValueErrorr   Linearfcget_reference_gridgridweightdatazero_biascopy_rr   Tensortheta)	r'   r   rZ   rq   r   rr   rw   rx   Zout_initr6   r(   r)   r,      s"    
*zAffineHead.__init__ztuple[int] | list[int]torch.Tensor)rZ   rR   c                 C  s.   dd | D }t jt| dd}|jt jdS )Nc                 S  s   g | ]}t d |qS )r   )r_   arange)r$   rY   r(   r(   r)   r*     s     z1AffineHead.get_reference_grid.<locals>.<listcomp>r   rX   rt   )r_   stackr
   torz   )rZ   mesh_pointsr   r(   r(   r)   r~     s    zAffineHead.get_reference_grid)r   c              	   C  s|   t | jt | jd d g}| jdkrDt d||ddd}n4| jdkrht d||ddd}ntd| j |S )	Nr    r"   zqij,bpq->bpijrA   r   zqijk,bpq->bpijk   zdo not support spatial_dims=)r_   r`   r   	ones_liker   einsumreshaper{   )r'   r   Zgrid_paddedZgrid_warpedr(   r(   r)   affine_transform  s     

zAffineHead.affine_transformzlist[torch.Tensor])ra   rZ   rR   c                 C  sV   |d }| j j|jd| _ | ||jd d}| jrB| | _| 	|| j  }|S )Nr   )devicerA   )
r   r   r   r}   r   r[   rr   detachr   r   )r'   ra   rZ   fr   rh   r(   r(   r)   ri   (  s    
zAffineHead.forward)F)	rj   rk   rl   r,   staticmethodr~   r   ri   rn   r(   r(   r6   r)   r      s    'c                      sD   e Zd ZdZddddddd	d	d
d
dd
d fddZdd Z  ZS )r   z
    Build GlobalNet for image registration.

    Reference:
        Hu, Yipeng, et al.
        "Label-driven weakly-supervised learning
        for multimodal deformable image registration,"
        https://arxiv.org/abs/1711.01666
    r   NTFr   ro   r   r   r   r   )rZ   r   r   r   r   r   r   r   r   r   rr   c                   s|   |D ]2}|d   dkrt d  dd   d| q|| _ fdd|D | _|| _t j||| |||||	|
d
 d	S )
a  
        Args:
            image_size: output displacement field spatial size
            spatial_dims: number of spatial dims
            in_channels: number of input channels
            num_channel_initial: number of initial channels
            depth: input is at level 0, bottom is at level depth.
            out_kernel_initializer: kernel initializer for the last layer
            out_activation: activation at the last layer
            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv
            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
            encode_kernel_sizes: kernel size for down-sampling
            save_theta: whether to save the theta matrix estimation
        r"   r   zgiven depth z3, all input spatial dimension must be divisible by z, got input of size c                   s   g | ]}|d    qS r!   r(   r$   sizer   r(   r)   r*   a  s     z&GlobalNet.__init__.<locals>.<listcomp>)
r   r   r   r   r   r   r   r   r   r   N)r{   rZ   rq   rr   r+   r,   )r'   rZ   r   r   r   r   r   r   r   r   r   rr   r   r6   r   r)   r,   =  s(    zGlobalNet.__init__c                 C  s    t | j| j| j| jd | jdS )NrA   rp   )r   r   rZ   rq   r2   rr   r&   r(   r(   r)   rN   p  s    zGlobalNet.build_output_block)r   NTFr   F)rj   rk   rl   rm   r,   rN   rn   r(   r(   r6   r)   r   2  s         (3c                      s<   e Zd Zddddddd fddZd	d	d
ddZ  ZS )AdditiveUpSampleBlocknearestNr   strbool | Noner   r   r   modealign_cornersc                   s*   t    t|||d| _|| _|| _d S rS   )r+   r,   r	   deconvr   r   )r'   r   r   r   r   r   r6   r(   r)   r,   |  s    
zAdditiveUpSampleBlock.__init__r   )ra   rR   c                 C  sp   dd |j dd  D }| |}tj||| j| jd}tjtj|j	|j d d dddddd}|| }|S )	Nc                 S  s   g | ]}|d  qS r!   r(   r   r(   r(   r)   r*     s     z1AdditiveUpSampleBlock.forward.<locals>.<listcomp>r"   )r   r   r    )
split_sizerY   rA   rX   )
r[   r   Finterpolater   r   r_   sumr   split)r'   ra   output_sizeZdeconvedresizedrh   r(   r(   r)   ri     s    
,zAdditiveUpSampleBlock.forward)r   N)rj   rk   rl   r,   ri   rn   r(   r(   r6   r)   r   z  s     r   c                      s`   e Zd ZdZddddd	d
d
ddddddd fddZdddddZddddddZ  ZS )r   a  
    Reimplementation of LocalNet, based on:
    `Weakly-supervised convolutional neural networks for multimodal image registration
    <https://doi.org/10.1016/j.media.2018.07.002>`_.
    `Label-driven weakly-supervised learning for multimodal deformable image registration
    <https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    r   Nr   TFr   r   z
tuple[int]r   r   r   r   )r   r   r   r   r   r   r   r   use_additive_samplingr   r   r   c                   sL   |	| _ || _|| _t j||||t||||||
dgdgt|  d dS )a  
        Args:
            spatial_dims: number of spatial dims
            in_channels: number of input channels
            num_channel_initial: number of initial channels
            out_kernel_initializer: kernel initializer for the last layer
            out_activation: activation at the last layer
            out_channels: number of channels for the output
            extract_levels: list, which levels from net to extract. The maximum level must equal to ``depth``
            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv3d
            use_additive_sampling: whether use additive up-sampling layer for decoding.
            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
            mode: mode for interpolation when use_additive_sampling, default is "nearest".
            align_corners: align_corners for interpolation when use_additive_sampling, default is None.
           r   )r   r   r   r   r   r   r   r   r   r   r   N)use_additive_upsamplingr   r   r+   r,   r-   )r'   r   r   r   r   r   r   r   r   r   r   r   r   r6   r(   r)   r,     s     zLocalNet.__init__rB   c                 C  s   | j | j }t| j|||dS rH   )r   r   r   r   rJ   r(   r(   r)   rF     s       zLocalNet.build_bottom_blockrP   rQ   c                 C  s.   | j rt| j||| j| jdS t| j||dS )Nr   rT   )r   r   r   r   r   r	   rU   r(   r(   r)   rK     s    z LocalNet.build_up_sampling_block)r   Nr   TTFr   N)rj   rk   rl   rm   r,   rF   rK   rn   r(   r(   r6   r)   r     s           */)
__future__r   r_   r   torch.nnr   r   Z#monai.networks.blocks.regunet_blockr   r   r   r   r	   monai.networks.utilsr
   __all__Moduler   r   r   r   r   r(   r(   r(   r)   <module>   s    OFH