o
    -i                     @  s   d dl mZ d dlmZ d dlZd dlZd dlmZ d dlm	Z
 d dlmZmZ d dlmZmZ d dlmZ dgZG d	d deZdS )
    )annotations)SequenceN)
functional)calculate_out_shapesame_padding)ActNorm)AutoEncoderVarAutoEncoderc                      sh   e Zd ZdZddddddejejdddfd3 fd"d#Zd4d'd(Z	d5d6d*d+Z
d7d.d/Zd8d1d2Z  ZS )9r
   a	  
    Variational Autoencoder based on the paper - https://arxiv.org/abs/1312.6114

    Args:
        spatial_dims: number of spatial dimensions.
        in_shape: shape of input data starting with channel dimension.
        out_channels: number of output channels.
        latent_size: size of the latent variable.
        channels: sequence of channels. Top block first. The length of `channels` should be no less than 2.
        strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`.
        kernel_size: convolution kernel size, the value(s) should be odd. If sequence,
            its length should equal to dimensions. Defaults to 3.
        up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence,
            its length should equal to dimensions. Defaults to 3.
        num_res_units: number of residual units. Defaults to 0.
        inter_channels: sequence of channels defining the blocks in the intermediate layer between encode and decode.
        inter_dilations: defines the dilation value for each block of the intermediate layer. Defaults to 1.
        num_inter_units: number of residual units for each block of the intermediate layer. Defaults to 0.
        act: activation type and arguments. Defaults to PReLU.
        norm: feature normalization type and arguments. Defaults to instance norm.
        dropout: dropout ratio. Defaults to no dropout.
        bias: whether to have a bias term in convolution blocks. Defaults to True.
            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
            if a conv layer is directly followed by a batch norm layer, bias should be False.
        use_sigmoid: whether to use the sigmoid function on final output. Defaults to True.

    Examples::

        from monai.networks.nets import VarAutoEncoder

        # 3 layer network accepting images with dimensions (1, 32, 32) and using a latent vector with 2 values
        model = VarAutoEncoder(
            spatial_dims=2,
            in_shape=(32, 32),  # image spatial shape
            out_channels=1,
            latent_size=2,
            channels=(16, 32, 64),
            strides=(1, 2, 2),
        )

    see also:
        - Variational autoencoder network with MedNIST Dataset
          https://github.com/Project-MONAI/tutorials/blob/master/modules/varautoencoder_mednist.ipynb
       r   N   Tspatial_dimsintin_shapeSequence[int]out_channelslatent_sizechannelsstrideskernel_sizeSequence[int] | intup_kernel_sizenum_res_unitsinter_channelslist | Noneinter_dilationsnum_inter_unitsacttuple | str | Nonenormtuple | strdropouttuple | str | float | Nonebiasbooluse_sigmoidreturnNonec                   s   |^| _ | _|| _|| _tj| jtd| _t 	|| j ||||||	|
|||||| t
| j}|D ]}t| j| j||| _q1tt| j| j }t|| j| _t|| j| _t| j|| _d S )N)dtype)in_channelsr   r%   r   npasarrayr   
final_sizesuper__init__r   r   r   prodencoded_channelsnnLinearmulogvardecodeL)selfr   r   r   r   r   r   r   r   r   r   r   r   r   r   r!   r#   r%   paddingsZlinear_size	__class__ d/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/varautoencoder.pyr.   J   s8   
zVarAutoEncoder.__init__xtorch.Tensor!tuple[torch.Tensor, torch.Tensor]c                 C  sB   |  |}| |}||jd d}| |}| |}||fS Nr   )encodeintermediateviewshaper3   r4   )r6   r=   r3   r4   r;   r;   r<   encode_forward   s   



zVarAutoEncoder.encode_forwardzc                 C  sN   t | |}|j|jd | jd g| jR  }| |}|r%t	|}|S r@   )
Frelur5   rD   rE   r   r,   decodetorchsigmoid)r6   rG   r%   r=   r;   r;   r<   decode_forward   s   "

zVarAutoEncoder.decode_forwardr3   r4   c                 C  s.   t d| }| jrt ||}||S )Ng      ?)rK   exptraining
randn_likemuladd_)r6   r3   r4   stdr;   r;   r<   reparameterize   s   
zVarAutoEncoder.reparameterize=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]c                 C  s0   |  |\}}| ||}| || j|||fS )N)rF   rT   rM   r%   )r6   r=   r3   r4   rG   r;   r;   r<   forward   s   zVarAutoEncoder.forward)$r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r!   r"   r#   r$   r%   r$   r&   r'   )r=   r>   r&   r?   )T)rG   r>   r%   r$   r&   r>   )r3   r>   r4   r>   r&   r>   )r=   r>   r&   rU   )__name__
__module____qualname____doc__r   PRELUr   INSTANCEr.   rF   rM   rT   rV   __classcell__r;   r;   r9   r<   r
      s$    5
6
)
__future__r   collections.abcr   numpyr*   rK   torch.nnr1   r   rH   monai.networks.layers.convutilsr   r   monai.networks.layers.factoriesr   r   monai.networks.netsr	   __all__r
   r;   r;   r;   r<   <module>   s   