U
    Ph=                     @  sb  d dl mZ d dlZd dlmZ d dlmZmZ d dlZd dl	m
Z
 d dlmZ d dlmZmZmZ d dlmZmZ d dlmZ d	d
ddddddddddddgZG dd de
jZG dd de
jZG dd de
jZG dd	 d	e
jZddd d!d"d#ZG d$d deZG d%d deZG d&d deZ G d'd deZ!eZ"e Z#Z$e Z%Z&e  Z'Z(e! Z)Z*dS )(    )annotationsN)OrderedDict)CallableSequence)load_state_dict_from_url)ConvDropoutPool)get_act_layerget_norm_layer)look_up_optionDenseNetDensenetDenseNet121densenet121Densenet121DenseNet169densenet169Densenet169DenseNet201densenet201Densenet201DenseNet264densenet264Densenet264c                
      sN   e Zd Zdddifdfddddddddd	 fd
dZdddddZ  ZS )_DenseLayerreluinplaceTbatchintfloatstr | tupleNone)spatial_dimsin_channelsgrowth_ratebn_sizedropout_probactnormreturnc              
     s   t    || }ttj|f }	ttj|f }
t | _| j	dt
|||d | j	dt|d | j	d|	||ddd | j	d	t
|||d | j	d
t|d | j	d|	||dddd |dkr| j	d|
| dS )aH  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of the input channel.
            growth_rate: how many filters to add each layer (k in paper).
            bn_size: multiplicative factor for number of bottle neck layers.
                (i.e. bn_size * k features in the bottleneck layer)
            dropout_prob: dropout rate after each dense layer.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        norm1namer#   channelsrelu1r-   conv1   Fkernel_sizebiasnorm2relu2conv2   )r4   paddingr5   r   dropoutN)super__init__r   CONVr   DROPOUTnn
Sequentiallayers
add_moduler   r
   )selfr#   r$   r%   r&   r'   r(   r)   out_channels	conv_typeZdropout_type	__class__ Q/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/densenet.pyr=   .   s    

z_DenseLayer.__init__torch.Tensorxr*   c                 C  s   |  |}t||gdS )Nr2   )rB   torchcat)rD   rM   new_featuresrI   rI   rJ   forwardV   s    
z_DenseLayer.forward)__name__
__module____qualname__r=   rQ   __classcell__rI   rI   rG   rJ   r   ,   s   	
"(r   c                      s@   e Zd Zdddifdfdddddddddd		 fd
dZ  ZS )_DenseBlockr   r   Tr   r   r    r!   r"   )	r#   rB   r$   r&   r%   r'   r(   r)   r*   c	              
     sN   t    t|D ]6}	t|||||||d}
||7 }| d|	d  |
 qdS )a{  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            layers: number of layers in the block.
            in_channels: number of the input channel.
            bn_size: multiplicative factor for number of bottle neck layers.
                (i.e. bn_size * k features in the bottleneck layer)
            growth_rate: how many filters to add each layer (k in paper).
            dropout_prob: dropout rate after each dense layer.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        )r(   r)   zdenselayer%dr2   N)r<   r=   ranger   rC   )rD   r#   rB   r$   r&   r%   r'   r(   r)   ilayerrG   rI   rJ   r=   ]   s
    
z_DenseBlock.__init__rR   rS   rT   r=   rU   rI   rI   rG   rJ   rV   [   s   

rV   c                      s:   e Zd Zdddifdfddddddd fd	d
Z  ZS )_Transitionr   r   Tr   r   r!   r"   )r#   r$   rE   r(   r)   r*   c              	     s~   t    ttj|f }ttj|f }| dt|||d | dt|d | d|||ddd | d	|d
d
d dS )af  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of the input channel.
            out_channels: number of the output classes.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        r)   r,   r   r0   convr2   Fr3   pool   )r4   strideN)	r<   r=   r   r>   r	   AVGrC   r   r
   )rD   r#   r$   rE   r(   r)   rF   	pool_typerG   rI   rJ   r=   }   s    
z_Transition.__init__rZ   rI   rI   rG   rJ   r[   {   s   
r[   c                      sb   e Zd ZdZdddddddifd	d
fdddddddddddd fddZdddddZ  ZS )r   a  
    Densenet based on: `Densely Connected Convolutional Networks <https://arxiv.org/pdf/1608.06993.pdf>`_.
    Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.
    This network is non-deterministic When `spatial_dims` is 3 and CUDA is enabled. Please check the link below
    for more details:
    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms

    Args:
        spatial_dims: number of spatial dimensions of the input image.
        in_channels: number of the input channel.
        out_channels: number of the output classes.
        init_features: number of filters in the first convolution layer.
        growth_rate: how many filters to add each layer (k in paper).
        block_config: how many layers in each pooling block.
        bn_size: multiplicative factor for number of bottle neck layers.
            (i.e. bn_size * k features in the bottleneck layer)
        act: activation type and arguments. Defaults to relu.
        norm: feature normalization type and arguments. Defaults to batch norm.
        dropout_prob: dropout rate after each dense layer.
    @                      r   r   Tr   g        r   Sequence[int]r!   r    r"   )r#   r$   rE   init_featuresr%   block_configr&   r(   r)   r'   r*   c                   s  t    ttj|f }ttj|f }ttj|f }tt	d|||dddddfdt
|	||dfd	t|d
fd|ddddfg| _|}t|D ]\}}t||||||
||	d}| jd|d  | ||| 7 }|t|d kr| jdt
|	||d q|d }t|||||	d}| jd|d  | |}qtt	dt|d
fd|dfdtdfdt||fg| _|  D ]}t||rtjt|j njt|tjtjtjfrtjt|jd tjt|j d n$t|tjrtjt|j d qd S )Nconv0   r^   r9   F)r4   r_   r:   r5   norm0r,   relu0r0   pool0r2   )r4   r_   r:   )r#   rB   r$   r&   r%   r'   r(   r)   
denseblocknorm5)r$   rE   r(   r)   
transitionr   r]   flattenoutr   )!r<   r=   r   r>   r	   MAXADAPTIVEAVGr@   rA   r   r   r
   features	enumeraterV   rC   lenr[   FlattenLinearclass_layersmodules
isinstanceinitkaiming_normal_rN   	as_tensorweightBatchNorm1dBatchNorm2dBatchNorm3d	constant_r5   )rD   r#   r$   rE   rk   r%   rl   r&   r(   r)   r'   rF   ra   Zavg_pool_typerX   
num_layersblockZ_out_channelstransmrG   rI   rJ   r=      sz    

     
zDenseNet.__init__rK   rL   c                 C  s   |  |}| |}|S )N)ry   r~   )rD   rM   rI   rI   rJ   rQ      s    

zDenseNet.forward)rR   rS   rT   __doc__r=   rQ   rU   rI   rI   rG   rJ   r      s   
(Oz	nn.Modulestrbool)modelarchprogressc           	        s   dddd}t ||d}|dkr(tdtd}t||dt D ]F}||}|rJ|d	d
 |d |d }| |< |= qJ| 	   fdd
 D   |   dS )z
    This function is used to load pretrained models.
    Adapted from PyTorch Hub 2D version: https://pytorch.org/vision/stable/models.html#id16.

    z<https://download.pytorch.org/models/densenet121-a639ec97.pthz<https://download.pytorch.org/models/densenet169-b2777c0a.pthz<https://download.pytorch.org/models/densenet201-c1103571.pth)r   r   r   Nz]only 'densenet121', 'densenet169' and 'densenet201' are supported to load pretrained weights.z_^(.*denselayer\d+)(\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$)r   r2   z.layersr^   r9   c                   s2   i | ]*\}}| kr | j | j kr||qS rI   )shape).0kv
model_dict
state_dictrI   rJ   
<dictcomp>!  s
       z$_load_state_dict.<locals>.<dictcomp>)r   
ValueErrorrecompiler   listkeysmatchgroupr   itemsupdateload_state_dict)	r   r   r   
model_urls	model_urlpatternkeyresnew_keyrI   r   rJ   _load_state_dict  s2    
"
r   c                      s8   e Zd ZdZdddddddd	d	d
d	 fddZ  ZS )r   zFDenseNet121 with optional pretrained support when `spatial_dims` is 2.rb   rc   rd   FTr   rj   r   r"   	r#   r$   rE   rk   r%   rl   
pretrainedr   r*   c	           
   	     sD   t  jf ||||||d|	 |r@|dkr4tdt| d| d S )Nr#   r$   rE   rk   r%   rl   r^   Parameter `spatial_dims` is > 2 ; currently PyTorch Hub does notprovide pretrained models for more than two spatial dimensions.r   r<   r=   NotImplementedErrorr   
rD   r#   r$   rE   rk   r%   rl   r   r   kwargsrG   rI   rJ   r=   +  s     	zDenseNet121.__init__)rb   rc   rd   FTrR   rS   rT   r   r=   rU   rI   rI   rG   rJ   r   (  s        c                      s8   e Zd ZdZdddddddd	d	d
d	 fddZ  ZS )r   zFDenseNet169 with optional pretrained support when `spatial_dims` is 2.rb   rc   re   rf   rc   rc   FTr   rj   r   r"   r   c	           
   	     sD   t  jf ||||||d|	 |r@|dkr4tdt| d| d S )Nr   r^   r   r   r   r   rG   rI   rJ   r=   L  s     	zDenseNet169.__init__)rb   rc   r   FTr   rI   rI   rG   rJ   r   I  s        c                      s8   e Zd ZdZdddddddd	d	d
d	 fddZ  ZS )r   zFDenseNet201 with optional pretrained support when `spatial_dims` is 2.rb   rc   re   rf   0   rc   FTr   rj   r   r"   r   c	           
   	     sD   t  jf ||||||d|	 |r@|dkr4tdt| d| d S )Nr   r^   r   r   r   r   rG   rI   rJ   r=   m  s     	zDenseNet201.__init__)rb   rc   r   FTr   rI   rI   rG   rJ   r   j  s        c                      s8   e Zd Zd Zdddddddddd	d
	 fddZ  ZS )r   rb   rc   re   rf   rb   r   FTr   rj   r   r"   r   c	           
   	     s0   t  jf ||||||d|	 |r,tdd S )Nr   zECurrently PyTorch Hub does not provide densenet264 pretrained models.)r<   r=   r   r   rG   rI   rJ   r=     s    	zDenseNet264.__init__)rb   rc   r   FTr   rI   rI   rG   rJ   r     s        )+
__future__r   r   collectionsr   collections.abcr   r   rN   torch.nnr@   Z	torch.hubr   monai.networks.layers.factoriesr   r   r	   monai.networks.layers.utilsr
   r   monai.utils.moduler   __all__Moduler   rA   rV   r[   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rI   rI   rI   rJ   <module>   sL   / k%!!!