U
    PhYK                     @  s  d dl mZ d dlZd dlmZ d dlmZ d dlmZ d dl	Z	d dl
mZ d dlmZ d dlmZ d dlmZ d d	lmZmZmZ d d
lmZmZmZmZmZ d dlmZ ddddddddgZdddddddZG dd dej Z!ddddd d!Z"G d"d de!Z#G d#d de!Z$G d$d de!Z%G d%d de!Z&G d&d' d'e!Z'G d(d de!Z(e! Z)Z*e# Z+ Z,Z-e$ Z. Z/Z0e% Z1 Z2Z3e& Z4 Z5Z6e' Z7 Z8 Z9Z:e( Z; Z< Z=Z>dS ))    )annotationsN)OrderedDict)Sequence)Any)load_state_dict_from_url)download_url)Convolution)SEBottleneckSEResNetBottleneckSEResNeXtBottleneck)ActConvDropoutNormPool)look_up_optionSENetSENet154
SEResNet50SEResNet101SEResNet152SEResNeXt50SEResNext101SE_NET_MODELSzAhttp://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pthzDhttp://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pthzEhttp://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pthzEhttp://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pthzKhttp://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pthzLhttp://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth)senet154se_resnet50se_resnet101se_resnet152se_resnext50_32x4dse_resnext101_32x4dc                      s   e Zd ZdZdddd	d
dddddddddd fddZd dddddddddddZddddZddddZdddddZ  Z	S )!r   a  
    SENet based on `Squeeze-and-Excitation Networks <https://arxiv.org/pdf/1709.01507.pdf>`_.
    Adapted from `Cadene Hub 2D version
    <https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py>`_.

    Args:
        spatial_dims: spatial dimension of the input data.
        in_channels: channel number of the input data.
        block: SEBlock class or str.
            for SENet154: SEBottleneck or 'se_bottleneck'
            for SE-ResNet models: SEResNetBottleneck or 'se_resnet_bottleneck'
            for SE-ResNeXt models:  SEResNeXtBottleneck or 'se_resnetxt_bottleneck'
        layers: number of residual blocks for 4 layers of the network (layer1...layer4).
        groups: number of groups for the 3x3 convolution in each bottleneck block.
            for SENet154: 64
            for SE-ResNet models: 1
            for SE-ResNeXt models:  32
        reduction: reduction ratio for Squeeze-and-Excitation modules.
            for all models: 16
        dropout_prob: drop probability for the Dropout layer.
            if `None` the Dropout layer is not used.
            for SENet154: 0.2
            for SE-ResNet models: None
            for SE-ResNeXt models: None
        dropout_dim: determine the dimensions of dropout. Defaults to 1.
            When dropout_dim = 1, randomly zeroes some of the elements for each channel.
            When dropout_dim = 2, Randomly zeroes out entire channels (a channel is a 2D feature map).
            When dropout_dim = 3, Randomly zeroes out entire channels (a channel is a 3D feature map).
        inplanes:  number of input channels for layer1.
            for SENet154: 128
            for SE-ResNet models: 64
            for SE-ResNeXt models: 64
        downsample_kernel_size: kernel size for downsampling convolutions in layer2, layer3 and layer4.
            for SENet154: 3
            for SE-ResNet models: 1
            for SE-ResNeXt models: 1
        input_3x3: If `True`, use three 3x3 convolutions instead of
            a single 7x7 convolution in layer0.
            - For SENet154: True
            - For SE-ResNet models: False
            - For SE-ResNeXt models: False
        num_classes: number of outputs in `last_linear` layer.
            for all models: 1000
    皙?         T  intzCtype[SEBottleneck | SEResNetBottleneck | SEResNeXtBottleneck] | strSequence[int]float | NoneboolNone)spatial_dimsin_channelsblocklayersgroups	reductiondropout_probdropout_diminplanesdownsample_kernel_size	input_3x3num_classesreturnc                   s  t    t|trJ|dkr"t}n(|dkr0t}n|dkr>t}ntd| ttj	 }t
t
j|f }ttj|f }ttj|f }ttj|f }ttj|f }|	| _|| _|r<d||dddd	d
dfd|ddfd|ddfd|dddd	d	d
dfd|ddfd|ddfd|d|	dd	d	d
dfd||	dfd|ddfg	}n2d|||	dddd
dfd||	dfd|ddfg}|d|ddddf tt|| _| j|d|d ||d	d| _| j|d|d	 d|||
d| _| j|d|d d|||
d| _| j|d|d d|||
d| _|d	| _|d k	r(||nd | _ t!d|j" || _#| $ D ]}t||rptj%&t'(|j) n^t||rtj%*t'(|j)d	 tj%*t'(|j+d n$t|tj!rJtj%*t'(|j+d qJd S ) NZse_bottleneckZse_resnet_bottleneckZse_resnetxt_bottleneckzUUnknown block '%s', use se_bottleneck, se_resnet_bottleneck or se_resnetxt_bottleneckconv1@   r#      r!   F)r+   out_channelskernel_sizestridepaddingbiasbn1)num_featuresrelu1T)inplaceconv2bn2relu2conv3bn3relu3   pool)r;   r<   	ceil_moder   )planesblocksr.   r/   r3   r"   )rL   rM   r<   r.   r/   r3      i   ),super__init__
isinstancestrr	   r
   r   
ValueErrorr   RELUr   CONVr   MAXr   BATCHr   DROPOUTADAPTIVEAVGr2   r*   appendnn
Sequentialr   layer0_make_layerlayer1layer2layer3layer4adaptive_avg_pooldropoutLinear	expansionlast_linearmodulesinitkaiming_normal_torch	as_tensorweight	constant_r>   )selfr*   r+   r,   r-   r.   r/   r0   r1   r2   r3   r4   r5   	relu_type	conv_type	pool_type	norm_typedropout_typeavg_pool_typeZlayer0_modulesm	__class__ N/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/senet.pyrP   `   s    


          			
zSENet.__init__z=type[SEBottleneck | SEResNetBottleneck | SEResNeXtBottleneck]znn.Sequential)r,   rL   rM   r.   r/   r<   r3   r6   c                 C  s   d }|dks| j ||j kr@t| j| j ||j ||d tjdd}g }	|	|| j| j |||||d ||j | _ td|D ] }
|	|| j| j |||d qztj	|	 S )Nr!   F)r*   r+   r:   stridesr;   actnormr>   )r*   r2   rL   r.   r/   r<   
downsample)r*   r2   rL   r.   r/   )
r2   rf   r   r*   r   rW   rZ   ranger[   r\   )ro   r,   rL   rM   r.   r/   r<   r3   r~   r-   _numry   ry   rz   r^      sH    

zSENet._make_layerztorch.Tensor)xc                 C  s6   |  |}| |}| |}| |}| |}|S N)r]   r_   r`   ra   rb   ro   r   ry   ry   rz   features  s    




zSENet.featuresc                 C  s8   |  |}| jd k	r| |}t|d}| |}|S )Nr!   )rc   rd   rk   flattenrg   r   ry   ry   rz   logits  s    



zSENet.logits)r   r6   c                 C  s   |  |}| |}|S r   )r   r   r   ry   ry   rz   forward  s    

zSENet.forward)r    r!   r"   r#   Tr$   )r!   r!   )
__name__
__module____qualname____doc__rP   r^   r   r   r   __classcell__ry   ry   rw   rz   r   2   s   5      ,}  1z	nn.ModulerR   r(   )modelarchprogressc                   s  t |td}|dkrtdtd}td}td}td}td}td}	t|trt|d	 |d
 d tj	|d
 ddnt
||dt D ]}
d}||
rt|d|
}n||
rt|d|
}n||
r
|
  |
< t|d|
}nb||
r6|
  |
< t|d|
}n6||
rRt|d|
}n|	|
rlt|	d|
}|r|
 |< |
= q|    fdd D   |   dS )z:
    This function is used to load pretrained models.
    Nzonly 'senet154', 'se_resnet50', 'se_resnet101',  'se_resnet152', 'se_resnext50_32x4d', and se_resnext101_32x4d are supported to load pretrained weights.z%^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$z%^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$z+^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$z+^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$z*^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$z*^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$urlfilename)filepath)map_location)r   z	\1conv.\2z\1conv\2adn.N.\3z\1se_layer.fc.0.\2z\1se_layer.fc.2.\2z\1project.conv.\2z\1project.adn.N.\2c                   s2   i | ]*\}}| kr | j | j kr||qS ry   )shape).0kv
model_dict
state_dictry   rz   
<dictcomp>I  s
       z$_load_state_dict.<locals>.<dictcomp>)r   r   rS   recompilerQ   dictr   rk   loadr   listkeysmatchsubsqueezer   itemsupdateload_state_dict)r   r   r   	model_urlZpattern_convZ
pattern_bnZ
pattern_seZpattern_se2Zpattern_down_convZpattern_down_bnkeynew_keyry   r   rz   _load_state_dict  sP    









r   c                      s2   e Zd ZdZddddd	d	d
d fddZ  ZS )r   zlSENet154 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.r#      $   r#   r8      FTr&   r%   r(   r)   )r-   r.   r/   
pretrainedr   r6   c                   s0   t  jf t|||d| |r,t| d| d S )N)r,   r-   r.   r/   r   )rO   rP   r	   r   )ro   r-   r.   r/   r   r   kwargsrw   ry   rz   rP   S  s    	zSENet154.__init__)r   r8   r   FTr   r   r   r   rP   r   ry   ry   rw   rz   r   P  s        c                      s:   e Zd ZdZdd	d
d
dd
d
ddddd
 fddZ  ZS )r   znSEResNet50 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.r#         r#   r!   r   Nr8   FTr&   r%   r'   r(   r)   
r-   r.   r/   r0   r2   r3   r4   r   r   r6   c
                   s8   t  jf t|||||||d|
 |r4t| d|	 d S )N)r,   r-   r.   r/   r0   r2   r3   r4   r   rO   rP   r
   r   ro   r-   r.   r/   r0   r2   r3   r4   r   r   r   rw   ry   rz   rP   e  s    	zSEResNet50.__init__)	r   r!   r   Nr8   r!   FFTr   ry   ry   rw   rz   r   b  s            c                      s8   e Zd ZdZddd	d	d	d	d
d
d
dd	 fddZ  ZS )r   zy
    SEResNet101 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.
    r#   r      r#   r!   r   r8   FTr&   r%   r(   r)   	r-   r.   r/   r2   r3   r4   r   r   r6   c	           
   
     s6   t  jf t||||||d|	 |r2t| d| d S )Nr,   r-   r.   r/   r2   r3   r4   r   r   
ro   r-   r.   r/   r2   r3   r4   r   r   r   rw   ry   rz   rP     s    
zSEResNet101.__init__)r   r!   r   r8   r!   FFTr   ry   ry   rw   rz   r     s           c                      s8   e Zd ZdZddd	d	d	d	d
d
d
dd	 fddZ  ZS )r   zy
    SEResNet152 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.
    r   r!   r   r8   FTr&   r%   r(   r)   r   c	           
   
     s6   t  jf t||||||d|	 |r2t| d| d S )Nr   r   r   r   rw   ry   rz   rP     s    
zSEResNet152.__init__)r   r!   r   r8   r!   FFTr   ry   ry   rw   rz   r     s           c                      s:   e Zd ZdZdd
dddddddddd
 fddZ  ZS )SEResNext50zy
    SEResNext50 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.
    r       r   Nr8   r!   FTr&   r%   r'   r(   r)   r   c
                   s8   t  jf t|||||||d|
 |r4t| d|	 d S )Nr,   r-   r.   r0   r/   r2   r3   r4   r   rO   rP   r   r   r   rw   ry   rz   rP     s    	zSEResNext50.__init__)	r   r   r   Nr8   r!   FFTr   ry   ry   rw   rz   r     s            r   c                      s:   e Zd ZdZdd
dddddddddd
 fddZ  ZS )r   zz
    SEResNext101 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.
    r   r   r   Nr8   r!   FTr&   r%   r'   r(   r)   r   c
                   s8   t  jf t|||||||d|
 |r4t| d|	 d S )Nr   r   r   r   rw   ry   rz   rP     s    	zSEResNext101.__init__)	r   r   r   Nr8   r!   FFTr   ry   ry   rw   rz   r     s            )?
__future__r   r   collectionsr   collections.abcr   typingr   rk   torch.nnr[   Z	torch.hubr   monai.apps.utilsr   "monai.networks.blocks.convolutionsr   Z,monai.networks.blocks.squeeze_and_excitationr	   r
   r   monai.networks.layers.factoriesr   r   r   r   r   monai.utils.moduler   __all__r   Moduler   r   r   r   r   r   r   r   SEnetSenetSEnet154Senet154r   
SEresnet50
Seresnet50
seresnet50SEresnet101Seresnet101seresnet101SEresnet152Seresnet152seresnet152r   SEresnext50Seresnext50seresnext50SEResNeXt101SEresnext101Seresnext101seresnext101ry   ry   ry   rz   <module>   sX   
 l3   ""