o
    -i%                     @  s   d dl mZ d dlmZ d dlZd dlZd dlmZ d dl	m
Z
 d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZ h d
ZdgZG dd dejZdS )    )annotations)SequenceN)PatchEmbeddingBlock)build_sincos_position_embeddingTransformerBlock)trunc_normal_)ensure_tuple_rep)look_up_option>   sincos	learnablenoneMaskedAutoEncoderViTc                      sd   e Zd ZdZ															
			d1d2 fd'd(Zd)d* Zd3d4d-d.Zd3d4d/d0Z  ZS )5r   a3  
    Masked Autoencoder (ViT), based on: "Kaiming et al.,
    Masked Autoencoders Are Scalable Vision Learners <https://arxiv.org/abs/2111.06377>"
    Only a subset of the patches passes through the encoder. The decoder tries to reconstruct
    the masked patches, resulting in improved training speed.
                   ?     convr              Fin_channelsintimg_sizeSequence[int] | int
patch_sizehidden_sizemlp_dim
num_layers	num_headsmasking_ratiofloatdecoder_hidden_sizedecoder_mlp_dimdecoder_num_layersdecoder_num_heads	proj_typestrpos_embed_typedecoder_pos_embed_typedropout_ratespatial_dimsqkv_biasbool	save_attnreturnNonec                   s  t    d  krdksn td d dkr"td  dkr,tdt||| _t||| _|| _t| j| jD ]\}}|| dkrWtd| d| dqB | _|dksc|dkrktd	| d|| _	t
tdd| _t|||||| jd
	| _fddt|D }t
jg |t
R  | _t
 | _t
tdd | _t|t| _t
td| jj | _ fddt|D }t
jg |t
 R  | _t
 tt | j| | _!| "  dS )a  
        Args:
            in_channels: dimension of input channels or the number of channels for input.
            img_size: dimension of input image.
            patch_size: dimension of patch size
            hidden_size: dimension of hidden layer. Defaults to 768.
            mlp_dim: dimension of feedforward layer. Defaults to 512.
            num_layers:  number of transformer blocks. Defaults to 12.
            num_heads: number of attention heads. Defaults to 12.
            masking_ratio: ratio of patches to be masked. Defaults to 0.75.
            decoder_hidden_size: dimension of hidden layer for decoder. Defaults to 384.
            decoder_mlp_dim: dimension of feedforward layer for decoder. Defaults to 512.
            decoder_num_layers: number of transformer blocks for decoder. Defaults to 4.
            decoder_num_heads: number of attention heads for decoder. Defaults to 12.
            proj_type: position embedding layer type. Defaults to "conv".
            pos_embed_type: position embedding layer type. Defaults to "sincos".
            decoder_pos_embed_type: position embedding layer type for decoder. Defaults to "sincos".
            dropout_rate: fraction of the input units to drop. Defaults to 0.0.
            spatial_dims: number of spatial dimensions. Defaults to 3.
            qkv_bias: apply bias to the qkv linear layer in self attention block. Defaults to False.
            save_attn: to make accessible the attention in self attention block. Defaults to False.
        Examples::
            # for single channel input with image size of (96,96,96), and sin-cos positional encoding
            >>> net = MaskedAutoEncoderViT(in_channels=1, img_size=(96,96,96), patch_size=(16,16,16),
            pos_embed_type='sincos')
            # for 3-channel with image size of (128,128,128) and a learnable positional encoding
            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=128, patch_size=16, pos_embed_type='learnable')
            # for 3-channel with image size of (224,224) and a masking ratio of 0.25
            >>> net = MaskedAutoEncoderViT(in_channels=3, img_size=(224,224), patch_size=(16,16), masking_ratio=0.25,
            spatial_dims=2)
        r      z,dropout_rate should be between 0 and 1, got .z-hidden_size should be divisible by num_heads.z=decoder_hidden_size should be divisible by decoder_num_heads.zpatch_size=z! should be divisible by img_size=z1masking_ratio should be in the range (0, 1), got )	r   r   r   r   r    r'   r)   r+   r,   c              	     s   g | ]}t  qS  r   .0_)r+   r   r   r    r-   r/   r4   l/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/masked_autoencoder_vit.py
<listcomp>       z1MaskedAutoEncoderViT.__init__.<locals>.<listcomp>c              	     s   g | ]}t  qS r4   r   r5   )r#   r$   r&   r+   r-   r/   r4   r8   r9      r:   N)#super__init__
ValueErrorr	   r   r   r,   zipr#   r!   nn	Parametertorchzeros	cls_tokenr   patch_embeddingrange
Sequential	LayerNormblocksLineardecoder_embedmask_tokensr
   SUPPORTED_POS_EMBEDDING_TYPESr*   	n_patchesdecoder_pos_embeddingdecoder_blocksr   npproddecoder_pred_init_weights)selfr   r   r   r   r   r   r    r!   r#   r$   r%   r&   r'   r)   r*   r+   r,   r-   r/   mprH   rO   	__class__)	r#   r$   r&   r+   r   r   r    r-   r/   r8   r<   (   sX   
6zMaskedAutoEncoderViT.__init__c                 C  s   | j dkrn=| j dkrt| jddddd n-| j dkr:g }t| j| jD ]\}}|||  q$t|| j| j	| _n	t
d	| j  d
t| jddddd t| jddddd dS )z
        similar to monai/networks/blocks/patchembedding.py for the decoder positional encoding and for mask and
        classification tokens
        r   r   r   g{Gz?g       g       @)meanstdabr   zdecoder_pos_embed_type z not supported.N)r*   r   rN   r>   r   r   appendr   r#   r,   r=   rK   rC   )rT   	grid_sizein_sizepa_sizer4   r4   r8   rS      s   



z"MaskedAutoEncoderViT._init_weightsNfloat | Nonec           
      C  s   |j \}}}|d urd| nd| j }tjt||t|| dd}|t|d|f }tj||tjd|j	}	d|	t|d|f< |||	fS )Nr2   F)replacement)dtyper   )
shaper!   rA   multinomialonesr   arange	unsqueezetodevice)
rT   xr!   
batch_size
num_tokensr7   Zpercentage_to_keepselected_indicesx_maskedmaskr4   r4   r8   _masking   s   
zMaskedAutoEncoderViT._maskingc                 C  s  |  |}| j||d\}}}| j|jd dd}tj||fdd}| |}| |}| j	
|jd |jd d}|d d dd d d f |t|jd d|f< || j }tj|d d d dd d f |gdd}| |}| |}|d d dd d d f }||fS )N)r!   r   rd   r2   )dim)rD   rr   rC   expandre   rA   catrH   rJ   rK   repeatrh   ri   rN   rO   rR   )rT   rl   r!   ro   rq   Z
cls_tokensx_r4   r4   r8   forward   s   


4
(

zMaskedAutoEncoderViT.forward)r   r   r   r   r   r   r   r   r   r   r   r   r   r   FF)(r   r   r   r   r   r   r   r   r   r   r   r   r    r   r!   r"   r#   r   r$   r   r%   r   r&   r   r'   r(   r)   r(   r*   r(   r+   r"   r,   r   r-   r.   r/   r.   r0   r1   )N)r!   ra   )	__name__
__module____qualname____doc__r<   rS   rr   rx   __classcell__r4   r4   rW   r8   r       s,    r)
__future__r   collections.abcr   numpyrP   rA   torch.nnr?   Z$monai.networks.blocks.patchembeddingr   %monai.networks.blocks.pos_embed_utilsr   Z&monai.networks.blocks.transformerblockr   monai.networks.layersr   monai.utilsr	   monai.utils.moduler
   rL   __all__Moduler   r4   r4   r4   r8   <module>   s   