U
    Ph@                     @  s   d Z ddlmZ ddlmZmZ ddlmZ ddlm	Z	 ddl
mZmZmZ e	d\ZZd	gZG d
d	 d	ejZddddddd	dddZdS )z
This script is modified from from torchvision to support N-D images,
by overriding the definition of convolutional layers and pooling layers.

https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py
    )annotations)Tensornn)resnet)optional_import   )ExtraFPNBlockFeaturePyramidNetworkLastLevelMaxPoolztorchvision.modelsBackboneWithFPNc                	      sD   e Zd ZdZdddddddd	d
 fddZdddddZ  ZS )r   a  
    Adds an FPN on top of a model.
    Internally, it uses torchvision.models._utils.IntermediateLayerGetter to
    extract a submodel that returns the feature maps specified in return_layers.
    The same limitations of IntermediateLayerGetter apply here.

    Same code as https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py
    Except that this class uses spatial_dims

    Args:
        backbone: backbone network
        return_layers: a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).
        in_channels_list: number of channels for each feature map
            that is returned, in the order they are present in the OrderedDict
        out_channels: number of channels in the FPN.
        spatial_dims: 2D or 3D images
    Nz	nn.Modulezdict[str, str]z	list[int]intz
int | NoneExtraFPNBlock | NoneNone)backbonereturn_layersin_channels_listout_channelsspatial_dimsextra_blocksreturnc                   s   t    |d kr`t|dr0t|jtr0|j}n0t|jtjrDd}nt|jtj	rXd}nt
d|d krpt|}tjj||d| _t||||d| _|| _d S )Nr         z;Could not find spatial_dims of backbone, please specify it.)r   )r   r   r   r   )super__init__hasattr
isinstancer   r   conv1r   Conv2dConv3d
ValueErrorr
   torchvision_models_utilsZIntermediateLayerGetterbodyr	   fpnr   )selfr   r   r   r   r   r   	__class__ ]/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/blocks/backbone_fpn_utils.pyr   Y   s&    	
zBackboneWithFPN.__init__r   zdict[str, Tensor])xr   c                 C  s   |  |}| |}|S )z
        Computes the resulted feature maps of the network.

        Args:
            x: input images

        Returns:
            feature maps after FPN layers. They are ordered from highest resolution first.
        )r"   r#   )r$   r)   yr'   r'   r(   forward{   s    


zBackboneWithFPN.forward)NN)__name__
__module____qualname____doc__r   r+   __classcell__r'   r'   r%   r(   r   C   s
      "   Nzresnet.ResNetr   zlist[int] | Noner   )r   r   trainable_layersreturned_layersr   r   c           
        s  |dk s|dkrt d| dddddgd	| }|dkrF|d
 |  D ](\}tfdd|D rN|d qN|d	krt|}|d	krddddg}t|dkst|dkrt d| dd t|D }| j	d   fdd|D }d}	t
| |||	||dS )a)  
    Same code as https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/backbone_utils.py
    Except that ``in_channels_stage2 = backbone.in_planes // 8`` instead of ``in_channels_stage2 = backbone.inplanes // 8``,
    and it requires spatial_dims: 2D or 3D images.
    r   r1   z3Trainable layers should be in the range [0,5], got layer4layer3layer2layer1r   Nbn1c                 3  s   | ]}  | V  qd S )N)
startswith).0layer)namer'   r(   	<genexpr>   s     z(_resnet_fpn_extractor.<locals>.<genexpr>Fr   r   r      z6Each returned layer should be in the range [1,4]. Got c                 S  s    i | ]\}}d | t |qS )r;   )str)r:   vkr'   r'   r(   
<dictcomp>   s      z)_resnet_fpn_extractor.<locals>.<dictcomp>   c                   s   g | ]} d |d   qS )r   r   r'   )r:   i)in_channels_stage2r'   r(   
<listcomp>   s     z)_resnet_fpn_extractor.<locals>.<listcomp>   )r   r   )r   appendnamed_parametersallrequires_grad_r
   minmax	enumerateZ	in_planesr   )
r   r   r2   r3   r   Zlayers_to_train	parameterr   r   r   r'   )rE   r<   r(   _resnet_fpn_extractor   s4    

     rP   )r1   NN)r/   
__future__r   torchr   r   monai.networks.netsr   monai.utilsr   feature_pyramid_networkr   r	   r
   r    ___all__Moduler   rP   r'   r'   r'   r(   <module>.   s   J   