o
    ,ia2                     @  s   d dl mZ d dlZd dlmZ d dlmZmZ d dlm	Z	 d dl
mZ d dlmZ d dlmZ d dlmZmZ G d	d
 d
ejZG dd de	ZG dd dejZdS )    )annotationsN)CABlockFeedForward)Convolution)
DownSample)UpSample)Norm)DownsampleModeUpsampleModec                      s2   e Zd ZdZ		dd fddZdddZ  ZS )MDTATransformerBlocka  Basic transformer unit combining MDTA and GDFN with skip connections.
    Unlike standard transformers that use LayerNorm, this block uses Instance Norm
    for better adaptation to image restoration tasks.

    Args:
        spatial_dims: Number of spatial dimensions (2D or 3D)
        dim: Number of input channels
        num_heads: Number of attention heads
        ffn_expansion_factor: Expansion factor for feed-forward network
        bias: Whether to use bias in attention layers
        layer_norm_use_bias: Whether to use bias in layer normalization. Defaults to False.
        flash_attention: Whether to use flash attention optimization. Defaults to False.
    Fspatial_dimsintdim	num_headsffn_expansion_factorfloatbiasboollayer_norm_use_biasflash_attentionc                   s`   t    ttj|f ||d| _t|||||| _ttj|f ||d| _t||||| _	d S )N)affine)
super__init__r   INSTANCEnorm1r   attnnorm2r   ffn)selfr   r   r   r   r   r   r   	__class__ _/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/restormer.pyr   '   s
   

zMDTATransformerBlock.__init__xtorch.Tensorreturnc                 C  s,   ||  | | }|| | | }|S N)r   r   r   r   r   r#   r!   r!   r"   forward7   s   zMDTATransformerBlock.forward)FF)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r#   r$   r%   r$   __name__
__module____qualname____doc__r   r(   __classcell__r!   r!   r   r"   r      s    r   c                      s2   e Zd ZdZdd fddZd fddZ  ZS )OverlapPatchEmbeda  Initial feature extraction using overlapped convolutions.
    Unlike standard patch embeddings that use non-overlapping patches,
    this approach maintains spatial continuity through 3x3 convolutions.

    Args:
        spatial_dims: Number of spatial dimensions (2D or 3D)
        in_channels: Number of input channels
        embed_dim: Dimension of embedded features. Defaults to 48.
        bias: Whether to use bias in convolution layer. Defaults to False.
       0   Fr   r   in_channels	embed_dimr   r   c              
     s    t  j|||ddd|dd d S )Nr1      Tr   r3   out_channelskernel_sizestridespaddingr   	conv_only)r   r   )r   r   r3   r4   r   r   r!   r"   r   I   s   
zOverlapPatchEmbed.__init__r#   r$   r%   c                   s   t  |}|S r&   )r   r(   r'   r   r!   r"   r(   U   s   zOverlapPatchEmbed.forward)r1   r2   F)r   r   r3   r   r4   r   r   r   r)   r*   r!   r!   r   r"   r0   =   s    r0   c                      sF   e Zd ZdZ													d"d# fddZd$d d!Z  ZS )%	Restormera{  Restormer: Efficient Transformer for High-Resolution Image Restoration.

    Implements a U-Net style architecture with transformer blocks, combining:
    - Multi-scale feature processing through progressive down/upsampling
    - Efficient attention via MDTA blocks
    - Local feature mixing through GDFN
    - Skip connections for preserving spatial details

    Architecture:
        - Encoder: Progressive feature downsampling with increasing channels
        - Latent: Deep feature processing at lowest resolution
        - Decoder: Progressive upsampling with skip connections
        - Refinement: Final feature enhancement
       r1   r2   r5   r5   r5   r5      HzG@FTr   r   r3   r7   r   
num_blockstuple[int, ...]headsnum_refinement_blocksr   r   r   r   r   dual_pixel_taskr   r%   Nonec                   s  t    	 t|dksJ dt|tksJ dtdd |D s)J dt
||| _t | _t | _	t | _
t | _t | _t|d 		| _
| _d
d  }t	D ]7|d  | }| jtj 
fdd	t| D   | j	t| j|tjd d
 q_|d	  tj 	
fdd	t|	 D  | _tt	D ]T|d  |dd   }| jt| j|tjd dd dkr| jt| j|d dd n|| j
tj 
fdd	t| D   qtj 
fdd	t|D  | _|| _| jr:t| j||d d dd| _t| j|d |ddd dd| _d S )Nr5   z'Number of blocks must be greater than 1z(Number of blocks and heads must be equalc                 s  s    | ]}|d kV  qdS )r   Nr!   ).0nr!   r!   r"   	<genexpr>   s    z%Restormer.__init__.<locals>.<genexpr>z'Number of blocks must be greater than 0r=   c                   &   g | ]}t   d qS r   r   r   r   r   r   r   r   rG   _)r   current_dimr   r   rC   r   rH   r   r!   r"   
<listcomp>       
z&Restormer.__init__.<locals>.<listcomp>)r   r3   r7   modescale_factorr   c                   s&   g | ]}t   d qS rK   rM   rN   )r   r   r   rC   
latent_dimr   	num_stepsr   r!   r"   rQ      rR   F)r   r3   r7   rS   rT   r   apply_pad_poolr   T)r   r3   r7   r8   r   r;   c                   rJ   rK   rM   rN   )r   decoder_dimr   r   rC   r   rH   r   r!   r"   rQ      rR   c                   s&   g | ]}t d   dqS )r   rL   rM   rN   )r   rX   r   r   rC   r   r   r!   r"   rQ     rR   r1   r6   )r   r   lenallr0   patch_embednn
ModuleListencoder_levelsdownsamplesdecoder_levels	upsamplesreduce_channelsrV   r   rangeappend
Sequentialr   r	   PIXELUNSHUFFLElatentreversedr   r
   PIXELSHUFFLEr   
refinementrE   	skip_convoutput)r   r   r3   r7   r   rA   rC   rD   r   r   r   rE   r   Zspatial_multiplierZnext_dimr   )r   rP   rX   r   r   rC   rU   r   rH   rV   r   r"   r   j   s   













zRestormer.__init__r#   r$   c                   s&  t  fddtd jd D sJ d g }tt j jD ]\}\}}|| |q& 	tt
 jD ]-} j| t||d   gd|t
 jd k rl j|  j| qF  jr |d   S  S )a  Forward pass of Restormer.
        Processes input through encoder-decoder architecture with skip connections.
        Args:
            inp_img: Input image tensor of shape (B, C, H, W, [D])

        Returns:
            Restored image tensor of shape (B, C, H, W, [D])
        c                 3  s&    | ]}j |  d  j kV  qdS )r=   N)shaperV   )rG   ir'   r!   r"   rI   /  s    
z$Restormer.forward.<locals>.<genexpr>r5   z=All spatial dimensions should be larger than 2^number_of_stepr   )rZ   rc   r   r[   	enumeratezipr^   r_   rd   rg   rY   r`   ra   torchconcatrb   rj   rE   rk   rl   )r   r#   Zskip_connections_idxencoder
downsampleidxr!   r'   r"   r(   &  s2   	







zRestormer.forward)r=   r1   r1   r2   r>   r>   r?   r@   FTFF)r   r   r3   r   r7   r   r   r   rA   rB   rC   rB   rD   r   r   r   r   r   r   r   rE   r   r   r   r%   rF   r)   r*   r!   r!   r   r"   r<   Z   s"     =r<   )
__future__r   rq   torch.nnr\   Zmonai.networks.blocks.cablockr   r   "monai.networks.blocks.convolutionsr   Z monai.networks.blocks.downsampler   monai.networks.blocks.upsampler   monai.networks.layers.factoriesr   monai.utils.enumsr	   r
   Moduler   r0   r<   r!   r!   r!   r"   <module>   s   
%