U
    Ph                     @  s   d dl mZ d dlmZ d dlZd dlZd dlmZ d dl	m
Z
mZ d dlmZmZ d dlmZ d dlmZmZ G dd	 d	ejZdS )
    )annotations)SequenceN)ConvolutionResidualUnit)ActNorm)Reshape)ensure_tupleensure_tuple_repc                      sn   e Zd ZdZddejejddfddddddd	d
dd	 fddZdddd
ddddZ	dddddZ
  ZS )	GeneratoraV	  
    Defines a simple generator network accepting a latent vector and through a sequence of convolution layers
    constructs an output tensor of greater size and high dimensionality. The method `_get_layer` is used to
    create each of these layers, override this method to define layers beyond the default
    :py:class:`monai.networks.blocks.Convolution` or :py:class:`monai.networks.blocks.ResidualUnit` layers.

    The layers are constructed using the values in the `channels` and `strides` arguments, the number being defined by
    the length of these (which must match). Input is first passed through a :py:class:`torch.nn.Linear` layer to
    convert the input vector to an image tensor with dimensions `start_shape`. This passes through the convolution
    layers and is progressively upsampled if the `strides` values are greater than 1 using transpose convolutions. The
    size of the final output is defined by the `start_shape` dimension and the amount of upsampling done through
    strides. In the default definition the size of the output's spatial dimensions will be that of `start_shape`
    multiplied by the product of `strides`, thus the example network below upsamples an starting size of (64, 8, 8)
    to (1, 64, 64) since its `strides` are (2, 2, 2).

    Args:
        latent_shape: tuple of integers stating the dimension of the input latent vector (minus batch dimension)
        start_shape: tuple of integers stating the dimension of the tensor to pass to convolution subnetwork
        channels: tuple of integers stating the output channels of each convolutional layer
        strides: tuple of integers stating the stride (upscale factor) of each convolutional layer
        kernel_size: integer or tuple of integers stating size of convolutional kernels
        num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
        act: name or type defining activation layers
        norm: name or type defining normalization layers
        dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
        bias: boolean stating if convolution layers should have a bias component

    Examples::

        # 3 layers, latent input vector of shape (42, 24), output volume of shape (1, 64, 64)
        net = Generator((42, 24), (64, 8, 8), (32, 16, 1), (2, 2, 2))

          NTzSequence[int]zSequence[int] | intintzfloat | NoneboolNone)	latent_shapestart_shapechannelsstrideskernel_sizenum_res_unitsdropoutbiasreturnc                   s  t    t|^| _| _t| j| _t|| _t|| _t|| _	t
|| j| _|| _|| _|| _|	| _|
| _t | _ttt| jtt|| _t| | _t | _| j}tt||D ]B\}\}}|t|d k}| ||||}| j d| | |}qd S )N   zlayer_%i)!super__init__r	   in_channelsr   len
dimensionsr   r   r   r
   r   r   actnormr   r   nnFlattenflattenLinearr   npprodlinearr   reshape
Sequentialconv	enumeratezip
_get_layer
add_module)selfr   r   r   r   r   r   r    r!   r   r   Zechannelicsis_lastlayer	__class__ R/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/generator.pyr   =   s,    




$

zGenerator.__init__zConvolution | nn.Sequential)r   out_channelsr   r4   r   c                 C  sx   t ||d|p| jdk| j|| j| j| j| j| jd}| jdkrtt|| j|| j|| j| j| j| j| jd
}t	
||}|S )ad  
        Returns a layer accepting inputs with `in_channels` number of channels and producing outputs of `out_channels`
        number of channels. The `strides` indicates upsampling factor, ie. transpose convolutional stride. If `is_last`
        is True this is the final layer and is not expected to include activation and normalization layers.
        Tr   )r   r   is_transposed	conv_onlyspatial_dimsr:   r   r    r!   r   r   )
r   subunitslast_conv_onlyr=   r:   r   r    r!   r   r   )r   r   r   r   r    r!   r   r   r   r"   r*   )r0   r   r:   r   r4   r5   rur8   r8   r9   r.   g   s8    
zGenerator._get_layerztorch.Tensor)xr   c                 C  s,   |  |}| |}| |}| |}|S )N)r$   r(   r)   r+   )r0   rA   r8   r8   r9   forward   s
    



zGenerator.forward)__name__
__module____qualname____doc__r   PRELUr   INSTANCEr   r.   rB   __classcell__r8   r8   r6   r9   r      s   ($*+r   )
__future__r   collections.abcr   numpyr&   torchtorch.nnr"   Zmonai.networks.blocksr   r   monai.networks.layers.factoriesr   r   Z"monai.networks.layers.simplelayersr   monai.utilsr	   r
   Moduler   r8   r8   r8   r9   <module>   s   