o
    (i%                     @  s   d dl mZ d dlmZmZmZ d dlZd dlmZ d dl	m  m
Z d dlmZ d dlmZ eddd\ZZG d	d
 d
ejZdS )    )annotations)OptionalTupleUnionN)get_rel_pos_embedding_layer)optional_importzeinops.layers.torch	Rearrange)namec                      sJ   e Zd ZdZ													d$d% fddZd&d'd"d#Z  ZS )(SABlockz
    A self-attention block, based on: "Dosovitskiy et al.,
    An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
            FNThidden_sizeint	num_headsdropout_ratefloatqkv_biasbool	save_attndim_head
int | Nonehidden_input_sizecausalsequence_lengthrel_pos_embedding
str | None
input_sizeTuple | Noneattention_dtypetorch.dtype | None
include_fcuse_combined_linearuse_flash_attentionreturnNonec              	     sx  t    d|  krdkstd td|| dkr!td|r,|| | _|| _n|| dkr6td|| _|| | _|rH|	du rHtd|rP|rPtd|rZ|
durZtd|| _|ra|n|| _|  |rrt| j| j| _	nt
 | _	|  |  |  |  |rtj| j| jd	 |d
| _t
  | _ | _| _tdd	|d| _n-tj| j| j|d
| _tj| j| j|d
| _tj| j| j|d
| _t
 | _td|d| _td| _t|| _t|| _|| _| jd | _|| _t | _|| _|| _|	| _|| _|| _|| _ |r!|	dur!| !dt"t#|	|	$dd|	|	 |  nt | _%|
dur4t&|
|| j| jnd| _'|| _(dS )a  
        Args:
            hidden_size (int): dimension of hidden layer.
            num_heads (int): number of attention heads.
            dropout_rate (float, optional): fraction of the input units to drop. Defaults to 0.0.
            qkv_bias (bool, optional): bias term for the qkv linear layer. Defaults to False.
            save_attn (bool, optional): to make accessible the attention matrix. Defaults to False.
            dim_head (int, optional): dimension of each head. Defaults to hidden_size // num_heads.
            hidden_input_size (int, optional): dimension of the input tensor. Defaults to hidden_size.
            causal: whether to use causal attention (see https://arxiv.org/abs/1706.03762).
            sequence_length: if causal is True, it is necessary to specify the sequence length.
            rel_pos_embedding (str, optional): Add relative positional embeddings to the attention map.
                For now only "decomposed" is supported (see https://arxiv.org/abs/2112.01526). 2D and 3D are supported.
            input_size (tuple(spatial_dim), optional): Input resolution for calculating the relative
                positional parameter size.
            attention_dtype: cast attention operations to this dtype.
            include_fc: whether to include the final linear layer. Default to True.
            use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
            use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
                (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

        r      z'dropout_rate should be between 0 and 1.z-hidden size should be divisible by num_heads.Nz2sequence_length is necessary for causal attention.zsave_attn has been set to True, but use_flash_attention is also setto True. save_attn can only be used if use_flash_attention is False.z@rel_pos_embedding must be None if you are using flash_attention.   )biaszb h (qkv l d) -> qkv b l h d)qkvlzb h (l d) -> b l h d)r(   zb l h d -> b h (l d)g      causal_mask))super__init__
ValueErrorZ	inner_dimr   r   r   nnLinearout_projIdentityr'   to_qto_kto_vr   input_rearrangeout_rearrangeDropoutdrop_outputdrop_weightsr   scaler   torchTensoratt_matr   r   r   r   r    r!   register_buffertrilonesviewr)   r   rel_positional_embeddingr   )selfr   r   r   r   r   r   r   r   r   r   r   r   r   r    r!   	__class__ e/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/selfattention.pyr+       s   
)








zSABlock.__init__	attn_maskOptional[torch.Tensor]c              	   C  s  | j r| | |}|d |d |d }}}n| | |}| | |}| | |}| jdurC|| j}|| j}| jrVt	j
||||| j| j| jd}nytd||| j }| jdurl| |||}| jr|durwtd|| jddddd|jd d|jd f dktd	}|dur|dd}|d
| jd
d
}||dktd	}|jd
d}| jr| | _| |}td||}| |}| jr|  |}| !|}|S )aF  
        Args:
            x (torch.Tensor): input tensor. B x (s_dim_1 * ... * s_dim_n) x C
            attn_mask (torch.Tensor, optional): mask to apply to the attention matrix.
            B x (s_dim_1 * ... * s_dim_n). Defaults to None.

        Return:
            torch.Tensor: B x (s_dim_1 * ... * s_dim_n) x C
        r   r$      N)querykeyvaluerG   r9   	dropout_p	is_causalzblxd,blyd->blxyz2Causal attention does not support attention masks.z-inf)dimzbhxy,bhyd->bhxd)"r    r4   r'   r1   r2   r3   r   tor!   Fscaled_dot_product_attentionr9   r   r   r:   einsumrA   r,   masked_fillr)   shaper   	unsqueezeexpandr   softmaxr   detachr<   r8   r5   r   r/   r7   )rB   xrG   outputqkvr<   rE   rE   rF   forward   sP   



>




zSABlock.forward)r   FFNNFNNNNTTF) r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r   r!   r   r"   r#   )N)rG   rH   )__name__
__module____qualname____doc__r+   ra   __classcell__rE   rE   rC   rF   r
      s"    	|r
   )
__future__r   typingr   r   r   r:   torch.nnr-   torch.nn.functional
functionalrS   monai.networks.layers.utilsr   monai.utilsr   r   _Moduler
   rE   rE   rE   rF   <module>   s   