U
    Ph,                  	   @  s
  d dl mZ d dlmZ d dlZd dlmZ d dlmZ d dl	m
Z
 d dlmZ d dlmZmZmZ d#ddddddddddZd$ddddddddZdddddddZG dd dejZG dd dejZG dd dejZG dd  d ejZG d!d" d"ejZdS )%    )annotations)SequenceN)nn)
functionalConvolution)same_padding)ConvNormPool   RELUBATCHintSequence[int] | inttuple | str | Nonez	nn.Module)spatial_dimsin_channelsout_channelskernel_sizeactnormreturnc                 C  s&   t |}t| |||||dd|d	}|S )NF)r   r   r   bias	conv_onlypaddingr   r   )r   r   r   r   r   r   r   mod r   Y/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/blocks/localnet_block.pyget_conv_block   s    r    r   r   r   r   r   c              	   C  s"   t |}t| |||dd|d}|S )NFT)r   r   r   r   r   )r   r   r   r   r   r   r   r   r   get_conv_layer0   s          r"   r   r   r   r   c                 C  s    t | ||dddddddd
}|S )N   r   r   FT   )
r   r   r   stridesr   r   r   is_transposedr   output_paddingr   )r   r   r   r   r   r   r   get_deconv_block:   s    r)   c                      s8   e Zd Zdddddd fddZddd	d
Z  ZS )ResidualBlockr   r   Noner!   c                   sl   t    ||kr&td| d| t||||d| _t||||d| _ttj|f || _	t
 | _d S )N7expecting in_channels == out_channels, got in_channels=, out_channels=r   r   r   r   )super__init__
ValueErrorr    
conv_blockr"   convr
   r   r   r   ReLUreluselfr   r   r   r   	__class__r   r   r0   L   s&    
      zResidualBlock.__init__torch.Tensorr   c              	   C  s$   |  | | | || }|S N)r5   r   r3   r2   r7   xoutr   r   r   forward]   s     zResidualBlock.forward__name__
__module____qualname__r0   r@   __classcell__r   r   r8   r   r*   J   s   r*   c                      s6   e Zd Zddddd fddZdddd	Z  ZS )
LocalNetResidualBlockr   r+   r#   c                   sX   t    ||kr&td| d| t|||d| _ttj|f || _t	 | _
d S )Nr,   r-   r   r   r   )r/   r0   r1   r"   
conv_layerr
   r   r   r   r4   r5   )r7   r   r   r   r8   r   r   r0   d   s    
zLocalNetResidualBlock.__init__r:   r;   c                 C  s   |  | | || }|S r<   )r5   r   rH   r7   r>   midr?   r   r   r   r@   n   s    zLocalNetResidualBlock.forwardrA   r   r   r8   r   rF   b   s   
rF   c                      s<   e Zd ZdZdddddd fddZdd	d
dZ  ZS )LocalNetDownSampleBlocka  
    A down-sample module that can be used for LocalNet, based on:
    `Weakly-supervised convolutional neural networks for multimodal image registration
    <https://doi.org/10.1016/j.media.2018.07.002>`_.
    `Label-driven weakly-supervised learning for multimodal deformable image registration
    <https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    r   r   r+   r!   c                   sH   t    t||||d| _t||||d| _ttj|f dd| _dS )a7  
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            kernel_size: convolution kernel size.
        Raises:
            NotImplementedError: when ``kernel_size`` is even
        r.   r$   )r   N)	r/   r0   r    r2   r*   residual_blockr   MAXmax_poolr6   r8   r   r   r0      s    
      z LocalNetDownSampleBlock.__init__z!tuple[torch.Tensor, torch.Tensor]r;   c                 C  sV   |j dd D ] }|d dkrtd|j  q| |}| |}| |}||fS )a  
        Halves the spatial dimensions.
        A tuple of (x, mid) is returned:

            -  x is the downsample result, in shape (batch, ``out_channels``, insize_1 / 2, insize_2 / 2, [insize_3 / 2]),
            -  mid is the mid-level feature, in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3])

        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])

        Raises:
            ValueError: when input spatial dimensions are not even.
        r$   Nr   z7expecting x spatial dimensions be even, got x of shape )shaper1   r2   rL   rN   )r7   r>   irJ   r   r   r   r@      s    


zLocalNetDownSampleBlock.forwardrB   rC   rD   __doc__r0   r@   rE   r   r   r8   r   rK   s   s   rK   c                      sN   e Zd ZdZdddddddd fd	d
ZddddZddddZ  ZS )LocalNetUpSampleBlocka  
    An up-sample module that can be used for LocalNet, based on:
    `Weakly-supervised convolutional neural networks for multimodal image registration
    <https://doi.org/10.1016/j.media.2018.07.002>`_.
    `Label-driven weakly-supervised learning for multimodal deformable image registration
    <https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    nearestNr   strzbool | Noner+   )r   r   r   modealign_cornersr   c                   sp   t    t|||d| _t|||d| _t|||d| _|| dkrZtd| d| || _	|| _
|| _dS )a  
        Args:
            spatial_dims: number of spatial dimensions.
            in_channels: number of input channels.
            out_channels: number of output channels.
            mode: interpolation mode of the additive upsampling, default to 'nearest'.
            align_corners: whether to align corners for the additive upsampling, default to None.
        Raises:
            ValueError: when ``in_channels != 2 * out_channels``
        rG   r$   z;expecting in_channels == 2 * out_channels, got in_channels=r-   N)r/   r0   r)   deconv_blockr    r2   rF   rL   r1   r   rV   rW   )r7   r   r   r   rV   rW   r8   r   r   r0      s&    
    zLocalNetUpSampleBlock.__init__r:   r;   c                 C  sP   t j||jdd  | j| jd}|jt| jdd}tj	tj
|dddd}|S )Nr$   )rV   rW   r%   )
split_sizedim)rZ   )FinterpolaterO   rV   rW   splitr   r   torchsumstackrI   r   r   r   additive_upsampling   s     z)LocalNetUpSampleBlock.additive_upsamplingc           	      C  s   t |jdd |jdd D ],\}}|d| krtd|j d|j q| || || }|| }| |}| ||}|S )a  
        Halves the channel and doubles the spatial dimensions.

        Args:
            x: feature to be up-sampled, in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
            mid: mid-level feature saved during down-sampling,
                in shape (batch, ``out_channels``, midsize_1, midsize_2, [midsize_3])

        Raises:
            ValueError: when ``midsize != insize * 2``
        r$   Nz_expecting mid spatial dimensions be exactly the double of x spatial dimensions, got x of shape z, mid of shape )ziprO   r1   rX   rb   r2   rL   )	r7   r>   rJ   rP   jh0r1r2r?   r   r   r   r@      s    &
zLocalNetUpSampleBlock.forward)rT   N)rB   rC   rD   rR   r0   rb   r@   rE   r   r   r8   r   rS      s     #rS   c                      s@   e Zd ZdZdddddddd fd	d
ZddddZ  ZS )LocalNetFeatureExtractorBlocka  
    A feature-extraction module that can be used for LocalNet, based on:
    `Weakly-supervised convolutional neural networks for multimodal image registration
    <https://doi.org/10.1016/j.media.2018.07.002>`_.
    `Label-driven weakly-supervised learning for multimodal deformable image registration
    <https://arxiv.org/abs/1711.01666>`_.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    r   kaiming_uniformr   r   rU   r+   )r   r   r   r   initializerr   c                   s   t    t||||dd| _ttj|f }| j D ]Z}t||r6|dkrbtj	
t|j q6|dkrtj	t|j q6td| dq6dS )a+  
        Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        act: activation type and arguments. Defaults to ReLU.
        kernel_initializer: kernel initializer. Defaults to None.
        N)r   r   r   r   r   ri   zeroszinitializer zA is not supported, currently supporting kaiming_uniform and zeros)r/   r0   r    r2   r	   CONVmodules
isinstancer   initkaiming_normal_r_   	as_tensorweightzeros_r1   )r7   r   r   r   r   rj   	conv_typemr8   r   r   r0     s$    
    

z&LocalNetFeatureExtractorBlock.__init__r:   r;   c                 C  s   |  |}|S )zo
        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])
        )r2   r=   r   r   r   r@   '  s    
z%LocalNetFeatureExtractorBlock.forward)r   ri   rQ   r   r   r8   r   rh      s
      rh   )r   r   r   )r   )
__future__r   collections.abcr   r_   r   torch.nnr   r\   Zmonai.networks.blocksr   monai.networks.layersr   monai.networks.layers.factoriesr	   r
   r   r    r"   r)   Moduler*   rF   rK   rS   rh   r   r   r   r   <module>   s&       
8P