o
    -i                     @  s   d dl mZ d dlmZ d dlZd dlZd dlmZ d dl	m
Z
mZ d dlmZmZ d dlmZ d dlmZmZ G dd	 d	ejZdS )
    )annotations)SequenceN)ConvolutionResidualUnit)ActNorm)Reshape)ensure_tupleensure_tuple_repc                      sH   e Zd ZdZddejejddfd! fddZd"ddZ	d#dd Z
  ZS )$	GeneratoraV	  
    Defines a simple generator network accepting a latent vector and through a sequence of convolution layers
    constructs an output tensor of greater size and high dimensionality. The method `_get_layer` is used to
    create each of these layers, override this method to define layers beyond the default
    :py:class:`monai.networks.blocks.Convolution` or :py:class:`monai.networks.blocks.ResidualUnit` layers.

    The layers are constructed using the values in the `channels` and `strides` arguments, the number being defined by
    the length of these (which must match). Input is first passed through a :py:class:`torch.nn.Linear` layer to
    convert the input vector to an image tensor with dimensions `start_shape`. This passes through the convolution
    layers and is progressively upsampled if the `strides` values are greater than 1 using transpose convolutions. The
    size of the final output is defined by the `start_shape` dimension and the amount of upsampling done through
    strides. In the default definition the size of the output's spatial dimensions will be that of `start_shape`
    multiplied by the product of `strides`, thus the example network below upsamples an starting size of (64, 8, 8)
    to (1, 64, 64) since its `strides` are (2, 2, 2).

    Args:
        latent_shape: tuple of integers stating the dimension of the input latent vector (minus batch dimension)
        start_shape: tuple of integers stating the dimension of the tensor to pass to convolution subnetwork
        channels: tuple of integers stating the output channels of each convolutional layer
        strides: tuple of integers stating the stride (upscale factor) of each convolutional layer
        kernel_size: integer or tuple of integers stating size of convolutional kernels
        num_res_units: integer stating number of convolutions in residual units, 0 means no residual units
        act: name or type defining activation layers
        norm: name or type defining normalization layers
        dropout: optional float value in range [0, 1] stating dropout probability for layers, None for no dropout
        bias: boolean stating if convolution layers should have a bias component

    Examples::

        # 3 layers, latent input vector of shape (42, 24), output volume of shape (1, 64, 64)
        net = Generator((42, 24), (64, 8, 8), (32, 16, 1), (2, 2, 2))

          NTlatent_shapeSequence[int]start_shapechannelsstrideskernel_sizeSequence[int] | intnum_res_unitsintdropoutfloat | NonebiasboolreturnNonec                   s  t    t|^| _| _t| j| _t|| _t|| _t|| _	t
|| j| _|| _|| _|| _|	| _|
| _t | _ttt| jtt|| _t| | _t | _| j}tt||D ]!\}\}}|t|d k}| ||||}| j d| | |}qcd S )N   zlayer_%i)!super__init__r	   in_channelsr   len
dimensionsr   r   r   r
   r   r   actnormr   r   nnFlattenflattenLinearr   npprodlinearr   reshape
Sequentialconv	enumeratezip
_get_layer
add_module)selfr   r   r   r   r   r   r#   r$   r   r   echannelicsis_lastlayer	__class__ _/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/generator.pyr   =   s.   




$

zGenerator.__init__r    out_channelsr8   Convolution | nn.Sequentialc                 C  sx   t ||d|p
| jdk| j|| j| j| j| j| jd}| jdkr:t|| j|| j|| j| j| j| j| jd
}t	
||}|S )ad  
        Returns a layer accepting inputs with `in_channels` number of channels and producing outputs of `out_channels`
        number of channels. The `strides` indicates upsampling factor, ie. transpose convolutional stride. If `is_last`
        is True this is the final layer and is not expected to include activation and normalization layers.
        Tr   )r    r   is_transposed	conv_onlyspatial_dimsr>   r   r#   r$   r   r   )
r    subunitslast_conv_onlyrB   r>   r   r#   r$   r   r   )r   r   r"   r   r#   r$   r   r   r   r%   r-   )r3   r    r>   r   r8   r9   rur<   r<   r=   r1   g   s8   
zGenerator._get_layerxtorch.Tensorc                 C  s,   |  |}| |}| |}| |}|S )N)r'   r+   r,   r.   )r3   rF   r<   r<   r=   forward   s
   



zGenerator.forward)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   )
r    r   r>   r   r   r   r8   r   r   r?   )rF   rG   r   rG   )__name__
__module____qualname____doc__r   PRELUr   INSTANCEr   r1   rH   __classcell__r<   r<   r:   r=   r      s    (
*+r   )
__future__r   collections.abcr   numpyr)   torchtorch.nnr%   monai.networks.blocksr   r   monai.networks.layers.factoriesr   r   "monai.networks.layers.simplelayersr   monai.utilsr	   r
   Moduler   r<   r<   r<   r=   <module>   s   