o
    -i=                     @  sB  d dl mZ d dlZd dlmZ d dlmZmZ d dlZd dl	m
Z
 d dlmZ d dlmZmZmZ d dlmZmZ d dlmZ g d	ZG d
d de
jZG dd de
jZG dd de
jZG dd de
jZd"ddZG dd deZG dd deZG dd deZ G d d! d!eZ!eZ"e Z#Z$e Z%Z&e  Z'Z(e! Z)Z*dS )#    )annotationsN)OrderedDict)CallableSequence)load_state_dict_from_url)ConvDropoutPool)get_act_layerget_norm_layer)look_up_option)DenseNetDensenetDenseNet121densenet121Densenet121DenseNet169densenet169Densenet169DenseNet201densenet201Densenet201DenseNet264densenet264Densenet264c                      s6   e Zd Zdddifdfd fddZdddZ  ZS )_DenseLayerreluinplaceTbatchspatial_dimsintin_channelsgrowth_ratebn_sizedropout_probfloatactstr | tuplenormreturnNonec              
     s   t    || }ttj|f }	ttj|f }
t | _| j	dt
|||d | j	dt|d | j	d|	||ddd | j	d	t
|||d | j	d
t|d | j	d|	||dddd |dkrr| j	d|
| dS dS )aH  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of the input channel.
            growth_rate: how many filters to add each layer (k in paper).
            bn_size: multiplicative factor for number of bottle neck layers.
                (i.e. bn_size * k features in the bottleneck layer)
            dropout_prob: dropout rate after each dense layer.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        norm1namer   channelsrelu1r-   conv1   Fkernel_sizebiasnorm2relu2conv2   )r4   paddingr5   r   dropoutN)super__init__r   CONVr   DROPOUTnn
Sequentiallayers
add_moduler   r
   )selfr   r!   r"   r#   r$   r&   r(   out_channels	conv_typeZdropout_type	__class__ ^/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/densenet.pyr=   .   s   

z_DenseLayer.__init__xtorch.Tensorc                 C  s   |  |}t||gdS )Nr2   )rB   torchcat)rD   rK   new_featuresrI   rI   rJ   forwardV   s   
z_DenseLayer.forward)r   r    r!   r    r"   r    r#   r    r$   r%   r&   r'   r(   r'   r)   r*   rK   rL   r)   rL   )__name__
__module____qualname__r=   rP   __classcell__rI   rI   rG   rJ   r   ,   s
    
	(r   c                      s,   e Zd Zdddifdfd fddZ  ZS )_DenseBlockr   r   Tr   r   r    rB   r!   r#   r"   r$   r%   r&   r'   r(   r)   r*   c	              
     sN   t    t|D ]}	t|||||||d}
||7 }| d|	d  |
 q	dS )a{  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            layers: number of layers in the block.
            in_channels: number of the input channel.
            bn_size: multiplicative factor for number of bottle neck layers.
                (i.e. bn_size * k features in the bottleneck layer)
            growth_rate: how many filters to add each layer (k in paper).
            dropout_prob: dropout rate after each dense layer.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        )r&   r(   zdenselayer%dr2   N)r<   r=   ranger   rC   )rD   r   rB   r!   r#   r"   r$   r&   r(   ilayerrG   rI   rJ   r=   ]   s   
z_DenseBlock.__init__)r   r    rB   r    r!   r    r#   r    r"   r    r$   r%   r&   r'   r(   r'   r)   r*   rR   rS   rT   r=   rU   rI   rI   rG   rJ   rV   [   s    

rV   c                      s,   e Zd Zdddifdfd fddZ  ZS )_Transitionr   r   Tr   r   r    r!   rE   r&   r'   r(   r)   r*   c              	     s~   t    ttj|f }ttj|f }| dt|||d | dt|d | d|||ddd | d	|d
d
d dS )af  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of the input channel.
            out_channels: number of the output classes.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        r(   r,   r   r0   convr2   Fr3   pool   )r4   strideN)	r<   r=   r   r>   r	   AVGrC   r   r
   )rD   r   r!   rE   r&   r(   rF   	pool_typerG   rI   rJ   r=   }   s   
z_Transition.__init__)r   r    r!   r    rE   r    r&   r'   r(   r'   r)   r*   rZ   rI   rI   rG   rJ   r[   {   s    
r[   c                      sD   e Zd ZdZdddddddifd	d
fd! fddZd"dd Z  ZS )#r   a  
    Densenet based on: `Densely Connected Convolutional Networks <https://arxiv.org/pdf/1608.06993.pdf>`_.
    Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.
    This network is non-deterministic When `spatial_dims` is 3 and CUDA is enabled. Please check the link below
    for more details:
    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of the input channel.
        out_channels: number of the output classes.
        init_features: number of filters in the first convolution layer.
        growth_rate: how many filters to add each layer (k in paper).
        block_config: how many layers in each pooling block.
        bn_size: multiplicative factor for number of bottle neck layers.
            (i.e. bn_size * k features in the bottleneck layer)
        act: activation type and arguments. Defaults to relu.
        norm: feature normalization type and arguments. Defaults to batch norm.
        dropout_prob: dropout rate after each dense layer.
    @                      r   r   Tr   g        r   r    r!   rE   init_featuresr"   block_configSequence[int]r#   r&   r'   r(   r$   r%   r)   r*   c                   s  t    ttj|f }ttj|f }ttj|f }tt	d|||dddddfdt
|	||dfd	t|d
fd|ddddfg| _|}t|D ]R\}}t||||||
||	d}| jd|d  | ||| 7 }|t|d kr| jdt
|	||d qI|d }t|||||	d}| jd|d  | |}qItt	dt|d
fd|dfdtdfdt||fg| _|  D ]F}t||rtjt|j qt|tjtjtjfrtjt|jd tjt|j d qt|tjrtjt|j d qd S )Nconv0   r^   r9   F)r4   r_   r:   r5   norm0r,   relu0r0   pool0r2   )r4   r_   r:   )r   rB   r!   r#   r"   r$   r&   r(   
denseblocknorm5)r!   rE   r&   r(   
transitionr   r]   flattenoutr   )!r<   r=   r   r>   r	   MAXADAPTIVEAVGr@   rA   r   r   r
   features	enumeraterV   rC   lenr[   FlattenLinearclass_layersmodules
isinstanceinitkaiming_normal_rM   	as_tensorweightBatchNorm1dBatchNorm2dBatchNorm3d	constant_r5   )rD   r   r!   rE   rj   r"   rk   r#   r&   r(   r$   rF   ra   Zavg_pool_typerX   
num_layersblockZ_out_channelstransmrG   rI   rJ   r=      st   




zDenseNet.__init__rK   rL   c                 C  s   |  |}| |}|S )N)ry   r~   )rD   rK   rI   rI   rJ   rP      s   

zDenseNet.forward)r   r    r!   r    rE   r    rj   r    r"   r    rk   rl   r#   r    r&   r'   r(   r'   r$   r%   r)   r*   rQ   )rR   rS   rT   __doc__r=   rP   rU   rI   rI   rG   rJ   r      s    
Or   model	nn.Modulearchstrprogressboolc           	        s   dddd}t ||d}|du rtdtd}t||dt D ]#}||}|rH|d	d
 |d |d }| |< |= q%| 	   fdd
 D   |   dS )z
    This function is used to load pretrained models.
    Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.

    z<https://download.pytorch.org/models/densenet121-a639ec97.pthz<https://download.pytorch.org/models/densenet169-b2777c0a.pthz<https://download.pytorch.org/models/densenet201-c1103571.pth)r   r   r   Nz]only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights.z_^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$)r   r2   z.layersr^   r9   c                   s2   i | ]\}}| v r | j | j kr||qS rI   )shape).0kv
model_dict
state_dictrI   rJ   
<dictcomp>!  s    ,z$_load_state_dict.<locals>.<dictcomp>)r   
ValueErrorrecompiler   listkeysmatchgroupr   itemsupdateload_state_dict)	r   r   r   
model_urls	model_urlpatternkeyresnew_keyrI   r   rJ   _load_state_dict  s4   
"
r   c                      .   e Zd ZdZ					dd fddZ  ZS )r   zFDenseNet121 with optional pretrained support when `spatial_dims` is 2.rb   rc   rd   FTr   r    r!   rE   rj   r"   rk   rl   
pretrainedr   r   r)   r*   c	           
   	     H   t  jd||||||d|	 |r"|dkrtdt| d| d S d S )Nr   r!   rE   rj   r"   rk   r^   Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does notprovide pretrained models for more than two spatial dimensions.r   rI   r<   r=   NotImplementedErrorr   
rD   r   r!   rE   rj   r"   rk   r   r   kwargsrG   rI   rJ   r=   +  "   	zDenseNet121.__init__)rb   rc   rd   FTr   r    r!   r    rE   r    rj   r    r"   r    rk   rl   r   r   r   r   r)   r*   rR   rS   rT   r   r=   rU   rI   rI   rG   rJ   r   (      r   c                      r   )r   zFDenseNet169 with optional pretrained support when `spatial_dims` is 2.rb   rc   re   rf   rc   rc   FTr   r    r!   rE   rj   r"   rk   rl   r   r   r   r)   r*   c	           
   	     r   )Nr   r^   r   r   rI   r   r   rG   rI   rJ   r=   L  r   zDenseNet169.__init__)rb   rc   r   FTr   r   rI   rI   rG   rJ   r   I  r   r   c                      r   )r   zFDenseNet201 with optional pretrained support when `spatial_dims` is 2.rb   rc   re   rf   0   rc   FTr   r    r!   rE   rj   r"   rk   rl   r   r   r   r)   r*   c	           
   	     r   )Nr   r^   r   r   rI   r   r   rG   rI   rJ   r=   m  r   zDenseNet201.__init__)rb   rc   r   FTr   r   rI   rI   rG   rJ   r   j  r   r   c                      s.   e Zd Zd Z					dd fddZ  ZS )r   rb   rc   re   rf   rb   r   FTr   r    r!   rE   rj   r"   rk   rl   r   r   r   r)   r*   c	           
   	     s0   t  jd||||||d|	 |rtdd S )Nr   zECurrently PyTorch Hub does not provide densenet264 pretrained models.rI   )r<   r=   r   r   rG   rI   rJ   r=     s   	zDenseNet264.__init__)rb   rc   r   FTr   r   rI   rI   rG   rJ   r     r   r   )r   r   r   r   r   r   )+
__future__r   r   collectionsr   collections.abcr   r   rM   torch.nnr@   Z	torch.hubr   monai.networks.layers.factoriesr   r   r	   monai.networks.layers.utilsr
   r   monai.utils.moduler   __all__Moduler   rA   rV   r[   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rI   rI   rI   rJ   <module>   s2   / 
k%!!!