o
    -i                     @  sr  d dl mZ d dlZd dlmZmZmZmZmZ d dl	Z
d dlZd dlm  mZ d dlmZ d dlZd dlmZmZ d dlmZ d dlmZ d dlmZ d d	lmZ d d
lmZmZmZ eddd\ZZ ddgZ!d%d&ddZ"G dd dej#Z$G dd dej#Z%G dd dej#Z&G dd dej#Z'G dd dej#Z(G dd  d ej#Z)G d!d" d"ej#Z*G d#d$ d$ej#Z+dS )'    )annotationsN)AnyCallableOptionalSequenceTuple)nn)MLPBlockUnetrBasicBlock)SegResNetDS2)convert_points_to_disc)!keep_merge_components_with_points)sample_points_from_label)optional_importunsqueeze_leftunsqueeze_righteinops	rearrange)nameVISTA3D
vista3d1320      encoder_embed_dimintin_channelsc                 C  sB   t |dd| | dd}t| ddd}td| dd	}t|||d
}|S )a  
    Exact VISTA3D network configuration used in https://arxiv.org/abs/2406.05285>`_.
    The model treats class index larger than 132 as zero-shot.

    Args:
        encoder_embed_dim: hidden dimension for encoder.
        in_channels: input channel number.
    )r      r      r   instancer   )r   blocks_downnormout_channelsinit_filtersdsdepth      )feature_size	n_classeslast_supportedT)r'   r&   use_mlp)image_encoder
class_head
point_head)r   PointMappingSAMClassMappingClassifyr   )r   r   	segresnetr,   r+   Zvista r0   ]/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/vista3d.pyr   #   s   	c                      s   e Zd ZdZdG fddZdHddZdIddZ		dJdKddZ			 	!dLdMd)d*ZdNd,d-Z		.dOdPd4d5Z
	dQdRd8d9ZdSdTd=d>Z											:dUdVdEdFZ  ZS )Wr   a  
    VISTA3D based on:
        `VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography
        <https://arxiv.org/abs/2406.05285>`_.

    Args:
        image_encoder: image encoder backbone for feature extraction.
        class_head: class head used for class index based segmentation
        point_head: point head used for interactive segmetnation
    r*   	nn.Moduler+   r,   c                   s>   t    || _|| _|| _d | _d| _d| _d| _d| _	d S )NFii'  )
super__init__r*   r+   r,   image_embeddingsauto_freezepoint_freeze
NINF_VALUEZ
PINF_VALUE)selfr*   r+   r,   	__class__r0   r1   r4   F   s   

zVISTA3D.__init__pad_sizelist | Nonelabelstorch.Tensor | None	prev_maskpoint_coordsc                 C  s~   |du r	|||fS |durt j||ddd}|dur#t j||ddd}|dur:|tj|d |d |d g|jd }|||fS )	a  
        Image has been padded by sliding window inferer.
        The related padding need to be performed outside of slidingwindow inferer.

        Args:
            pad_size: padding size passed from sliding window inferer.
            labels: image label ground truth.
            prev_mask: previous segmentation mask.
            point_coords: point click coordinates.
        Nconstantr   )padmodevalueidevice)FrC   torchtensorrI   )r9   r<   r>   r@   rA   r0   r0   r1   update_slidingwindow_paddingQ   s   

z$VISTA3D.update_slidingwindow_paddingclass_vectorreturnr   c                 C  s,   |du r|du rt d|jd S |jd S )zAGet number of foreground classes based on class and point prompt.Nz2class_vector and point_coords cannot be both None.r   )
ValueErrorshape)r9   rN   rA   r0   r0   r1   get_foreground_class_countn   s
   

z"VISTA3D.get_foreground_class_countN               9      point_labeltorch.Tensor	label_setSequence[int] | Nonespecial_indexSequence[int]c                 C  s   |du r|S |j d t|kstdtt|D ]-}|| |v rFtt|| D ]}|||f dkr;|||f d n|||f |||f< q)q|S )a  
        Convert point label based on its class prompt. For special classes defined in special index,
        the positive/negative point label will be converted from 1/0 to 3/2. The purpose is to separate those
        classes with ambiguous classes.

        Args:
            point_label: the point label tensor, [B, N].
            label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
                this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
                evaluation, this label_set should be the original index.
            special_index: the special class index that needs to be converted.
        Nr   z4point_label and label_set must have the same length.r   )rQ   lenrP   range)r9   r[   r]   r_   ijr0   r0   r1   convert_point_labelw   s   6zVISTA3D.convert_point_labelTr   r   patch_coordsSequence[slice]
use_centerboolmapped_label_set
max_ppoint
max_npointc           
      C  sH   t || ||||j|d\}}	| |	|}	||	t||jdfS )aM  
        Sample points for patch during sliding window validation. Only used for point only validation.

        Args:
            labels: shape [1, 1, H, W, D].
            patch_coords: a sequence of sliding window slice objects.
            label_set: local index, must match values in labels.
            use_center: sample points from the center.
            mapped_label_set: global index, it is used to identify special classes and is the global index
                for the sampled points.
            max_ppoint/max_npoint: positive points and negative points to sample.
        )rl   rm   rI   ri   ra   )r   rI   rf   rK   rL   to	unsqueeze)
r9   r>   rg   r]   ri   rk   rl   rm   rA   point_labelsr0   r0   r1   sample_points_patch_val   s   
zVISTA3D.sample_points_patch_valrp   c           
      C  s   |d j |d j |d j g}|d j|d j|d jg}ttj||jdd}ttj||jdd}t|| dkd|| dkd}| | }| }|	 ryd|| < d|| < |dk	d}	|dd|	f }|dd|	f }||fS dS )	a  
        Update point_coords with respect to patch coords.
        If point is outside of the patch, remove the coordinates and set label to -1.

        Args:
            patch_coords: a sequence of the python slice objects representing the patch coordinates during sliding window inference.
                This value is passed from sliding_window_inferer.
            point_coords: point coordinates, [B, N, 3].
            point_labels: point labels, [B, N].
        rF   ra   rH   r   r   NNN)
stopstartr   rK   rL   rI   logical_andallcloneany)
r9   rg   rA   rp   Z
patch_endsZpatch_startsZpatch_starts_tensorZpatch_ends_tensorindicesZnot_pad_indicesr0   r0   r1   update_point_to_patch   s"    

zVISTA3D.update_point_to_patch      ?logitspoint_logitsmapping_indexthredfloatc              	     sb  t |tjjr| n|}||  g }t jd D ]|t	 fdd| 
   tD  qt||j}t }	tj | jd  | |k}
tt |kt|d|
}t |k|
 }t||||d}||j}t|
| }t|	|}|	 rd||< t|	||j}||  d| 9  < ||  || 7  < |S )	a  
        Combine auto results with point click response. The auto results have shape [B, 1, H, W, D] which means B foreground masks
        from a single image patch.
        Out of those B foreground masks, user may add points to a subset of B1 foreground masks for editing.
        mapping_index represents the correspondence between B and B1.
        For mapping_index with point clicks, NaN values in logits will be replaced with point_logits. Meanwhile, the added/removed
        region in point clicks must be updated by the lcc function.
        Notice, if a positive point is within logits/prev_mask, the components containing the positive point will be added.

        Args:
            logits: automatic branch results, [B, 1, H, W, D].
            point_logits: point branch results, [B1, 1, H, W, D].
            point_coords: point coordinates, [B1, N, 3].
            point_labels: point labels, [B1, N].
            mapping_index: [B].
            thred: the threshold to convert logits to binary.
        r   c              	     s2   g | ]} d |d  |d |d f   d kqS )r   r   r   )item).0pZ_logitsrd   r0   r1   
<listcomp>   s    $z8VISTA3D.connected_components_combine.<locals>.<listcomp>)nan   )rA   rp   ra   r   )
isinstancemonaidata
MetaTensor	as_tensorrc   rQ   appendnpry   cpunumpyroundastyper   rK   rL   rn   rI   isnan
nan_to_numr8   sigmoidrv   
logical_orr   lccdtype)r9   r}   r~   rA   rp   r   r   insideZinside_tensornan_maskZ
pos_regionZdiff_posZdiff_negccZuc_pos_region	fill_maskr0   r   r1   connected_components_combine   s6   
z$VISTA3D.connected_components_combineradius
int | Nonec                 C  s   |du rt |jdd d }dt|jdd |||djddd }d||dk < t|tjjr4| n|}||  |9  < ||  d| | 7  < |S )	a  
        Combine point results with auto results using gaussian.

        Args:
            logits: automatic branch results, [B, 1, H, W, D].
            point_logits: point branch results, [B1, 1, H, W, D].
            point_coords: point coordinates, [B1, N, 3].
            point_labels: point labels, [B1, N].
            mapping_index: [B].
            radius: gaussian ball radius.
        Nrr   r   r   )r   T)keepdimsr   )	minrQ   r   sumr   r   r   r   r   )r9   r}   r~   rA   rp   r   r   weightr0   r0   r1   gaussian_combine  s   zVISTA3D.gaussian_combineFr6   r7   c                 C  s   || j kr2t| jdr| jj||d n| j D ]	}| o | |_q| j D ]}| |_q(|| _ || jkrft| jdrF| jj||d n| j D ]	}| oR| |_qK| j D ]}| |_qZ|| _dS dS )z
        Freeze auto-branch or point-branch.

        Args:
            auto_freeze: whether to freeze the auto branch.
            point_freeze: whether to freeze the point branch.
        set_auto_grad)r6   r7   N)	r6   hasattrr*   r   
parametersrequires_gradr+   r7   r,   )r9   r6   r7   paramr0   r0   r1   r   4  s"   




zVISTA3D.set_auto_gradinput_imageslist[Sequence[slice]] | Noneprompt_classval_point_samplerCallable | None	transposec                 K  s  |  |dd||	|\}}	}|jdd }|j}|du r/|du r/| jtjddg||d S | ||}|durs|durd|durd|du rH| j}|||d |\}}}|d 	 dkr_d|d< d\}}	n|durs| 
|d ||\}}|dur|dur|dkddktj}| r|| }|| }|dur|| }n| js|du r|du r|d	 nd\}}|du r|du r| jtj|dg||d }|r|dd}|S | jdur|d
dr|du r| jd}}n| j||du|dud\}}d}tj  |dur2| ||\}}|dur1| j||||d}|du r(| ||||||
}nF| |||||}n<| jtj|dg|||jd }| j||||d||< |	durn|durn| |	|d  dd|j|| |||}|d
dr|du r| | _|r|dd}|S )a7  
        The forward function for VISTA3D. We only support single patch in training and inference.
        One exception is allowing sliding window batch size > 1 for automatic segmentation only case.
        B represents number of objects, N represents number of points for each objects.

        Args:
            input_images: [1, 1, H, W, D]
            point_coords: [B, N, 3]
            point_labels: [B, N], -1 represents padding. 0/1 means negative/positive points for regular class.
                2/3 means negative/postive ponits for special supported class like tumor.
            class_vector: [B, 1], the global class index.
            prompt_class: [B, 1], the global class index. This value is associated with point_coords to identify if
                the points are for zero-shot or supported class. When class_vector and point_coords are both
                provided, prompt_class is the same as class_vector. For prompt_class[b] > 512, point_coords[b]
                will be considered novel class.
            patch_coords: a list of sequence of the python slice objects representing the patch coordinates during sliding window
                inference. This value is passed from sliding_window_inferer.
                This is an indicator for training phase or validation phase.
                Notice for sliding window batch size > 1 (only supported by automatic segmentation), patch_coords will inlcude
                coordinates of multiple patches. If point prompts are included, the batch size can only be one and all the
                functions using patch_coords will by default use patch_coords[0].
            labels: [1, 1, H, W, D], the groundtruth label tensor, only used for point-only evaluation
            label_set: the label index matching the indexes in labels. If labels are mapped to global index using RelabelID,
                this label_set should be global mapped index. If labels are not mapped to global index, e.g. in zero-shot
                evaluation, this label_set should be the original index.
            prev_mask: [B, N, H_fullsize, W_fullsize, D_fullsize].
                This is the transposed raw output from sliding_window_inferer before any postprocessing.
                When user click points to perform auto-results correction, this can be the auto-results.
            radius: single float value controling the gaussian blur when combining point and auto results.
                The gaussian combine is not used in VISTA3D training but might be useful for finetuning purposes.
            val_point_sampler: function used to sample points from labels. This is only used for point-only evaluation.
            transpose: bool. If true, the output will be transposed to be [1, B, H, W, D]. Required to be true if calling from
                sliding window inferer/point inferer.
        r<   Nrr   r   rH   r   ra   rs   TZ
keep_cacheF)
with_point
with_label)rN   rI   r   )rM   getrQ   rI   r8   rK   zerosrR   rq   r   r{   r   rn   rj   ry   r6   fill_r   r5   r*   cudaempty_cacher+   r,   r   r   r   detach)r9   r   rg   rA   rp   rN   r   r>   r]   r@   r   r   r   kwargs
image_sizerI   bsr   r}   outZout_auto_r~   r0   r0   r1   forwardP  s   2







 
zVISTA3D.forward)r*   r2   r+   r2   r,   r2   )r<   r=   r>   r?   r@   r?   rA   r?   )rN   r?   rA   r?   rO   r   )NrS   )r[   r\   r]   r^   r_   r`   )TNr   r   )r>   r\   rg   rh   r]   r`   ri   rj   rk   r^   rl   r   rm   r   )rg   rh   rA   r\   rp   r\   )r|   )r}   r\   r~   r\   rA   r\   rp   r\   r   r\   r   r   N)r}   r\   r~   r\   rA   r\   rp   r\   r   r\   r   r   )FF)r6   rj   r7   rj   )NNNNNNNNNNF)r   r\   rg   r   rA   r?   rp   r?   rN   r?   r   r?   r>   r?   r]   r^   r@   r?   r   r   r   r   r   rj   )__name__
__module____qualname____doc__r4   rM   rR   rf   rq   r{   r   r   r   r   __classcell__r0   r0   r:   r1   r   :   s>    

"
!*Dc                      s.   e Zd Zdd fd	d
Z	ddddZ  ZS )r-       r$   r%   r&   r   
max_promptr'   r(   c                   s`  t    |}|| _ttj||ddddt|t tj||ddddt|| _tjdddddd| _	t
d|ddd| _t|d | _ttd|td|g| _td|| _td|| _td|| _ttj||dddddt|t tj||dddd	| _t|||d| _|| _|| _t||| _td|| _td|| _d
S )aH  Interactive point head used for VISTA3D.
        Adapted from segment anything:
        `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`.

        Args:
            feature_size: feature channel from encoder.
            max_prompt: max prompt number in each forward iteration.
            n_classes: number of classes the model can potentially support. This is the maximum number of class embeddings.
            last_supported: number of classes the model support, this value should match the trained model weights.
           r   r   )r   r!   kernel_sizestridepaddingr$   r   )depthembedding_dimmlp_dim	num_heads)r   r   r   output_padding)r   r   r   N)r3   r4   r   r   
SequentialConv3dInstanceNorm3dGELUfeat_downsampleZmask_downsampleTwoWayTransformertransformerPositionEmbeddingRandompe_layer
ModuleList	Embeddingpoint_embeddingsnot_a_point_embedspecial_class_embedmask_tokensConvTranspose3doutput_upscalingMLPoutput_hypernetworks_mlpsr'   r(   class_embeddingszeroshot_embedsupported_embed)r9   r&   r   r'   r(   Ztransformer_dimr:   r0   r1   r4     s:   
 zPointMappingSAM.__init__Nr   r\   rA   rp   rN   r?   c              
   C  s  |  |}t|jdd }d}tj  |d }| j||}d||dk< ||dk  | jj	7  < ||dk  | j
d j	7  < ||dk  | j
d j	7  < ||dk  | j
d j	| jj	 7  < ||d	k  | j
d j	| jj	 7  < | jj	}	|	d|ddd}	|du rtj|	|| jj	d|dddfdd
}
n(g }|D ]}|| jkr|| jj	 q|| jj	 qtj|	|t|fdd
}
g }| j}ttt|
jd | D ]}d\}}}tj  || t|d | |
jd f}|
|d |d  }tj||jd dd
}tj| |jdd d|jd dd
}|j\}}}}}| |||\}}|ddddddf }| |}|dd |||||}| !|}|j\}}}}}|| |||| |  }|| dd||| qt"|S )zArgs:
        out: feature from encoder, [1, C, H, W, C]
        point_coords: point coordinates, [B, N, 3]
        point_labels: point labels, [B, N]
        class_vector: class prompts, [B]
        rr   Nr|           ra   r   r   r   r   dim)NNN)#r   tuplerQ   rK   r   r   r   forward_with_coordsr   r   r   r   r   ro   expandsizecatr   r(   r   r   stackr   rc   r   r   ceilr   repeat_interleaver   r   r   viewr   vstack)r9   r   rA   rp   rN   Zout_low	out_shapepointspoint_embeddingoutput_tokensZ
tokens_allr   rd   masksr   srcZupscaled_embeddingZhyper_inidxtokensZpos_srcbchwdhsZmask_tokens_outmaskr0   r0   r1   r   
  s^   

$$

 

 ,


zPointMappingSAM.forward)r   r$   r%   )r&   r   r   r   r'   r   r(   r   r   )r   r\   rA   r\   rp   r\   rN   r?   r   r   r   r4   r   r   r0   r0   r:   r1   r-     s    3r-   c                      s.   e Zd ZdZdd fdd	ZdddZ  ZS )r.   zFClass head that performs automatic segmentation based on class vector.Tr'   r   r&   r)   rj   c                   s   t    || _|r tt||tdt t||| _t	||| _
ttd||dddddtd||ddddd| _dS )zArgs:
        n_classes: maximum number of class embedding.
        feature_size: class embedding size.
        use_mlp: use mlp to further map class embedding.
        r   r   r   T)spatial_dimsr   r!   r   r   	norm_name	res_blockN)r3   r4   r)   r   r   LinearInstanceNorm1dr   mlpr   r   r
   image_post_mapping)r9   r'   r&   r)   r:   r0   r1   r4   V  s<   


	
zClassMappingClassify.__init__r   r\   rN   c           
      C  st   |j \}}}}}| |}| |}| jr| |}| ||||| |  }	|	|d|||dd}	|	|fS )Nra   r   r   )rQ   r
  r   r)   r	  squeezer   r   )
r9   r   rN   r   r   r   r   r   class_embeddingZmasks_embeddingr0   r0   r1   r   {  s   


zClassMappingClassify.forward)T)r'   r   r&   r   r)   rj   )r   r\   rN   r\   r   r   r   r   r4   r   r   r0   r0   r:   r1   r.   S  s    %r.   c                      s.   e Zd Z		dd fddZdddZ  ZS )r   relur   r   r   r   r   r   
activationtuple | strattention_downsample_raterO   Nonec                   sz   t    || _|| _|| _|| _t | _t	|D ]}| j
t||||||dkd qt|||d| _t|| _dS )a  
        A transformer decoder that attends to an input image using
        queries whose positional embedding is supplied.
        Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.

        Args:
            depth: number of layers in the transformer.
            embedding_dim: the channel dimension for the input embeddings.
            num_heads: the number of heads for multihead attention. Must divide embedding_dim.
            mlp_dim: the channel dimension internal to the MLP block.
            activation: the activation to use in the MLP block.
            attention_downsample_rate: the rate at which to downsample the image before projecting.
        r   )r   r   r   r  r  skip_first_layer_pedownsample_rateN)r3   r4   r   r   r   r   r   r   layersrc   r   TwoWayAttentionBlock	Attentionfinal_attn_token_to_image	LayerNormnorm_final_attn)r9   r   r   r   r   r  r  rd   r:   r0   r1   r4     s&   

zTwoWayTransformer.__init__image_embeddingr\   image_per   !Tuple[torch.Tensor, torch.Tensor]c           
      C  s   | dddd}| dddd}|}|}| jD ]}|||||d\}}q|| }|| }| j|||d}	||	 }| |}||fS )a-  
        Args:
            image_embedding: image to attend to. Should be shape
                B x embedding_dim x h x w for any h and w.
            image_pe: the positional encoding to add to the image. Must
                have the same shape as image_embedding.
            point_embedding: the embedding to add to the query points.
                Must have shape B x N_points x embedding_dim for any N_points.

        Returns:
            torch.Tensor: the processed point_embedding.
            torch.Tensor: the processed image_embedding.
        r   r   r   )querieskeysquery_pekey_peqkv)flattenpermuter  r  r  )
r9   r  r  r   r  r   layerr$  r%  attn_outr0   r0   r1   r     s   

zTwoWayTransformer.forward)r  r   )r   r   r   r   r   r   r   r   r  r  r  r   rO   r  )r  r\   r  r\   r   r\   rO   r  r  r0   r0   r:   r1   r     s
    ,r   c                      s2   e Zd Z				dd fddZdddZ  ZS )r     r  r   Fr   r   r   r   r  r  r  r  rj   rO   r  c                   s   t    t||| _t|| _t|||d| _t|| _t	|||dd| _
t|| _t|| _t|||d| _|| _dS )a  
        A transformer block with four layers: (1) self-attention of sparse
        inputs, (2) cross attention of sparse inputs to dense inputs, (3) mlp
        block on sparse inputs, and (4) cross attention of dense inputs to sparse
        inputs.
        Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.

        Args:
            embedding_dim: the channel dimension of the embeddings.
            num_heads: the number of heads in the attention layers.
            mlp_dim: the hidden dimension of the mlp block.
            activation: the activation of the mlp block.
            skip_first_layer_pe: skip the PE on the first layer.
        r  vista3d)hidden_sizer   actdropout_modeN)r3   r4   r  	self_attnr   r  norm1cross_attn_token_to_imagenorm2r	   r	  norm3norm4cross_attn_image_to_tokenr  )r9   r   r   r   r  r  r  r:   r0   r1   r4     s   

zTwoWayAttentionBlock.__init__r  r\   r   r!  r"  r  c           	      C  s   | j r| j|||d}n|| }| j|||d}|| }| |}|| }|| }| j|||d}|| }| |}| |}|| }| |}|| }|| }| j|||d}|| }| |}||fS )Nr#  )	r  r0  r1  r2  r3  r	  r4  r6  r5  )	r9   r  r   r!  r"  r$  r*  r%  Zmlp_outr0   r0   r1   r     s(   




zTwoWayAttentionBlock.forward)r+  r  r   F)r   r   r   r   r   r   r  r  r  r   r  rj   rO   r  )
r  r\   r   r\   r!  r\   r"  r\   rO   r  r  r0   r0   r:   r1   r    s    &r  c                      sB   e Zd ZdZdd fd	d
ZdddZdddZdddZ  ZS )r  a  
    An attention layer that allows for downscaling the size of the embedding
    after projection to queries, keys, and values.
    Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/transformer.py`.

    Args:
        embedding_dim: the channel dimension of the embeddings.
        num_heads: the number of heads in the attention layers.
        downsample_rate: the rate at which to downsample the image before projecting.
    r   r   r   r   r  rO   r  c                   sz   t    || _|| | _|| _| j| dkstdt|| j| _t|| j| _	t|| j| _
t| j|| _d S )Nr   z$num_heads must divide embedding_dim.)r3   r4   r   Zinternal_dimr   rP   r   r  q_projk_projv_projout_proj)r9   r   r   r  r:   r0   r1   r4   /  s   

zAttention.__init__xr\   c                 C  s,   |j \}}}|||||| }|ddS Nr   r   )rQ   reshaper   )r9   r;  r   r   nr   r0   r0   r1   _separate_heads<  s   zAttention._separate_headsc                 C  s,   |j \}}}}|dd}||||| S r<  )rQ   r   r=  )r9   r;  r   n_headsZn_tokens
c_per_headr0   r0   r1   _recombine_headsB  s   zAttention._recombine_headsr$  r%  r&  c                 C  s   |  |}| |}| |}| || j}| || j}| || j}|j\}}}}||dddd }|t| }t	j
|dd}|| }| |}| |}|S )Nr   r   r   r   ra   r   )r7  r8  r9  r?  r   rQ   r(  mathsqrtrK   softmaxrB  r:  )r9   r$  r%  r&  r   rA  attnr   r0   r0   r1   r   H  s   




zAttention.forward)r   )r   r   r   r   r  r   rO   r  )r;  r\   r   r   rO   r\   r;  r\   rO   r\   )r$  r\   r%  r\   r&  r\   rO   r\   )	r   r   r   r   r4   r?  rB  r   r   r0   r0   r:   r1   r  #  s    

r  c                      sB   e Zd ZdZdd fd
dZdddZdddZdddZ  ZS )r   aA  
    Positional encoding using random spatial frequencies.
    Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/prompt_encoder.py`.

    Args:
        num_pos_feats: the number of positional encoding features.
        scale: the scale of the positional encoding.
    @   Nnum_pos_featsr   scaleOptional[float]rO   r  c                   s<   t    |d u s|dkrd}| d|td|f  d S )Nr   g      ?#positional_encoding_gaussian_matrixr   )r3   r4   register_bufferrK   randn)r9   rI  rJ  r:   r0   r1   r4   k  s   
z PositionEmbeddingRandom.__init__coordstorch.torch.Tensorc                 C  sB   d| d }|| j  }dtj | }tjt|t|gddS )z8Positionally encode points that are normalized to [0,1].r   r   ra   r   )rL  r   pirK   r   sincos)r9   rO  r0   r0   r1   _pe_encodingq  s   
z$PositionEmbeddingRandom._pe_encodingr   Tuple[int, int, int]c                 C  s   |\}}}| j j}tj|||f|tjd}|jddd }|jddd }|jddd }	|| }|| }|	| }	| tj|||	gdd}
|
ddddS )	z>Generate positional encoding for a grid of the specified size.r   r   r   r|   r   r   ra   r   )	rL  rI   rK   onesfloat32cumsumrT  r   r(  )r9   r   r   r   r   rI   gridZx_embedZy_embedZz_embedper0   r0   r1   r   }  s   
zPositionEmbeddingRandom.forwardcoords_inputr   c                 C  s   |  }|dddddf |d  |dddddf< |dddddf |d  |dddddf< |dddddf |d  |dddddf< | |tjS )z<Positionally encode points that are not normalized to [0,1].Nr   r   r   )rx   rT  rn   rK   r   )r9   r[  r   rO  r0   r0   r1   r     s
   000z+PositionEmbeddingRandom.forward_with_coords)rH  N)rI  r   rJ  rK  rO   r  )rO  rP  rO   rP  )r   rU  rO   rP  )r[  rP  r   rU  rO   rP  )	r   r   r   r   r4   rT  r   r   r   r0   r0   r:   r1   r   a  s    	

r   c                      s0   e Zd ZdZ	dd fddZdddZ  ZS )r   a  
    Multi-layer perceptron. This class is only used for `PointMappingSAM`.
    Adapted from `https://github.com/facebookresearch/segment-anything/blob/main/segment_anything/modeling/mask_decoder.py`.

    Args:
        input_dim: the input dimension.
        hidden_dim: the hidden dimension.
        output_dim: the output dimension.
        num_layers: the number of layers.
        sigmoid_output: whether to apply a sigmoid activation to the output.
    F	input_dimr   
hidden_dim
output_dim
num_layerssigmoid_outputrj   rO   r  c                   sP   t    || _|g|d  }tdd t|g| ||g D | _|| _d S )Nr   c                 s  s     | ]\}}t ||V  qd S r   )r   r  )r   r>  r%  r0   r0   r1   	<genexpr>  s    zMLP.__init__.<locals>.<genexpr>)r3   r4   r_  r   r   zipr  r`  )r9   r\  r]  r^  r_  r`  r   r:   r0   r1   r4     s
   
(
zMLP.__init__r;  r\   c                 C  sL   t | jD ]\}}|| jd k rt||n||}q| jr$t|}|S )Nr   )	enumerater  r_  rJ   r  r`  r   )r9   r;  rd   r)  r0   r0   r1   r     s
   &
zMLP.forward)F)r\  r   r]  r   r^  r   r_  r   r`  rj   rO   r  rG  r  r0   r0   r:   r1   r     s
    	r   )r   r   )r   r   r   r   ),
__future__r   rC  typingr   r   r   r   r   r   r   rK   torch.nn.functionalr   
functionalrJ   r   monai.networks.blocksr	   r
   monai.networks.netsr   monai.transforms.utilsr   r   r   r   monai.utilsr   r   r   r   r   __all__r   Moduler   r-   r.   r   r  r  r   r   r0   r0   r0   r1   <module>   s8      $x4SI>7