U
    Ph                     @  s   d dl mZ d dlZd dlmZ d dlZd dlmZ d dlm	Z	 d dl
mZ d dlmZ d dlmZmZmZ dgZG d	d dejZdS )
    )annotationsN)Sequence)PatchEmbeddingBlockTransformerBlock)Conv)deprecated_argensure_tuple_repis_sqrt
ViTAutoEncc                      s`   e Zd ZdZedddddddddddddddddddddddd fddZdd Z  ZS )r   a  
    Vision Transformer (ViT), based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"

    Modified to also give same dimension outputs as the input size of the image
    	pos_embedz1.2z1.4	proj_typezplease use `proj_type` instead.)namesinceremovednew_name
msg_suffix               conv           FintzSequence[int] | intstrfloatboolNone)in_channelsimg_size
patch_sizeout_channelsdeconv_chnshidden_sizemlp_dim
num_layers	num_headsr   r   dropout_ratespatial_dimsqkv_bias	save_attnreturnc              
     s  t    t|s"td| dt||| _t||| _|| _t| j| jD ]*\}}|| dkrNtd| d| dqNt	|||| | jd| _
t fddt|D | _t| _ttj| jf }d	d | jD }||||d
| _|||||d| _dS )a  
        Args:
            in_channels: dimension of input channels or the number of channels for input.
            img_size: dimension of input image.
            patch_size: dimension of patch size
            out_channels:  number of output channels. Defaults to 1.
            deconv_chns: number of channels for the deconvolution layers. Defaults to 16.
            hidden_size: dimension of hidden layer. Defaults to 768.
            mlp_dim: dimension of feedforward layer. Defaults to 3072.
            num_layers:  number of transformer blocks. Defaults to 12.
            num_heads: number of attention heads. Defaults to 12.
            proj_type: position embedding layer type. Defaults to "conv".
            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
            spatial_dims: number of spatial dimensions. Defaults to 3.
            qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
            save_attn: to make accessible the attention in self attention block. Defaults to False. Defaults to False.

        .. deprecated:: 1.4
            ``pos_embed`` is deprecated in favor of ``proj_type``.

        Examples::

            # for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
            # It will provide an output of same size as that of the input
            >>> net = ViTAutoEnc(in_channels=1, patch_size=(16,16,16), img_size=(96,96,96), proj_type='conv')

            # for 3-channel with image size of (128,128,128), output will be same size as of input
            >>> net = ViTAutoEnc(in_channels=3, patch_size=(16,16,16), img_size=(128,128,128), proj_type='conv')

        z(patch_size should be square number, got .r   zpatch_size=z! should be divisible by img_size=)r    r!   r"   r%   r(   r   r)   r*   c              	     s   g | ]}t  qS  r   .0ir)   r%   r&   r(   r+   r,   r/   S/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/vitautoenc.py
<listcomp>m   s   z'ViTAutoEnc.__init__.<locals>.<listcomp>c                 S  s   g | ]}t t|qS r/   )r   mathsqrtr0   r/   r/   r4   r5   v   s     )kernel_sizestride)r    r#   r8   r9   N)super__init__r
   
ValueErrorr	   r"   r!   r*   zipr   patch_embeddingnn
ModuleListrangeblocks	LayerNormnormr   	CONVTRANSconv3d_transposeconv3d_transpose_1)selfr    r!   r"   r#   r$   r%   r&   r'   r(   r   r   r)   r*   r+   r,   mpZ
conv_transup_kernel_size	__class__r3   r4   r;   $   sD    4

   zViTAutoEnc.__init__c                 C  s   |j dd }| |}g }| jD ]}||}|| q"| |}|dd}dd t|| jD }t	||j d |j d f|}| 
|}| |}||fS )z
        Args:
            x: input tensor must have isotropic spatial dimensions,
                such as ``[batch_size, channels, sp_size, sp_size[, sp_size]]``.
           Nr   c                 S  s   g | ]\}}|| qS r/   r/   )r1   srJ   r/   r/   r4   r5      s     z&ViTAutoEnc.forward.<locals>.<listcomp>r   )shaper>   rB   appendrD   	transposer=   r"   torchreshaperF   rG   )rH   xspatial_sizehidden_states_outblkdr/   r/   r4   forward|   s    


 

zViTAutoEnc.forward)r   r   r   r   r   r   r   r   r   r   FF)__name__
__module____qualname____doc__r   r;   rZ   __classcell__r/   r/   rL   r4   r      s,                   4U)
__future__r   r6   collections.abcr   rS   torch.nnr?   Z$monai.networks.blocks.patchembeddingr   Z&monai.networks.blocks.transformerblockr   monai.networks.layersr   monai.utilsr   r	   r
   __all__Moduler   r/   r/   r/   r4   <module>   s   