o
    &i4                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlmZm	Z	 d dl
mZ d dlmZmZmZ g dZG dd	 d	ejZG d
d dejZG dd dejZeZe ZZdS )    )annotations)SequenceN)ConvPool)pixelunshuffle)DownsampleModeensure_tuple_replook_up_option)
MaxAvgPool
DownSample
DownsampleSubpixelDownsampleSubpixelDownSampleSubpixeldownsamplec                      s4   e Zd ZdZ			dd fddZdddZ  ZS )r
   z
    Downsample with both maxpooling and avgpooling,
    double the channel size by concatenating the downsampled feature maps.
    Nr   Fspatial_dimsintkernel_sizeSequence[int] | intstrideSequence[int] | int | Nonepadding	ceil_modeboolreturnNonec                   sn   t    t|||du rdnt||t|||d}ttj|f di || _ttj|f di || _dS )a  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            kernel_size: the kernel size of both pooling operations.
            stride: the stride of the window. Default value is `kernel_size`.
            padding: implicit zero padding to be added to both pooling operations.
            ceil_mode: when True, will use ceil instead of floor to compute the output shape.
        N)r   r   r   r    )super__init__r   r   MAXmax_poolAVGavg_pool)selfr   r   r   r   r   _params	__class__r   b/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/downsample.pyr       s   
zMaxAvgPool.__init__xtorch.Tensorc                 C  s   t j| || |gddS )z
        Args:
            x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...]).

        Returns:
            Tensor in shape (batch, 2*channel, spatial_1[, spatial_2, ...]).
           )dim)torchcatr   r!   r"   r'   r   r   r&   forward:   s   zMaxAvgPool.forward)Nr   F)r   r   r   r   r   r   r   r   r   r   r   r   r'   r(   r   r(   __name__
__module____qualname____doc__r   r.   __classcell__r   r   r$   r&   r
      s    	r
   c                      s6   e Zd ZdZddddejdddfd fddZ  ZS )r   aJ  
    Downsamples data by `scale_factor`.

    Supported modes are:

    - "conv": uses a strided convolution for learnable downsampling.
    - "convgroup": uses a grouped strided convolution for efficient feature reduction.
    - "nontrainable": uses :py:class:`torch.nn.Upsample` with inverse scale factor.
    - "pixelunshuffle": uses :py:class:`monai.networks.blocks.PixelUnshuffle` for channel-space rearrangement.

    This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``.
    Please check the link below for more details:
    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms

    This module can optionally take a pre-convolution
    (often used to map the number of features from `in_channels` to `out_channels`).
    N   defaultTr   r   in_channels
int | Noneout_channelsscale_factorSequence[float] | floatr   Sequence[float] | float | NonemodeDownsampleMode | strpre_convnn.Module | str | None	post_convnn.Module | Nonebiasr   r   r   c
                   s0  t    t||}
t|t}|s|
}td|}nt||}tdd |D }|tjkrI|s2td| dt	t	j|f ||p?|||
||	d dS |tj
krz|sTtd|du rZ|}|| dkrb|nd	}| d
t	t	j|f ||||
|||	d dS |tjkr|dkr||kr|std| dt	t	j|f ||p|d	|	d | dttj|f ||
|d |r| d| dS dS |tjkr|dkr||kr|std| dt	t	j|f ||p|d	|	d | dttj|f ||
|d |r| d| dS dS |tjkr| dt||||
d ||	d dS dS )a  
        Downsamples data by `scale_factor`.
        Supported modes are:

            - DownsampleMode.CONV: uses a strided convolution for learnable downsampling.
            - DownsampleMode.CONVGROUP: uses a grouped strided convolution for efficient feature reduction.
            - DownsampleMode.MAXPOOL: uses maxpooling for non-learnable downsampling.
            - DownsampleMode.AVGPOOL: uses average pooling for non-learnable downsampling.
            - DownsampleMode.PIXELUNSHUFFLE: uses :py:class:`monai.networks.blocks.SubpixelDownsample`.

        This operation will cause non-deterministic behavior when ``mode`` is ``DownsampleMode.NONTRAINABLE``.
        Please check the link below for more details:
        https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms

        This module can optionally take a pre-convolution and post-convolution
        (often used to map the number of features from `in_channels` to `out_channels`).

        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of channels of the input image.
            out_channels: number of channels of the output image. Defaults to `in_channels`.
            scale_factor: multiplier for spatial size reduction. Has to match input size if it is a tuple. Defaults to 2.
            kernel_size: kernel size used during convolutions. Defaults to `scale_factor`.
            mode: {``DownsampleMode.CONV``, ``DownsampleMode.CONVGROUP``, ``DownsampleMode.MAXPOOL``, ``DownsampleMode.AVGPOOL``,
                ``DownsampleMode.PIXELUNSHUFFLE``}. Defaults to ``DownsampleMode.CONV``.
            pre_conv: a conv block applied before downsampling. Defaults to "default".
                When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized.
                Only used in the "maxpool", "avgpool" or "pixelunshuffle" modes.
            post_conv: a conv block applied after downsampling. Defaults to None. Only used in the "maxpool" and "avgpool" modes.
            bias: whether to have a bias term in the default preconv and conv layers. Defaults to True.
        r   c                 s  s    | ]	}|d  d V  qdS )r)   r6   Nr   ).0kr   r   r&   	<genexpr>   s    z&DownSample.__init__.<locals>.<genexpr>z.in_channels needs to be specified in conv modeconvr8   r:   r   r   r   rD   z!in_channels needs to be specifiedNr)   	convgroup)r8   r:   r   r   r   groupsrD   r7   Zpreconv)r8   r:   r   rD   maxpool)r   r   r   Zpostconvavgpoolr   )r   r8   r:   r;   
conv_blockrD   )r   r   r   r	   r   tupleCONV
ValueError
add_moduler   	CONVGROUPMAXPOOLr   r   AVGPOOLr    PIXELUNSHUFFLEr   )r"   r   r8   r:   r;   r   r>   r@   rB   rD   Zscale_factor_Z	down_modekernel_size_r   rK   r$   r   r&   r   X   s   
+






zDownSample.__init__)r   r   r8   r9   r:   r9   r;   r<   r   r=   r>   r?   r@   rA   rB   rC   rD   r   r   r   )r1   r2   r3   r4   r   rP   r   r5   r   r   r$   r&   r   E   s    r   c                      s6   e Zd ZdZ				dd fddZdddZ  ZS )r   u  
    Downsample via using a subpixel CNN. This module supports 1D, 2D and 3D input images.
    The module consists of two parts. First, a convolutional layer is employed
    to adjust the number of channels. Secondly, a pixel unshuffle manipulation
    rearranges the spatial information into channel space, effectively reducing
    spatial dimensions while increasing channel depth.

    The pixel unshuffle operation is the inverse of pixel shuffle, rearranging dimensions
    from (B, C, H*r, W*r) to (B, C*r², H, W) for 2D images or from (B, C, H*r, W*r, D*r) to (B, C*r³, H, W, D) in 3D case.

    Example: (1, 1, 4, 4) with r=2 becomes (1, 4, 2, 2).

    See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution
    Using a nEfficient Sub-Pixel Convolutional Neural Network."

    The pixel unshuffle mechanism is the inverse operation of:
    https://github.com/Project-MONAI/MONAI/blob/dev/monai/networks/blocks/upsample.py
    Nr6   r7   Tr   r   r8   r9   r:   r;   rN   rA   rD   r   r   r   c                   s   t    |dkrtd| d|| _|| _|dkr8|s!td|p$|}ttj| jf ||ddd|d| _d	S |d	u rCt	 | _d	S || _d	S )
a  
        Downsamples data by rearranging spatial information into channel space.
        This reduces spatial dimensions while increasing channel depth.

        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of channels of the input image.
            out_channels: optional number of channels of the output image.
            scale_factor: factor to reduce the spatial dimensions by. Defaults to 2.
            conv_block: a conv block to adjust channels before downsampling. Defaults to None.
                When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized.
                When ``conv_block`` is an ``nn.module``,
                please ensure the input number of channels matches requirements.
            bias: whether to have a bias term in the default conv_block. Defaults to True.
        r   zEThe `scale_factor` multiplier must be an integer greater than 0, got .r7   z!in_channels need to be specified.   r)   rI   N)
r   r   rQ   
dimensionsr;   r   rP   rN   nnIdentity)r"   r   r8   r:   r;   rN   rD   r$   r   r&   r      s   

zSubpixelDownsample.__init__r'   r(   c                   s^     |}t fdd|jdd D s%td|jdd  d j t| j j}|S )z
        Args:
            x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...).
        Returns:
            Tensor with reduced spatial dimensions and increased channel depth.
        c                 3  s    | ]
}| j  d kV  qdS )r   N)r;   )rE   dr"   r   r&   rG   %  s    z-SubpixelDownsample.forward.<locals>.<genexpr>r6   NzAll spatial dimensions z* must be evenly divisible by scale_factor )rN   allshaperQ   r;   r   rZ   r-   r   r^   r&   r.     s   
 zSubpixelDownsample.forward)Nr6   r7   T)r   r   r8   r9   r:   r9   r;   r   rN   rA   rD   r   r   r   r/   r0   r   r   r$   r&   r      s    ,r   )
__future__r   collections.abcr   r+   torch.nnr[   monai.networks.layers.factoriesr   r   monai.networks.utilsr   monai.utilsr   r   r	   __all__Moduler
   
Sequentialr   r   r   r   r   r   r   r   r&   <module>   s   + P