o
    )ir                     @  sN   d dl mZ d dlmZ d dlZd dlmZ d dlmZ G dd dej	Z
dS )    )annotations)OptionalN)SABlockc                      s<   e Zd ZdZ							dd fddZdddZ  ZS ) SpatialAttentionBlockaZ  Perform spatial self-attention on the input tensor.

    The input tensor is reshaped to B x (x_dim * y_dim [ * z_dim]) x C, where C is the number of channels, and then
    self-attention is performed on the reshaped tensor. The output tensor is reshaped back to the original shape.

    Args:
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
        num_channels: number of input channels. Must be divisible by num_head_channels.
        num_head_channels: number of channels per head.
        norm_num_groups: Number of groups for the group norm layer.
        norm_eps: Epsilon for the normalization.
        attention_dtype: cast attention operations to this dtype.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

    N    ư>TFspatial_dimsintnum_channelsnum_head_channels
int | Nonenorm_num_groupsnorm_epsfloatattention_dtypeOptional[torch.dtype]
include_fcbooluse_combined_linearuse_flash_attentionreturnNonec
              	     sp   t    || _tj|||dd| _|d ur || dkr td|d ur(|| nd}
t||
d||||	d| _d S )NT)
num_groupsr
   epsaffiner   z3num_channels must be divisible by num_head_channels   )hidden_size	num_headsqkv_biasr   r   r   r   )	super__init__r   nn	GroupNormnorm
ValueErrorr   attn)selfr   r
   r   r   r   r   r   r   r   r   	__class__ h/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/spatialattention.pyr    *   s   
zSpatialAttentionBlock.__init__xtorch.Tensorc                 C  s`   |}|j }| |}|jg |d d dR  dd}| |}|dd|}|| }|S )N   r   )shaper#   reshape	transposer%   )r&   r+   residualr/   r)   r)   r*   forwardH   s   
$
zSpatialAttentionBlock.forward)Nr   r   NTFF)r   r	   r
   r	   r   r   r   r	   r   r   r   r   r   r   r   r   r   r   r   r   )r+   r,   )__name__
__module____qualname____doc__r    r3   __classcell__r)   r)   r'   r*   r      s    r   )
__future__r   typingr   torchtorch.nnr!   Zmonai.networks.blocksr   Moduler   r)   r)   r)   r*   <module>   s   