o
    )i                     @  sf   d dl mZ d dlmZ d dlZd dlmZ d dlmZ d dl	m
Z
 d dlmZ G dd dejZdS )	    )annotations)SequenceN)Convolutionsame_padding)Convc                      s8   e Zd ZdZ					dd fddZdddZ  ZS )
SimpleASPPa  
    A simplified version of the atrous spatial pyramid pooling (ASPP) module.

    Chen et al., Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation.
    https://arxiv.org/abs/1802.02611

    Wang et al., A Noise-robust Framework for Automatic Segmentation of COVID-19 Pneumonia Lesions
    from CT Images. https://ieeexplore.ieee.org/document/9109297
          r   r   r
            BATCH	LEAKYRELUFspatial_dimsintin_channelsconv_out_channelskernel_sizesSequence[int]	dilations	norm_typetuple | str | None	acti_typebiasboolreturnNonec	              	     s   t    t|t|krtdt| dt| dtdd t||D }	t | _t|||	D ]\}
}}t	t	j
|f |||
||d}| j| q3|t|	 }t|||d|||d| _d	S )
a  
        Args:
            spatial_dims: number of spatial dimensions, could be 1, 2, or 3.
            in_channels: number of input channels.
            conv_out_channels: number of output channels of each atrous conv.
                The final number of output channels is conv_out_channels * len(kernel_sizes).
            kernel_sizes: a sequence of four convolutional kernel sizes.
                Defaults to (1, 3, 3, 3) for four (dilated) convolutions.
            dilations: a sequence of four convolutional dilation parameters.
                Defaults to (1, 2, 4, 6) for four (dilated) convolutions.
            norm_type: final kernel-size-one convolution normalization type.
                Defaults to batch norm.
            acti_type: final kernel-size-one convolution activation type.
                Defaults to leaky ReLU.
            bias: whether to have a bias term in convolution blocks. Defaults to False.
                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
                if a conv layer is directly followed by a batch norm layer, bias should be False.

        Raises:
            ValueError: When ``kernel_sizes`` length differs from ``dilations``.

        See also:

            :py:class:`monai.networks.layers.Act`
            :py:class:`monai.networks.layers.Conv`
            :py:class:`monai.networks.layers.Norm`

        z?kernel_sizes and dilations length must match, got kernel_sizes=z dilations=.c                 s  s    | ]
\}}t ||V  qd S )Nr   ).0kd r$   \/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/aspp.py	<genexpr>P   s    z&SimpleASPP.__init__.<locals>.<genexpr>)r   out_channelskernel_sizedilationpaddingr
   )r   r   r'   r(   actnormr   N)super__init__len
ValueErrortuplezipnn
ModuleListconvsr   CONVappendr   conv_k1)selfr   r   r   r   r   r   r   r   Zpadsr"   r#   pZ_convr'   	__class__r$   r%   r.   #   s4   
'

zSimpleASPP.__init__xtorch.Tensorc                   s,   t j fdd| jD dd}| |}|S )z^
        Args:
            x: in shape (batch, channel, spatial_1[, spatial_2, ...]).
        c                   s   g | ]}| qS r$   r$   )r!   convr=   r$   r%   
<listcomp>i   s    z&SimpleASPP.forward.<locals>.<listcomp>r
   )dim)torchcatr5   r8   )r9   r=   x_outr$   r@   r%   forwardd   s   
zSimpleASPP.forward)r	   r   r   r   F)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   )r=   r>   r   r>   )__name__
__module____qualname____doc__r.   rF   __classcell__r$   r$   r;   r%   r      s    Ar   )
__future__r   collections.abcr   rC   torch.nnr3   Z"monai.networks.blocks.convolutionsr   monai.networks.layersr   monai.networks.layers.factoriesr   Moduler   r$   r$   r$   r%   <module>   s   