o
    &iH'                     @  sb   d dl mZ d dlmZmZ d dlZd dlmZ ddgZG dd dejZ	G dd dejjZ
dS )	    )annotations)SequenceTupleN)nnVectorQuantizerEMAQuantizerc                      sV   e Zd ZdZ					d%d& fddZd'ddZd(ddZd)d!d"Zd'd#d$Z  Z	S )*r   a9  
    Vector Quantization module using Exponential Moving Average (EMA) to learn the codebook parameters based on  Neural
    Discrete Representation Learning by Oord et al. (https://arxiv.org/abs/1711.00937) and the official implementation
    that can be found at https://github.com/deepmind/sonnet/blob/v2/sonnet/src/nets/vqvae.py#L148 and commit
    58d9a2746493717a7c9252938da7efa6006f3739.

    This module is not compatible with TorchScript while working in a Distributed Data Parallelism Module. This is due
    to lack of TorchScript support for torch.distributed module as per https://github.com/pytorch/pytorch/issues/41353
    on 22/10/2022. If you want to TorchScript your model, please turn set `ddp_sync` to False.

    Args:
        spatial_dims: number of spatial dimensions of the input.
        num_embeddings: number of atomic elements in the codebook.
        embedding_dim: number of channels of the input and atomic elements.
        commitment_cost: scaling factor of the MSE loss between input and its quantized version. Defaults to 0.25.
        decay: EMA decay. Defaults to 0.99.
        epsilon: epsilon value. Defaults to 1e-5.
        embedding_init: initialization method for the codebook. Defaults to "normal".
        ddp_sync: whether to synchronize the codebook across processes. Defaults to True.
          ?Gz?h㈵>normalTspatial_dimsintnum_embeddingsembedding_dimcommitment_costfloatdecayepsilonembedding_initstrddp_syncboolc	           	        s  t    || _|| _|| _| jdv sJ td| dtj| j| j| _	|dkr,n|dkr=tjj
j| j	jjddd d	| j	j_|| _| d
t| j | d| j	jj  |  |  || _|| _|| _dgttd| jd  dg | _d| jd gttd| jd  | _d S )N)      zMEMAQuantizer only supports 4D and 5D tensor inputs but received spatial dims .r   kaiming_uniformfan_inlinear)modenonlinearityFema_cluster_sizeema_wr   r      )super__init__r   r   r   
ValueErrortorchr   	Embedding	embeddinginitkaiming_uniform_weightdatarequires_gradr   register_bufferzeroscloner   r   r   listrangeflatten_permutationquantization_permutation)	selfr   r   r   r   r   r   r   r   	__class__ h/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/layers/vector_quantizer.pyr$   ,   s2   


"zEMAQuantizer.__init__inputstorch.Tensorreturn/Tuple[torch.Tensor, torch.Tensor, torch.Tensor]c                 C  s   t jdddb t|j}|d= | }|| j d| j	}|d j
ddd| jj d j
d	dd dt || jj   }t j| dd
d }t jj|| j }||}|||fW  d   S 1 slw   Y  dS )a  
        Given an input it projects it to the quantized space and returns additional tensors needed for EMA loss.

        Args:
            inputs: Encoding space tensors of shape [B, C, H, W, D].

        Returns:
            torch.Tensor: Flatten version of the input of shape [B*H*W*D, C].
            torch.Tensor: One-hot representation of the quantization indices of shape [B*H*W*D, self.num_embeddings].
            torch.Tensor: Quantization indices of shape [B,H,W,D,1]

        cudaFenabledr"   r   T)dimkeepdimr   )rB   N)r&   autocastr1   shaper   permuter3   
contiguousviewr   sumr(   r+   tmmmaxr   
functionalone_hotr   )r5   r:   Zencoding_indices_view
flat_input	distancesencoding_indices	encodingsr8   r8   r9   quantizeZ   s    

$zEMAQuantizer.quantizeembedding_indicesc                 C  sL   t jddd | || j }|W  d   S 1 sw   Y  dS )a  
        Given encoding indices of shape [B,D,H,W,1] embeds them in the quantized space
        [B, D, H, W, self.embedding_dim] and reshapes them to [B, self.embedding_dim, D, H, W] to be fed to the
        decoder.

        Args:
            embedding_indices: Tensor in channel last format which holds indices referencing atomic
                elements from self.embedding

        Returns:
            torch.Tensor: Quantize space representation of encoding_indices in channel first format.
        r>   Fr?   N)r&   rD   r(   rF   r4   rG   )r5   rT   r(   r8   r8   r9   embed   s
   $zEMAQuantizer.embedencodings_sumdwNonec                 C  sF   | j r tj r tjj|tjjjd tjj|tjjjd dS 	 dS )a'  
        TorchScript does not support torch.distributed.all_reduce. This function is a bypassing trick based on the
        example: https://pytorch.org/docs/stable/generated/torch.jit.unused.html#torch.jit.unused

        Args:
            encodings_sum: The summation of one hot representation of what encoding was used for each
                position.
            dw: The multiplication of the one hot representation of what encoding was used for each
                position with the flattened input.

        Returns:
            None
        )tensoropN)r   r&   distributedis_initialized
all_reduceReduceOpSUM)r5   rV   rW   r8   r8   r9   distributed_synchronization   s   z(EMAQuantizer.distributed_synchronizationc                 C  s6  |  |\}}}| |}| jrt e |d}t| |}| jr+| 	|| | j
j| jt|d| j  | j
 }| j
| j || j| j   | }	| jj| jt|d| j  | jjj| j|	d  W d    n1 s|w   Y  | jtjj| | }
|||   }||
|fS )Nr   r"   )rS   rU   trainingr&   no_gradrI   rK   rJ   r   r`   r    r,   mul_r   add_mulr   r   r!   r(   r+   copy_	unsqueezer   r   rM   mse_lossdetach)r5   r:   rO   rR   rQ   	quantizedrV   rW   nweightslossr8   r8   r9   forward   s"   


$
 $
zEMAQuantizer.forward)r   r	   r
   r   T)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   )r:   r;   r<   r=   rT   r;   r<   r;   )rV   r;   rW   r;   r<   rX   )
__name__
__module____qualname____doc__r$   rS   rU   r`   rn   __classcell__r8   r8   r6   r9   r      s    
.
&
c                      s@   e Zd ZdZd fddZdd
dZdddZdddZ  ZS )r   aU  
    Vector Quantization wrapper that is needed as a workaround for the AMP to isolate the non fp16 compatible parts of
    the quantization in their own class.

    Args:
        quantizer (torch.nn.Module):  Quantizer module that needs to return its quantized representation, loss and index
            based quantized representation.
    	quantizerr   c                   s    t    || _td| _d S )Nr"   )r#   r$   ru   r&   rand
perplexity)r5   ru   r6   r8   r9   r$      s   
zVectorQuantizer.__init__r:   r;   r<   !Tuple[torch.Tensor, torch.Tensor]c              	   C  sd   |  |\}}}tj| | j j| j jd | }tt|t	|d   | _
||fS )N)binsrL   g|=)ru   r&   histcr   r   divnumelexprI   logrw   )r5   r:   rj   rm   rQ   Z	avg_probsr8   r8   r9   rn      s   
"zVectorQuantizer.forwardrT   c                 C  s   | j j|dS )N)rT   )ru   rU   )r5   rT   r8   r8   r9   rU      s   zVectorQuantizer.embedrR   c                 C  s   |  |}|d }|S )Nr   )ru   )r5   rR   outputrQ   r8   r8   r9   rS      s   
zVectorQuantizer.quantize)ru   r   )r:   r;   r<   rx   ro   )rR   r;   r<   r;   )	rp   rq   rr   rs   r$   rn   rU   rS   rt   r8   r8   r6   r9   r      s    	

)
__future__r   typingr   r   r&   r   __all__Moduler   r   r8   r8   r8   r9   <module>   s    0