o
    -i                     @  sZ   d dl mZ d dlZd dlmZ d dlmZ dgZG dd dejZ	G dd dejZ
dS )    )annotationsNTransformerBlockDecoderOnlyTransformerc                      s,   e Zd ZdZd fddZdddZ  ZS )AbsolutePositionalEmbeddingzAbsolute positional embedding.

    Args:
        max_seq_len: Maximum sequence length.
        embedding_dim: Dimensionality of the embedding.
    max_seq_lenintembedding_dimreturnNonec                   s(   t    || _|| _t||| _d S N)super__init__r   r	   nn	Embedding	embedding)selfr   r	   	__class__ a/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/transformer.pyr      s   
z$AbsolutePositionalEmbedding.__init__xtorch.Tensorc                 C  s2   |  \}}tj||jd|d}| |}|S )N)device   )sizetorcharanger   repeatr   )r   r   
batch_sizeseq_len	positionsr   r   r   r   forward$   s   
z#AbsolutePositionalEmbedding.forward)r   r   r	   r   r
   r   )r   r   r
   r   )__name__
__module____qualname____doc__r   r"   __classcell__r   r   r   r   r      s    r   c                      sF   e Zd ZdZ					d!d" fddZd#d$ddZd%d&dd Z  ZS )'r   a  Decoder-only (Autoregressive) Transformer model.

    Args:
        num_tokens: Number of tokens in the vocabulary.
        max_seq_len: Maximum sequence length.
        attn_layers_dim: Dimensionality of the attention layers.
        attn_layers_depth: Number of attention layers.
        attn_layers_heads: Number of attention heads.
        with_cross_attention: Whether to use cross attention for conditioning.
        embedding_dropout_rate: Dropout rate for the embedding.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    F        T
num_tokensr   r   attn_layers_dimattn_layers_depthattn_layers_headswith_cross_attentionboolembedding_dropout_ratefloat
include_fcuse_combined_linearuse_flash_attentionr
   r   c              	     s   t    || _| _ | _|| _| _| _t	| | _
t d| _t|| _t fddt|D | _t || _d S )N)r   r	   c                   s.   g | ]}t   d  ddddqS )   r(   FT)hidden_sizemlp_dim	num_headsdropout_rateqkv_biascausalsequence_lengthr-   r1   r2   r3   r   ).0_r*   r,   r1   r   r2   r3   r-   r   r   
<listcomp>V   s     z3DecoderOnlyTransformer.__init__.<locals>.<listcomp>)r   r   r)   r   r*   r+   r,   r-   r   r   token_embeddingsr   position_embeddingsDropoutembedding_dropout
ModuleListrangeblocksLinear	to_logits)r   r)   r   r*   r+   r,   r-   r/   r1   r2   r3   r   r>   r   r   <   s    
zDecoderOnlyTransformer.__init__Nr   r   contexttorch.Tensor | Nonec                 C  sH   |  |}| |}| || }| jD ]}|||d}q| |}|S )N)rI   )r@   rA   rC   rF   rH   )r   r   rI   Ztok_embpos_embblocklogitsr   r   r   r"   j   s   



zDecoderOnlyTransformer.forwardold_state_dictdictc                   s  |    t fdd|D rtd | | dS |rB D ]}||vr,td| d qtd |D ]}| vrAtd| d q3 D ]}||v rQ|| |< qDt| D ] }d	|v ri|| |d	d
< d|v rx|| |dd	< qX|rtd|  |   dS )z
        Load a state dict from a DecoderOnlyTransformer trained with
        [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).

        Args:
            old_state_dict: state dict from the old DecoderOnlyTransformer  model.
        c                 3  s    | ]}| v V  qd S r   r   )r<   knew_state_dictr   r   	<genexpr>   s    z=DecoderOnlyTransformer.load_old_state_dict.<locals>.<genexpr>z#All keys match, loading state dict.Nzkey z not found in old state dictz.----------------------------------------------z not found in new state dictnorm2norm_cross_attnnorm3z!remaining keys in old_state_dict:)
state_dictallprintload_state_dictpoplistkeysreplace)r   rN   verboserP   r   rQ   r   load_old_state_dictt   s8   	
z*DecoderOnlyTransformer.load_old_state_dict)Fr(   TFF)r)   r   r   r   r*   r   r+   r   r,   r   r-   r.   r/   r0   r1   r.   r2   r.   r3   r.   r
   r   r   )r   r   rI   rJ   r
   r   )F)rN   rO   r
   r   )r#   r$   r%   r&   r   r"   r`   r'   r   r   r   r   r   +   s    .
)
__future__r   r   torch.nnr   monai.networks.blocksr   __all__Moduler   r   r   r   r   r   <module>   s   