o
    ,i.                     @  s`   d dl mZ d dlmZ d dlZd dlmZ d dlmZ d dl	m
Z
 dgZG dd dejZdS )    )annotations)SequenceN)PatchEmbeddingBlockTransformerBlockViTc                      sF   e Zd ZdZ											
			d&d' fd"d#Zd$d% Z  ZS )(r   z
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    ViT supports Torchscript but only works for Pytorch after 1.8.
             conv	learnableF              Tanhin_channelsintimg_sizeSequence[int] | int
patch_sizehidden_sizemlp_dim
num_layers	num_heads	proj_typestrpos_embed_typeclassificationboolnum_classesdropout_ratefloatspatial_dimsqkv_bias	save_attnreturnNonec                   s   t    d   krdkstd td dkr!td|
| _t|||||	 |d	| _t fddt|D | _	t
| _| jrvttdd| _|dkrmtt|t | _d	S t|| _d	S d	S )
aO	  
        Args:
            in_channels (int): dimension of input channels.
            img_size (Union[Sequence[int], int]): dimension of input image.
            patch_size (Union[Sequence[int], int]): dimension of patch size.
            hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
            mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
            num_layers (int, optional): number of transformer blocks. Defaults to 12.
            num_heads (int, optional): number of attention heads. Defaults to 12.
            proj_type (str, optional): patch embedding layer type. Defaults to "conv".
            pos_embed_type (str, optional): position embedding type. Defaults to "learnable".
            classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
            num_classes (int, optional): number of classes if classification is used. Defaults to 2.
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
            post_activation (str, optional): add a final acivation function to the classification head
                when `classification` is True. Default to "Tanh" for `nn.Tanh()`.
                Set to other values to remove this function.
            qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
            save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.

        Examples::

            # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
            >>> net = ViT(in_channels=1, img_size=(96,96,96), proj_type='conv', pos_embed_type='sincos')

            # for 3-channel with image size of (128,128,128), 24 layers and classification backbone
            >>> net = ViT(in_channels=3, img_size=(128,128,128), proj_type='conv', pos_embed_type='sincos', classification=True)

            # for 3-channel with image size of (224,224), 12 layers and classification backbone
            >>> net = ViT(in_channels=3, img_size=(224,224), proj_type='conv', pos_embed_type='sincos', classification=True,
            >>>           spatial_dims=2)

        r      z'dropout_rate should be between 0 and 1.z-hidden_size should be divisible by num_heads.)	r   r   r   r   r   r   r   r    r"   c              	     s   g | ]}t  qS  r   ).0ir    r   r   r   r#   r$   r(   Y/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/vit.py
<listcomp>l   s    z ViT.__init__.<locals>.<listcomp>r   N)super__init__
ValueErrorr   r   patch_embeddingnn
ModuleListrangeblocks	LayerNormnorm	Parametertorchzeros	cls_token
SequentialLinearr   classification_head)selfr   r   r   r   r   r   r   r   r   r   r   r    r"   Zpost_activationr#   r$   	__class__r+   r,   r/   !   s>   
6zViT.__init__c                 C  s   |  |}t| dr| j|jd dd}tj||fdd}g }| jD ]}||}|| q#| 	|}t| drD| 
|d d df }||fS )Nr;   r   r'   )dimr>   )r1   hasattrr;   expandshaper9   catr5   appendr7   r>   )r?   xr;   hidden_states_outblkr(   r(   r,   forwardy   s   




zViT.forward)r   r	   r
   r
   r   r   Fr   r   r   r   FF) r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r!   r"   r   r#   r   r$   r   r%   r&   )__name__
__module____qualname____doc__r/   rL   __classcell__r(   r(   r@   r,   r      s"    X)
__future__r   collections.abcr   r9   torch.nnr2   Z$monai.networks.blocks.patchembeddingr   Z&monai.networks.blocks.transformerblockr   __all__Moduler   r(   r(   r(   r,   <module>   s   