o
    ,i                     @  s   d dl mZ d dlmZmZ d dlZd dlmZmZ d dl	m
Z
mZ d dlmZmZ eddd	\ZZedd
d	\ZZG dd dejjZdS )    )annotations)AnyDictN)Convget_pool_layer)look_up_named_moduleset_named_module)look_up_optionoptional_importz%torchvision.models.feature_extractionget_graph_node_names)namecreate_feature_extractorc                      sF   e Zd ZdZddddddddfd	d
dfd fddZdd Z  ZS ) 
NetAdapteraa  
    Wrapper to replace the last layer of model by convolutional layer or FC layer.

    See also: :py:class:`monai.networks.nets.TorchVisionFCModel`

    Args:
        model: a PyTorch model, which can be both 2D and 3D models. typically, it can be a pretrained model
            in Torchvision, like: ``resnet18``, ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``, etc.
            more details: https://pytorch.org/vision/stable/models.html.
        num_classes: number of classes for the last classification layer. Default to 1.
        dim: number of supported spatial dimensions in the specified model, depends on the model implementation.
            default to 2 as most Torchvision models are for 2D image processing.
        in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
        use_conv: whether to use convolutional layer to replace the last layer, default to False.
        pool: parameters for the pooling layer, it should be a tuple, the first item is name of the pooling layer,
            the second item is dictionary of the initialization args. if None, will not replace the `layers[-2]`.
            default to `("avg", {"kernel_size": 7, "stride": 1})`.
        bias: the bias value when replacing the last layer. if False, the layer will not learn an additive bias,
            default to True.
        fc_name: the corresponding layer attribute of the last fully connected layer. Defaults to ``"fc"``.
        node_name: the corresponding feature extractor node name of `model`.
            Defaults to "", the extractor is not in use.

          NFavg   )kernel_sizestrideTfc modeltorch.nn.Modulenum_classesintdimin_channels
int | Noneuse_convboolpool!tuple[str, dict[str, Any]] | Nonebiasfc_namestr	node_namec
                   sT  t    t| }
t||}|d u r|
d }|d u r)t|ds%td|j}n|}|d u rV|	dkr7tdt||rGt||t	j
 | _nt	j
j|
d d  | _d | _n,|	rptrpt|	t||jrddnd }	t||	g| _nt	j
j|
d d  | _t||d	| _|  |rttj|f ||d|d
| _n
t	j
j|||d| _|| _|| _|	| _d S )Nin_featureszSplease specify input channels of the last fully connected layer with `in_channels`.r   zE`node_name` is not compatible with `pool=None`, please set `pool=''`.r   r   )r   spatial_dims)r   out_channelsr   r"   )r'   out_featuresr"   )super__init__listchildrenr   hasattr
ValueErrorr'   r   torchnnIdentityfeatures
Sequentialr    
_has_utilsr	   r   trainingr   r   r   CONVr   Linearr   r   r%   )selfr   r   r   r   r   r    r"   r#   r%   layersZorig_fcZin_channels_	__class__ `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/netadapter.pyr-   4   s:   




zNetAdapter.__init__c                 C  s   |  |}t|tr|d }ntj|tttjf r || j }| j	d ur*| 	|}| j
s4t|d}nt|j| jd k rL|d }t|j| jd k s>| |}|S )Nr   r   r   ).N)r5   
isinstancetupler2   jitr   r$   Tensorr%   r    r   flattenlenshaper   r   )r;   xr?   r?   r@   forwardl   s   






zNetAdapter.forward)r   r   r   r   r   r   r   r   r   r   r    r!   r"   r   r#   r$   r%   r$   )__name__
__module____qualname____doc__r-   rI   __classcell__r?   r?   r=   r@   r      s    8r   )
__future__r   typingr   r   r2   monai.networks.layersr   r   monai.networks.utilsr   r   monai.utilsr	   r
   r   r7   r   _r3   Moduler   r?   r?   r?   r@   <module>   s   