o
    (i                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlm  mZ	 d dl
mZ d dlmZ eddd\ZZd	d
gZG dd	 d	ejZG dd
 d
ejZdS )    )annotations)castN)Convolution)optional_importeinops	rearrange)nameFeedForwardCABlockc                      s,   e Zd ZdZd fd	d
ZdddZ  ZS )r	   a  Gated-DConv Feed-Forward Network (GDFN) that controls feature flow using gating mechanism.
    Uses depth-wise convolutions for local context mixing and GELU-activated gating for refined feature selection.

    Args:
        spatial_dims: Number of spatial dimensions (2D or 3D)
        dim: Number of input channels
        ffn_expansion_factor: Factor to expand hidden features dimension
        bias: Whether to use bias in convolution layers
    spatial_dimsintdimffn_expansion_factorfloatbiasboolc                   sr   t    t|| }t|||d d|dd| _t||d |d ddd|d |dd	| _t|||d|dd| _d S )N      Tr   in_channelsout_channelskernel_sizer   	conv_only   	r   r   r   r   stridespaddinggroupsr   r   )super__init__r   r   
project_indwconvproject_out)selfr   r   r   r   Zhidden_features	__class__ _/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/cablock.pyr   &   s:   
	zFeedForward.__init__xtorch.Tensorreturnc                 C  s>   |  |}| |jddd\}}ttj| t|| S )Nr   r   r   )	r    r!   chunkr   torchTensorr"   Fgelu)r#   r(   x1x2r&   r&   r'   forwardH   s   
zFeedForward.forward)r   r   r   r   r   r   r   r   r(   r)   r*   r)   )__name__
__module____qualname____doc__r   r3   __classcell__r&   r&   r$   r'   r	      s    
"c                      sF   e Zd ZdZdd fd	d
Zdd Zdd Zdd ZdddZ  Z	S )r
   a5  Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention
    by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
    convolutions for local mixing before attention, achieving linear complexity vs quadratic
    in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>

    Args:
        spatial_dims: Number of spatial dimensions (2D or 3D)
        dim: Number of input channels
        num_heads: Number of attention heads
        bias: Whether to use bias in convolution layers
        flash_attention: Whether to use flash attention optimization. Defaults to False.

    Raises:
        ValueError: If flash attention is not available in current PyTorch version
        ValueError: If spatial_dims is greater than 3
    Fr   r   	num_headsr   r   flash_attentionc                   s   t    |rttdstd|dkrtd| || _|| _tt	
|dd| _|| _t|||d d|dd| _t||d |d ddd|d |dd	| _t|||d|dd| _|  | _d S )	Nscaled_dot_product_attentionzFlash attention not availabler   z6Only 2D and 3D inputs are supported. Got spatial_dims=r   Tr   r   )r   r   hasattrr/   
ValueErrorr   r:   nn	Parameterr-   onestemperaturer;   r   qkv
qkv_dwconvr"   _get_attention_fn_attention_fn)r#   r   r   r:   r   r;   r$   r&   r'   r   `   s6   
zCABlock.__init__c                 C  s   | j r| jS | jS )N)r;   _flash_attention_normal_attention)r#   r&   r&   r'   rE      s   zCABlock._get_attention_fnc                 C  s(   t | j }tj||||ddd}|S )zBFlash attention implementation using scaled dot-product attention.g        F)scale	dropout_p	is_causal)r   rB   meanr/   r<   )r#   qkvrI   outr&   r&   r'   rG      s   zCABlock._flash_attentionc                 C  s*   || dd | j }|jdd}|| S )z=Attention matrix multiplication with depth-wise convolutions.r+   )	transposerB   softmax)r#   rM   rN   rO   attnr&   r&   r'   rH      s   zCABlock._normal_attentionr(   r)   r*   c           
      C  s   |j dd }| | |}|jddd\}}}| jdkr#d}d}nd}d	}t||| jd
}t||| jd
}t||| jd
}tjj	j
|dd}tjj	j
|dd}| |||}	t|	|fd| jitt| jdkrkddgng d|}	ttj| |	S )a   Forward pass for MDTA attention.
        1. Apply depth-wise convolutions to Q, K, V
        2. Reshape Q, K, V for multi-head attention
        3. Compute attention matrix using flash or normal attention
        4. Reshape and project out attention outputr   Nr   r   r+   z b (head c) h w -> b head c (h w)z b head c (h w) -> b (head c) h wz$b (head c) d h w -> b head c (d h w)z$b head c (d h w) -> b (head c) d h w)headrR   rV   hw)drW   rX   )shaperD   rC   r,   r   r   r:   r-   r?   
functional	normalizerF   dictzipr   r.   r"   )
r#   r(   r   rC   rM   rN   rO   Zqkv_to_multiheadZmultihead_to_qkvrP   r&   r&   r'   r3      s.   
"zCABlock.forward)F)r   r   r:   r   r   r   r;   r   r4   )
r5   r6   r7   r8   r   rE   rG   rH   r3   r9   r&   r&   r$   r'   r
   N   s    !)
__future__r   typingr   r-   torch.nnr?   torch.nn.functionalr[   r/   "monai.networks.blocks.convolutionsr   monai.utilsr   r   ___all__Moduler	   r
   r&   r&   r&   r'   <module>   s   
3