o
    -ip                     @  sV   d dl mZ d dlmZ d dlmZ d dlmZ ed\ZZ	dgZ
G dd deZdS )	    )annotations)Any)
NetAdapter)optional_importztorchvision.modelsTorchVisionFCModelc                      sD   e Zd ZdZddddddddd	fd
ddddfd fddZ  ZS )r   a<  
    Customize the fully connected layer of (pretrained) TorchVision model or replace it by convolutional layer.

    This class supports two primary use cases:

        - use ``pool=None`` to indicate no modification in the pooling layers. It should be used with ``fc_name``
          to locate the target FC layer to modify:
          In this case, the class will load a torchvision classification model,
          replace the last fully connected (FC) layer with a new FC layer with ``num_classes`` outputs,
          example input arguments: ``use_conv=False, pool=None, fc_name="heads.head"``.
          The ``heads.head`` specifies the target FC of the input model, could be found by ``model.named_modules()``,
          for example::

              from torchvision.models import vit_b_16
              print([name[0] for name in vit_b_16().named_modules()])

        - use ``pool=""`` or set it to a tuple of pooling parameters to indicate modifications of both
          the pooling and the FC layer. It should be used with ``node_name`` to locate the model feature outputs:
          In this case, the class will load a torchvision model, remove the existing pooling and FC layers, and

          - append an additional convolution layer:
            ``use_conv=True, pool="", node_name="permute"``
          - append an additional pooling and FC layers:
            ``use_conv=False, pool=("avg", {"kernel_size": 7, "stride": 1}), node_name="permute"``
          - append an additional pooling and convolution layers:
            ``use_conv=True, pool=("avg", {"kernel_size": 7, "stride": 1}), node_name="permute"``

          The ``permute`` in the example is the target feature extraction node of the input
          `model_name`, could be found by using the torchvision feature extraction utilities, for example::

              from torchvision.models.feature_extraction import get_graph_node_names
              from torchvision.models import swin_t
              print(get_graph_node_names(swin_t())[0])

    Args:
        model_name: name of any torchvision model with fully connected layer at the end.
            ``resnet18`` (default), ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``,
            ``resnext50_32x4d``, ``resnext101_32x8d``, ``wide_resnet50_2``, ``wide_resnet101_2``, ``inception_v3``.
            model details: https://pytorch.org/vision/stable/models.html.
        num_classes: number of classes for the last classification layer. Default to 1.
        dim: number of supported spatial dimensions in the specified model, depends on the model implementation.
            default to 2 as most Torchvision models are for 2D image processing.
        in_channels: number of the input channels of last layer. if None, get it from `in_features` of last layer.
        use_conv: whether to use convolutional layer to replace the last layer, default to False.
        pool: parameters for the pooling layer, when it's a tuple, the first item is name of the pooling layer,
            the second item is dictionary of the initialization args. If None, will not replace the `layers[-2]`.
            default to `("avg", {"kernel_size": 7, "stride": 1})`. ``""`` indicates not adding a pooling layer.
        bias: the bias value when replacing the last layer. if False, the layer will not learn an additive bias,
            default to True.
        pretrained: whether to use the imagenet pretrained weights. Default to False.
        fc_name: the corresponding layer attribute of the last fully connected layer. Defaults to ``"fc"``.
        node_name: the corresponding feature extractor node name of `model`. Defaults to "", not in use.
        weights: additional weights enum for the torchvision model.
        kwargs: additional parameters for the torchvision model.

    Example::

        import torch
        from torchvision.models.inception import Inception_V3_Weights

        from monai.networks.nets import TorchVisionFCModel

        model = TorchVisionFCModel(
            "inception_v3",
            num_classes=4,
            weights=Inception_V3_Weights.IMAGENET1K_V1,
            use_conv=False,
            pool=None,
        )
        # model = TorchVisionFCModel("vit_b_16", num_classes=4, pool=None, in_channels=768, fc_name="heads")
        output = model.forward(torch.randn(2, 3, 299, 299))
        print(output.shape)  # torch.Size([2, 4])

    resnet18      NFavg   )kernel_sizestrideTfc 
model_namestrnum_classesintdimin_channels
int | Noneuse_convboolpool!tuple[str, dict[str, Any]] | Nonebias
pretrainedfc_name	node_namec                   sJ   |r|d u rd}t t|dd|i|}t j||||||||	|
d	 d S )NDEFAULTweights)	modelr   r   r   r   r   r   r   r    )getattrmodelssuper__init__)selfr   r   r   r   r   r   r   r   r   r   r    kwargsr!   	__class__r"   d/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/torchvision_fc.pyr&   d   s   
zTorchVisionFCModel.__init__)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   )__name__
__module____qualname____doc__r&   __classcell__r"   r"   r)   r+   r      s    MN)
__future__r   typingr   monai.networks.netsr   monai.utilsr   r$   ___all__r   r"   r"   r"   r+   <module>   s   