o
    i                     @  s(  d dl mZ d dlZd dlZd dlmZ d dlZd dlmZ d dl	m  m
Z d dlmZ d dlmZ d dlmZmZ d dlmZ eeZdddZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZ G dd dejZ!G dd deZ"dS )    )annotationsN)Sequence)Convolution)SpatialAttentionBlock)AEKLResBlockAutoencoderKL)convert_to_tensorsave_memboolreturnNonec                 C  s   t j r| rt j  d S N)torchcudais_availableempty_cache)r	    r   z/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/generation/maisi/networks/autoencoderkl_maisi.py_empty_cuda_cache   s   
r   c                      s@   e Zd ZdZ					dd fddZdddZdd Z  ZS )MaisiGroupNorm3DaY  
    Custom 3D Group Normalization with optional print_info output.

    Args:
        num_groups: Number of groups for the group norm.
        num_channels: Number of channels for the group norm.
        eps: Epsilon value for numerical stability.
        affine: Whether to use learnable affine parameters, default to `True`.
        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
        print_info: Whether to print information, default to `False`.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    h㈵>TF
num_groupsintnum_channelsepsfloataffiner
   norm_float16
print_infor	   c                   s(   t  |||| || _|| _|| _d S r   )super__init__r   r   r	   )selfr   r   r   r   r   r   r	   	__class__r   r   r    3   s   

zMaisiGroupNorm3D.__init__inputtorch.Tensorr   c              	   C  s  | j rtd|   t|jdkrtd|j\}}}}}||| j|| j |||}g }t	|dD ]H}|d d ||d df j
tjd}	|	jg ddd	}
|	jg dd
dd| j }| jrv||	|
 | j
tjd q7||	|
 |  q7~t| j t|d  dk rtj|ddn| |}||||||}| jr|| jd|ddd| jd|ddd | j rtd|   |S )Nz"MaisiGroupNorm3D with input size:    zExpected a 5D tensor   .)dtype)         r&   T)keepdimF)unbiasedr,   r     dimz#MaisiGroupNorm3D with output size: )r   loggerinfosizelenshape
ValueErrorviewr   rangetor   float32meanvaradd_r   sqrt_r   appendfloat16r   r	   maxcat_cat_inputsr   mul_weightbias)r!   r$   Zparam_nZparam_cZparam_dZparam_hZparam_winputsiarrayr;   stdr   r   r   forwardB   s.   $ 
,0zMaisiGroupNorm3D.forwardc                 C  s   |d j j}|dkr|d  jdddn|d  }d|d< t| j tt|d D ]5}tj	|||d  
 fdd}d||d < t| j t  | jratd|d  d	t|d  d
 q,|dkrm|jdddS |S )Nr   r   cpuTnon_blockingr'   r/   z"MaisiGroupNorm3D concat progress: /.)devicetypecloner9   r   r	   r8   r4   r   rB   rL   gccollectr   r1   r2   )r!   rG   
input_typer$   kr   r   r   rC   d   s   *

$zMaisiGroupNorm3D._cat_inputs)r   TFFT)r   r   r   r   r   r   r   r
   r   r
   r   r
   r	   r
   )r$   r%   r   r%   )__name__
__module____qualname____doc__r    rK   rC   __classcell__r   r   r"   r   r   %   s    
"r   c                      s`   e Zd ZdZ																	d5d6 fd(d)Zd7d.d/Zd8d1d2Zd9d3d4Z  ZS ):MaisiConvolutiona  
    Convolutional layer with optional print_info output and custom splitting mechanism.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        print_info: Whether to print information.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
        Additional arguments for the convolution operation.
        https://docs.monai.io/en/stable/networks.html#convolution
    Tr'   r*   NDAPRELUINSTANCENFspatial_dimsr   in_channelsout_channels
num_splits	dim_splitr   r
   r	   stridesSequence[int] | intkernel_sizeadn_orderingstracttuple | str | Nonenormdropouttuple | str | float | Nonedropout_dimdilationgroupsrF   	conv_onlyis_transposedpaddingSequence[int] | int | Noneoutput_paddingr   r   c                   s   t    tdi d|d|d|d|d|	d|
d|d|d	|d
|d|d|d|d|d|d|d|| _|| _t|trK|| j n|| _|| _|| _	|| _
d S )Nra   rb   rc   rf   rh   ri   rk   rm   rn   rp   rq   rr   rF   rs   rt   ru   rw   r   )r   r    r   convre   
isinstanceliststriderd   r   r	   )r!   ra   rb   rc   rd   re   r   r	   rf   rh   ri   rk   rm   rn   rp   rq   rr   rF   rs   rt   ru   rw   r"   r   r   r       sR   
	

zMaisiConvolution.__init__xr%   
split_sizelist[torch.Tensor]c           
   
   C  s   dg|g| j d   }|| jd | }td gd }g }t| j D ]*}t|| ||  |d | || j d kr;|n| || jd < ||t|  q#| jrptt|D ]}	t	
d|	d  dt| d||	    qW|S )Nr   r'   r)   r&   zSplit rO   z size: )rd   r3   re   slicer8   r?   tupler   r4   r1   r2   )
r!   r|   r}   ru   overlapsZlast_paddingslicessplitsrH   jr   r   r   _split_tensor   s    .zMaisiConvolution._split_tensoroutputsc              
   C  s  t d gd }t| jD ]!}|dkrt d |nt ||| || jd < || t| ||< q| jrOt| jD ]}td|d  dt| d|| 	   q6t
|d 	 dk retj|| jd d	}|S |d  jd
dd}td|d< t| j tt|d D ];}tj|||d   f| jd d	}td||d < t| j t  | jrtd|d  dt|d  d q|jddd}|S )Nr&   r   r)   Output r'   rO   z size after: r.   r/   rL   TrM   z"MaisiConvolution concat progress: rP   r   )r   r8   rd   re   r   r   r1   r2   r4   r3   rA   r   rB   rS   r9   Tensorr   r	   rL   rT   rU   )r!   r   r}   ru   r   rH   r|   rW   r   r   r   _concatenate_tensors   s.   *.
$
$z%MaisiConvolution._concatenate_tensorsc              
     s   j rtd j   jdkr |}|S | jd }| j }d}| j dkr8| j d  j } j rCtd|   |||}~t	 j
  fdd|D } j r{tt|D ]}td	|d  d
t| d||    qb|}|}	 jdk r jd nd}
|d |
d |d |
d  dkr|d9 }|	d9 }	n|d |
d |d |
d  dkr|d }|	d }	 |||	}~t	 j
 |S )NzNumber of splits: r'   r)   r*   r   zPadding size: c                   s   g | ]}  |qS r   )rx   ).0splitr!   r   r   
<listcomp>  s    z,MaisiConvolution.forward.<locals>.<listcomp>r   rO   z size before: )r   r1   r2   rd   rx   r3   re   r{   r   r   r	   r8   r4   r   )r!   r|   lr}   ru   r   r   r   Zsplit_size_outZ	padding_sZnon_dim_splitr   r   r   rK      s@   



.(
(
zMaisiConvolution.forward)Tr'   r*   r^   r_   r`   Nr'   r'   r'   TFFNN),ra   r   rb   r   rc   r   rd   r   re   r   r   r
   r	   r
   rf   rg   rh   rg   ri   rj   rk   rl   rm   rl   rn   ro   rp   r   rq   rg   rr   r   rF   r
   rs   r
   rt   r
   ru   rv   rw   rv   r   r   )r|   r%   r}   r   ru   r   r   r~   )r   r~   r}   r   ru   r   r   r%   r|   r%   r   r%   )	rX   rY   rZ   r[   r    r   r   rK   r\   r   r   r"   r   r]   v   s*    
3
r]   c                      s0   e Zd ZdZ	dd fddZdddZ  ZS )MaisiUpsamplea  
    Convolution-based upsampling layer.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        in_channels: Number of input channels to the layer.
        use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        print_info: Whether to print information.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    Tra   r   rb   use_convtransposer
   rd   re   r   r	   r   r   c                   sD   t    t||||rdndddd|||||d| _|| _|| _d S )Nr)   r'   r*   T)ra   rb   rc   rf   rh   ru   rs   rt   rd   re   r   r	   )r   r    r]   rx   r   r	   )r!   ra   rb   r   rd   re   r   r	   r"   r   r   r    (  s"   



zMaisiUpsample.__init__r|   r%   c                 C  sV   | j r| |}t|}|S tj|ddd}t| j | |}t| j t|}|S )Ng       @	trilinear)scale_factormode)r   rx   r   Finterpolater   r	   )r!   r|   x_tensor
out_tensorr   r   r   rK   D  s   



zMaisiUpsample.forwardT)ra   r   rb   r   r   r
   rd   r   re   r   r   r
   r	   r
   r   r   r   rX   rY   rZ   r[   r    rK   r\   r   r   r"   r   r     s
    r   c                      s0   e Zd ZdZ	dd fddZdddZ  ZS )MaisiDownsamplea  
    Convolution-based downsampling layer.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        in_channels: Number of input channels.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        print_info: Whether to print information.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    Tra   r   rb   rd   re   r   r
   r	   r   r   c                   s8   t    d| | _t|||dddd||||d| _d S )N)r   r'   r)   r*   r   Tra   rb   rc   rf   rh   ru   rs   rd   re   r   r	   )r   r    padr]   rx   )r!   ra   rb   rd   re   r   r	   r"   r   r   r    `  s   
	
zMaisiDownsample.__init__r|   r%   c                 C  s"   t j|| jddd}| |}|S )Nconstantg        )r   value)r   r   rx   )r!   r|   r   r   r   rK   y  s   
zMaisiDownsample.forwardr   )ra   r   rb   r   rd   r   re   r   r   r
   r	   r
   r   r   r   r   r   r   r"   r   r   S  s
    r   c                      s4   e Zd ZdZ			dd fddZdddZ  ZS )MaisiResBlockaJ  
    Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
    residual connection between input and output.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        in_channels: Input channels to the layer.
        norm_num_groups: Number of groups for the group norm layer.
        norm_eps: Epsilon for the normalization.
        out_channels: Number of output channels.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
        print_info: Whether to print information, default to `False`.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    FTra   r   rb   norm_num_groupsnorm_epsr   rc   rd   re   r   r
   r   r	   r   r   c                   s   t    || _|d u r|n|| _|
| _t|||d||	|
d| _t|| j| jdddd|||	|
d| _t|||d||	|
d| _	t|| j| jdddd|||	|
d| _
| j| jkrjt|| j| jdddd|||	|
d| _d S t | _d S )NTr   r   r   r   r   r   r	   r'   r*   r   r   )r   r    rb   rc   r	   r   norm1r]   conv1norm2conv2nnIdentitynin_shortcut)r!   ra   rb   r   r   rc   rd   re   r   r   r	   r"   r   r   r      s   
		zMaisiResBlock.__init__r|   r%   c                 C  s   |  |}t| j t|}t| j | |}t| j | |}t| j t|}t| j | |}t| j | j| j	krL| 
|}t| j || }t|}|S r   )r   r   r	   r   silur   r   r   rb   rc   r   r   )r!   r|   houtr   r   r   r   rK     s$   













zMaisiResBlock.forward)FFT)ra   r   rb   r   r   r   r   r   rc   r   rd   r   re   r   r   r
   r   r
   r	   r
   r   r   r   r   r   r   r"   r   r     s    Qr   c                      s<   e Zd ZdZ							d"d# fddZd$d d!Z  ZS )%MaisiEncodera  
    Convolutional cascade that downsamples the image into a spatial latent space.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        in_channels: Number of input channels.
        num_channels: Sequence of block output channels.
        out_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
        norm_num_groups: Number of groups for the group norm layers.
        norm_eps: Epsilon for the normalization.
        attention_levels: Indicate which level from num_channels contain an attention block.
        with_nonlocal_attn: If True, use non-local attention block.
        include_fc: whether to include the final linear layer in the attention block. Default to False.
        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
        use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
        print_info: Whether to print information, default to `False`.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    FTra   r   rb   r   Sequence[int]rc   num_res_blocksr   r   r   attention_levelsSequence[bool]rd   re   r   r
   r   r	   with_nonlocal_attn
include_fcuse_combined_linearuse_flash_attentionr   r   c                   s  t    t|t|krtdt|t|krtd|| _g }|t|||d dddd|	|
||d |d }tt|D ]M}|}|| }|t|d k}t|| D ]'}|t||||||	|
|||d
 |}|| r}|t	|||||||d	 qV|s|t
|||	|
||d
 q@|r|t||d |||d d |t	||d |||||d	 |t||d |||d d |t||d |d|||d |t||d |dddd|	|
||d t|| _d S )Nz9attention_levels and num_channels must have the same sizez7num_res_blocks and num_channels must have the same sizer   r'   r*   Tr   
ra   rb   r   r   rc   rd   re   r   r   r	   ra   r   r   r   r   r   r   )ra   rb   rd   re   r   r	   ra   rb   r   r   rc   r   )r   r    r4   r6   r	   r?   r]   r8   r   r   r   r   r   r   
ModuleListblocks)r!   ra   rb   r   rc   r   r   r   r   rd   re   r   r   r	   r   r   r   r   r   output_channelrH   input_channelis_final_block_r"   r   r   r      s   


zMaisiEncoder.__init__r|   r%   c                 C  "   | j D ]}||}t| j q|S r   r   r   r	   r!   r|   blockr   r   r   rK        
zMaisiEncoder.forward)FFTTFFF)$ra   r   rb   r   r   r   rc   r   r   r   r   r   r   r   r   r   rd   r   re   r   r   r
   r   r
   r	   r
   r   r
   r   r
   r   r
   r   r
   r   r   r   r   r   r   r"   r   r     s    # r   c                      s>   e Zd ZdZ								d#d$ fddZd%d!d"Z  ZS )&MaisiDecodera  
    Convolutional cascade upsampling from a spatial latent space into an image space.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        num_channels: Sequence of block output channels.
        in_channels: Number of channels in the bottom layer (latent space) of the autoencoder.
        out_channels: Number of output channels.
        num_res_blocks: Number of residual blocks (see AEKLResBlock) per level.
        norm_num_groups: Number of groups for the group norm layers.
        norm_eps: Epsilon for the normalization.
        attention_levels: Indicate which level from num_channels contain an attention block.
        with_nonlocal_attn: If True, use non-local attention block.
        include_fc: whether to include the final linear layer in the attention block. Default to False.
        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
        use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
        use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
        print_info: Whether to print information, default to `False`.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    FTra   r   r   r   rb   rc   r   r   r   r   r   r   rd   re   r   r
   r   r	   r   r   r   r   r   r   r   c                   s  t    || _|| _tt|}g }|t|||d dddd|	|
||d |rY|t||d |||d d |t	||d |||||d |t||d |||d d tt|}tt|}|d }t
t|D ]N}|}|| }|t|d k}t
|| D ]'}|t||||||	|
|||d
 |}|| r|t	|||||||d q|s|t||||	|
||d	 qo|t|||d|||d
 |t|||dddd|	|
||d t|| _d S )Nr   r'   r*   Tr   r   r   r   )ra   rb   r   rd   re   r   r	   r   )r   r    r   r	   rz   reversedr?   r]   r   r   r8   r4   r   r   r   r   r   r   )r!   ra   r   rb   rc   r   r   r   r   rd   re   r   r   r	   r   r   r   r   r   reversed_block_out_channelsr   reversed_attention_levelsreversed_num_res_blocksblock_out_chrH   block_in_chr   r   r"   r   r   r      s   
	
zMaisiDecoder.__init__r|   r%   c                 C  r   r   r   r   r   r   r   rK   n  r   zMaisiDecoder.forward)FFTTFFFF)&ra   r   r   r   rb   r   rc   r   r   r   r   r   r   r   r   r   rd   r   re   r   r   r
   r   r
   r	   r
   r   r
   r   r
   r   r
   r   r
   r   r
   r   r   r   r   r   r   r"   r   r     s    $ r   c                      sB   e Zd ZdZ															d'd( fd%d&Z  ZS ))AutoencoderKlMaisiar  
    AutoencoderKL with custom MaisiEncoder and MaisiDecoder.

    Args:
        spatial_dims: Number of spatial dimensions (1D, 2D, 3D).
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        num_res_blocks: Number of residual blocks per level.
        num_channels: Sequence of block output channels.
        attention_levels: Indicate which level from num_channels contain an attention block.
        latent_channels: Number of channels in the latent space.
        norm_num_groups: Number of groups for the group norm layers.
        norm_eps: Epsilon for the normalization.
        with_encoder_nonlocal_attn: If True, use non-local attention block in the encoder.
        with_decoder_nonlocal_attn: If True, use non-local attention block in the decoder.
        include_fc: whether to include the final linear layer. Default to False.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
        use_checkpointing: If True, use activation checkpointing.
        use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
        num_splits: Number of splits for the input tensor.
        dim_split: Dimension of splitting for the input tensor.
        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
        print_info: Whether to print information, default to `False`.
        save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
    r*       ư>F   r   Tra   r   rb   rc   r   r   r   r   r   latent_channelsr   r   r   with_encoder_nonlocal_attnr
   with_decoder_nonlocal_attnr   r   r   use_checkpointingr   rd   re   r   r   r	   r   r   c                   s  t  |||||||||	|
|||||| tdi d|d|d|d|d|d|d|	d|d	|
d
|d|d|d|d|d|d|d|| _tdi d|d|d|d|d|d|d|	d|d	|d
|d|d|d|d|d|d|d|d|| _d S )Nra   rb   r   rc   r   r   r   r   r   r   r   r   rd   re   r   r   r	   r   r   )r   r    r   encoderr   decoder)r!   ra   rb   rc   r   r   r   r   r   r   r   r   r   r   r   r   r   rd   re   r   r   r	   r"   r   r   r      s   	
	

zAutoencoderKlMaisi.__init__)r*   r   r   FFFFFFFr   r   FFT),ra   r   rb   r   rc   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r
   r   r
   r   r
   r   r
   r   r
   r   r
   r   r
   rd   r   re   r   r   r
   r   r
   r	   r
   r   r   )rX   rY   rZ   r[   r    r\   r   r   r"   r   r   u  s$    #r   )r	   r
   r   r   )#
__future__r   rT   loggingcollections.abcr   r   torch.nnr   torch.nn.functional
functionalr   monai.networks.blocksr   Z&monai.networks.blocks.spatialattentionr   Z!monai.networks.nets.autoencoderklr   r   monai.utils.type_conversionr   	getLoggerrX   r1   r   	GroupNormr   Moduler]   r   r   r   r   r   r   r   r   r   r   <module>   s0   

Q %9,} > =