o
    (i1                     @  s   d dl mZ d dlZd dlZd dlmZ d dlmZ d dlm	Z	m
Z
mZmZmZ G dd dejZG dd deZG d	d
 d
ejZG dd deZG dd deZG dd deZdS )    )annotationsN)Convolution)ActConvNormPool
split_argsc                      s>   e Zd ZdZddddifddfd fddZdddZ  ZS )ChannelSELayerz
    Re-implementation of the Squeeze-and-Excitation block based on:
    "Hu et al., Squeeze-and-Excitation Networks, https://arxiv.org/abs/1709.01507".
       reluinplaceTsigmoidFspatial_dimsintin_channelsracti_type_1tuple[str, dict] | stracti_type_2add_residualboolreturnNonec              	     s   t    || _ttj|f }|d| _t|| }|dkr)td| d| dt|\}	}
t|\}}t	
t	j||ddt|	 d	i |
t	j||ddt| d	i || _dS )
aS  
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to ``("relu", {"inplace": True})``.
            acti_type_2: activation type of the output squeeze layer. Defaults to "sigmoid".

        Raises:
            ValueError: When ``r`` is nonpositive or larger than ``in_channels``.

        See also:

            :py:class:`monai.networks.layers.Act`

           r   z7r must be positive and smaller than in_channels, got r=z in_channels=.T)biasN )super__init__r   r   ADAPTIVEAVGavg_poolr   
ValueErrorr   nn
SequentialLinearr   fc)selfr   r   r   r   r   r   	pool_typechannelsZact_1Z
act_1_argsZact_2Z
act_2_args	__class__r   n/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/squeeze_and_excitation.pyr      s   


zChannelSELayer.__init__xtorch.Tensorc                 C  sb   |j dd \}}| |||}| |||gdg|jd   }|| }| jr/||7 }|S )b
        Args:
            x: in shape (batch, in_channels, spatial_1[, spatial_2, ...]).
        Nr
   r   )shaper    viewr%   ndimr   )r&   r,   bcyresultr   r   r+   forwardJ   s   $zChannelSELayer.forward)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r,   r-   r   r-   __name__
__module____qualname____doc__r   r6   __classcell__r   r   r)   r+   r	      s    	
-r	   c                      s*   e Zd ZdZ			dd fddZ  ZS )ResidualSELayerz
    A "squeeze-and-excitation"-like layer with a residual connection::

        --+-- SE --o--
          |        |
          +--------+
    r
   	leakyrelur   r   r   r   r   r   r   r   r   r   c                   s   t  j|||||dd dS )a  
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: defaults to "leakyrelu".
            acti_type_2: defaults to "relu".

        See also:
            :py:class:`monai.networks.blocks.ChannelSELayer`
        T)r   r   r   r   r   r   N)r   r   )r&   r   r   r   r   r   r)   r   r+   r   e   s   
zResidualSELayer.__init__)r
   r?   r   )r   r   r   r   r   r   r   r   r   r   r   r   )r9   r:   r;   r<   r   r=   r   r   r)   r+   r>   \   s    r>   c                
      sN   e Zd ZdZddddddddifddddiffd! fddZd"dd Z  ZS )#SEBlockac  
    Residual module enhanced with Squeeze-and-Excitation::

        ----+- conv1 --  conv2 -- conv3 -- SE -o---
            |                                  |
            +---(channel project if needed)----+

    Re-implementation of the SE-Resnet block based on:
    "Hu et al., Squeeze-and-Excitation Networks, https://arxiv.org/abs/1709.01507".
    Nr
   r   r   Tr   r   r   r   n_chns_1n_chns_2n_chns_3conv_param_1dict | Noneconv_param_2conv_param_3projectConvolution | Noner   r   r   r   acti_type_finaltuple[str, dict] | str | Nonec                   s,  t    |sdtjdddifd}td|||d|| _|s+dtjdddifd}td|||d|| _|s@dtjdd}td|||d|| _t|||
||d	| _	|	du rl||krlt
t
j|f ||dd
| _n|	du rvt | _n|	| _|durt|\}}t| di || _dS t | _dS )ai  
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            n_chns_1: number of output channels in the 1st convolution.
            n_chns_2: number of output channels in the 2nd convolution.
            n_chns_3: number of output channels in the 3rd convolution.
            conv_param_1: additional parameters to the 1st convolution.
                Defaults to ``{"kernel_size": 1, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}``
            conv_param_2: additional parameters to the 2nd convolution.
                Defaults to ``{"kernel_size": 3, "norm": Norm.BATCH, "act": ("relu", {"inplace": True})}``
            conv_param_3: additional parameters to the 3rd convolution.
                Defaults to ``{"kernel_size": 1, "norm": Norm.BATCH, "act": None}``
            project: in the case of residual chns and output chns doesn't match, a project
                (Conv) layer/block is used to adjust the number of chns. In SENET, it is
                consisted with a Conv layer as well as a Norm layer.
                Defaults to None (chns are matchable) or a Conv layer with kernel size 1.
            r: the reduction ratio r in the paper. Defaults to 2.
            acti_type_1: activation type of the hidden squeeze layer. Defaults to "relu".
            acti_type_2: activation type of the output squeeze layer. Defaults to "sigmoid".
            acti_type_final: activation type of the end of the block. Defaults to "relu".

        See also:

            :py:class:`monai.networks.blocks.ChannelSELayer`

        r   r   r   T)kernel_sizenormact)r   r   out_channels   N)r   r   r   r   r   )rL   r   )r   r   r   BATCHr   conv1conv2conv3r	   se_layerr   CONVrH   r"   Identityr   r   rN   )r&   r   r   rA   rB   rC   rD   rF   rG   rH   r   r   r   rJ   Z	act_finalZact_final_argsr)   r   r+   r      s4   
+
zSEBlock.__init__r,   r-   r   c                 C  sH   |  |}| |}| |}| |}| |}||7 }| |}|S )r.   )rH   rR   rS   rT   rU   rN   )r&   r,   residualr   r   r+   r6      s   





zSEBlock.forward)r   r   r   r   rA   r   rB   r   rC   r   rD   rE   rF   rE   rG   rE   rH   rI   r   r   r   r   r   r   rJ   rK   r7   r8   r   r   r)   r+   r@      s    

Lr@   c                      ,   e Zd ZdZdZ		dd fddZ  ZS )SEBottleneckz"
    Bottleneck for SENet154.
       r   Nr   r   inplanesplanesgroups	reductionstride
downsamplerI   r   r   c                   sx   dddddift jdd}|ddddift j|dd}	ddd t jdd}
t j|||d	 |d
 |d
 ||	|
||d
 d S )Nr   r   r   TFstridesrL   rN   rM   r   rP   rc   rL   rN   rM   r^   r   r
   r[   
r   r   rA   rB   rC   rD   rF   rG   rH   r   r   rQ   r   r   r&   r   r\   r]   r^   r_   r`   ra   rD   rF   rG   r)   r   r+   r      s4   


zSEBottleneck.__init__r   Nr   r   r\   r   r]   r   r^   r   r_   r   r`   r   ra   rI   r   r   r9   r:   r;   r<   	expansionr   r=   r   r   r)   r+   rZ      s    	rZ   c                      rY   )SEResNetBottleneckz
    ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
    implementation and uses `strides=stride` in `conv1` and not in `conv2`
    (the latter is used in the torchvision implementation of ResNet).
    r[   r   Nr   r   r\   r]   r^   r_   r`   ra   rI   r   r   c                   sp   |ddddift jdd}dddddift j|dd}	ddd t jdd}
t j|||||d	 ||	|
||d

 d S )Nr   r   r   TFrb   rP   rd   r[   re   rf   rg   r)   r   r+   r   "  s4   


zSEResNetBottleneck.__init__rh   ri   rj   r   r   r)   r+   rl     s    	rl   c                      s.   e Zd ZdZdZ			dd fddZ  ZS )SEResNeXtBottleneckzI
    ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
    r[   r   Nr   r   r\   r]   r^   r_   r`   ra   rI   
base_widthr   r   c	                   s   dddddift jdd}	|ddddift j|dd}
ddd t jdd}t||d	  | }t j|||||d
 |	|
|||d
 d S )Nr   r   r   TFrb   rP   rd   @   r[   re   )r   rQ   mathfloorr   r   )r&   r   r\   r]   r^   r_   r`   ra   rn   rD   rF   rG   widthr)   r   r+   r   R  s6   


zSEResNeXtBottleneck.__init__)r   Nr[   )r   r   r\   r   r]   r   r^   r   r_   r   r`   r   ra   rI   rn   r   r   r   rj   r   r   r)   r+   rm   K  s    	rm   )
__future__r   rp   torchtorch.nnr"   Zmonai.networks.blocksr   monai.networks.layers.factoriesr   r   r   r   r   Moduler	   r>   r@   rZ   rl   rm   r   r   r   r+   <module>   s   E&g02