o
    )iW                     @  sV   d dl mZ d dlmZ d dlZd dlm  mZ d dlmZ dddZ	dddZ
dS )    )annotations)TupleN)nnq_sizeintk_sizerel_postorch.Tensorreturnc                 C  s   t  }tdt| | d }|jd |kr6tj|d|jd dddd|dd}|d|dd}n|}t 	| dddf t||  d }t 	|dddf t| | d }|| |d t| | d  }||
  S )	aY  
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
          r   linear)sizemodeNg      ?)torchTensorr   maxshapeFinterpolatereshapepermutearangelong)r   r   r   Zrel_pos_resizedZmax_rel_distZq_coordsZk_coordsrelative_coords r   g/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/attention_utils.pyget_rel_pos   s    $$r   attnqrel_pos_lstnn.ParameterListr   c              	   C  s  t |d |d |d }t |d |d |d }|j\}}}	t|dkr|dd \}
}|dd \}}|||
||	}td||}td||}| ||
||||dddddddddf  |dddddddddf  ||
| || } | S t|dkr|dd \}
}}|dd \}}}t |||d }|||
|||	}td||}td	||}td	||}| ||
||||||ddddddddddf  |ddddddddddf  |ddddddddddf  ||
| | || | } | S )
a  
    Calculate decomposed Relative Positional Embeddings from mvitv2 implementation:
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py

    Only 2D and 3D are supported.

    Encoding the relative position of tokens in the attention matrix: tokens spaced a distance
    `d` apart will have the same embedding value (unlike absolute positional embedding).

    .. math::
        Attn_{logits}(Q, K) = (QK^{T} + E_{rel})*scale

    where

    .. math::
        E_{ij}^{(rel)} = Q_{i}.R_{p(i), p(j)}

    with :math:`R_{p(i), p(j)} \in R^{dim}` and :math:`p(i), p(j)`,
    respectively spatial positions of element :math:`i` and :math:`j`

    When using "decomposed" relative positional embedding, positional embedding is defined ("decomposed") as follow:

    .. math::
        R_{p(i), p(j)} = R^{d1}_{d1(i), d1(j)} + ... + R^{dn}_{dn(i), dn(j)}

    with :math:`n = 1...dim`

    Decomposed relative positional embedding reduces the complexity from :math:`\mathcal{O}(d1*...*dn)` to
    :math:`\mathcal{O}(d1+...+dn)` compared with classical relative positional embedding.

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, s_dim_1 * ... * s_dim_n, C).
        rel_pos_lst (ParameterList): relative position embeddings for each axis: rel_pos_lst[n] for nth axis.
        q_size (Tuple): spatial sequence size of query q with (q_dim_1, ..., q_dim_n).
        k_size (Tuple): spatial sequence size of key k with (k_dim_1, ...,  k_dim_n).

    Returns:
        attn (Tensor): attention logits with added relative positional embeddings.
    r   r   r   Nzbhwc,hkc->bhwkzbhwc,wkc->bhwk   zbhwdc,hkc->bhwdkzbhwdc,wkc->bhwdk)r   r   lenr   r   einsumview)r   r    r!   r   r   rhrwbatch_dimq_hq_wk_hk_wr_qZrel_hZrel_wZq_dk_drdZrel_dr   r   r   add_decomposed_rel_pos4   s>   +V"""r3   )r   r   r   r   r   r	   r
   r	   )r   r	   r    r	   r!   r"   r   r   r   r   r
   r	   )
__future__r   typingr   r   torch.nn.functionalr   
functionalr   r   r3   r   r   r   r   <module>   s   	
!