U
    Ph                     @  sn   d dl mZ d dlZd dlmZ d dlmZmZ d dl	Z	d dl
mZ dgZdd Zdd
dddddddZdS )    )annotationsN)repeat)ListUnionbuild_sincos_position_embeddingc                   s    fdd}|S )Nc                   s.   t | tjjr t | ts t| S tt|  S )N)
isinstancecollectionsabcIterablestrtupler   )xn Z/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/blocks/pos_embed_utils.pyparse   s    z_ntuple.<locals>.parser   )r   r   r   r   r   _ntuple   s    r           @zUnion[int, List[int]]intfloatztorch.nn.Parameter)	grid_size	embed_dimspatial_dimstemperaturereturnc              	   C  sP  |dkrt d}|| }|\}}tj|tjd}tj|tjd}	tj||	dd\}}	|d dkrhtd|d }
tj|
tjd|
 }d||  }td	| |g}td	|	 |g}tjt	|t
|t	|t
|gd
ddddddf }n@|dkr4t d}|| }|\}}}tj|tjd}tj|tjd}	tj|tjd}tj||	|dd\}}	}|d dkr~td|d }
tj|
tjd|
 }d||  }td	| |g}td	|	 |g}td	| |g}tjt	|t
|t	|t
|t	|t
|gd
ddddddf }ntdt|}d|_|S )a  
    Builds a sin-cos position embedding based on the given grid size, embed dimension, spatial dimensions, and temperature.
    Reference: https://github.com/cvlab-stonybrook/SelfMedMAE/blob/68d191dfcc1c7d0145db93a6a570362de29e3b30/lib/models/mae3d.py

    Args:
        grid_size (List[int]): The size of the grid in each spatial dimension.
        embed_dim (int): The dimension of the embedding.
        spatial_dims (int): The number of spatial dimensions (2 for 2D, 3 for 3D).
        temperature (float): The temperature for the sin-cos position embedding.

    Returns:
        pos_embed (nn.Parameter): The sin-cos position embedding as a fixed parameter.
       )dtypeij)indexing   r   zHEmbed dimension must be divisible by 4 for 2D sin-cos position embeddingg      ?zm,d->md   )dimNr      zHEmbed dimension must be divisible by 6 for 3D sin-cos position embeddingz6Spatial Dimension Size {spatial_dims} Not Implemented!F)r   torcharangefloat32meshgridAssertionErroreinsumflattencatsincosNotImplementedErrornn	Parameterrequires_grad)r   r   r   r   Z	to_2tupleZgrid_size_thwZgrid_hZgrid_wZpos_dimomegaout_hout_wZpos_embZ	to_3tupledZgrid_dout_d	pos_embedr   r   r   r   #   s^    D



)r   r   )
__future__r   collections.abcr   	itertoolsr   typingr   r   r%   torch.nnr0   __all__r   r   r   r   r   r   <module>   s      