o
    &i4                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlmZm	Z	m
Z
 d dlmZmZ d dlmZmZmZmZ g dZG dd	 d	ejZG d
d dejZeZe ZZdS )    )annotations)SequenceN)ConvPadPool)	icnr_initpixelshuffle)InterpolateModeUpsampleModeensure_tuple_replook_up_option)UpsampleUpSampleSubpixelUpsampleSubpixelupsampleSubpixelUpSamplec                      s@   e Zd ZdZdddddejddejdddfd" fd d!Z  Z	S )#r   a  
    Upsamples data by `scale_factor`.
    Supported modes are:

        - "deconv": uses a transposed convolution.
        - "deconvgroup": uses a transposed group convolution.
        - "nontrainable": uses :py:class:`torch.nn.Upsample`.
        - "pixelshuffle": uses :py:class:`monai.networks.blocks.SubpixelUpsample`.

    This operation will cause non-deterministic when ``mode`` is ``UpsampleMode.NONTRAINABLE``.
    Please check the link below for more details:
    https://pytorch.org/docs/stable/generated/torch.use_deterministic_algorithms.html#torch.use_deterministic_algorithms
    This module can optionally take a pre-convolution
    (often used to map the number of features from `in_channels` to `out_channels`).
    N   defaultTspatial_dimsintin_channels
int | Noneout_channelsscale_factorSequence[float] | floatkernel_sizeSequence[float] | float | Nonesizetuple[int] | int | NonemodeUpsampleMode | strpre_convnn.Module | str | None	post_convnn.Module | Noneinterp_modestralign_cornersbool | Nonebiasboolapply_pad_poolreturnNonec                   s\  t    t||}t|t}|s|}d }}nt||}tdd |D }tdd t||D }|tjkrY|sAtd| d| 	dt
t
j|f ||pN||||||d d	S |tjkr|shtd| d|d	u rn|}|| dkrv|nd
}| 	dt
t
j|f ||||||||d d	S |tjkr|dkr||kr|std| d| 	dt
t
j|f ||p|d
|d n|d	ur|dkr| 	d| n|d	u r||krtdt|
}
tjtjtjg}|
|v r||d
  }
tj||rd	n||
j|d}| 	d| |	r| 	d|	 d	S d	S |tjkr&| 	dt||||d |||d d	S td| d)aK	  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of channels of the input image.
            out_channels: number of channels of the output image. Defaults to `in_channels`.
            scale_factor: multiplier for spatial size. Has to match input size if it is a tuple. Defaults to 2.
            kernel_size: kernel size used during transposed convolutions. Defaults to `scale_factor`.
            size: spatial size of the output image.
                Only used when ``mode`` is ``UpsampleMode.NONTRAINABLE``.
                In torch.nn.functional.interpolate, only one of `size` or `scale_factor` should be defined,
                thus if size is defined, `scale_factor` will not be used.
                Defaults to None.
            mode: {``"deconv"``, ``"deconvgroup"``, ``"nontrainable"``, ``"pixelshuffle"``}. Defaults to ``"deconv"``.
            pre_conv: a conv block applied before upsampling. Defaults to "default".
                When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized when
                Only used in the "nontrainable" or "pixelshuffle" mode.
            post_conv: a conv block applied after upsampling. Defaults to None. Only used in the "nontrainable"  mode.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
                If ends with ``"linear"`` will use ``spatial dims`` to determine the correct interpolation.
                This corresponds to linear, bilinear, trilinear for 1D, 2D, and 3D respectively.
                The interpolation mode. Defaults to ``"linear"``.
                See also: https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html
            align_corners: set the align_corners parameter of `torch.nn.Upsample`. Defaults to True.
                Only used in the "nontrainable" mode.
            bias: whether to have a bias term in the default preconv and deconv layers. Defaults to True.
            apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the
                size of `scale_factor` with a stride of 1. See also: :py:class:`monai.networks.blocks.SubpixelUpsample`.
                Only used in the "pixelshuffle" mode.

        r   c                 s  s    | ]	}|d  d V  qdS    r   N ).0kr0   r0   `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/upsample.py	<genexpr>c   s    z$UpSample.__init__.<locals>.<genexpr>c                 s  s(    | ]\}}|d  |d  d  V  qdS r.   r0   )r1   r2   sr0   r0   r3   r4   d   s   & z*in_channels needs to be specified in the 'z' mode.deconv)r   r   r   stridepaddingoutput_paddingr)   Nr/   deconvgroup)r   r   r   r7   r8   r9   groupsr)   r   preconv)r   r   r   r)   z\in the nontrainable mode, if not setting pre_conv, out_channels should equal to in_channels.)r   r   r   r'   Zupsample_non_trainablepostconvr   )r   r   r   r   
conv_blockr+   r)   zUnsupported upsampling mode .)super__init__r   r   r
   tuplezipDECONV
ValueError
add_moduler   	CONVTRANSDECONVGROUPNONTRAINABLECONVr	   LINEARBILINEAR	TRILINEARnnr   valuePIXELSHUFFLEr   NotImplementedError)selfr   r   r   r   r   r   r   r!   r#   r%   r'   r)   r+   scale_factor_Zup_modekernel_size_r9   r8   r;   Zlinear_modeupsample	__class__r0   r3   rA   +   s   
/






zUpSample.__init__)r   r   r   r   r   r   r   r   r   r   r   r   r   r    r!   r"   r#   r$   r%   r&   r'   r(   r)   r*   r+   r*   r,   r-   )
__name__
__module____qualname____doc__r
   rD   r	   rK   rA   __classcell__r0   r0   rV   r3   r      s    r   c                      s8   e Zd ZdZ					dd fddZdddZ  ZS )r   a  
    Upsample via using a subpixel CNN. This module supports 1D, 2D and 3D input images.
    The module is consisted with two parts. First of all, a convolutional layer is employed
    to increase the number of channels into: ``in_channels * (scale_factor ** dimensions)``.
    Secondly, a pixel shuffle manipulation is utilized to aggregates the feature maps from
    low resolution space and build the super resolution space.
    The first part of the module is not fixed, a sequential layers can be used to replace the
    default single layer.

    See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution
    Using a nEfficient Sub-Pixel Convolutional Neural Network."

    See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".

    The idea comes from:
    https://arxiv.org/abs/1609.05158

    The pixel shuffle mechanism refers to:
    https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html#torch.nn.PixelShuffle.
    and:
    https://github.com/pytorch/pytorch/pull/6340.

    Nr   r   Tr   r   r   r   r   r   r>   r"   r+   r*   r)   r,   r-   c                   s
  t    |dkrtd| d|| _|| _|dkrE|p|}|s%td||| j  }ttj| jf ||ddd|d| _t| j| j n|d	u rOt	
 | _n|| _t	
 | _|rttj| jf }	ttj| jf }
t	|
| jd df| j d
d|	| jdd| _d	S d	S )a4  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of channels of the input image.
            out_channels: optional number of channels of the output image.
            scale_factor: multiplier for spatial size. Defaults to 2.
            conv_block: a conv block to extract feature maps before upsampling. Defaults to None.

                - When ``conv_block`` is ``"default"``, one reserved conv layer will be utilized.
                - When ``conv_block`` is an ``nn.module``,
                  please ensure the output number of channels is divisible ``(scale_factor ** dimensions)``.

            apply_pad_pool: if True the upsampled tensor is padded then average pooling is applied with a kernel the
                size of `scale_factor` with a stride of 1. This implements the nearest neighbour resize convolution
                component of subpixel convolutions described in Aitken et al.
            bias: whether to have a bias term in the default conv_block. Defaults to True.

        r   zEThe `scale_factor` multiplier must be an integer greater than 0, got r?   r   z!in_channels need to be specified.   r/   )r   r   r   r7   r8   r)   Ng        )r8   rO   )r   r7   )r@   rA   rE   
dimensionsr   r   rJ   r>   r   rN   Identitypad_poolr   AVGr   ZCONSTANTPAD
Sequential)rR   r   r   r   r   r>   r+   r)   conv_out_channels	pool_typeZpad_typerV   r0   r3   rA      s4   


zSubpixelUpsample.__init__xtorch.Tensorc              
   C  sv   |  |}|jd | j| j  dkr,td|jd  d| j d| j d| j| j  d	t|| j| j}| |}|S )zd
        Args:
            x: Tensor in shape (batch, channel, spatial_1[, spatial_2, ...).
        r/   r   z'Number of channels after `conv_block` (z:) must be evenly divisible by scale_factor ** dimensions (^=z).)r>   shaper   r^   rE   r   r`   )rR   re   r0   r0   r3   forward  s   


zSubpixelUpsample.forward)Nr   r   TT)r   r   r   r   r   r   r   r   r>   r"   r+   r*   r)   r*   r,   r-   )re   rf   r,   rf   )rX   rY   rZ   r[   rA   rj   r\   r0   r0   rV   r3   r      s    >r   )
__future__r   collections.abcr   torchtorch.nnrN   monai.networks.layers.factoriesr   r   r   monai.networks.utilsr   r   monai.utilsr	   r
   r   r   __all__rb   r   Moduler   r   r   r   r0   r0   r0   r3   <module>   s    "h