U
    Ph1                     @  s   d dl mZ d dlZd dlZd dlmZ d dlmZ d dlm	Z	m
Z
mZmZmZ G dd dejZG dd deZG d	d
 d
ejZG dd deZG dd deZG dd deZdS )    )annotationsN)Convolution)ActConvNormPool
split_argsc                	      sT   e Zd ZdZddddifddfdddd	d	d
dd fddZdddddZ  ZS )ChannelSELayerz
    Re-implementation of the Squeeze-and-Excitation block based on:
    "Hu et al., Squeeze-and-Excitation Networks, https://arxiv.org/abs/1709.01507".
       reluinplaceTsigmoidFinttuple[str, dict] | strboolNone)spatial_dimsin_channelsracti_type_1acti_type_2add_residualreturnc              	     s   t    || _ttj|f }|d| _t|| }|dkrRtd| d| dt|\}	}
t|\}}t	
t	j||ddt|	 f |
t	j||ddt| f || _dS )	aS  
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to ``("relu", {"inplace": True})``.
            acti_type_2: activation type of the output squeeze layer. Defaults to "sigmoid".

        Raises:
            ValueError: When ``r`` is nonpositive or larger than ``in_channels``.

        See also:

            :py:class:`monai.networks.layers.Act`

           r   z7r must be positive and smaller than in_channels, got r=z in_channels=.T)biasN)super__init__r   r   ADAPTIVEAVGavg_poolr   
ValueErrorr   nn
SequentialLinearr   fc)selfr   r   r   r   r   r   	pool_typechannelsZact_1Z
act_1_argsZact_2Z
act_2_args	__class__ a/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/blocks/squeeze_and_excitation.pyr      s    

zChannelSELayer.__init__torch.Tensorxr   c                 C  sb   |j dd \}}| |||}| |||gdg|jd   }|| }| jr^||7 }|S )b
        Args:
            x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).
        Nr
   r   )shaper   viewr$   ndimr   )r%   r.   bcyresultr*   r*   r+   forwardJ   s    $zChannelSELayer.forward__name__
__module____qualname____doc__r   r7   __classcell__r*   r*   r(   r+   r	      s   	
 -r	   c                      s2   e Zd ZdZdddddddd fd	d
Z  ZS )ResidualSELayerz
    A "squeeze-and-excitation"-like layer with a residual connection::

        --+-- SE --o--
          |        |
          +--------+
    r
   	leakyrelur   r   r   r   )r   r   r   r   r   r   c                   s   t  j|||||dd dS )a  
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: defaults to "leakyrelu".
            acti_type_2: defaults to "relu".

        See also:
            :py:class:`monai.networks.blocks.ChannelSELayer`
        T)r   r   r   r   r   r   N)r   r   )r%   r   r   r   r   r   r(   r*   r+   r   e   s    zResidualSELayer.__init__)r
   r?   r   )r9   r:   r;   r<   r   r=   r*   r*   r(   r+   r>   \   s
      r>   c                      sp   e Zd ZdZddddddddifddddiffdddddd	d	d	d
ddddd fddZdddddZ  ZS )SEBlockac  
    Residual module enhanced with Squeeze-and-Excitation::

        ----+- conv1 --  conv2 -- conv3 -- SE -o---
            |                                  |
            +---(channel project if needed)----+

    Re-implementation of the SE-Resnet block based on:
    "Hu et al., Squeeze-and-Excitation Networks, https://arxiv.org/abs/1709.01507".
    Nr
   r   r   Tr   r   zdict | NoneConvolution | Noner   ztuple[str, dict] | str | None)r   r   n_chns_1n_chns_2n_chns_3conv_param_1conv_param_2conv_param_3projectr   r   r   acti_type_finalc                   s(  t    |s$dtjdddifd}tf |||d|| _|sVdtjdddifd}tf |||d|| _|sdtjdd}tf |||d|| _t|||
||d	| _	|	dkr||krt
t
j|f ||dd
| _n|	dkrt | _n|	| _|dk	rt|\}}t| f || _n
t | _dS )ai  
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            n_chns_1: number of output channels in the 1st convolution.
            n_chns_2: number of output channels in the 2nd convolution.
            n_chns_3: number of output channels in the 3rd convolution.
            conv_param_1: additional parameters to the 1st convolution.
                Defaults to ``{"kernel_size": 1, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}``
            conv_param_2: additional parameters to the 2nd convolution.
                Defaults to ``{"kernel_size": 3, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}``
            conv_param_3: additional parameters to the 3rd convolution.
                Defaults to ``{"kernel_size": 1, "norm": Norm.BATCH, "act": None}``
            project: in the case of residual chns and output chns doesn't match, a project
                (Conv) layer/block is used to adjust the number of chns. In SENET, it is
                consisted with a Conv layer as well as a Norm layer.
                Defaults to None (chns are matchable) or a Conv layer with kernel size 1.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to "relu".
            acti_type_2: activation type of the output squeeze layer. Defaults to "sigmoid".
            acti_type_final: activation type of the end of the block. Defaults to "relu".

        See also:

            :py:class:`monai.networks.blocks.ChannelSELayer`

        r   r   r   T)kernel_sizenormact)r   r   out_channels   N)r   r   r   r   r   )rJ   )r   r   r   BATCHr   conv1conv2conv3r	   se_layerr   CONVrH   r!   Identityr   r   rL   )r%   r   r   rB   rC   rD   rE   rF   rG   rH   r   r   r   rI   Z	act_finalZact_final_argsr(   r*   r+   r      s@    +
      
zSEBlock.__init__r,   r-   c                 C  sH   |  |}| |}| |}| |}| |}||7 }| |}|S )r/   )rH   rP   rQ   rR   rS   rL   )r%   r.   residualr*   r*   r+   r7      s    





zSEBlock.forwardr8   r*   r*   r(   r+   r@      s   

,Lr@   c                
      s:   e Zd ZdZdZdddddddddd fd	d
Z  ZS )SEBottleneckz"
    Bottleneck for SENet154.
       r   Nr   rA   r   r   inplanesplanesgroups	reductionstride
downsampler   c                   sx   dddddift jdd}|ddddift j|dd}	ddd t jdd}
t j|||d	 |d
 |d
 ||	|
||d
 d S )Nr   r   r   TFstridesrJ   rL   rK   r   rN   ra   rJ   rL   rK   r\   r   r
   rX   
r   r   rB   rC   rD   rE   rF   rG   rH   r   r   rO   r   r   r%   r   rZ   r[   r\   r]   r^   r_   rE   rF   rG   r(   r*   r+   r      s4    

zSEBottleneck.__init__)r   Nr9   r:   r;   r<   	expansionr   r=   r*   r*   r(   r+   rW      s
   	  rW   c                
      s:   e Zd ZdZdZdddddddddd fd	d
Z  ZS )SEResNetBottleneckz
    ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
    implementation and uses `strides=stride` in `conv1` and not in `conv2`
    (the latter is used in the torchvision implementation of ResNet).
    rX   r   Nr   rA   r   rY   c                   sp   |ddddift jdd}dddddift j|dd}	ddd t jdd}
t j|||||d	 ||	|
||d

 d S )Nr   r   r   TFr`   rN   rb   rX   rc   rd   re   r(   r*   r+   r   "  s4    

zSEResNetBottleneck.__init__)r   Nrf   r*   r*   r(   r+   rh     s
   	  rh   c                      s<   e Zd ZdZdZddddddddddd	 fd	d
Z  ZS )SEResNeXtBottleneckzI
    ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
    rX   r   Nr   rA   r   )	r   rZ   r[   r\   r]   r^   r_   
base_widthr   c	                   s   dddddift jdd}	|ddddift j|dd}
ddd t jdd}t||d	  | }t j|||||d
 |	|
|||d
 d S )Nr   r   r   TFr`   rN   rb   @   rX   rc   )r   rO   mathfloorr   r   )r%   r   rZ   r[   r\   r]   r^   r_   rj   rE   rF   rG   widthr(   r*   r+   r   R  s6    

zSEResNeXtBottleneck.__init__)r   NrX   rf   r*   r*   r(   r+   ri   K  s   	   ri   )
__future__r   rl   torchtorch.nnr!   Zmonai.networks.blocksr   monai.networks.layers.factoriesr   r   r   r   r   Moduler	   r>   r@   rW   rh   ri   r*   r*   r*   r+   <module>   s   E&g02