o
    ,inL                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlm  mZ	 d dl
mZmZmZ d dlmZ d dlmZ d dlmZ dgZG d	d
 d
ejZG dd dejZG dd dejZdS )    )annotations)SequenceN)ConvolutionSpatialAttentionBlockUpsample)SPADE)Encoder)ensure_tuple_repSPADEAutoencoderKLc                      s,   e Zd ZdZd fddZdddZ  ZS )SPADEResBlocka2  
    Residual block consisting of a cascade of 2 convolutions + activation + normalisation block, and a
    residual connection between input and output.
    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)

    Args:
        spatial_dims: number of spatial dimensions (1D, 2D, 3D).
        in_channels: input channels to the layer.
        norm_num_groups: number of groups involved for the group normalisation layer. Ensure that your number of
            channels is divisible by this number.
        norm_eps: epsilon for the normalisation.
        out_channels: number of output channels.
        label_nc: number of semantic channels for SPADE normalisation
        spade_intermediate_channels: number of intermediate channels for SPADE block layer
    spatial_dimsintin_channelsnorm_num_groupsnorm_epsfloatout_channelslabel_ncspade_intermediate_channelsreturnNonec              	     s   t    || _|d u r|n|| _t||d|d|d|d|d| _t|| j| jddddd| _t||d|d|d|d|d| _t|| j| jddddd| _	|  | j| jkret|| j| jddd	dd| _
d S t | _
d S )
NGROUPF)
num_groupsaffineeps   )r   norm_ncnormnorm_paramshidden_channelskernel_sizer      Tr   r   r   stridesr    padding	conv_onlyr   )super__init__r   r   r   norm1r   conv1norm2conv2nin_shortcutnnIdentity)selfr   r   r   r   r   r   r   	__class__ i/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/spade_autoencoderkl.pyr'   -   sf   


		
	

zSPADEResBlock.__init__xtorch.Tensorsegc                 C  sV   |}|  ||}t|}| |}| ||}t|}| |}| |}|| S N)r(   Fsilur)   r*   r+   r,   )r/   r4   r6   hr2   r2   r3   forwardm   s   




zSPADEResBlock.forward)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r4   r5   r6   r5   r   r5   __name__
__module____qualname____doc__r'   r;   __classcell__r2   r2   r0   r3   r      s    @r   c                      s8   e Zd ZdZ					d!d" fddZd#dd Z  ZS )$SPADEDecodera  
    Convolutional cascade upsampling from a spatial latent space into an image space.
    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)

    Args:
        spatial_dims: number of spatial dimensions (1D, 2D, 3D).
        channels: sequence of block output channels.
        in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see ResBlock) per level.
        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
        norm_eps: epsilon for the normalization.
        attention_levels: indicate which level from channels contain an attention block.
        label_nc: number of semantic channels for SPADE normalisation.
        with_nonlocal_attn: if True use non-local attention block.
        spade_intermediate_channels: number of intermediate channels for SPADE block layer.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    T   Fr   r   channelsSequence[int]r   r   num_res_blocksr   r   r   attention_levelsSequence[bool]r   with_nonlocal_attnboolr   
include_fcuse_combined_linearuse_flash_attentionr   r   c                   s   t    || _|| _|| _|| _|| _|| _|| _|| _	|	| _
tt|}g }|t|||d ddddd |
du rp|t||d |||d |	|d |t||d |||||d |t||d |||d |	|d tt|}tt|}|d }tt|D ]W}|}|| }|t|d k}t|| D ]$}|t||||||	|d |}|| r|t|||||||d q|st|||ddddd}|t|d||d	d
|d d q|tj|||dd |t|||ddddd t|| _d S )Nr   r!   r   Tr"   )r   r   r   r   r   r   r   )r   num_channelsr   r   rL   rM   rN   nontrainablenearestg       @)r   moder   r   interp_modescale_factor	post_convalign_corners)r   rO   r   r   )r&   r'   r   rE   r   r   rG   r   r   rH   r   listreversedappendr   r   r   rangelenr   r-   	GroupNorm
ModuleListblocks)r/   r   rE   r   r   rG   r   r   rH   r   rJ   r   rL   rM   rN   reversed_block_out_channelsr^   reversed_attention_levelsreversed_num_res_blocksblock_out_chiblock_in_chis_final_block_rU   r0   r2   r3   r'      s   
	zSPADEDecoder.__init__r4   r5   r6   c                 C  s.   | j D ]}t|tr|||}q||}q|S r7   )r^   
isinstancer   )r/   r4   r6   blockr2   r2   r3   r;   ,  s
   


zSPADEDecoder.forward)TrD   TFF)r   r   rE   rF   r   r   r   r   rG   rF   r   r   r   r   rH   rI   r   r   rJ   rK   r   r   rL   rK   rM   rK   rN   rK   r   r   r<   r=   r2   r2   r0   r3   rC   {   s    ! rC   c                      s   e Zd ZdZ													
				d<d= fd$d%Zd>d)d*Zd?d-d.Zd@d0d1ZdAd3d4ZdBd6d7Z	dCd8d9Z
dAd:d;Z  ZS )Dr
   a/  
    Autoencoder model with KL-regularized latent space based on
    Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
    and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162
    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)

    Args:
        spatial_dims: number of spatial dimensions (1D, 2D, 3D).
        label_nc: number of semantic channels for SPADE normalisation.
        in_channels: number of input channels.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see ResBlock) per level.
        channels: sequence of block output channels.
        attention_levels: sequence of levels to add attention.
        latent_channels: latent embedding dimension.
        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
        norm_eps: epsilon for the normalization.
        with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
        with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
        spade_intermediate_channels: number of intermediate channels for SPADE block layer.
    r!      rj   rj   rj       @   rm   rm   FFTTr   rl   ư>TrD   Fr   r   r   r   r   rG   Sequence[int] | intrE   rF   rH   rI   latent_channelsr   r   r   with_encoder_nonlocal_attnrK   with_decoder_nonlocal_attnr   rL   rM   rN   r   r   c                   s
  t    t fdd|D rtdt|t|kr tdt|tr,t|t|}t|t|kr8tdt||||| |
|||||d| _	t
||||| |
|||||||d| _t|||ddd	d
d| _t|||ddd	d
d| _t|||ddd	d
d| _|| _d S )Nc                 3  s    | ]	}|  d kV  qdS )r   Nr2   ).0out_channelr   r2   r3   	<genexpr>b  s    z.SPADEAutoencoderKL.__init__.<locals>.<genexpr>zISPADEAutoencoderKL expects all channels being multiple of norm_num_groupszGSPADEAutoencoderKL expects channels being same size of attention_levelszf`num_res_blocks` should be a single integer or a tuple of integers with the same length as `channels`.)r   r   rE   r   rG   r   r   rH   rJ   rL   rM   rN   )r   rE   r   r   rG   r   r   rH   r   rJ   r   rL   rM   rN   r!   r   Tr"   )r&   r'   any
ValueErrorr[   rg   r   r	   r   encoderrC   decoderr   quant_conv_muquant_conv_log_sigmapost_quant_convrq   )r/   r   r   r   r   rG   rE   rH   rq   r   r   rr   rs   r   rL   rM   rN   r0   rv   r3   r'   L  s   

		
	zSPADEAutoencoderKL.__init__r4   r5   !tuple[torch.Tensor, torch.Tensor]c                 C  sB   |  |}| |}| |}t|dd}t|d }||fS )z
        Forwards an image through the spatial encoder, obtaining the latent mean and sigma representations.

        Args:
            x: BxCx[SPATIAL DIMS] tensor

        g      >g      4@rj   )rz   r|   r}   torchclampexp)r/   r4   r:   z_mu	z_log_varz_sigmar2   r2   r3   encode  s   


zSPADEAutoencoderKL.encoder   r   c                 C  s   t |}|||  }|S )aE  
        From the mean and sigma representations resulting of encoding an image through the latent space,
        obtains a noise sample resulting from sampling gaussian noise, multiplying by the variance (sigma) and
        adding the mean.

        Args:
            z_mu: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] mean vector obtained by the encoder when you encode an image
            z_sigma: Bx[Z_CHANNELS]x[LATENT SPACE SIZE] variance vector obtained by the encoder when you encode an image

        Returns:
            sample of shape Bx[Z_CHANNELS]x[LATENT SPACE SIZE]
        )r   
randn_like)r/   r   r   r   z_vaer2   r2   r3   sampling  s   
zSPADEAutoencoderKL.samplingr6   c                 C  s   |  |\}}| ||}|S )a  
        Encodes and decodes an input image.

        Args:
            x: BxCx[SPATIAL DIMENSIONS] tensor.
            seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
        Returns:
            reconstructed image, of the same shape as input
        )r   decode)r/   r4   r6   r   rf   reconstructionr2   r2   r3   reconstruct  s   
zSPADEAutoencoderKL.reconstructzc                 C  s   |  |}| ||}|S )a!  
        Based on a latent space sample, forwards it through the Decoder.

        Args:
            z: Bx[Z_CHANNELS]x[LATENT SPACE SHAPE]
            seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
        Returns:
            decoded image tensor
        )r~   r{   )r/   r   r6   decr2   r2   r3   r     s   

zSPADEAutoencoderKL.decode/tuple[torch.Tensor, torch.Tensor, torch.Tensor]c                 C  s0   |  |\}}| ||}| ||}|||fS r7   )r   r   r   )r/   r4   r6   r   r   r   r   r2   r2   r3   r;     s   
zSPADEAutoencoderKL.forwardc                 C  s   |  |\}}| ||}|S r7   )r   r   )r/   r4   r   r   r   r2   r2   r3   encode_stage_2_inputs  s   z(SPADEAutoencoderKL.encode_stage_2_inputsc                 C  s   |  ||}|S r7   )r   )r/   r   r6   imager2   r2   r3   decode_stage_2_outputs  s   z)SPADEAutoencoderKL.decode_stage_2_outputs)r!   r!   ri   rk   rn   r   rl   ro   TTrD   TFF)"r   r   r   r   r   r   r   r   rG   rp   rE   rF   rH   rI   rq   r   r   r   r   r   rr   rK   rs   rK   r   r   rL   rK   rM   rK   rN   rK   r   r   )r4   r5   r   r   )r   r5   r   r5   r   r5   r<   )r   r5   r6   r5   r   r5   )r4   r5   r6   r5   r   r   )r4   r5   r   r5   )r>   r?   r@   rA   r'   r   r   r   r   r;   r   r   rB   r2   r2   r0   r3   r
   5  s0    
`




)
__future__r   collections.abcr   r   torch.nnr-   torch.nn.functional
functionalr8   monai.networks.blocksr   r   r   Z monai.networks.blocks.spade_normr   Z!monai.networks.nets.autoencoderklr   monai.utilsr	   __all__Moduler   rC   r
   r2   r2   r2   r3   <module>   s   _ ;