o
    ,iH                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlmZm	Z	m
Z
mZmZ d dlmZ g dZG dd	 d	ejZG d
d dejZG dd deZG dd dejZG dd deZdS )    )annotationsN)nn)
functional)RegistrationDownSampleBlockRegistrationExtractionBlockRegistrationResidualConvBlockget_conv_blockget_deconv_block)meshgrid_ij)RegUNet
AffineHead	GlobalNetLocalNetc                      s   e Zd ZdZ							d.d/ fddZdd Zdd Zdd Zd0d d!Zd1d"d#Z	d$d% Z
d2d(d)Zd3d*d+Zd,d- Z  ZS )4r   u  
    Class that implements an adapted UNet. This class also serve as the parent class of LocalNet and GlobalNet

    Reference:
        O. Ronneberger, P. Fischer, and T. Brox,
        “U-net: Convolutional networks for biomedical image segmentation,”,
        Lecture Notes in Computer Science, 2015, vol. 9351, pp. 234–241.
        https://arxiv.org/abs/1505.04597

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    kaiming_uniformN   TFspatial_dimsintin_channelsnum_channel_initialdepthout_kernel_initializer
str | Noneout_activationout_channelsextract_levelstuple[int] | Nonepoolingboolconcat_skipencode_kernel_sizesint | list[int]c                   s   t    |s
|f}t||krt| _| _| _| _| _| _	| _
| _|	 _|
 _t|tr=|g jd  }t| jd krHt| _ fddt jd D  _t j _               dS )a,  
        Args:
            spatial_dims: number of spatial dims
            in_channels: number of input channels
            num_channel_initial: number of initial channels
            depth: input is at level 0, bottom is at level depth.
            out_kernel_initializer: kernel initializer for the last layer
            out_activation: activation at the last layer
            out_channels: number of channels for the output
            extract_levels: list, which levels from net to extract. The maximum level must equal to ``depth``
            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv
            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
            encode_kernel_sizes: kernel size for down-sampling
           c                   s   g | ]	} j d |  qS    )r   .0dself ]/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/regunet.py
<listcomp>`   s    z$RegUNet.__init__.<locals>.<listcomp>N)super__init__maxAssertionErrorr   r   r   r   r   r   r   r   r   r   
isinstancer   lenr   rangenum_channelsminmin_extract_levelbuild_layers)r(   r   r   r   r   r   r   r   r   r   r   r   	__class__r'   r*   r-   ,   s:   

zRegUNet.__init__c                 C  s   |    |   d S )N)build_encode_layersbuild_decode_layersr'   r)   r)   r*   r6   o   s   zRegUNet.build_layersc                   s`   t  fddt jD  _t  fddt jD  _ j jd  jd d _d S )Nc                   s@   g | ]} j |d kr jn j|d   j|  j| dqS )r   r!   r   r   kernel_size)build_conv_blockr   r3   r   r$   r'   r)   r*   r+   v   s    z/RegUNet.build_encode_layers.<locals>.<listcomp>c                   s   g | ]} j  j| d qS ))channels)build_down_sampling_blockr3   r$   r'   r)   r*   r+      s    r   r   )	r   
ModuleListr2   r   encode_convsencode_poolsbuild_bottom_blockr3   bottom_blockr'   r)   r'   r*   r9   s   s   

zRegUNet.build_encode_layersc              	   C  s(   t t| j|||dt| j|||dS N)r   r   r   r<   )r   
Sequentialr   r   r   r(   r   r   r<   r)   r)   r*   r=      s   zRegUNet.build_conv_blockr>   c                 C  s   t | j|| jdS )N)r   r>   r   )r   r   r   )r(   r>   r)   r)   r*   r?      s   z!RegUNet.build_down_sampling_blockc              	   C  s4   | j | j }tt| j|||dt| j|||dS rH   )r   r   r   rI   r   r   r   rJ   r)   r)   r*   rF      s   zRegUNet.build_bottom_blockc                   sj   t  fddt jd  jd dD  _t  fddt jd  jd dD  _   _d S )Nc                   s*   g | ]} j  j|d    j| dqS )r!   rB   )build_up_sampling_blockr3   r$   r'   r)   r*   r+      s    z/RegUNet.build_decode_layers.<locals>.<listcomp>r!   rA   c                   s<   g | ]} j  jrd  j|  n j|  j| ddqS )r#   r   r;   )r=   r   r3   r$   r'   r)   r*   r+      s    )	r   rC   r2   r   r5   decode_deconvsdecode_convsbuild_output_blockoutput_blockr'   r)   r'   r*   r:      s   

zRegUNet.build_decode_layersreturn	nn.Modulec                 C  s   t | j||dS Nr   r   r   )r	   r   r(   r   r   r)   r)   r*   rK      s   zRegUNet.build_up_sampling_blockc                 C  s    t | j| j| j| j| j| jdS )N)r   r   r3   r   kernel_initializer
activation)r   r   r   r3   r   r   r   r'   r)   r)   r*   rN      s   zRegUNet.build_output_blockc                 C  s   |j dd }g }|}t| j| jD ]\}}||}||}|| q| |}|g}	tt| j| jD ].\}
\}}||}| j	rQt
j|||
 d  gdd}n	|||
 d   }||}|	| q5| j|	|d}|S )z
        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])

        Returns:
            Tensor in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3]), with the same spatial size as ``x``
        r#   Nr!   dim)
image_size)shapeziprD   rE   appendrG   	enumeraterL   rM   r   torchcatrO   )r(   xrY   skipsencodedZencode_convZencode_poolskipdecodedoutsiZdecode_deconvZdecode_convoutr)   r)   r*   forward   s$   
zRegUNet.forward)r   Nr   NTFr   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    )r>   r   r   r   r   r   r   r   r   r   rP   rQ   )rP   rQ   )__name__
__module____qualname____doc__r-   r6   r9   r=   r?   rF   r:   rK   rN   rh   __classcell__r)   r)   r7   r*   r      s&    C




r   c                      sD   e Zd Z	dd fd
dZedddZdddZdddZ  ZS )r   Fr   r   rY   	list[int]decode_sizer   
save_thetar   c           	        s   t    || _|dkr#||d  |d  }d}tjg dtjd}n&|dkrB||d  |d  |d  }d}tjg d	tjd}ntd
| tj||d| _	| 
|| _| j	jj  | j	jj| || _t | _dS )aR  
        Args:
            spatial_dims: number of spatial dimensions
            image_size: output spatial size
            decode_size: input spatial size (two or three integers depending on ``spatial_dims``)
            in_channels: number of input channels
            save_theta: whether to save the theta matrix estimation
        r#   r   r!      )r!   r   r   r   r!   r   dtyper      )r!   r   r   r   r   r!   r   r   r   r   r!   r   z/only support 2D/3D operation, got spatial_dims=)in_featuresout_featuresN)r,   r-   r   r^   tensorfloat
ValueErrorr   Linearfcget_reference_gridgridweightdatazero_biascopy_rr   Tensortheta)	r(   r   rY   rq   r   rr   rw   rx   Zout_initr7   r)   r*   r-      s"   
zAffineHead.__init__tuple[int] | list[int]rP   torch.Tensorc                 C  s.   dd | D }t jt| dd}|jt jdS )Nc                 S  s   g | ]}t d |qS )r   )r^   arange)r%   rX   r)   r)   r*   r+         z1AffineHead.get_reference_grid.<locals>.<listcomp>r   rW   rt   )r^   stackr
   torz   )rY   mesh_pointsr   r)   r)   r*   r~     s   zAffineHead.get_reference_gridr   c              	   C  s|   t | jt | jd d g}| jdkr#t d||ddd}|S | jdkr6t d||ddd}|S td| j )	Nr!   r#   zqij,bpq->bpijrA   r   zqijk,bpq->bpijk   zdo not support spatial_dims=)r^   r_   r   	ones_liker   einsumreshaper{   )r(   r   Zgrid_paddedZgrid_warpedr)   r)   r*   affine_transform  s    

zAffineHead.affine_transformr`   list[torch.Tensor]c                 C  sV   |d }| j j|jd| _ | ||jd d}| jr!| | _| 	|| j  }|S )Nr   )devicerA   )
r   r   r   r}   r   rZ   rr   detachr   r   )r(   r`   rY   fr   rg   r)   r)   r*   rh   (  s   
zAffineHead.forward)F)
r   r   rY   rp   rq   rp   r   r   rr   r   )rY   r   rP   r   )r   r   )r`   r   rY   rp   rP   r   )	rk   rl   rm   r-   staticmethodr~   r   rh   ro   r)   r)   r7   r*   r      s    '
r   c                      s8   e Zd ZdZ						dd fddZdd Z  ZS )r   z
    Build GlobalNet for image registration.

    Reference:
        Hu, Yipeng, et al.
        "Label-driven weakly-supervised learning
        for multimodal deformable image registration,"
        https://arxiv.org/abs/1711.01666
    r   NTFr   rY   rp   r   r   r   r   r   r   r   r   r   r   r   r   r    rr   c                   s|   |D ]}|d   dkrt d  dd   d| q|| _ fdd|D | _|| _t j||| |||||	|
d
 d	S )
a  
        Args:
            image_size: output displacement field spatial size
            spatial_dims: number of spatial dims
            in_channels: number of input channels
            num_channel_initial: number of initial channels
            depth: input is at level 0, bottom is at level depth.
            out_kernel_initializer: kernel initializer for the last layer
            out_activation: activation at the last layer
            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv
            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
            encode_kernel_sizes: kernel size for down-sampling
            save_theta: whether to save the theta matrix estimation
        r#   r   zgiven depth z3, all input spatial dimension must be divisible by z, got input of size c                   s   g | ]}|d    qS r"   r)   r%   sizer   r)   r*   r+   a  r   z&GlobalNet.__init__.<locals>.<listcomp>)
r   r   r   r   r   r   r   r   r   r   N)r{   rY   rq   rr   r,   r-   )r(   rY   r   r   r   r   r   r   r   r   r   rr   r   r7   r   r*   r-   =  s2   
zGlobalNet.__init__c                 C  s    t | j| j| j| jd | jdS )NrA   )r   rY   rq   r   rr   )r   r   rY   rq   r3   rr   r'   r)   r)   r*   rN   p  s   zGlobalNet.build_output_block)r   NTFr   F)rY   rp   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    rr   r   )rk   rl   rm   rn   r-   rN   ro   r)   r)   r7   r*   r   2  s    3r   c                      s.   e Zd Z		dd fddZdddZ  ZS )AdditiveUpSampleBlocknearestNr   r   r   r   modestralign_cornersbool | Nonec                   s*   t    t|||d| _|| _|| _d S rR   )r,   r-   r	   deconvr   r   )r(   r   r   r   r   r   r7   r)   r*   r-   |  s   

zAdditiveUpSampleBlock.__init__r`   r   rP   c                 C  sp   dd |j dd  D }| |}tj||| j| jd}tjtj|j	|j d d dddddd}|| }|S )	Nc                 S  s   g | ]}|d  qS r"   r)   r   r)   r)   r*   r+     s    z1AdditiveUpSampleBlock.forward.<locals>.<listcomp>r#   )r   r   r!   )
split_sizerX   rA   rW   )
rZ   r   Finterpolater   r   r^   sumr   split)r(   r`   output_sizeZdeconvedresizedrg   r)   r)   r*   rh     s   
,zAdditiveUpSampleBlock.forward)r   N)
r   r   r   r   r   r   r   r   r   r   )r`   r   rP   r   )rk   rl   rm   r-   rh   ro   r)   r)   r7   r*   r   z  s
    r   c                      sH   e Zd ZdZ								d"d# fddZd$ddZd%d d!Z  ZS )&r   a  
    Reimplementation of LocalNet, based on:
    `Weakly-supervised convolutional neural networks for multimodal image registration
    <https://doi.org/10.1016/j.media.2018.07.002>`_.
    `Label-driven weakly-supervised learning for multimodal deformable image registration
    <https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    r   Nr   TFr   r   r   r   r   r   
tuple[int]r   r   r   r   r   r   use_additive_samplingr   r   r   r   r   c                   sL   |	| _ || _|| _t j||||t||||||
dgdgt|  d dS )a  
        Args:
            spatial_dims: number of spatial dims
            in_channels: number of input channels
            num_channel_initial: number of initial channels
            out_kernel_initializer: kernel initializer for the last layer
            out_activation: activation at the last layer
            out_channels: number of channels for the output
            extract_levels: list, which levels from net to extract. The maximum level must equal to ``depth``
            pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv3d
            use_additive_sampling: whether use additive up-sampling layer for decoding.
            concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
            mode: mode for interpolation when use_additive_sampling, default is "nearest".
            align_corners: align_corners for interpolation when use_additive_sampling, default is None.
           r   )r   r   r   r   r   r   r   r   r   r   r   N)use_additive_upsamplingr   r   r,   r-   r.   )r(   r   r   r   r   r   r   r   r   r   r   r   r   r7   r)   r*   r-     s    
zLocalNet.__init__c                 C  s   | j | j }t| j|||dS rH   )r   r   r   r   rJ   r)   r)   r*   rF     s   
zLocalNet.build_bottom_blockrP   rQ   c                 C  s.   | j rt| j||| j| jdS t| j||dS )N)r   r   r   r   r   rS   )r   r   r   r   r   r	   rT   r)   r)   r*   rK     s   z LocalNet.build_up_sampling_block)r   Nr   TTFr   N)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   ri   rj   )rk   rl   rm   rn   r-   rF   rK   ro   r)   r)   r7   r*   r     s    
/r   )
__future__r   r^   r   torch.nnr   r   Z#monai.networks.blocks.regunet_blockr   r   r   r   r	   monai.networks.utilsr
   __all__Moduler   r   r   r   r   r)   r)   r)   r*   <module>   s    OFH