o
    &ilK                     @  s  d dl mZ d dlZd dlmZ d dlmZ d dlmZ d dl	Z	d dl
mZ d dlmZ d dlmZ d dlmZ d d	lmZmZmZ d d
lmZmZmZmZmZ d dlmZ g dZdddddddZG dd dej Z!d*ddZ"G dd de!Z#G d d! d!e!Z$G d"d# d#e!Z%G d$d% d%e!Z&G d&d' d'e!Z'G d(d) d)e!Z(e! Z)Z*e# Z+ Z,Z-e$ Z. Z/Z0e% Z1 Z2Z3e& Z4 Z5Z6e' Z7 Z8 Z9Z:e( Z; Z< Z=Z>dS )+    )annotationsN)OrderedDict)Sequence)Any)load_state_dict_from_url)download_url)Convolution)SEBottleneckSEResNetBottleneckSEResNeXtBottleneck)ActConvDropoutNormPool)look_up_option)SENetSENet154
SEResNet50SEResNet101SEResNet152SEResNeXt50SEResNext101SE_NET_MODELSzAhttp://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pthzDhttp://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pthzEhttp://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pthzEhttp://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pthzKhttp://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pthzLhttp://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth)senet154se_resnet50se_resnet101se_resnet152se_resnext50_32x4dse_resnext101_32x4dc                      s^   e Zd ZdZ						d,d- fddZ		d.d/d"d#Zd0d&d'Zd0d(d)Zd1d*d+Z  Z	S )2r   a  
    SENet based on `Squeeze-and-Excitation Networks <https://arxiv.org/pdf/1709.01507.pdf>`_.
    Adapted from `Cadene Hub 2D version
    <https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/senet.py>`_.

    Args:
        spatial_dims: spatial dimension of the input data.
        in_channels: channel number of the input data.
        block: SEBlock class or str.
            for SENet154: SEBottleneck or 'se_bottleneck'
            for SE-ResNet models: SEResNetBottleneck or 'se_resnet_bottleneck'
            for SE-ResNeXt models:  SEResNeXtBottleneck or 'se_resnetxt_bottleneck'
        layers: number of residual blocks for 4 layers of the network (layer1...layer4).
        groups: number of groups for the 3x3 convolution in each bottleneck block.
            for SENet154: 64
            for SE-ResNet models: 1
            for SE-ResNeXt models:  32
        reduction: reduction ratio for Squeeze-and-Excitation modules.
            for all models: 16
        dropout_prob: drop probability for the Dropout layer.
            if `None` the Dropout layer is not used.
            for SENet154: 0.2
            for SE-ResNet models: None
            for SE-ResNeXt models: None
        dropout_dim: determine the dimensions of dropout. Defaults to 1.
            When dropout_dim = 1, randomly zeroes some of the elements for each channel.
            When dropout_dim = 2, Randomly zeroes out entire channels (a channel is a 2D feature map).
            When dropout_dim = 3, Randomly zeroes out entire channels (a channel is a 3D feature map).
        inplanes:  number of input channels for layer1.
            for SENet154: 128
            for SE-ResNet models: 64
            for SE-ResNeXt models: 64
        downsample_kernel_size: kernel size for downsampling convolutions in layer2, layer3 and layer4.
            for SENet154: 3
            for SE-ResNet models: 1
            for SE-ResNeXt models: 1
        input_3x3: If `True`, use three 3x3 convolutions instead of
            a single 7x7 convolution in layer0.
            - For SENet154: True
            - For SE-ResNet models: False
            - For SE-ResNeXt models: False
        num_classes: number of outputs in `last_linear` layer.
            for all models: 1000
    皙?         T  spatial_dimsintin_channelsblockCtype[SEBottleneck | SEResNetBottleneck | SEResNeXtBottleneck] | strlayersSequence[int]groups	reductiondropout_probfloat | Nonedropout_diminplanesdownsample_kernel_size	input_3x3boolnum_classesreturnNonec                   s  t    t|tr%|dkrt}n|dkrt}n|dkrt}ntd| ttj	 }t
t
j|f }ttj|f }ttj|f }ttj|f }ttj|f }|	| _|| _|rd||dddd	d
dfd|ddfd|ddfd|dddd	d	d
dfd|ddfd|ddfd|d|	dd	d	d
dfd||	dfd|ddfg	}nd|||	dddd
dfd||	dfd|ddfg}|d|ddddf tt|| _| j|d|d ||d	d| _| j|d|d	 d|||
d| _| j|d|d d|||
d| _| j|d|d d|||
d| _|d	| _|d ur||nd | _ t!d|j" || _#| $ D ]E}t||r8tj%&t'(|j) q$t||rVtj%*t'(|j)d	 tj%*t'(|j+d q$t|tj!rhtj%*t'(|j+d q$d S ) NZse_bottleneckZse_resnet_bottleneckZse_resnetxt_bottleneckzUUnknown block '%s', use se_bottleneck, se_resnet_bottleneck or se_resnetxt_bottleneckconv1@   r#      r!   F)r'   out_channelskernel_sizestridepaddingbiasbn1)num_featuresrelu1T)inplaceconv2bn2relu2conv3bn3relu3   pool)r<   r=   	ceil_moder   )planesblocksr,   r-   r2   r"   )rM   rN   r=   r,   r-   r2      i   ),super__init__
isinstancestrr	   r
   r   
ValueErrorr   RELUr   CONVr   MAXr   BATCHr   DROPOUTADAPTIVEAVGr1   r%   appendnn
Sequentialr   layer0_make_layerlayer1layer2layer3layer4adaptive_avg_pooldropoutLinear	expansionlast_linearmodulesinitkaiming_normal_torch	as_tensorweight	constant_r?   )selfr%   r'   r(   r*   r,   r-   r.   r0   r1   r2   r3   r5   	relu_type	conv_type	pool_type	norm_typedropout_typeavg_pool_typeZlayer0_modulesm	__class__ [/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/senet.pyrQ   `   s   


		
	zSENet.__init__=type[SEBottleneck | SEResNetBottleneck | SEResNeXtBottleneck]rM   rN   r=   nn.Sequentialc                 C  s   d }|dks| j ||j kr t| j| j ||j ||d tjdd}g }	|	|| j| j |||||d ||j | _ td|D ]}
|	|| j| j |||d q=tj	|	 S )Nr!   F)r%   r'   r;   stridesr<   actnormr?   )r%   r1   rM   r,   r-   r=   
downsample)r%   r1   rM   r,   r-   )
r1   rg   r   r%   r   rX   r[   ranger\   r]   )rp   r(   rM   rN   r,   r-   r=   r2   r   r*   _numrz   rz   r{   r_      sH   


zSENet._make_layerxtorch.Tensorc                 C  s6   |  |}| |}| |}| |}| |}|S N)r^   r`   ra   rb   rc   rp   r   rz   rz   r{   features  s   




zSENet.featuresc                 C  s8   |  |}| jd ur| |}t|d}| |}|S )Nr!   )rd   re   rl   flattenrh   r   rz   rz   r{   logits  s   



zSENet.logitsc                 C  s   |  |}| |}|S r   )r   r   r   rz   rz   r{   forward  s   

zSENet.forward)r    r!   r"   r#   Tr$   )r%   r&   r'   r&   r(   r)   r*   r+   r,   r&   r-   r&   r.   r/   r0   r&   r1   r&   r2   r&   r3   r4   r5   r&   r6   r7   )r!   r!   )r(   r|   rM   r&   rN   r&   r,   r&   r-   r&   r=   r&   r2   r&   r6   r}   )r   r   )r   r   r6   r   )
__name__
__module____qualname____doc__rQ   r_   r   r   r   __classcell__rz   rz   rx   r{   r   2   s    5}
1
r   model	nn.ModulearchrS   progressr4   c                   s  t |td}|du rtdtd}td}td}td}td}td}	t|trFt|d	 |d
 d tj	|d
 dddnt
||dt D ]l}
d}||
rct|d|
}nP||
rpt|d|
}nC||
r|
  |
< t|d|
}n.||
r|
  |
< t|d|
}n||
rt|d|
}n|	|
rt|	d|
}|r|
 |< |
= qR|    fdd D   |   dS )z:
    This function is used to load pretrained models.
    Nzonly 'senet154', 'se_resnet50', 'se_resnet101',  'se_resnet152', 'se_resnext50_32x4d', and se_resnext101_32x4d are supported to load pretrained weights.z%^(layer[1-4]\.\d\.(?:conv)\d\.)(\w*)$z%^(layer[1-4]\.\d\.)(?:bn)(\d\.)(\w*)$z+^(layer[1-4]\.\d\.)(?:se_module.fc1.)(\w*)$z+^(layer[1-4]\.\d\.)(?:se_module.fc2.)(\w*)$z*^(layer[1-4]\.\d\.)(?:downsample.0.)(\w*)$z*^(layer[1-4]\.\d\.)(?:downsample.1.)(\w*)$urlfilename)filepathT)map_locationweights_only)r   z	\1conv.\2z\1conv\2adn.N.\3z\1se_layer.fc.0.\2z\1se_layer.fc.2.\2z\1project.conv.\2z\1project.adn.N.\2c                   s2   i | ]\}}| v r | j | j kr||qS rz   )shape).0kv
model_dict
state_dictrz   r{   
<dictcomp>I  s    ,z$_load_state_dict.<locals>.<dictcomp>)r   r   rT   recompilerR   dictr   rl   loadr   listkeysmatchsubsqueezer   itemsupdateload_state_dict)r   r   r   	model_urlZpattern_convZ
pattern_bnZ
pattern_seZpattern_se2Zpattern_down_convZpattern_down_bnkeynew_keyrz   r   r{   _load_state_dict  sR   













r   c                      s.   e Zd ZdZ					dd fddZ  ZS )r   zlSENet154 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.r#      $   r#   r9      FTr*   r+   r,   r&   r-   
pretrainedr4   r   r6   r7   c                   s4   t  jdt|||d| |rt| d| d S d S )N)r(   r*   r,   r-   r   rz   )rP   rQ   r	   r   )rp   r*   r,   r-   r   r   kwargsrx   rz   r{   rQ   S  s   	zSENet154.__init__)r   r9   r   FT)r*   r+   r,   r&   r-   r&   r   r4   r   r4   r6   r7   r   r   r   r   rQ   r   rz   rz   rx   r{   r   P  s    r   c                      s6   e Zd ZdZ									dd fddZ  ZS )r   znSEResNet50 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.r#         r#   r!   r   Nr9   FTr*   r+   r,   r&   r-   r.   r/   r1   r2   r3   r4   r   r   r6   r7   c
                   s<   t  jdt|||||||d|
 |rt| d|	 d S d S )N)r(   r*   r,   r-   r.   r1   r2   r3   r   rz   rP   rQ   r
   r   rp   r*   r,   r-   r.   r1   r2   r3   r   r   r   rx   rz   r{   rQ   e     	zSEResNet50.__init__)	r   r!   r   Nr9   r!   FFTr*   r+   r,   r&   r-   r&   r.   r/   r1   r&   r2   r&   r3   r4   r   r4   r   r4   r6   r7   r   rz   rz   rx   r{   r   b  s    r   c                      4   e Zd ZdZ								dd fddZ  ZS )r   zy
    SEResNet101 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.
    r#   r      r#   r!   r   r9   FTr*   r+   r,   r&   r-   r1   r2   r3   r4   r   r   r6   r7   c	           
   
     :   t  jdt||||||d|	 |rt| d| d S d S )Nr(   r*   r,   r-   r1   r2   r3   r   rz   r   
rp   r*   r,   r-   r1   r2   r3   r   r   r   rx   rz   r{   rQ        
zSEResNet101.__init__)r   r!   r   r9   r!   FFTr*   r+   r,   r&   r-   r&   r1   r&   r2   r&   r3   r4   r   r4   r   r4   r6   r7   r   rz   rz   rx   r{   r         r   c                      r   )r   zy
    SEResNet152 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.
    r   r!   r   r9   FTr*   r+   r,   r&   r-   r1   r2   r3   r4   r   r   r6   r7   c	           
   
     r   )Nr   r   rz   r   r   rx   rz   r{   rQ     r   zSEResNet152.__init__)r   r!   r   r9   r!   FFTr   r   rz   rz   rx   r{   r     r   r   c                      6   e Zd ZdZ										dd fddZ  ZS )SEResNext50zy
    SEResNext50 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.
    r       r   Nr9   r!   FTr*   r+   r,   r&   r-   r.   r/   r1   r2   r3   r4   r   r   r6   r7   c
                   <   t  jdt|||||||d|
 |rt| d|	 d S d S )Nr(   r*   r,   r.   r-   r1   r2   r3   r   rz   rP   rQ   r   r   r   rx   rz   r{   rQ     r   zSEResNext50.__init__)	r   r   r   Nr9   r!   FFTr   r   rz   rz   rx   r{   r         r   c                      r   )r   zz
    SEResNext101 based on `Squeeze-and-Excitation Networks` with optional pretrained support when spatial_dims is 2.
    r   r   r   Nr9   r!   FTr*   r+   r,   r&   r-   r.   r/   r1   r2   r3   r4   r   r   r6   r7   c
                   r   )Nr   r   rz   r   r   rx   rz   r{   rQ     r   zSEResNext101.__init__)	r   r   r   Nr9   r!   FFTr   r   rz   rz   rx   r{   r     r   r   )r   r   r   rS   r   r4   )?
__future__r   r   collectionsr   collections.abcr   typingr   rl   torch.nnr\   Z	torch.hubr   monai.apps.utilsr   "monai.networks.blocks.convolutionsr   Z,monai.networks.blocks.squeeze_and_excitationr	   r
   r   monai.networks.layers.factoriesr   r   r   r   r   monai.utils.moduler   __all__r   Moduler   r   r   r   r   r   r   r   SEnetSenetSEnet154Senet154r   
SEresnet50
Seresnet50
seresnet50SEresnet101Seresnet101seresnet101SEresnet152Seresnet152seresnet152r   SEresnext50Seresnext50seresnext50SEResNeXt101SEresnext101Seresnext101seresnext101rz   rz   rz   r{   <module>   sJ   
 
l3   ""