o
    -il                     @  s   d dl mZ d dlZd dlmZ d dlZd dlmZ d dlm	Z	 d dl
mZ d dlmZ d dlmZmZ dgZG d	d dejZdS )
    )annotationsN)Sequence)PatchEmbeddingBlockTransformerBlock)Conv)ensure_tuple_repis_sqrt
ViTAutoEncc                      sB   e Zd ZdZ											
	
d$d% fd d!Zd"d# Z  ZS )&r
   a  
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    Modified to also give same dimension outputs as the input size of the image
                   conv           Fin_channelsintimg_sizeSequence[int] | int
patch_sizeout_channelsdeconv_chnshidden_sizemlp_dim
num_layers	num_heads	proj_typestrdropout_ratefloatspatial_dimsqkv_biasbool	save_attnreturnNonec              
     s  t    t|std| dt||| _t||| _|| _t| j| jD ]\}}|| dkr<td| d| dq't	||||
 | jd| _
t fddt|D | _t| _ttj| jf }d	d | jD }||||d
| _|||||d| _dS )a]  
        Args:
            in_channels: dimension of input channels or the number of channels for input.
            img_size: dimension of input image.
            patch_size: dimension of patch size
            out_channels:  number of output channels. Defaults to 1.
            deconv_chns: number of channels for the deconvolution layers. Defaults to 16.
            hidden_size: dimension of hidden layer. Defaults to 768.
            mlp_dim: dimension of feedforward layer. Defaults to 3072.
            num_layers:  number of transformer blocks. Defaults to 12.
            num_heads: number of attention heads. Defaults to 12.
            proj_type: position embedding layer type. Defaults to "conv".
            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
            spatial_dims: number of spatial dimensions. Defaults to 3.
            qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
            save_attn: to make accessible the attention in self attention block. Defaults to False. Defaults to False.

        Examples::

            # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
            # It will provide an output of same size as that of the input
            >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), proj_type='conv')

            # for 3-channel with image size of (128,128,128), output will be same size as of input
            >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), proj_type='conv')

        z(patch_size should be square number, got .r   zpatch_size=z! should be divisible by img_size=)r   r   r   r   r   r   r    r"   c              	     s   g | ]}t  qS  r   .0ir    r   r   r   r#   r%   r)   `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/vitautoenc.py
<listcomp>f   s    z'ViTAutoEnc.__init__.<locals>.<listcomp>c                 S  s   g | ]	}t t|qS r)   )r   mathsqrtr*   r)   r)   r.   r/   o   s    )kernel_sizestride)r   r   r2   r3   N)super__init__r	   
ValueErrorr   r   r   r"   zipr   patch_embeddingnn
ModuleListrangeblocks	LayerNormnormr   	CONVTRANSconv3d_transposeconv3d_transpose_1)selfr   r   r   r   r   r   r   r   r   r   r    r"   r#   r%   mpZ
conv_transup_kernel_size	__class__r-   r.   r5   $   s@   
-
zViTAutoEnc.__init__c                 C  s   |j dd }| |}g }| jD ]}||}|| q| |}|dd}dd t|| jD }t	||j d |j d g|}| 
|}| |}||fS )z
        Args:
            x: input tensor must have isotropic spatial dimensions,
                such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
           Nr   c                 S  s   g | ]\}}|| qS r)   r)   )r+   srD   r)   r)   r.   r/      s    z&ViTAutoEnc.forward.<locals>.<listcomp>r   )shaper8   r<   appendr>   	transposer7   r   torchreshaper@   rA   )rB   xspatial_sizehidden_states_outblkdr)   r)   r.   forwardu   s   


 

zViTAutoEnc.forward)r   r   r   r   r   r   r   r   r   FF)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r!   r"   r   r#   r$   r%   r$   r&   r'   )__name__
__module____qualname____doc__r5   rT   __classcell__r)   r)   rF   r.   r
      s    Q)
__future__r   r0   collections.abcr   rM   torch.nnr9   Z$monai.networks.blocks.patchembeddingr   Z&monai.networks.blocks.transformerblockr   monai.networks.layersr   monai.utilsr   r	   __all__Moduler
   r)   r)   r)   r.   <module>   s   