U
    Ph>l                     @  s  d dl mZ d dlZd dlZd dlmZ d dlmZ d dlm	Z	 d dl
mZ d dlZd dlmZ d dlmZ d dlmZmZmZ d d	lmZmZ d d
lmZ d dlmZmZ eddd\ZZeddd\ZZdZ dZ!ddddddddddg
Z"dddddgdd d!fdd"d"d"d"gd#d!d!fdd$d%d&d$gd#d!d!fd'd$d%d&d$gdd d!fd'd$d%d(d$gdd d fd'd$d)d*d$gdd d fd'd$d+d*d$gdd d fd,Z#e$e%Z&d-d. Z'd/d0 Z(G d1d dej)Z*G d2d dej)Z+G d3d dej)Z,G d4d5 d5e,Z-G d6d7 d7e-eZ.d8d9d:d:d;d<d=dd>d?d@Z/dVd<d<d=ddAdBdZ0dWd<d<d=ddAdCdZ1dXd<d<d=ddAdDdZ2dYd<d<d=ddAdEdZ3dZd<d<d=ddAdFdZ4d[d<d<d=ddAdGdZ5d\d<d<d=ddAdHdZ6d]dJd8d<dKdLdMZ7dJdNdOdPZ8d^dQd8d<dRdSdTdUZ9dS )_    )annotationsN)Callable)partial)Path)Any)BaseEncoder)ConvNormPool)get_act_layerget_pool_layer)ensure_tuple_rep)look_up_optionoptional_importZhuggingface_hubhf_hub_downloadnamezhuggingface_hub.utils._errorsEntryNotFoundError#TencentMedicalNet/MedicalNet-Resnetresnet_ResNetResNetBlockResNetBottleneckresnet10resnet18resnet34resnet50	resnet101	resnet152	resnet200basic   BFT   A         
bottleneck      $      )r   r   r   r   r   r   r   c                   C  s   ddddgS )N@             r1   r1   r1   O/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/resnet.pyget_inplanes?   s    r3   c                   C  s   ddddgS )Nr   r!   )r!   r!   )r!   r!   r!   r1   r1   r1   r1   r2   get_avgpoolC   s    r4   c                	      sT   e Zd ZdZddddddiffdddddd	d
d fddZdddddZ  ZS )r   r!   r%   NreluinplaceTintnn.Module | partial | Nonestr | tupleNone	in_planesplanesspatial_dimsstride
downsampleactreturnc           	        s   t    ttj|f }ttj|f }|||dd|dd| _||| _t|d| _	|||dddd| _
||| _|| _|| _dS )as  
        Args:
            in_planes: number of input channels.
            planes: number of output channels.
            spatial_dims: number of spatial dimensions of the input image.
            stride: stride to use for first conv layer.
            downsample: which downsample layer to use.
            act: activation type and arguments. Defaults to relu.
        r%   r!   F)kernel_sizepaddingr?   biasr   )rC   rD   rE   N)super__init__r   CONVr	   BATCHconv1bn1r   rA   conv2bn2r@   r?   	selfr<   r=   r>   r?   r@   rA   	conv_type	norm_type	__class__r1   r2   rG   J   s    


zResNetBlock.__init__torch.TensorxrB   c                 C  s`   |}|  |}| |}| |}| |}| |}| jd k	rJ| |}||7 }| |}|S N)rJ   rK   rA   rL   rM   r@   rO   rV   residualoutr1   r1   r2   forwardi   s    







zResNetBlock.forward__name__
__module____qualname__	expansionrG   r[   __classcell__r1   r1   rR   r2   r   G   s   
 c                	      sT   e Zd ZdZddddddiffddddd	d
dd fddZdddddZ  ZS )r   r&   r%   r!   Nr5   r6   Tr7   r8   r9   r:   r;   c           	        s   t    ttj|f }ttj|f }|||ddd| _||| _|||d|ddd| _||| _	|||| j
 ddd| _||| j
 | _t|d| _|| _|| _dS )a  
        Args:
            in_planes: number of input channels.
            planes: number of output channels (taking expansion into account).
            spatial_dims: number of spatial dimensions of the input image.
            stride: stride to use for second conv layer.
            downsample: which downsample layer to use.
            act: activation type and arguments. Defaults to relu.
        r!   F)rC   rE   r%   rC   r?   rD   rE   r   N)rF   rG   r   rH   r	   rI   rJ   rK   rL   rM   r`   conv3bn3r   rA   r@   r?   rN   rR   r1   r2   rG      s    


zResNetBottleneck.__init__rT   rU   c                 C  s~   |}|  |}| |}| |}| |}| |}| |}| |}| |}| jd k	rh| |}||7 }| |}|S rW   )rJ   rK   rA   rL   rM   rc   rd   r@   rX   r1   r1   r2   r[      s    










zResNetBottleneck.forwardr\   r1   r1   rR   r2   r   |   s   
 "c                      s   e Zd ZdZddddddddd	d	d
dd	iffdddddddddddddddd fddZd$ddddddddZd%dddddddddd Zddd!d"d#Z  ZS )&r   ag  
    ResNet based on: `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`_
    and `Can Spatiotemporal 3D CNNs Retrace the History of 2D CNNs and ImageNet? <https://arxiv.org/pdf/1711.09577.pdf>`_.
    Adapted from `<https://github.com/kenshohara/3D-ResNets-PyTorch/tree/master/models>`_.

    Args:
        block: which ResNet block to use, either Basic or Bottleneck.
            ResNet block class or str.
            for Basic: ResNetBlock or 'basic'
            for Bottleneck: ResNetBottleneck or 'bottleneck'
        layers: how many layers to use.
        block_inplanes: determine the size of planes at each step. Also tunable with widen_factor.
        spatial_dims: number of spatial dimensions of the input image.
        n_input_channels: number of input channels for first convolutional layer.
        conv1_t_size: size of first convolution layer, determines kernel and padding.
        conv1_t_stride: stride of first convolution layer.
        no_max_pool: bool argument to determine if to use maxpool layer.
        shortcut_type: which downsample block to use. Options are 'A', 'B', default to 'B'.
            - 'A': using `self._downsample_basic_block`.
            - 'B': kernel_size 1 conv + norm.
        widen_factor: widen output for each layer.
        num_classes: number of output (classifications).
        feed_forward: whether to add the FC layer for the output, default to `True`.
        bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
        act: activation type and arguments. Defaults to relu.

    r%      r!   Fr"   g      ?i  Tr5   r6   z*type[ResNetBlock | ResNetBottleneck] | str	list[int]r7   ztuple[int] | intboolstrfloatr9   r:   )blocklayersblock_inplanesr>   n_input_channelsconv1_t_sizeconv1_t_strideno_max_poolshortcut_typewiden_factornum_classesfeed_forwardbias_downsamplerA   rB   c                   sJ  t    t|tr<|dkr"t}n|dkr0t}ntd| ttj|f }t	t	j
|f }ttj|f }ttj|f }t } fdd|D }|d | _|| _|| _t||}t||}||| j||tdd |D d	d
| _|| j| _t|d| _|dddd| _| ||d |d ||	| _| j||d |d ||	dd| _| j||d |d ||	dd| _| j||d |d ||	dd| _||| | _|rt |d |j! |nd | _"| # D ]}t||rtj$j%t&'|j(ddd n^t||rtj$)t&'|j(d tj$)t&'|j*d n$t|tj rtj$)t&'|j*d qd S )Nr    r(   z+Unknown block '%s', use basic or bottleneckc                   s   g | ]}t |  qS r1   )r7   ).0rV   rr   r1   r2   
<listcomp>   s     z#ResNet.__init__.<locals>.<listcomp>r   c                 s  s   | ]}|d  V  qdS )r#   Nr1   )rv   kr1   r1   r2   	<genexpr>  s     z"ResNet.__init__.<locals>.<genexpr>Frb   r   r%   r#   r!   )rC   r?   rD   )r?   fan_outr5   )modenonlinearity)+rF   rG   
isinstancerh   r   r   
ValueErrorr   rH   r	   rI   r
   MAXZADAPTIVEAVGr4   r<   rp   ru   r   tuplerJ   rK   r   rA   maxpool_make_layerlayer1layer2layer3layer4avgpoolnnLinearr`   fcmodulesinitkaiming_normal_torch	as_tensorweight	constant_rE   )rO   rj   rk   rl   r>   rm   rn   ro   rp   rq   rr   rs   rt   ru   rA   rP   rQ   	pool_typeZ	avgp_typeZblock_avgpoolZconv1_kernel_sizeZconv1_stridemrR   rw   r2   rG      sZ    




   "zResNet.__init__rT   )rV   r=   r?   r>   rB   c                 C  sj   t dd|df|d|}tj|d||d f|jdd  |j|jd}tj|j|gdd}|S )	Navgr!   )rC   r?   r>   r   r#   )dtypedevice)dim)	r   r   zerossizeshaper   r   catdata)rO   rV   r=   r?   r>   rZ   Z	zero_padsr1   r1   r2   _downsample_basic_block  s    8zResNet._downsample_basic_block$type[ResNetBlock | ResNetBottleneck]znn.Sequential)rj   r=   blocksr>   rq   r?   rB   c              	   C  s   t t j|f }ttj|f }d }	|dks8| j||j krt|ddhdkrdt| j||j ||d}	n.t	
|| j||j d|| jd|||j }	|| j||||	dg}
||j | _td|D ]}|
|| j||d qt	j
|
 S )Nr!   r$   r"   )r=   r?   r>   )rC   r?   rE   )r<   r=   r>   r?   r@   r   )r   rH   r	   rI   r<   r`   r   r   r   r   
Sequentialru   rangeappend)rO   rj   r=   r   r>   rq   r?   rP   rQ   r@   rk   _ir1   r1   r2   r   "  sB    	    zResNet._make_layerrU   c                 C  s   |  |}| |}| |}| js.| |}| |}| |}| |}| |}| 	|}|
|dd}| jd k	r| |}|S )Nr   )rJ   rK   rA   rp   r   r   r   r   r   r   viewr   r   )rO   rV   r1   r1   r2   r[   O  s    










zResNet.forward)r%   )r!   )	r]   r^   r_   __doc__rG   r   r   r[   ra   r1   r1   rR   r2   r      s"   !
0G -c                      s:   e Zd Zddddddd fd	d
ZddddZ  ZS )ResNetFeaturesTr%   r!   rh   rg   r7   r:   )
model_name
pretrainedr>   in_channelsrB   c                   s   |t kr,dt  }td| d| dt | \}}}}	}
t j||t ||d|d|	d	 |r|dkr|d	krt| ||
d
 ntddS )a  Initialize resnet18 to resnet200 models as a backbone, the backbone can be used as an encoder for
        segmentation and objection models.

        Compared with the class `ResNet`, the only different place is the forward function.

        Args:
            model_name: name of model to initialize, can be from [resnet10, ..., resnet200].
            pretrained: whether to initialize pretrained MedicalNet weights,
                only available for spatial_dims=3 and in_channels=1.
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of input channels for first convolutional layer.
        z, zinvalid model_name z found, must be one of  r#   F)	rj   rk   rl   r>   rm   ro   rq   rt   ru   r%   r!   )
datasets23zQPretrained resnet models are only available for in_channels=1 and spatial_dims=3.N)resnet_paramsjoinkeysr   rF   rG   r3   _load_state_dict)rO   r   r   r>   r   Zmodel_name_stringrj   rk   rq   ru   r   rR   r1   r2   rG   f  s&    zResNetFeatures.__init__rT   )inputsc                 C  s   |  |}| |}| |}g }|| | js<| |}| |}|| | |}|| | |}|| | 	|}|| |S )z
        Args:
            inputs: input should have spatially N dimensions
            ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.

        Returns:
            a list of torch Tensors.
        )
rJ   rK   rA   r   rp   r   r   r   r   r   )rO   r   rV   featuresr1   r1   r2   r[     s     	












zResNetFeatures.forward)Tr%   r!   )r]   r^   r_   rG   r[   ra   r1   r1   rR   r2   r   d  s   $r   c                   @  sj   e Zd ZdZdddddddgZed	d
ddZedd
ddZedd
ddZedd
ddZ	dS )ResNetEncoderz9Wrap the original resnet to an encoder for flexible-unet.r   r   r   r   r   r   r   z
list[dict])rB   c                 C  s(   g }| j D ]}||dddd q
|S )z6Get the initialization parameter for resnet backbones.Tr%   r!   )r   r   r>   r   )backbone_namesr   )clsparameter_listZbackbone_namer1   r1   r2   get_encoder_parameters  s    
z$ResNetEncoder.get_encoder_parameterszlist[tuple[int, ...]]c                 C  s   dddddddgS )z:Get number of resnet backbone output feature maps channel.)r-   r-   r.   r/   r0   )r-   r/   r0   i   i   r1   r   r1   r1   r2   num_channels_per_output  s    z%ResNetEncoder.num_channels_per_outputrf   c                 C  s
   dgd S )zGet number of resnet backbone output feature maps.

        Since every backbone contains the same 5 output feature maps, the number list should be `[5] * 7`.
           re   r1   r   r1   r1   r2   num_outputs  s    zResNetEncoder.num_outputsz	list[str]c                 C  s   | j S )zGet names of resnet backbones.)r   r   r1   r1   r2   get_encoder_names  s    zResNetEncoder.get_encoder_namesN)
r]   r^   r_   r   r   classmethodr   r   r   r   r1   r1   r1   r2   r     s   	r   rh   r   rf   z
bool | strrg   r   )archrj   rk   rl   r   progresskwargsrB   c                 K  s|  t |||f|}|rxtj r$dnd}t|trjt| r`t	d| d tj
||d}	ntdn|dddkrP|d	dd
krF|dddkrFtd| }
|
rt|
d
}ntdt|\}}||ddkr|dkrt||ddkrn t||dd}	n,td| d|dkr4t|nd d| ntdntddd |	 D }	|j|	dd |S )NcudacpuzLoading weights from z...map_locationz+The pretrained checkpoint file is not foundr>   r%   rm   r!   rt   TFresnet(\d+)z1arch argument should be as 'resnet_{resnet_depth}rq   r"   r   ru   r   r   zPlease set shortcut_type to z and bias_downsample tozTrue or Falsez'when using pretrained MedicalNet resnetzgPlease set n_input_channels to 1and feed_forward to False in order to use MedicalNet pretrained weightsz>MedicalNet pretrained weights are only avalaible for 3D modelsc                 S  s   i | ]\}}| d d|qS zmodule. replacerv   keyvaluer1   r1   r2   
<dictcomp>	  s     
 z_resnet.<locals>.<dictcomp>)strict)r   r   r   is_availabler~   rh   r   existsloggerinfoloadFileNotFoundErrorgetresearchr7   groupr   %get_medicalnet_pretrained_resnet_argsrg    get_pretrained_resnet_medicalnetNotImplementedErroritemsload_state_dict)r   rj   rk   rl   r   r   r   modelr   model_state_dict
search_resresnet_depthru   rq   r1   r1   r2   _resnet  sJ    	

$    &r   )r   r   r   rB   c                 K  s    t dtddddgt | |f|S )aw  ResNet-10 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 23 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r!   r   r   r3   r   r   r   r1   r1   r2   r     s    	c                 K  s    t dtddddgt | |f|S )aw  ResNet-18 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 23 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r#   r   r   r1   r1   r2   r     s    	c                 K  s    t dtddddgt | |f|S )aw  ResNet-34 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 23 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r%   r&   r'   r   r   r1   r1   r2   r   &  s    	c                 K  s    t dtddddgt | |f|S )aw  ResNet-50 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 23 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r%   r&   r'   r   r   r3   r   r1   r1   r2   r   2  s    	c                 K  s    t dtddddgt | |f|S )aw  ResNet-101 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 8 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r%   r&   r)   r   r   r1   r1   r2   r   >  s    	c                 K  s    t dtddddgt | |f|S )aw  ResNet-152 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 8 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r%   r*   r+   r   r   r1   r1   r2   r   J  s    	c                 K  s    t dtddddgt | |f|S )aw  ResNet-200 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 8 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r%   r,   r+   r   r   r1   r1   r2   r   V  s    	r   r7   )r   r   r   c           	      C  s   d}d}ddddddd	g}t d
| |   | |kr|sH| |  dn| |  d}zt| |  |d}W nx tk
r   |rt | d|   | |  d}t d|  t| |  |d}nt| d| |  dY nX tj|t|d}ntdt | d |	dS )a  
    Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet

    Args:
        resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200
        device: device on which the returned state dict will be loaded. "cpu" or "cuda" for example.
        datasets23: if True, get the weights trained on more datasets (23).
                    Not all depths are available. If not, standard weights are returned.

    Returns:
        Pretrained state dict

    Raises:
        huggingface_hub.utils._errors.EntryNotFoundError: if pretrained weights are not found on huggingface hub
        NotImplementedError: if `resnet_depth` is not supported
    r   r   
      "   2   e         z@Loading MedicalNet pretrained model from https://huggingface.co/z.pthz_23dataset.pth)Zrepo_idfilenamez not available for resnetzTrying with z not found on Nr   z;Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]z downloaded
state_dict)
r   r   r   	Exceptionr   r   r   r   r   r   )	r   r   r   Z$medicalnet_huggingface_repo_basenameZ%medicalnet_huggingface_files_basenameZsupported_depthr   Zpretrained_path
checkpointr1   r1   r2   r   b  sD    
 

 r   )r   c                 C  s(   | dkrdnd}| dkrdnd}||fS )z{
    Return correct shortcut_type and bias_downsample
    for pretrained MedicalNet weights according to resnet depth.
    )r   r   r   r   r$   r"   r1   )r   ru   rq   r1   r1   r2   r     s    r   z	nn.Moduler:   )r   r   r   rB   c                 C  s`   t d|}|r*t|d}|d}ntdt|d|d}dd | D }| | d S )	Nr   r!   Z_23datasetszZmodel_name argument should contain resnet depth. Example: resnet18 or resnet18_23datasets.r   r   c                 S  s   i | ]\}}| d d|qS r   r   r   r1   r1   r2   r     s     
 z$_load_state_dict.<locals>.<dictcomp>)	r   r   r7   r   endswithr   r   r   r   )r   r   r   r   r   r   r1   r1   r2   r     s    r   )FT)FT)FT)FT)FT)FT)FT)r   T)T):
__future__r   loggingr   collections.abcr   	functoolsr   pathlibr   typingr   r   torch.nnr   Zmonai.networks.blocks.encoderr   monai.networks.layers.factoriesr   r	   r
   Zmonai.networks.layers.utilsr   r   monai.utilsr   monai.utils.moduler   r   r   _r   Z$MEDICALNET_HUGGINGFACE_REPO_BASENAMEZ%MEDICALNET_HUGGINGFACE_FILES_BASENAME__all__r   	getLoggerr]   r   r3   r4   Moduler   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r1   r1   r1   r2   <module>   sp   
5< -H*87