U
    Ph                     @  s   d dl mZ d dlmZmZ d dlZd dlmZmZ d dl	m
Z
mZ d dlmZmZ eddd	\ZZedd
d	\ZZG dd dejjZdS )    )annotations)AnyDictN)Convget_pool_layer)look_up_named_moduleset_named_module)look_up_optionoptional_importz%torchvision.models.feature_extractionget_graph_node_names)namecreate_feature_extractorc                      sZ   e Zd ZdZddddddddfd	d
dfdddddddddd	 fddZdd Z  ZS )
NetAdapteraa  
    Wrapper to replace the last layer of model by convolutional layer or FC layer.

    See also: :py:class:`monai.networks.nets.TorchVisionFCModel`

    Args:
        model: a PyTorch model, which can be both 2D and 3D models. typically, it can be a pretrained model
            in Torchvision, like: ``resnet18``, ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``, etc.
            more details: https://pytorch.org/vision/stable/models.html.
        num_classes: number of classes for the last classification layer. Default to 1.
        dim: number of supported spatial dimensions in the specified model, depends on the model implementation.
            default to 2 as most Torchvision models are for 2D image processing.
        in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
        use_conv: whether to use convolutional layer to replace the last layer, default to False.
        pool: parameters for the pooling layer, it should be a tuple, the first item is name of the pooling layer,
            the second item is dictionary of the initialization args. if None, will not replace the `layers[-2]`.
            default to `("avg", {"kernel_size": 7, "stride": 1})`.
        bias: the bias value when replacing the last layer. if False, the layer will not learn an additive bias,
            default to True.
        fc_name: the corresponding layer attribute of the last fully connected layer. Defaults to ``"fc"``.
        node_name: the corresponding feature extractor node name of `model`.
            Defaults to "", the extractor is not in use.

          NFavg   )kernel_sizestrideTfc ztorch.nn.Moduleintz
int | Noneboolz!tuple[str, dict[str, Any]] | Nonestr)	modelnum_classesdimin_channelsuse_convpoolbiasfc_name	node_namec
                   sV  t    t| }
t||}|d kr0|
d }|d krRt|dsJtd|j}n|}|d kr|	dkrntdt||rt||t	j
 | _nt	j
j|
d d  | _d | _nX|	rtrt|	t||jrdnd }	t||	g| _nt	j
j|
d d  | _t||d	| _|  |r,ttj|f ||d|d
| _nt	j
j|||d| _|| _|| _|	| _d S )Nin_featureszSplease specify input channels of the last fully connected layer with `in_channels`.r   zE`node_name` is not compatible with `pool=None`, please set `pool=''`.r   r   )r   spatial_dims)r   out_channelsr   r    )r$   out_featuresr    )super__init__listchildrenr   hasattr
ValueErrorr$   r   torchnnIdentityfeatures
Sequentialr   
_has_utilsr	   r   trainingr   r   r   CONVr   Linearr   r   r"   )selfr   r   r   r   r   r   r    r!   r"   layersZorig_fcZin_channels_	__class__ S/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/netadapter.pyr*   4   s:    



zNetAdapter.__init__c                 C  s   |  |}t|tr|d }n"tj|tttjf r@|| j }| j	d k	rT| 	|}| j
sht|d}nt|j| jd k r|d }qh| |}|S )Nr   r   r   ).N)r2   
isinstancetupler/   jitr   r   Tensorr"   r   r   flattenlenshaper   r   )r8   xr<   r<   r=   forwardl   s    







zNetAdapter.forward)__name__
__module____qualname____doc__r*   rF   __classcell__r<   r<   r:   r=   r      s   $8r   )
__future__r   typingr   r   r/   monai.networks.layersr   r   monai.networks.utilsr   r   monai.utilsr	   r
   r   r4   r   _r0   Moduler   r<   r<   r<   r=   <module>   s   