U
    Ph*                     @  s   d dl mZ d dlmZ d dlmZ d dlZd dlmZ d dl	m
Z
mZ d dlmZmZ d dlmZ dd	d
dgZG dd dejZG dd dejZG dd dejZG dd dejZe Z ZZdS )    )annotations)Sequence)OptionalN)ConvolutionUpSample)ConvPool)ensure_tuple_rep	BasicUnet	Basicunet	basicunet	BasicUNetc                	      s4   e Zd ZdZd
dddddddd fdd	Z  ZS )TwoConvztwo convolutions.        intstr | tupleboolfloat | tuplespatial_dimsin_chnsout_chnsactnormbiasdropoutc           
   
     sV   t    t|||||||dd}t|||||||dd}	| d| | d|	 dS )  
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.

           )r   r   r   r   paddingconv_0conv_1N)super__init__r   
add_module)
selfr   r   r   r   r   r   r   r   r    	__class__ S/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/basic_unet.pyr"      s    
       zTwoConv.__init__)r   __name__
__module____qualname____doc__r"   __classcell__r'   r'   r%   r(   r      s   
 r   c                	      s4   e Zd ZdZd
dddddddd fdd	Z  ZS )Downz-maxpooling downsampling and two convolutions.r   r   r   r   r   r   c           
        sL   t    td|f dd}t|||||||}	| d| | d|	 dS )r   MAX   kernel_sizemax_poolingconvsN)r!   r"   r   r   r#   )
r$   r   r   r   r   r   r   r   r4   r5   r%   r'   r(   r"   @   s
    
zDown.__init__)r   r)   r'   r'   r%   r(   r/   =   s   
 r/   c                      sR   e Zd ZdZdddddddd	d
ddddd	d	d fddZdddddZ  ZS )UpCatzHupsampling, concatenation with the encoder feature map, two convolutionsr   deconvdefaultlinearTr   r   r   r   strznn.Module | str | Nonezbool | None)r   r   cat_chnsr   r   r   r   r   upsamplepre_convinterp_modealign_cornershalvesis_padc              
     sn   t    |	dkr |
dkr |}n|r,|d n|}t|||d|	|
||d| _t||| |||||| _|| _dS )a6  
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels to be upsampled.
            cat_chns: number of channels from the encoder.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
            pre_conv: a conv block applied before upsampling.
                Only used in the "nontrainable" or "pixelshuffle" mode.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
            align_corners: set the align_corners parameter for upsample. Defaults to True.
                Only used in the "nontrainable" mode.
            halves: whether to halve the number of channels during upsampling.
                This parameter does not work on ``nontrainable`` mode if ``pre_conv`` is `None`.
            is_pad: whether to pad upsampling features to fit features from encoder. Defaults to True.

        nontrainableNr1   )moder=   r>   r?   )r!   r"   r   r<   r   r5   rA   )r$   r   r   r;   r   r   r   r   r   r<   r=   r>   r?   r@   rA   Zup_chnsr%   r'   r(   r"   _   s     (

zUpCat.__init__torch.TensorzOptional[torch.Tensor])xx_ec                 C  s   |  |}|dk	rtj|tjr| jrt|jd }dg|d  }t|D ]4}|j| d  |j| d  krLd||d d < qLtj	j
||d}| tj||gdd}n
| |}|S )z

        Args:
            x: features to be upsampled.
            x_e: optional features from the encoder, if None, this branch is not in use.
        Nr1   r   r   	replicate)dim)r<   torchjit
isinstanceTensorrA   lenshaperangenn
functionalpadr5   cat)r$   rE   rF   x_0
dimensionsspir'   r'   r(   forward   s    
 
zUpCat.forward)r   r7   r8   r9   TTT)r*   r+   r,   r-   r"   rX   r.   r'   r'   r%   r(   r6   \   s          .:r6   c                      sf   e Zd Zddddddddfd	d
difdddf	dddddddddd	 fddZddddZ  ZS )r      r   r1   )    rZ   @         rZ   	LeakyReLUg?T)negative_slopeinplaceinstanceaffiner   r7   r   zSequence[int]r   r   r   r:   )	r   in_channelsout_channelsfeaturesr   r   r   r   r<   c
                   sr  t    t|d}
td|
 d t|||d ||||| _t||
d |
d ||||| _t||
d |
d ||||| _t||
d |
d ||||| _	t||
d |
d ||||| _
t||
d |
d |
d |||||		| _t||
d |
d |
d |||||		| _t||
d |
d |
d |||||		| _t||
d |
d |
d	 |||||	d
d
| _td|f |
d	 |dd| _dS )u  
        A UNet implementation with 1D/2D/3D supports.

        Based on:

            Falk et al. "U-Net – Deep Learning for Cell Counting, Detection, and
            Morphometry". Nature Methods 16, 67–70 (2019), DOI:
            http://dx.doi.org/10.1038/s41592-018-0261-2

        Args:
            spatial_dims: number of spatial dimensions. Defaults to 3 for spatial 3D inputs.
            in_channels: number of input channels. Defaults to 1.
            out_channels: number of output channels. Defaults to 2.
            features: six integers as numbers of features.
                Defaults to ``(32, 32, 64, 128, 256, 32)``,

                - the first five values correspond to the five-level encoder feature sizes.
                - the last value corresponds to the feature size after the last upsampling.

            act: activation type and arguments. Defaults to LeakyReLU.
            norm: feature normalization type and arguments. Defaults to instance norm.
            bias: whether to have a bias term in convolution blocks. Defaults to True.
                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
                if a conv layer is directly followed by a batch norm layer, bias should be False.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.

        Examples::

            # for spatial 2D
            >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128))

            # for spatial 2D, with group norm
            >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4}))

            # for spatial 3D
            >>> net = BasicUNet(spatial_dims=3, features=(32, 32, 64, 128, 256, 32))

        See Also

            - :py:class:`monai.networks.nets.DynUNet`
            - :py:class:`monai.networks.nets.UNet`

           zBasicUNet features: .r   r   r1   rY         F)r@   convr2   N)r!   r"   r	   printr   r   r/   down_1down_2down_3down_4r6   upcat_4upcat_3upcat_2upcat_1r   
final_conv)r$   r   rc   rd   re   r   r   r   r   r<   Zfear%   r'   r(   r"      s    9

&&&*zBasicUNet.__init__rD   )rE   c                 C  sp   |  |}| |}| |}| |}| |}| ||}| ||}| ||}	| |	|}
| 	|
}|S )a  
        Args:
            x: input should have spatially N dimensions
                ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N-1])``, N is defined by `spatial_dims`.
                It is recommended to have ``dim_n % 16 == 0`` to ensure all maxpooling inputs have
                even edge lengths.

        Returns:
            A torch Tensor of "raw" predictions in shape
            ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N-1])``.
        )
r   rl   rm   rn   ro   rp   rq   rr   rs   rt   )r$   rE   x0x1x2x3x4u4u3u2u1logitsr'   r'   r(   rX      s    





zBasicUNet.forward)r*   r+   r,   r"   rX   r.   r'   r'   r%   r(   r      s   
$J)
__future__r   collections.abcr   typingr   rI   torch.nnrP   Zmonai.networks.blocksr   r   monai.networks.layers.factoriesr   r   monai.utilsr	   __all__
Sequentialr   r/   Moduler6   r   r
   r   r   r'   r'   r'   r(   <module>   s   "Vh