o
    )i                      @  sn   d dl mZ d dlmZmZ d dlZd dlmZ d dlm	Z	 d dl
mZ eddd\ZZG d	d
 d
ejZdS )    )annotations)OptionalTupleN)get_rel_pos_embedding_layer)optional_importzeinops.layers.torch	Rearrange)namec                      sH   e Zd ZdZ												d$d% fddZd&d'd"d#Z  ZS )(CrossAttentionBlocka  
    A cross-attention block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
    One can setup relative positional embedding as described in <https://arxiv.org/abs/2112.01526>
            NFhidden_sizeint	num_headsdropout_ratefloathidden_input_size
int | Nonecontext_input_sizedim_headqkv_biasbool	save_attncausalsequence_lengthrel_pos_embeddingOptional[str]
input_sizeOptional[Tuple]attention_dtypeOptional[torch.dtype]use_flash_attentionreturnNonec              	     s  t    d|  krdkstd td|r!|| }|| _n|| dkr+td|}|| | _|	r<|
du r<td|rD|rDtd|rN|durNtd|| _|rU|n|| _|r\|n|| _t|| j| _	tj| j||d	| _
tj| j||d	| _tj| j||d	| _td
|d| _td| _t|| _t|| _|| _| jd | _|| _|| _|	| _|
| _|| _|	r|
dur| dtt|
|
dd|
|
 |  nt | _ t | _!|durt"||| j| jnd| _#|| _$dS )a=  
        Args:
            hidden_size (int): dimension of hidden layer.
            num_heads (int): number of attention heads.
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
            context_input_size (int, optional): dimension of the context tensor. Defaults to hidden_size.
            dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
            save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
            causal (bool, optional): whether to use causal attention.
            sequence_length (int, optional): if causal is True, it is necessary to specify the sequence length.
            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map. For now only
                "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative positional
                parameter size.
            attention_dtype: cast attention operations to this dtype.
            use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
        r      z'dropout_rate should be between 0 and 1.z-hidden size should be divisible by num_heads.Nz2sequence_length is necessary for causal attention.zsave_attn has been set to True, but use_flash_attention is also setto True. save_attn can only be used if use_flash_attention is Falsez@rel_pos_embedding must be None if you are using flash_attention.)biaszb h (l d) -> b l h d)lzb l h d -> b h (l d)g      causal_mask)%super__init__
ValueErrorhead_dimr   r   r   nnLinearout_projto_qto_kto_vr   input_rearrangeout_rearrangeDropoutdrop_outputdrop_weightsr   scaler   r   r   r   r   register_buffertorchtrilonesviewTensorr%   att_matr   rel_positional_embeddingr   )selfr   r   r   r   r   r   r   r   r   r   r   r   r   r   Z
inner_size	__class__ f/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/crossattention.pyr'       sf   
&




zCrossAttentionBlock.__init__xtorch.TensorcontextOptional[torch.Tensor]c                 C  s`  |  \}}}| | |}|dur|n|}|  \}}	}| | |}
| | |}| jdur?|| j}|
| j}
| jrStj	j
j||
|| j| j| jd}nLtd||
| j }| jduri| |||}| jr|| jddddd|d|	f dktd}|jdd}| jr| | _| |}td||}| |}| |}| |}|S )	a  
        Args:
            x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
            context (torch.Tensor, optional): context tensor. B x (s_dim_1 * ... * s_dim_n) x C

        Return:
            torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
        N)querykeyvaluer5   	dropout_p	is_causalzblxd,blyd->blxyr   z-inf)dimzbhxy,bhyd->bhxd)sizer0   r-   r.   r/   r   tor   r7   r*   
functionalscaled_dot_product_attentionr5   r   r   einsumr=   masked_fillr%   r   softmaxr   detachr<   r4   r1   r,   r3   )r>   rC   rE   btcqkv_Zkv_tkvr<   rA   rA   rB   forward   s6   


2




zCrossAttentionBlock.forward)r
   NNNFFFNNNNF)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r!   )N)rC   rD   rE   rF   )__name__
__module____qualname____doc__r'   r^   __classcell__rA   rA   r?   rB   r	      s     
ir	   )
__future__r   typingr   r   r7   torch.nnr*   monai.networks.layers.utilsr   monai.utilsr   r   r[   Moduler	   rA   rA   rA   rB   <module>   s   