o
    -iG                     @  s   d dl mZ d dlmZ d dlmZ d dlZd dlmZ d dl	m
Z
 d dlmZ d dlmZmZ d dlmZ d	gZG d
d dejZG dd dejZG dd dejZG dd	 d	ejZdS )    )annotations)Sequence)TupleN)Convolution)Act)EMAQuantizerVectorQuantizer)ensure_tuple_repVQVAEc                      s4   e Zd ZdZejddfd fddZdd Z  ZS )VQVAEResidualUnita  
    Implementation of the ResidualLayer used in the VQVAE network as originally used in Morphology-preserving
    Autoregressive 3D Generative Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf).

    The original implementation that can be found at
    https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L150.

    Args:
        spatial_dims: number of spatial spatial_dims of the input data.
        in_channels: number of input channels.
        num_res_channels: number of channels in the residual layers.
        act: activation type and arguments. Defaults to RELU.
        dropout: dropout ratio. Defaults to no dropout.
        bias: whether to have a bias term. Defaults to True.
            Tspatial_dimsintin_channelsnum_res_channelsacttuple | str | NonedropoutfloatbiasboolreturnNonec              	     sr   t    || _|| _|| _|| _|| _|| _t| j| j| jd| j| j| jd| _	t| j| j| j| jdd| _
d S )NDA)r   r   out_channelsadn_orderingr   r   r   T)r   r   r   r   	conv_only)super__init__r   r   r   r   r   r   r   conv1conv2)selfr   r   r   r   r   r   	__class__ [/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/vqvae.pyr   -   s.   
	
zVQVAEResidualUnit.__init__c                 C  s    t jj|| | | dS )NT)torchnn
functionalrelur    r   )r!   xr$   r$   r%   forwardQ   s    zVQVAEResidualUnit.forward)r   r   r   r   r   r   r   r   r   r   r   r   r   r   )	__name__
__module____qualname____doc__r   RELUr   r+   __classcell__r$   r$   r"   r%   r      s    $r   c                      s,   e Zd ZdZd fddZdddZ  ZS )Encodera  
    Encoder module for VQ-VAE.

    Args:
        spatial_dims: number of spatial spatial_dims.
        in_channels: number of input channels.
        out_channels: number of channels in the latent space (embedding_dim).
        channels: sequence containing the number of channels at each level of the encoder.
        num_res_layers: number of sequential residual layers at each level.
        num_res_channels: number of channels in the residual layers at each level.
        downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the
            following information stride (int), kernel_size (int), dilation (int) and padding (int).
        dropout: dropout ratio.
        act: activation type and arguments.
    r   r   r   r   channelsSequence[int]num_res_layersr   downsample_parameters#Sequence[Tuple[int, int, int, int]]r   r   r   r   r   r   c
                   sN  t    || _|| _|| _|| _|| _|| _|| _|| _	|	| _
g }
tt| jD ]]}|
t| j|dkr7| jn| j|d  | j| | j| d | j| d d| j
|dkrVd n| j	d| j| d | j| d d t| jD ]}|
t| j| j| | j| | j
| j	d qnq)|
t| j| jt| jd  | jddddd	 t|
| _d S )
Nr      r         )r   r   r   strideskernel_sizer   r   r   dropout_dimdilationpaddingr   r   r   r   r   Tr   r   r   r;   r<   r?   r   )r   r   r   r   r   r3   r5   r   r6   r   r   rangelenappendr   r   r'   
ModuleListblocks)r!   r   r   r   r3   r5   r   r6   r   r   rF   i_r"   r$   r%   r   f   sd   
zEncoder.__init__r*   torch.Tensorc                 C     | j D ]}||}q|S NrF   r!   r*   blockr$   r$   r%   r+         

zEncoder.forward)r   r   r   r   r   r   r3   r4   r5   r   r   r4   r6   r7   r   r   r   r   r   r   r*   rI   r   rI   r,   r-   r.   r/   r   r+   r1   r$   r$   r"   r%   r2   U   s    Cr2   c                      s,   e Zd ZdZd fddZdddZ  ZS )DecoderaW  
    Decoder module for VQ-VAE.

    Args:
        spatial_dims: number of spatial spatial_dims.
        in_channels: number of channels in the latent space (embedding_dim).
        out_channels: number of output channels.
        channels: sequence containing the number of channels at each level of the decoder.
        num_res_layers: number of sequential residual layers at each level.
        num_res_channels: number of channels in the residual layers at each level.
        upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the
            following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).
        dropout: dropout ratio.
        act: activation type and arguments.
        output_act: activation type and arguments for the output.
    r   r   r   r   r3   r4   r5   r   upsample_parameters(Sequence[Tuple[int, int, int, int, int]]r   r   r   r   
output_actr   r   c                   s  t    || _|| _|| _|| _|| _|| _|| _|| _	|	| _
|
| _tt| j}g }|t| j| j|d ddddd tt| j}tt| jD ]r}t| jD ]}|t| j|| || | j
| j	d qS|t| j|| |t| jd kr}| jn||d  | j| d | j| d d| j
|t| jd kr| j	nd d | j| d |t| jd kd| j| d | j| d	 d
 qL| jr|t| j   t|| _d S )Nr   r8   r:   TrA   r@   r   r9      )r   r   r   r;   r<   r   r   r   normr>   r   is_transposedr?   output_padding)r   r   r   r   r   r3   r5   r   rS   r   r   rU   listreversedrD   r   rB   rC   r   r   r'   rE   rF   )r!   r   r   r   r3   r5   r   rS   r   r   rU   Zreversed_num_channelsrF   Zreversed_num_res_channelsrG   rH   r"   r$   r%   r      sr   

"zDecoder.__init__r*   rI   c                 C  rJ   rK   rL   rM   r$   r$   r%   r+     rO   zDecoder.forward)r   r   r   r   r   r   r3   r4   r5   r   r   r4   rS   rT   r   r   r   r   rU   r   r   r   rP   rQ   r$   r$   r"   r%   rR      s    MrR   c                      s   e Zd ZdZddddddddd	d
ddejdddfdG fd,d-ZdHd1d2ZdId5d6ZdJd8d9Z	dHd:d;Z
dKd=d>ZdLd?d@ZdMdBdCZdNdEdFZ  ZS )Or
   a  
    Vector-Quantised Variational Autoencoder (VQ-VAE) used in Morphology-preserving Autoregressive 3D Generative
    Modelling of the Brain by Tudosiu et al. (https://arxiv.org/pdf/2209.03177.pdf)

    The original implementation can be found at
    https://github.com/AmigoLab/SynthAnatomy/blob/main/src/networks/vqvae/baseline.py#L163/

    Args:
        spatial_dims: number of spatial spatial_dims.
        in_channels: number of input channels.
        out_channels: number of output channels.
        downsample_parameters: A Tuple of Tuples for defining the downsampling convolutions. Each Tuple should hold the
            following information stride (int), kernel_size (int), dilation (int) and padding (int).
        upsample_parameters: A Tuple of Tuples for defining the upsampling convolutions. Each Tuple should hold the
            following information stride (int), kernel_size (int), dilation (int), padding (int), output_padding (int).
        num_res_layers: number of sequential residual layers at each level.
        channels: number of channels at each level.
        num_res_channels: number of channels in the residual layers at each level.
        num_embeddings: VectorQuantization number of atomic elements in the codebook.
        embedding_dim: VectorQuantization number of channels of the input and atomic elements.
        commitment_cost: VectorQuantization commitment_cost.
        decay: VectorQuantization decay.
        epsilon: VectorQuantization epsilon.
        act: activation type and arguments.
        dropout: dropout ratio.
        output_act: activation type and arguments for the output.
        ddp_sync: whether to synchronize the codebook across processes.
        use_checkpointing if True, use activation checkpointing to save memory.
    )`   r\      r:   )r9   rV   r8   r8   r^   r^   )r9   rV   r8   r8   r   r_   r_       @   normalg      ?g      ?gh㈵>r   NTFr   r   r   r   r3   r4   r5   r   Sequence[int] | intr6   ?Sequence[Tuple[int, int, int, int]] | Tuple[int, int, int, int]rS   ISequence[Tuple[int, int, int, int, int]] | Tuple[int, int, int, int, int]num_embeddingsembedding_dimembedding_initstrcommitment_costr   decayepsilonr   r   r   rU   ddp_syncr   use_checkpointingc                   s  t    || _|| _|| _|| _|	| _|
| _|| _t	|t
r&t|t|}t|t|kr2tdtdd |D rC|ft| }n|}tdd |D rV|ft| }n|}tdd |D setdtdd |D srtd|D ]}t|d	krtd
qt|D ]}t|dkrtdqt|t|krtdt|t|krtd|| _|| _t|||
||||||d	| _t||
||||||||d
| _tt||	|
|||||dd| _d S )Nzk`num_res_channels` should be a single integer or a tuple of integers with the same length as `num_channls`.c                 s      | ]}t |tV  qd S rK   
isinstancer   .0valuesr$   r$   r%   	<genexpr>c      z!VQVAE.__init__.<locals>.<genexpr>c                 s  ro   rK   rp   rr   r$   r$   r%   ru   h  rv   c                 s  "    | ]}t d d |D V  qdS )c                 s  ro   rK   rp   rs   valuer$   r$   r%   ru   m  rv   +VQVAE.__init__.<locals>.<genexpr>.<genexpr>Nallrs   Zsub_itemr$   r$   r%   ru   m       zQ`downsample_parameters` should be a single tuple of integer or a tuple of tuples.c                 s  rw   )c                 s  ro   rK   rp   rx   r$   r$   r%   ru   q  rv   rz   Nr{   r}   r$   r$   r%   ru   q  r~   zO`upsample_parameters` should be a single tuple of integer or a tuple of tuples.rV   zD`downsample_parameters` should be a tuple of tuples with 4 integers.   zB`upsample_parameters` should be a tuple of tuples with 5 integers.z[`downsample_parameters` should be a tuple of tuples with the same length as `num_channels`.zY`upsample_parameters` should be a tuple of tuples with the same length as `num_channels`.)	r   r   r   r3   r5   r   r6   r   r   )
r   r   r   r3   r5   r   rS   r   r   rU   )r   rf   rg   rj   rk   rl   rh   rm   	quantizer)r   r   r   r   r   r3   rf   rg   rn   rq   r   r	   rC   
ValueErrorr|   r5   r   r2   encoderrR   decoderr   r   r   )r!   r   r   r   r3   r5   r   r6   rS   rf   rg   rh   rj   rk   rl   r   r   rU   rm   rn   Zupsample_parameters_tupleZdownsample_parameters_tuple	parameterr"   r$   r%   r   3  s   

zVQVAE.__init__imagesrI   r   c                 C  .   | j rtjjj| j|dd}|S | |}|S NF)use_reentrant)rn   r&   utils
checkpointr   )r!   r   outputr$   r$   r%   encode  s
   
zVQVAE.encode	encodings!tuple[torch.Tensor, torch.Tensor]c                 C  s   |  |\}}||fS rK   r   )r!   r   Zx_lossr*   r$   r$   r%   quantize  s   zVQVAE.quantizequantizationsc                 C  r   r   )rn   r&   r   r   r   )r!   r   r   r$   r$   r%   decode  s
   
zVQVAE.decodec                 C  s   | j | j|dS )N)r   )r   r   r   )r!   r   r$   r$   r%   index_quantize  s   zVQVAE.index_quantizeembedding_indicesc                 C  s   |  | j|S rK   )r   r   embed)r!   r   r$   r$   r%   decode_samples  s   zVQVAE.decode_samplesc                 C  s&   |  | |\}}| |}||fS rK   )r   r   r   )r!   r   r   Zquantization_lossesreconstructionr$   r$   r%   r+     s   
zVQVAE.forwardr*   c                 C  s   |  |}| |\}}|S rK   )r   r   )r!   r*   zerH   r$   r$   r%   encode_stage_2_inputs  s   
zVQVAE.encode_stage_2_inputsr   c                 C  s   |  |\}}| |}|S rK   )r   r   )r!   r   r   rH   imager$   r$   r%   decode_stage_2_outputs  s   
zVQVAE.decode_stage_2_outputs)&r   r   r   r   r   r   r3   r4   r5   r   r   rc   r6   rd   rS   re   rf   r   rg   r   rh   ri   rj   r   rk   r   rl   r   r   r   r   r   rU   r   rm   r   rn   r   )r   rI   r   rI   )r   rI   r   r   )r   rI   r   rI   )r   rI   r   rI   )r   rI   r   r   rP   )r   rI   r   rI   )r,   r-   r.   r/   r   r0   r   r   r   r   r   r   r+   r   r   r1   r$   r$   r"   r%   r
     s6    #
|


	


)
__future__r   collections.abcr   typingr   r&   torch.nnr'   monai.networks.blocksr   monai.networks.layersr   Z&monai.networks.layers.vector_quantizerr   r   monai.utilsr	   __all__Moduler   r2   rR   r
   r$   r$   r$   r%   <module>   s   9Ze