U
    Ph*1                     @  st   d dl mZ d dlmZ d dlmZ d dlZd dlmZ d dl	m
Z
mZ d dlmZmZ dgZG dd dejZdS )	    )annotations)Sequence)AnyN)ConvolutionResidualUnit)ActNormAutoEncoderc                      s   e Zd ZdZddddddejejdddfdddddd	d	dd
d
dddddddd fddZdddddddZ	ddddddZ
dddddddZddddddd d!Zddddd"dd#d$Zd%d&d'd(d)Z  ZS )*r	   a  
    Simple definition of an autoencoder and base class for the architecture implementing
    :py:class:`monai.networks.nets.VarAutoEncoder`. The network is composed of an encode sequence of blocks, followed
    by an intermediary sequence of blocks, and finally a decode sequence of blocks. The encode and decode blocks are
    default :py:class:`monai.networks.blocks.Convolution` instances with the encode blocks having the given stride
    and the decode blocks having transpose convolutions with the same stride. If `num_res_units` is given residual
    blocks are used instead.

    By default the intermediary sequence is empty but if `inter_channels` is given to specify the output channels of
    blocks then this will be become a sequence of Convolution blocks or of residual blocks if `num_inter_units` is
    given. The optional parameter `inter_dilations` can be used to specify the dilation values of the convolutions in
    these blocks, this allows a network to use dilated kernels in this  middle section. Since the intermediary section
    isn't meant to change the size of the output the strides for all these kernels is 1.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        channels: sequence of channels. Top block first. The length of `channels` should be no less than 2.
        strides: sequence of convolution strides. The length of `stride` should equal to `len(channels) - 1`.
        kernel_size: convolution kernel size, the value(s) should be odd. If sequence,
            its length should equal to dimensions. Defaults to 3.
        up_kernel_size: upsampling convolution kernel size, the value(s) should be odd. If sequence,
            its length should equal to dimensions. Defaults to 3.
        num_res_units: number of residual units. Defaults to 0.
        inter_channels: sequence of channels defining the blocks in the intermediate layer between encode and decode.
        inter_dilations: defines the dilation value for each block of the intermediate layer. Defaults to 1.
        num_inter_units: number of residual units for each block of the intermediate layer. Defaults to 0.
        act: activation type and arguments. Defaults to PReLU.
        norm: feature normalization type and arguments. Defaults to instance norm.
        dropout: dropout ratio. Defaults to no dropout.
        bias: whether to have a bias term in convolution blocks. Defaults to True.
            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
            if a conv layer is directly followed by a batch norm layer, bias should be False.
        padding: controls the amount of implicit zero-paddings on both sides for padding number of points
            for each dimension in convolution blocks. Defaults to None.

    Examples::

        from monai.networks.nets import AutoEncoder

        # 3 layers each down/up sampling their inputs by a factor 2 with no intermediate layer
        net = AutoEncoder(
            spatial_dims=2,
            in_channels=1,
            out_channels=1,
            channels=(2, 4, 8),
            strides=(2, 2, 2)
        )

        # 1 layer downsampling by 2, followed by a sequence of residual units with 2 convolutions defined by
        # progressively increasing dilations, then final upsample layer
        net = AutoEncoder(
                spatial_dims=2,
                in_channels=1,
                out_channels=1,
                channels=(4,),
                strides=(2,),
                inter_channels=(8, 8, 8),
                inter_dilations=(1, 2, 4),
                num_inter_units=2
            )

       r   N   TintzSequence[int]zSequence[int] | intzlist | Noneztuple | str | Noneztuple | strztuple | str | float | NoneboolzSequence[int] | int | NoneNone)spatial_dimsin_channelsout_channelschannelsstrideskernel_sizeup_kernel_sizenum_res_unitsinter_channelsinter_dilationsnum_inter_unitsactnormdropoutbiaspaddingreturnc                   s"  t    || _|| _|| _t|| _t|| _|| _|| _	|| _
|| _|| _|| _|| _|| _|| _|	d k	rr|	ng | _t|
pdgt| j | _t|t|krtd|| _t|dd d |g }| | j||\| _| _| | j|\| _| _| | j||d d d pdg\| _}d S )N   z;Autoencoder expects matching number of channels and strides)super__init__
dimensionsr   r   listr   r   r   r   r   r   r   r   r   r   r   r   lenr   
ValueErrorZencoded_channels_get_encode_moduleencode_get_intermediate_moduleintermediate_get_decode_moduledecode)selfr   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   Zdecode_channel_list_	__class__ T/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/autoencoder.pyr$   \   s0    


zAutoEncoder.__init__ztuple[nn.Sequential, int])r   r   r   r   c           
      C  sT   t  }|}tt||D ]0\}\}}| |||d}	|d| |	 |}q||fS )z}
        Returns the encode part of the network by building up a sequence of layers returned by `_get_encode_layer`.
        Fz	encode_%i)nn
Sequential	enumeratezip_get_encode_layer
add_module)
r/   r   r   r   r*   layer_channelsicslayerr3   r3   r4   r)      s    zAutoEncoder._get_encode_moduleztuple[nn.Module, int])r   r   r   c           	      C  s   t  }|}| jrt  }tt| j| jD ]\}\}}| jdkrtt| j	||d| j
| j| j| j| j|| j| jd}n,t| j	||d| j
| j| j| j|| j| jd}|d| | |}q,||fS )z
        Returns the intermediate block of the network which accepts input from the encoder and whose output goes
        to the decoder.
        r   r    )r   r   r   r   r   subunitsr   r   r   dilationr   r   )r   r   r   r   r   r   r   r   rA   r   r   zinter_%i)r5   Identityr   r6   r7   r8   r   r   r   r%   r   r   r   r   r   r   r   r:   )	r/   r   r   r,   r;   r<   dcdiunitr3   r3   r4   r+      sH    	
z$AutoEncoder._get_intermediate_modulec           
   	   C  s`   t  }|}tt||D ]<\}\}}| ||||t|d k}	|d| |	 |}q||fS )z}
        Returns the decode part of the network by building up a sequence of layers returned by `_get_decode_layer`.
        r    z	decode_%i)r5   r6   r7   r8   _get_decode_layerr'   r:   )
r/   r   r   r   r.   r;   r<   r=   r>   r?   r3   r3   r4   r-      s    zAutoEncoder._get_decode_modulez	nn.Module)r   r   r   is_lastr   c                 C  sn   | j dkr>t| j|||| j| j | j| j| j| j| j|d}|S t	| j|||| j| j| j| j| j| j|d}|S )zL
        Returns a single layer of the encoder part of the network.
        r   r   r   r   r   r   r@   r   r   r   r   r   last_conv_only)r   r   r   r   r   r   r   r   r   r   	conv_only)
r   r   r%   r   r   r   r   r   r   r   )r/   r   r   r   rG   modr3   r3   r4   r9      s<    
zAutoEncoder._get_encode_layerznn.Sequentialc                 C  s   t  }t| j|||| j| j| j| j| j| j	|o8| j
dkdd}|d| | j
dkrt| j||d| jd| j| j| j| j| j	|d}|d| |S )zL
        Returns a single layer of the decoder part of the network.
        r   T)r   r   r   r   r   r   r   r   r   r   rJ   is_transposedconvr    rH   Zresunit)r5   r6   r   r%   r   r   r   r   r   r   r   r:   r   r   )r/   r   r   r   rG   r.   rM   rur3   r3   r4   rF     sB    
zAutoEncoder._get_decode_layerztorch.Tensorr   )xr   c                 C  s"   |  |}| |}| |}|S )N)r*   r,   r.   )r/   rO   r3   r3   r4   forward.  s    


zAutoEncoder.forward)__name__
__module____qualname____doc__r   PRELUr   INSTANCEr$   r)   r+   r-   r9   rF   rP   __classcell__r3   r3   r1   r4   r	      s&   H403$+)
__future__r   collections.abcr   typingr   torchtorch.nnr5   Zmonai.networks.blocksr   r   monai.networks.layers.factoriesr   r   __all__Moduler	   r3   r3   r3   r4   <module>   s   