U
    Phj                     @  sl   d dl mZ d dlmZ d dlZd dlmZ d dlmZ d dl	m
Z
 d dlmZ dgZG dd dejZdS )	    )annotations)SequenceN)PatchEmbeddingBlockTransformerBlock)deprecated_argViTc                      sb   e Zd ZdZeddddddddddddddddddddddddd fddZdd Z  ZS )r   z
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    ViT supports Torchscript but only works for Pytorch after 1.8.
    	pos_embedz1.2z1.4	proj_typezplease use `proj_type` instead.)namesinceremovednew_name
msg_suffix         conv	learnableF              TanhintzSequence[int] | intstrboolfloatNone)in_channelsimg_size
patch_sizehidden_sizemlp_dim
num_layers	num_headsr	   r
   pos_embed_typeclassificationnum_classesdropout_ratespatial_dimsqkv_bias	save_attnreturnc                   s   t    d   krdks(n td dkr<td|| _t||||	|
 |d	| _t fddt|D | _	t
| _| jrttdd| _|dkrtt|t | _nt|| _d	S )
a	  
        Args:
            in_channels (int): dimension of input channels.
            img_size (Union[Sequence[int], int]): dimension of input image.
            patch_size (Union[Sequence[int], int]): dimension of patch size.
            hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
            mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
            num_layers (int, optional): number of transformer blocks. Defaults to 12.
            num_heads (int, optional): number of attention heads. Defaults to 12.
            proj_type (str, optional): patch embedding layer type. Defaults to "conv".
            pos_embed_type (str, optional): position embedding type. Defaults to "learnable".
            classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
            num_classes (int, optional): number of classes if classification is used. Defaults to 2.
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
            post_activation (str, optional): add a final acivation function to the classification head
                when `classification` is True. Default to "Tanh" for `nn.Tanh()`.
                Set to other values to remove this function.
            qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
            save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.

        .. deprecated:: 1.4
            ``pos_embed`` is deprecated in favor of ``proj_type``.

        Examples::

            # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
            >>> net = ViT(in_channels=1, img_size=(96,96,96), proj_type='conv', pos_embed_type='sincos')

            # for 3-channel with image size of (128,128,128), 24 layers and classification backbone
            >>> net = ViT(in_channels=3, img_size=(128,128,128), proj_type='conv', pos_embed_type='sincos', classification=True)

            # for 3-channel with image size of (224,224), 12 layers and classification backbone
            >>> net = ViT(in_channels=3, img_size=(224,224), proj_type='conv', pos_embed_type='sincos', classification=True,
            >>>           spatial_dims=2)

        r      z'dropout_rate should be between 0 and 1.z-hidden_size should be divisible by num_heads.)	r   r   r    r!   r$   r
   r%   r(   r)   c              	     s   g | ]}t  qS  r   ).0ir(   r!   r"   r$   r*   r+   r.   L/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/vit.py
<listcomp>t   s   z ViT.__init__.<locals>.<listcomp>r   N)super__init__
ValueErrorr&   r   patch_embeddingnn
ModuleListrangeblocks	LayerNormnorm	Parametertorchzeros	cls_token
SequentialLinearr   classification_head)selfr   r   r    r!   r"   r#   r$   r	   r
   r%   r&   r'   r(   r)   Zpost_activationr*   r+   	__class__r1   r2   r5   "   s8    =
zViT.__init__c                 C  s   |  |}t| dr<| j|jd dd}tj||fdd}g }| jD ]}||}|| qF| 	|}t| dr| 
|d d df }||fS )NrA   r   r-   )dimrD   )r7   hasattrrA   expandshaper?   catr;   appendr=   rD   )rE   xrA   hidden_states_outblkr.   r.   r2   forward   s    




zViT.forward)r   r   r   r   r   r   r   Fr   r   r   r   FF)__name__
__module____qualname____doc__r   r5   rR   __classcell__r.   r.   rF   r2   r      s0                     6\)
__future__r   collections.abcr   r?   torch.nnr8   Z$monai.networks.blocks.patchembeddingr   Z&monai.networks.blocks.transformerblockr   monai.utilsr   __all__Moduler   r.   r.   r.   r2   <module>   s   