o
    -in                     @  sH  d dl mZ d dlZd dlZd dlmZ d dlmZ d dlm	Z	 d dl
mZ d dlZd dlmZ d dlmZ d dlmZmZ d d	lmZmZmZ d d
lmZ d dlmZmZ eddd\ZZeddd\ZZdZ dZ!g dZ"dg ddddfdg ddddfdg ddddfdg ddddfdg ddddfdg ddddfdg ddddfd Z#e$e%Z&d!d" Z'd#d$ Z(G d%d& d&ej)Z*G d'd( d(ej)Z+G d)d* d*ej)Z,G d+d, d,e,Z-G d-d. d.e-eZ.d\d=d>Z/d]d^d?d@Z0d]d^dAdBZ1d]d^dCdDZ2d]d^dEdFZ3d]d^dGdHZ4d]d^dIdJZ5d]d^dKdLZ6d_d`dRdSZ7dadTdUZ8dbdcdZd[Z9dS )d    )annotationsN)Callable)partial)Path)Any)BaseEncoder)ConvPool)get_act_layerget_norm_layerget_pool_layer)ensure_tuple_rep)look_up_optionoptional_importZhuggingface_hubhf_hub_downloadnamezhuggingface_hub.utils._errorsEntryNotFoundError#TencentMedicalNet/MedicalNet-Resnetresnet_)
ResNetResNetBlockResNetBottleneckresnet10resnet18resnet34resnet50	resnet101	resnet152	resnet200basic   r"   r"   r"   BFT   r%   r%   r%   A         r(   
bottleneckr(   r)      r(   r(      $   r(   r(      r0   r(   r   r   r   r   r   r   r   c                   C     g dS )N)@             r9   r9   r9   \/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/resnet.pyget_inplanes?      r;   c                   C  r4   )N)r   r"   )r"   r"   )r"   r"   r"   r9   r9   r9   r9   r:   get_avgpoolC   r<   r=   c                      s@   e Zd ZdZddddddifdfd fddZdddZ  ZS )r   r"   r(   NreluinplaceTbatch	in_planesintplanesspatial_dimsstride
downsamplenn.Module | partial | Noneactstr | tuplenormreturnNonec           	        s~   t    ttj|f }|||dd|dd| _t|||d| _t|d| _|||dddd| _	t|||d| _
|| _|| _dS )	a  
        Args:
            in_planes: number of input channels.
            planes: number of output channels.
            spatial_dims: number of spatial dimensions of the input image.
            stride: stride to use for first conv layer.
            downsample: which downsample layer to use.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        r(   r"   F)kernel_sizepaddingrE   biasr   rD   channelsr   )rM   rN   rO   N)super__init__r   CONVconv1r   bn1r
   rH   conv2bn2rF   rE   )	selfrA   rC   rD   rE   rF   rH   rJ   	conv_type	__class__r9   r:   rS   J   s   

zResNetBlock.__init__xtorch.Tensorc                 C  s`   |}|  |}| |}| |}| |}| |}| jd ur%| |}||7 }| |}|S N)rU   rV   rH   rW   rX   rF   rY   r]   residualoutr9   r9   r:   forwardj   s   







zResNetBlock.forwardrA   rB   rC   rB   rD   rB   rE   rB   rF   rG   rH   rI   rJ   rI   rK   rL   r]   r^   rK   r^   __name__
__module____qualname__	expansionrS   rc   __classcell__r9   r9   r[   r:   r   G   s    
 r   c                      s@   e Zd ZdZddddddifdfd fddZdddZ  ZS )r   r)   r(   r"   Nr>   r?   Tr@   rA   rB   rC   rD   rE   rF   rG   rH   rI   rJ   rK   rL   c           
        s   t    ttj|f }tt||d}	|||ddd| _|	|d| _|||d|ddd| _|	|d| _	|||| j
 ddd| _|	|| j
 d| _t|d| _|| _|| _d	S )
a  
        Args:
            in_planes: number of input channels.
            planes: number of output channels (taking expansion into account).
            spatial_dims: number of spatial dimensions of the input image.
            stride: stride to use for second conv layer.
            downsample: which downsample layer to use.
            act: activation type and arguments. Defaults to relu.
            norm: feature normalization type and arguments. Defaults to batch norm.
        )r   rD   r"   F)rM   rO   )rQ   r(   rM   rE   rN   rO   r   N)rR   rS   r   rT   r   r   rU   rV   rW   rX   rj   conv3bn3r
   rH   rF   rE   )
rY   rA   rC   rD   rE   rF   rH   rJ   rZ   
norm_layerr[   r9   r:   rS      s   

zResNetBottleneck.__init__r]   r^   c                 C  s~   |}|  |}| |}| |}| |}| |}| |}| |}| |}| jd ur4| |}||7 }| |}|S r_   )rU   rV   rH   rW   rX   rm   rn   rF   r`   r9   r9   r:   rc      s   










zResNetBottleneck.forwardrd   re   rf   r9   r9   r[   r:   r   }   s    
$r   c                      sj   e Zd ZdZddddddddd	d	d
dd	ifdfd5 fd&d'Zd6d7d,d-Z		d8d9d1d2Zd:d3d4Z  ZS );r   a  
    ResNet based on: `Deep Residual Learning for Image Recognition <https://arxiv.org/pdf/1512.03385.pdf>`_
    and `Can Spatiotemporal 3D CNNs Retrace the History of 2D CNNs and ImageNet? <https://arxiv.org/pdf/1711.09577.pdf>`_.
    Adapted from `<https://github.com/kenshohara/3D-ResNets-PyTorch/tree/master/models>`_.

    Args:
        block: which ResNet block to use, either Basic or Bottleneck.
            ResNet block class or str.
            for Basic: ResNetBlock or 'basic'
            for Bottleneck: ResNetBottleneck or 'bottleneck'
        layers: how many layers to use.
        block_inplanes: determine the size of planes at each step. Also tunable with widen_factor.
        spatial_dims: number of spatial dimensions of the input image.
        n_input_channels: number of input channels for first convolutional layer.
        conv1_t_size: size of first convolution layer, determines kernel and padding.
        conv1_t_stride: stride of first convolution layer.
        no_max_pool: bool argument to determine if to use maxpool layer.
        shortcut_type: which downsample block to use. Options are 'A', 'B', default to 'B'.
            - 'A': using `self._downsample_basic_block`.
            - 'B': kernel_size 1 conv + norm.
        widen_factor: widen output for each layer.
        num_classes: number of output (classifications).
        feed_forward: whether to add the FC layer for the output, default to `True`.
        bias_downsample: whether to use bias term in the downsampling block when `shortcut_type` is 'B', default to `True`.
        act: activation type and arguments. Defaults to relu.
        norm: feature normalization type and arguments. Defaults to batch norm.

    r(      r"   Fr#   g      ?i  Tr>   r?   r@   block*type[ResNetBlock | ResNetBottleneck] | strlayers	list[int]block_inplanesrD   rB   n_input_channelsconv1_t_sizetuple[int] | intconv1_t_strideno_max_poolboolshortcut_typestrwiden_factorfloatnum_classesfeed_forwardbias_downsamplerH   rI   rJ   rK   rL   c                   sD  t    t|tr|dkrt}n|dkrt}ntd| ttj|f }t	t	j
|f }t	t	j|f }t } fdd|D }|d | _|| _|| _t||}t||}||| j||tdd |D d	d
| _t||| jd}|| _t|d| _|dddd| _| ||d |d ||	| _| j||d |d ||	dd| _| j||d |d ||	dd| _| j||d |d ||	dd| _||| | _|rt|d |j  |nd | _!| " D ]F}t||rtj#j$t%&|j'ddd qt|t(|rtj#)t%&|j'd tj#)t%&|j*d qt|tjrtj#)t%&|j*d qd S )Nr    r+   z+Unknown block '%s', use basic or bottleneckc                   s   g | ]}t |  qS r9   )rB   ).0r]   r~   r9   r:   
<listcomp>   s    z#ResNet.__init__.<locals>.<listcomp>r   c                 s  s    | ]}|d  V  qdS )r%   Nr9   )r   kr9   r9   r:   	<genexpr>
  s    z"ResNet.__init__.<locals>.<genexpr>Frl   rP   r   r(   r%   r"   )rM   rE   rN   )rE   fan_outr>   )modenonlinearity)+rR   rS   
isinstancer}   r   r   
ValueErrorr   rT   r	   MAXZADAPTIVEAVGr=   rA   rz   r   r   tuplerU   r   rV   r
   rH   maxpool_make_layerlayer1layer2layer3layer4avgpoolnnLinearrj   fcmodulesinitkaiming_normal_torch	as_tensorweighttype	constant_rO   )rY   rq   rs   ru   rD   rv   rw   ry   rz   r|   r~   r   r   r   rH   rJ   rZ   	pool_typeZ	avgp_typeZblock_avgpoolZconv1_kernel_sizeZconv1_stridero   mr[   r   r:   rS      s^   




	    
zResNet.__init__r]   r^   rC   rE   c                 C  sl   t dd|df|d|}tj|d||d g|jdd  R |j|jd}tj|j|gdd}|S )	Navgr"   )rM   rE   )rD   r   r%   )dtypedevice)dim)	r   r   zerossizeshaper   r   catdata)rY   r]   rC   rE   rD   rb   Z	zero_padsr9   r9   r:   _downsample_basic_block"  s   :zResNet._downsample_basic_block$type[ResNetBlock | ResNetBottleneck]blocksnn.Sequentialc              	   C  s   t t j|f }d }	|dks| j||j krEt|ddhdkr+t| j||j ||d}	nt|| j||j d|| j	dt
||||j d}	|| j||||	|dg}
||j | _td|D ]}|
|| j|||d q\tj|
 S )	Nr"   r&   r#   )rC   rE   rD   )rM   rE   rO   rP   )rA   rC   rD   rE   rF   rJ   )rD   rJ   )r   rT   rA   rj   r   r   r   r   
Sequentialr   r   rangeappend)rY   rq   rC   r   rD   r|   rE   rJ   rZ   rF   rs   _ir9   r9   r:   r   (  sB   

zResNet._make_layerc                 C  s   |  |}| |}| |}| js| |}| |}| |}| |}| |}| 	|}|
|dd}| jd urC| |}|S )Nr   )rU   rV   rH   rz   r   r   r   r   r   r   viewr   r   )rY   r]   r9   r9   r:   rc   Z  s   










zResNet.forward) rq   rr   rs   rt   ru   rt   rD   rB   rv   rB   rw   rx   ry   rx   rz   r{   r|   r}   r~   r   r   rB   r   r{   r   r{   rH   rI   rJ   rI   rK   rL   )r(   )
r]   r^   rC   rB   rE   rB   rD   rB   rK   r^   )r"   r@   )rq   r   rC   rB   r   rB   rD   rB   r|   r}   rE   rB   rJ   rI   rK   r   re   )	rg   rh   ri   __doc__rS   r   r   rc   rk   r9   r9   r[   r:   r      s(    "
I2r   c                      s*   e Zd Zdd fddZdddZ  ZS )ResNetFeaturesTr(   r"   
model_namer}   
pretrainedr{   rD   rB   in_channelsrK   rL   c                   s   |t vrdt  }td| d| dt | \}}}}	}
t j||t ||d|d|	d	 |rF|dkrB|d	krBt| ||
d
 dS tddS )a  Initialize resnet18 to resnet200 models as a backbone, the backbone can be used as an encoder for
        segmentation and objection models.

        Compared with the class `ResNet`, the only different place is the forward function.

        Args:
            model_name: name of model to initialize, can be from [resnet10, ..., resnet200].
            pretrained: whether to initialize pretrained MedicalNet weights,
                only available for spatial_dims=3 and in_channels=1.
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of input channels for first convolutional layer.
        z, zinvalid model_name z found, must be one of  r%   F)	rq   rs   ru   rD   rv   ry   r|   r   r   r(   r"   )
datasets23zQPretrained resnet models are only available for in_channels=1 and spatial_dims=3.N)resnet_paramsjoinkeysr   rR   rS   r;   _load_state_dict)rY   r   r   rD   r   Zmodel_name_stringrq   rs   r|   r   r   r[   r9   r:   rS   q  s(   zResNetFeatures.__init__inputsr^   c                 C  s   |  |}| |}| |}g }|| | js| |}| |}|| | |}|| | |}|| | 	|}|| |S )z
        Args:
            inputs: input should have spatially N dimensions
            ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``, N is defined by `dimensions`.

        Returns:
            a list of torch Tensors.
        )
rU   rV   rH   r   rz   r   r   r   r   r   )rY   r   r]   featuresr9   r9   r:   rc     s    
	











zResNetFeatures.forward)Tr(   r"   )
r   r}   r   r{   rD   rB   r   rB   rK   rL   )r   r^   )rg   rh   ri   rS   rc   rk   r9   r9   r[   r:   r   o  s    $r   c                   @  sP   e Zd ZdZg dZedddZeddd	ZedddZedddZ	dS )ResNetEncoderz9Wrap the original resnet to an encoder for flexible-unet.r3   rK   
list[dict]c                 C  s(   g }| j D ]}||dddd q|S )z6Get the initialization parameter for resnet backbones.Tr(   r"   )r   r   rD   r   )backbone_namesr   )clsparameter_listZbackbone_namer9   r9   r:   get_encoder_parameters  s   
z$ResNetEncoder.get_encoder_parameterslist[tuple[int, ...]]c                 C  r4   )z:Get number of resnet backbone output feature maps channel.)r5   r5   r6   r7   r8   r   r   r5   r7   r8   i   i   r   r   r   r9   r   r9   r9   r:   num_channels_per_output  s   z%ResNetEncoder.num_channels_per_outputrt   c                 C  s
   dgd S )zGet number of resnet backbone output feature maps.

        Since every backbone contains the same 5 output feature maps, the number list should be `[5] * 7`.
           rp   r9   r   r9   r9   r:   num_outputs  s   
zResNetEncoder.num_outputs	list[str]c                 C  s   | j S )zGet names of resnet backbones.)r   r   r9   r9   r:   get_encoder_names  s   zResNetEncoder.get_encoder_namesN)rK   r   )rK   r   )rK   rt   )rK   r   )
rg   rh   ri   r   r   classmethodr   r   r   r   r9   r9   r9   r:   r     s    	r   archr}   rq   r   rs   rt   ru   r   
bool | strprogressr{   kwargsr   rK   c                 K  sN  t |||fi |}|rtj rdnd}t|tr6t| r2t	d| d tj
||dd}	nctd|dd	d	kr|d
d	dkr|dddu rtd| }
|
r^t|
d}ntdt|\}}||ddkr||ddkrt||dd}	ntd| d| d| tdtddd |	 D }	|j|	dd |S )NcudacpuzLoading weights from z...Tmap_locationweights_onlyz+The pretrained checkpoint file is not foundrD   r(   rv   r"   r   Fresnet(\d+)z1arch argument should be as 'resnet_{resnet_depth}r|   r#   r   r   r   zPlease set shortcut_type to z and bias_downsample to z( when using pretrained MedicalNet resnetzgPlease set n_input_channels to 1and feed_forward to False in order to use MedicalNet pretrained weightsz>MedicalNet pretrained weights are only avalaible for 3D modelsc                 S     i | ]\}}| d d|qS zmodule. replacer   keyvaluer9   r9   r:   
<dictcomp>      z_resnet.<locals>.<dictcomp>)strict)r   r   r   is_availabler   r}   r   existsloggerinfoloadFileNotFoundErrorgetresearchrB   groupr   %get_medicalnet_pretrained_resnet_args get_pretrained_resnet_medicalnetNotImplementedErroritemsload_state_dict)r   rq   rs   ru   r   r   r   modelr   model_state_dict
search_resresnet_depthr   r|   r9   r9   r:   _resnet  s@   	
 r   c                 K      t dtg dt | |fi |S )aw  ResNet-10 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 23 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r!   r   r   r;   r   r   r   r9   r9   r:   r         	r   c                 K  r   )aw  ResNet-18 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 23 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r$   r   r   r9   r9   r:   r   $  r   r   c                 K  r   )aw  ResNet-34 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 23 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r'   r   r   r9   r9   r:   r   0  r   r   c                 K  r   )aw  ResNet-50 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 23 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r'   r   r   r;   r   r9   r9   r:   r   <  r   r   c                 K  r   )aw  ResNet-101 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 8 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r,   r  r   r9   r9   r:   r   H  r   r   c                 K  r   )aw  ResNet-152 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 8 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r.   r  r   r9   r9   r:   r   T  r   r   c                 K  r   )aw  ResNet-200 with optional pretrained support when `spatial_dims` is 3.

    Pretraining from `Med3D: Transfer Learning for 3D Medical Image Analysis <https://arxiv.org/pdf/1904.00625.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on 8 medical datasets
        progress (bool): If True, displays a progress bar of the download to stderr
    r   r1   r  r   r9   r9   r:   r   `  r   r   r   r   rB   r   r   c           	      C  s  d}d}g d}t d| |   | |v rz|s| |  dn| |  d}zt| |  |d}W n; tym   |r_t | d|   | |  d}t d	|  t| |  |d}nt| d
| |  dY nw tj|t|dd}ntdt | d |	dS )a  
    Download resnet pretrained weights from https://huggingface.co/TencentMedicalNet

    Args:
        resnet_depth: depth of the pretrained model. Supported values are 10, 18, 34, 50, 101, 152 and 200
        device: device on which the returned state dict will be loaded. "cpu" or "cuda" for example.
        datasets23: if True, get the weights trained on more datasets (23).
                    Not all depths are available. If not, standard weights are returned.

    Returns:
        Pretrained state dict

    Raises:
        huggingface_hub.utils._errors.EntryNotFoundError: if pretrained weights are not found on huggingface hub
        NotImplementedError: if `resnet_depth` is not supported
    r   r   )
      "   2   e         z@Loading MedicalNet pretrained model from https://huggingface.co/z.pthz_23dataset.pth)Zrepo_idfilenamez not available for resnetzTrying with z not found on NTr   z;Supported resnet_depth are: [10, 18, 34, 50, 101, 152, 200]z downloaded
state_dict)
r   r   r   	Exceptionr   r   r   r   r   r   )	r   r   r   Z$medicalnet_huggingface_repo_basenameZ%medicalnet_huggingface_files_basenameZsupported_depthr	  Zpretrained_path
checkpointr9   r9   r:   r   l  sD   

r   c                 C  s    | dv }| dv r
dnd}||fS )z{
    Return correct shortcut_type and bias_downsample
    for pretrained MedicalNet weights according to resnet depth.
    )r  r  r&   r#   r9   )r   r   r|   r9   r9   r:   r     s   r   r   	nn.Moduler   rL   c                 C  s`   t d|}|rt|d}|d}ntdt|d|d}dd | D }| | d S )	Nr   r"   Z_23datasetszZmodel_name argument should contain resnet depth. Example: resnet18 or resnet18_23datasets.r   r   c                 S  r   r   r   r   r9   r9   r:   r     r   z$_load_state_dict.<locals>.<dictcomp>)	r   r   rB   r   endswithr   r   r   r   )r   r   r   r   r   r   r9   r9   r:   r     s   r   )r   r}   rq   r   rs   rt   ru   rt   r   r   r   r{   r   r   rK   r   )FT)r   r{   r   r{   r   r   rK   r   )r   T)r   rB   r   r}   r   r{   )r   rB   )T)r   r  r   r}   r   r{   rK   rL   ):
__future__r   loggingr   collections.abcr   	functoolsr   pathlibr   typingr   r   torch.nnr   Zmonai.networks.blocks.encoderr   monai.networks.layers.factoriesr   r	   Zmonai.networks.layers.utilsr
   r   r   monai.utilsr   monai.utils.moduler   r   r   _r   Z$MEDICALNET_HUGGINGFACE_REPO_BASENAMEZ%MEDICALNET_HUGGINGFACE_FILES_BASENAME__all__r   	getLoggerrg   r   r;   r=   Moduler   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r9   r9   r9   r:   <module>   s^   
6> 5H
*7
7