o
    -io                     @  s   d dl mZ d dlmZ d dlmZ d dlZd dlmZ d dl	m  m
Z d dlmZmZmZ d dlmZmZ eddd	\ZZd
gZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZG dd
 d
ejZdS )    )annotations)Sequence)ListN)ConvolutionSpatialAttentionBlockUpsample)ensure_tuple_repoptional_importzeinops.layers.torch	Rearrange)nameAutoencoderKLc                      s,   e Zd ZdZd fddZdd
dZ  ZS )AsymmetricPadz
    Pad the input tensor asymmetrically along every spatial dimension.

    Args:
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
    spatial_dimsintreturnNonec                   s   t    d| | _d S )N)r      )super__init__pad)selfr   	__class__ c/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/autoencoderkl.pyr   %   s   
zAsymmetricPad.__init__xtorch.Tensorc                 C  s   t jj|| jddd}|S )Nconstantg        )modevalue)nn
functionalr   r   r   r   r   r   forward)   s   zAsymmetricPad.forward)r   r   r   r   r   r   r   r   __name__
__module____qualname____doc__r   r#   __classcell__r   r   r   r   r      s    r   c                      s,   e Zd ZdZd fddZdddZ  ZS )AEKLDownsamplez
    Convolution-based downsampling layer.

    Args:
        spatial_dims: number of spatial dimensions (1D, 2D, 3D).
        in_channels: number of input channels.
    r   r   in_channelsr   r   c              	     s2   t    t|d| _t|||ddddd| _d S )N)r         r   Tr   r,   out_channelsstrideskernel_sizepadding	conv_only)r   r   r   r   r   conv)r   r   r,   r   r   r   r   7   s   
zAEKLDownsample.__init__r   r   c                 C  s   |  |}| |}|S N)r   r5   r"   r   r   r   r#   E      

zAEKLDownsample.forward)r   r   r,   r   r   r   r$   r%   r   r   r   r   r+   .   s    r+   c                      s,   e Zd ZdZd fddZdddZ  ZS )AEKLResBlocka)  
    Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
    residual connection between input and output.

    Args:
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
        in_channels: input channels to the layer.
        norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
            channels is divisible by this number.
        norm_eps: epsilon for the normalisation.
        out_channels: number of output channels.
    r   r   r,   norm_num_groupsnorm_epsfloatr0   r   r   c              	     s   t    || _|d u r|n|| _tj|||dd| _t|| j| jddddd| _tj|||dd| _	t|| j| jddddd| _
|  | j| jkrYt|| j| jddddd| _d S t | _d S )NT
num_groupsnum_channelsepsaffiner   r.   r/   r   )r   r   r,   r0   r    	GroupNormnorm1r   conv1norm2conv2nin_shortcutIdentity)r   r   r,   r9   r:   r0   r   r   r   r   Y   sF   
	

zAEKLResBlock.__init__r   r   c                 C  sR   |}|  |}t|}| |}| |}t|}| |}| |}|| S r6   )rB   FsilurC   rD   rE   rF   )r   r   hr   r   r   r#      s   






zAEKLResBlock.forward)r   r   r,   r   r9   r   r:   r;   r0   r   r   r   r$   r%   r   r   r   r   r8   K   s    *r8   c                      s6   e Zd ZdZ				dd fddZdddZ  ZS ) Encodera  
    Convolutional cascade that downsamples the image into a spatial latent space.

    Args:
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
        in_channels: number of input channels.
        channels: sequence of block output channels.
        out_channels: number of channels in the bottom layer (latent space) of the autoencoder.
        num_res_blocks: number of residual blocks (see _ResBlock) per level.
        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
        norm_eps: epsilon for the normalization.
        attention_levels: indicate which level from channels contain an attention block.
        with_nonlocal_attn: if True use non-local attention block.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    TFr   r   r,   channelsSequence[int]r0   num_res_blocksr9   r:   r;   attention_levelsSequence[bool]with_nonlocal_attnbool
include_fcuse_combined_linearuse_flash_attentionr   r   c                   s  t    || _|| _|| _|| _|| _|| _|| _|| _	g }|
t|||d ddddd |d }tt|D ]E}|}|| }|t|d k}t| j| D ]"}|
t|||||d |}|| rr|
t|||||
||d qP|s~|
t||d q9|	du r|
t||d	 |||d	 d |
t||d	 |||
||d |
t||d	 |||d	 d |
tj||d	 |dd
 |
t| j|d	 |ddddd t|| _d S )Nr   r   r.   Tr/   r   r,   r9   r:   r0   r   r>   r9   r:   rS   rT   rU   )r   r,   r<   )r   r   r   r,   rL   r0   rN   r9   r:   rO   appendr   rangelenr8   r   r+   r    rA   
ModuleListblocks)r   r   r,   rL   r0   rN   r9   r:   rO   rQ   rS   rT   rU   r]   output_channeliinput_channelis_final_block_r   r   r   r      s   
	

zEncoder.__init__r   r   c                 C     | j D ]}||}q|S r6   r]   r   r   blockr   r   r   r#     r7   zEncoder.forward)TTFF)r   r   r,   r   rL   rM   r0   r   rN   rM   r9   r   r:   r;   rO   rP   rQ   rR   rS   rR   rT   rR   rU   rR   r   r   r$   r%   r   r   r   r   rK      s    xrK   c                      s8   e Zd ZdZ					dd fddZd ddZ  ZS )!Decodera  
    Convolutional cascade upsampling from a spatial latent space into an image space.

    Args:
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
        channels: sequence of block output channels.
        in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see _ResBlock) per level.
        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
        norm_eps: epsilon for the normalization.
        attention_levels: indicate which level from channels contain an attention block.
        with_nonlocal_attn: if True use non-local attention block.
        use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    TFr   r   rL   rM   r,   r0   rN   r9   r:   r;   rO   rP   rQ   rR   use_convtransposerS   rT   rU   r   r   c                   s
  t    || _|| _|| _|| _|| _|| _|| _|| _	t
t|}g }|t|||d ddddd |	du ri|t||d |||d d |t||d |||||d |t||d |||d d t
t|}t
t|}|d }tt|D ]c}|}|| }|t|d k}t|| D ]"}|t|||||d |}|| r|t|||||||d q|s|
r|t|d||d	 qt|||ddddd}|t|d
||dd|d d q|tj|||dd |t|||ddddd t|| _d S )Nr   r   r.   Tr/   rV   rW   deconv)r   r   r,   r0   nontrainablenearestg       @)r   r   r,   r0   interp_modescale_factor	post_convalign_cornersr<   )r   r   r   rL   r,   r0   rN   r9   r:   rO   listreversedrY   r   r8   r   rZ   r[   r   r    rA   r\   r]   )r   r   rL   r,   r0   rN   r9   r:   rO   rQ   rh   rS   rT   rU   Zreversed_block_out_channelsr]   Zreversed_attention_levelsZreversed_num_res_blocksZblock_out_chr_   Zblock_in_chra   rb   rn   r   r   r   r   9  s   
	
		zDecoder.__init__r   r   c                 C  rc   r6   rd   re   r   r   r   r#     r7   zDecoder.forward)TFTFF)r   r   rL   rM   r,   r   r0   r   rN   rM   r9   r   r:   r;   rO   rP   rQ   rR   rh   rR   rS   rR   rT   rR   rU   rR   r   r   r$   r%   r   r   r   r   rg   $  s     rg   c                      s   e Zd ZdZ													
	
			
	
d>d? fd#d$Zd@d(d)ZdAd,d-ZdBd.d/ZdCd1d2ZdDd4d5Z	dBd6d7Z
dCd8d9ZdEdFd<d=Z  ZS )Gr   a  
    Autoencoder model with KL-regularized latent space based on
    Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
    and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162

    Args:
        spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
        in_channels: number of input channels.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see _ResBlock) per level.
        channels: number of output channels for each block.
        attention_levels: sequence of levels to add attention.
        latent_channels: latent embedding dimension.
        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
        norm_eps: epsilon for the normalization.
        with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
        with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
        use_checkpoint: if True, use activation checkpoint to save memory.
        use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
        include_fc: whether to include the final linear layer in the attention block. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection in the attention block, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    r   r-   r-   r-   r-       @   ru   ru   FFTTr.   rt   ư>TFr   r   r,   r0   rN   Sequence[int] | intrL   rM   rO   rP   latent_channelsr9   r:   r;   with_encoder_nonlocal_attnrR   with_decoder_nonlocal_attnuse_checkpointrh   rS   rT   rU   r   r   c                   s  t    t fdd|D rtdt|t|kr tdt|tr,t|t|}t|t|kr8tdt||||| |	||
|||d| _	t
||||| |	||||||d| _t|||ddd	d
d| _t|||ddd	d
d| _t|||ddd	d
d| _|| _|| _d S )Nc                 3  s    | ]	}|  d kV  qdS )r   Nr   ).0out_channelr9   r   r   	<genexpr>  s    z)AutoencoderKL.__init__.<locals>.<genexpr>zDAutoencoderKL expects all channels being multiple of norm_num_groupszBAutoencoderKL expects channels being same size of attention_levelszf`num_res_blocks` should be a single integer or a tuple of integers with the same length as `channels`.)r   r,   rL   r0   rN   r9   r:   rO   rQ   rS   rT   rU   )r   rL   r,   r0   rN   r9   r:   rO   rQ   rh   rS   rT   rU   r   r   Tr/   )r   r   any
ValueErrorr[   
isinstancer   r   rK   encoderrg   decoderr   quant_conv_muquant_conv_log_sigmapost_quant_convry   r|   )r   r   r,   r0   rN   rL   rO   ry   r9   r:   rz   r{   r|   rh   rS   rT   rU   r   r   r   r     s   

			
zAutoencoderKL.__init__r   r   !tuple[torch.Tensor, torch.Tensor]c                 C  s`   | j rtjjj| j|dd}n| |}| |}| |}t|dd}t|d }||fS )z
        Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations.

        Args:
            x: BxCx[SPATIAL DIMS] tensor

        Fuse_reentrantg      >g      4@r-   )	r|   torchutils
checkpointr   r   r   clampexp)r   r   rJ   z_muZ	z_log_varz_sigmar   r   r   encodeR  s   


zAutoencoderKL.encoder   r   c                 C  s   t |}|||  }|S )aE  
        From the mean and sigma representations resulting of encoding an image through the latent space,
        obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and
        adding the mean.

        Args:
            z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image
            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image

        Returns:
            sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]
        )r   
randn_like)r   r   r   r?   Zz_vaer   r   r   samplingf  s   
zAutoencoderKL.samplingc                 C  s   |  |\}}| |}|S )z
        Encodes and decodes an input image.

        Args:
            x: BxCx[SPATIAL DIMENSIONS] tensor.

        Returns:
            reconstructed image, of the same shape as input
        )r   decode)r   r   r   rb   reconstructionr   r   r   reconstructw  s   

zAutoencoderKL.reconstructzc                 C  s8   |  |}| jrtjjj| j|dd}|S | |}|S )z
        Based on a latent space sample, forwards it through the Decoder.

        Args:
            z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE]

        Returns:
            decoded image tensor
        Fr   )r   r|   r   r   r   r   )r   r   decr   r   r   r     s   


zAutoencoderKL.decode/tuple[torch.Tensor, torch.Tensor, torch.Tensor]c                 C  s.   |  |\}}| ||}| |}|||fS r6   )r   r   r   )r   r   r   r   r   r   r   r   r   r#     s   

zAutoencoderKL.forwardc                 C  s   |  |\}}| ||}|S r6   )r   r   )r   r   r   r   r   r   r   r   encode_stage_2_inputs  s   z#AutoencoderKL.encode_stage_2_inputsc                 C  s   |  |}|S r6   )r   )r   r   imager   r   r   decode_stage_2_outputs  s   
z$AutoencoderKL.decode_stage_2_outputsold_state_dictdictc                   s  |    t fdd|D rtd | | dS |rB D ]}||vr,td| d qtd |D ]}| vrAtd| d q3 D ]}||v rQ|| |< qDd	d
  D }|D ]r}|| d | d< || d | d< || d | d< || d | d< || d | d< || d | d< t | d jd  | d< t | d j | d< q[ D ]}d|v r|	dd}|| |< q|rtd|
  | j dd dS )z
        Load a state dict from an AutoencoderKL trained with [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).

        Args:
            old_state_dict: state dict from the old AutoencoderKL model.
        c                 3  s    | ]}| v V  qd S r6   r   r}   knew_state_dictr   r   r     s    z4AutoencoderKL.load_old_state_dict.<locals>.<genexpr>z#All keys match, loading state dict.Nzkey z not found in old state dictz.----------------------------------------------z not found in new state dictc                 S  s    g | ]}d |v r| ddqS )zattn.to_q.weight.attn.to_q.weight )replacer   r   r   r   
<listcomp>  s     z5AutoencoderKL.load_old_state_dict.<locals>.<listcomp>z.to_q.weightr   z.to_k.weightz.attn.to_k.weightz.to_v.weightz.attn.to_v.weightz
.to_q.biasz.attn.to_q.biasz
.to_k.biasz.attn.to_k.biasz
.to_v.biasz.attn.to_v.biasz.attn.out_proj.weightr   z.attn.out_proj.biaspostconvr5   z!remaining keys in old_state_dict:T)strict)
state_dictallprintload_state_dictpopr   eyeshapezerosr   keys)r   r   verboser   Zattention_blocksrf   old_namer   r   r   load_old_state_dict  sR   
z!AutoencoderKL.load_old_state_dict)r   r   rr   rs   rv   r.   rt   rw   TTFFTFF)"r   r   r,   r   r0   r   rN   rx   rL   rM   rO   rP   ry   r   r9   r   r:   r;   rz   rR   r{   rR   r|   rR   rh   rR   rS   rR   rT   rR   rU   rR   r   r   )r   r   r   r   )r   r   r   r   r   r   r$   )r   r   r   r   )r   r   r   r   )F)r   r   r   r   )r&   r'   r(   r)   r   r   r   r   r   r#   r   r   r   r*   r   r   r   r   r     s4    
`





)
__future__r   collections.abcr   typingr   r   torch.nnr    torch.nn.functionalr!   rH   Zmonai.networks.blocksr   r   r   monai.utilsr   r	   r
   rb   __all__Moduler   r+   r8   rK   rg   r   r   r   r   r   <module>   s$   G  5