U
    Phy"                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlmZ d dl	m
Z
 d dlmZmZmZmZ d"ddddddddddd
ddZd#ddddddddZG dd dejZG dd dejZdddddddZG d d! d!ejZdS )$    )annotations)SequenceN)nn)
functionalConvolution)ConvNormPoolsame_padding      RELUBATCHkaiming_uniformintzSequence[int] | intztuple[int, ...] | int | Noneztuple | str | None
str | Nonez	nn.Module)
spatial_dimsin_channelsout_channelskernel_sizestridespaddingactnorminitializerreturnc	                 C  s   |d krt |}t| ||||||dd|d
}	ttj| f }
|	 D ]Z}t||
rB|dkrntjt	
|j qB|dkrtjt	
|j qBtd| dqB|	S )NF)r   r   r   r   bias	conv_onlyr   r   zeroszinitializer zA is not supported, currently supporting kaiming_uniform and zeros)r   r   r   CONVmodules
isinstancer   initkaiming_normal_torch	as_tensorweightzeros_
ValueError)r   r   r   r   r   r   r   r   r   
conv_block	conv_typem r-   X/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/blocks/regunet_block.pyget_conv_block   s2    

r/   )r   r   r   r   r   c              	   C  s"   t |}t| |||dd|d}|S )NFT)r   r   r   r   )r   r   )r   r   r   r   r   modr-   r-   r.   get_conv_layer?   s          r1   c                      s@   e Zd ZdZddddddd fddZddd	d
dZ  ZS )RegistrationResidualConvBlockz
    A block with skip links and layer - norm - activation.
    Only changes the number of channels, the spatial size is kept same.
       r   r   )r   r   r   
num_layersr   c                   sr   t    || _t fddt|D | _tfddt|D | _tdd t|D | _dS )a  

        Args:
            spatial_dims: number of spatial dimensions
            in_channels: number of input channels
            out_channels: number of output channels
            num_layers: number of layers inside the block
            kernel_size: kernel_size
        c                   s(   g | ] }t |d kr ndqS )r   )r   r   r   r   )r1   ).0ir   r   r   r   r-   r.   
<listcomp>^   s   z:RegistrationResidualConvBlock.__init__.<locals>.<listcomp>c                   s   g | ]}t t jf  qS r-   )r	   r   r5   _)r   r   r-   r.   r8   h   s     c                 S  s   g | ]}t  qS r-   )r   ReLUr9   r-   r-   r.   r8   i   s     N)	super__init__r4   r   
ModuleListrangelayersnormsacts)selfr   r   r   r4   r   	__class__r7   r.   r=   O   s    
 z&RegistrationResidualConvBlock.__init__torch.Tensorxr   c                 C  s\   |}t t| j| j| jD ]<\}\}}}||}||}|| jd krN|| }||}q|S )a	  

        Args:
            x: Tensor in shape (batch, ``in_channels``, insize_1, insize_2, [insize_3])

        Returns:
            Tensor in shape (batch, ``out_channels``, insize_1, insize_2, [insize_3]),
            with the same spatial size as ``x``
        r   )	enumeratezipr@   rA   rB   r4   )rC   rH   skipr6   convr   r   r-   r-   r.   forwardk   s    
$
z%RegistrationResidualConvBlock.forward)r3   r   __name__
__module____qualname____doc__r=   rM   __classcell__r-   r-   rD   r.   r2   I   s
      r2   c                      s<   e Zd ZdZddddd fddZddd	d
dZ  ZS )RegistrationDownSampleBlockz
    A down-sample module used in RegUNet to half the spatial size.
    The number of channels is kept same.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    r   boolNone)r   channelspoolingr   c                   s@   t    |r&ttj|f dd| _nt|||dddd| _dS )z
        Args:
            spatial_dims: number of spatial dimensions.
            channels: channels
            pooling: use MaxPool if True, strided conv if False
        r3   )r   r   )r   r   r   r   r   r   N)r<   r=   r
   MAXlayerr/   )rC   r   rW   rX   rD   r-   r.   r=      s    
z$RegistrationDownSampleBlock.__init__rF   rG   c                 C  s>   |j dd D ] }|d dkrtd|j  q| |}|S )a_  
        Halves the spatial dimensions and keeps the same channel.
        output in shape (batch, ``channels``, insize_1 / 2, insize_2 / 2, [insize_3 / 2]),

        Args:
            x: Tensor in shape (batch, ``channels``, insize_1, insize_2, [insize_3])

        Raises:
            ValueError: when input spatial dimensions are not even.
        r3   Nr   z7expecting x spatial dimensions be even, got x of shape )shaper)   rZ   )rC   rH   r6   outr-   r-   r.   rM      s
    
z#RegistrationDownSampleBlock.forwardrN   r-   r-   rD   r.   rT      s   rT   )r   r   r   r   c                 C  s    t | ||dddddddd
}|S )Nr3   r   r   FTr   )
r   r   r   r   r   r   r   is_transposedr   output_paddingr   )r   r   r   r0   r-   r-   r.   get_deconv_block   s    r_   c                
      sH   e Zd ZdZdddddddd	d
d fddZddddddZ  ZS )RegistrationExtractionBlockzx
    The Extraction Block used in RegUNet.
    Extracts feature from each ``extract_levels`` and takes the average.
    r   Nnearestr   z
tuple[int]ztuple[int] | list[int]r   strzbool | None)r   extract_levelsnum_channelsr   kernel_initializer
activationmodealign_cornersc	           	        sL   t    || _t|| _t fdd|D | _|| _|| _	dS )an  

        Args:
            spatial_dims: number of spatial dimensions
            extract_levels: spatial levels to extract feature from, 0 refers to the input scale
            num_channels: number of channels at each scale level,
                List or Tuple of length equals to `depth` of the RegNet
            out_channels: number of output channels
            kernel_initializer: kernel initializer
            activation: kernel activation function
            mode: feature map interpolation mode, default to "nearest".
            align_corners: whether to align corners for feature map interpolation.
        c              
     s$   g | ]}t | d  dqS )N)r   r   r   r   r   r   )r/   )r5   drf   re   rd   r   r   r-   r.   r8      s   	z8RegistrationExtractionBlock.__init__.<locals>.<listcomp>N)
r<   r=   rc   max	max_levelr   r>   r@   rg   rh   )	rC   r   rc   rd   r   re   rf   rg   rh   rD   rj   r.   r=      s    

	z$RegistrationExtractionBlock.__init__zlist[torch.Tensor]z	list[int]rF   )rH   
image_sizer   c                   s<    fddt jjD }tjtj|dddd}|S )a#  

        Args:
            x: Decoded feature at different spatial levels, sorted from deep to shallow
            image_size: output image size

        Returns:
            Tensor of shape (batch, `out_channels`, size1, size2, size3), where (size1, size2, size3) = ``image_size``
        c                   s4   g | ],\}}t j|j|   jjd qS ))sizerg   rh   )Finterpolaterl   rg   rh   )r5   rZ   levelrm   rC   rH   r-   r.   r8      s      z7RegistrationExtractionBlock.forward.<locals>.<listcomp>r   )dim)rJ   r@   rc   r%   meanstack)rC   rH   rm   Zfeature_listr\   r-   rr   r.   rM      s
    
z#RegistrationExtractionBlock.forward)r   Nra   NrN   r-   r-   rD   r.   r`      s       "+r`   )r   r   Nr   r   r   )r   )
__future__r   collections.abcr   r%   r   torch.nnr   ro   Zmonai.networks.blocksr   monai.networks.layersr   r	   r
   r   r/   r1   Moduler2   rT   r_   r`   r-   r-   r-   r.   <module>   s&         "( 
7/