U
    Phm                     @  s   d dl mZ d dlmZ d dlmZ d dlZd dlmZ d dl	m  m
Z d dlmZ d dlmZ d dlmZmZ d dlmZ d	d
ddgZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZe Z ZZ dS )    )annotations)Sequence)OptionalN)Convolution)UpSample)ConvPool)ensure_tuple_repBasicUnetPlusPlusBasicunetPlusPlusbasicunetplusplusBasicUNetPlusPlusc                   @  s   e Zd ZdZdd ZdS )MCDropout3du>   MC Dropout：无论 model.train()/eval() 都启用随机失活c                 C  s   t j|| jd| jdS )NT)ptraininginplace)F	dropout3dr   r   )selfinput r   [/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/basic_unetplusplus.pyforward%   s    zMCDropout3d.forwardN)__name__
__module____qualname____doc__r   r   r   r   r   r   "   s   r   c                      sJ   e Zd ZdZddddddddd	d	d	d

 fddZdddddZ  ZS )TwoConvzResidual two-conv block (ResUNet++ style).

    If input/output channels differ, applies a 1x1 projection on the skip path.
               r    r       r"   r"   intstr | tupleboolfloat | tupleSequence[int] | int)
spatial_dimsin_chnsout_chnsactnormbiasdropoutkernel_sizepaddingstridec                   sl   t    t|||||||||	d	| _t|||||||||	d	| _d | _||krhtd|f ||dd| _d S )N)r+   r,   r.   r-   r/   r0   convr"   r/   )super__init__r   conv_0conv_1projr   )r   r(   r)   r*   r+   r,   r-   r.   r/   r0   r1   	__class__r   r   r5   /   s4    
zTwoConv.__init__torch.Tensor)xreturnc                 C  s8   |}|  |}| |}| jd k	r,| |}|| }|S )N)r6   r7   r8   )r   r<   identityoutr   r   r   r   Y   s    



zTwoConv.forward)r   r   r!   r!   r   r   r   r   r5   r   __classcell__r   r   r9   r   r   )   s       &*r   c                      s8   e Zd ZdZddddddddd	d	d
	 fddZ  ZS )Downz-maxpooling downsampling and two convolutions.r   r   r!   r#   r$   r%   r&   r'   )	r(   r)   r*   r+   r,   r-   r.   r/   r0   c
                   sR   t    td|f dd}
t|||||||||	d	}| d|
 | d| dS )a  
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.

        MAX   rE   rE   r3   r/   r0   max_poolingconvsN)r4   r5   r   r   
add_module)r   r(   r)   r*   r+   r,   r-   r.   r/   r0   rG   rH   r9   r   r   r5   f   s    
zDown.__init__)r   r   r!   )r   r   r   r   r5   rA   r   r   r9   r   rB   c   s
   
   rB   c                      sV   e Zd ZdZdd	d	d	d	d
d
ddddddddddd fddZdddddZ  ZS )UpCatzHupsampling, concatenation with the encoder feature map, two convolutionsr   deconvdefaultlinearTr   r!   r#   r$   r%   r&   strznn.Module | str | Nonezbool | Noner'   )r(   r)   cat_chnsr*   r+   r,   r-   r.   upsamplepre_convinterp_modealign_cornershalvesis_padr/   r0   c                   sv   t    |	dkr |
dkr |}n|r,|d n|}t|||d|	|
|||d	| _t||| |||||||d	| _|| _dS )a6  
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels to be upsampled.
            cat_chns: number of channels from the encoder.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
            pre_conv: a conv block applied before upsampling.
                Only used in the "nontrainable" or "pixelshuffle" mode.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
            align_corners: set the align_corners parameter for upsample. Defaults to True.
                Only used in the "nontrainable" mode.
            halves: whether to halve the number of channels during upsampling.
                This parameter does not work on ``nontrainable`` mode if ``pre_conv`` is `None`.
            is_pad: whether to pad upsampling features to fit features from encoder. Defaults to True.

        nontrainableNrE   rD   )moderQ   rR   rS   r/   rF   )r4   r5   r   rP   r   rH   rU   )r   r(   r)   rO   r*   r+   r,   r-   r.   rP   rQ   rR   rS   rT   rU   r/   r0   up_chnsr9   r   r   r5      s6    *
zUpCat.__init__r;   zOptional[torch.Tensor])r<   x_ec                 C  s  |  |}|dk	rxtj|tjrxt|jd }dg|d  }t|D ]`}|j|d   }|j|d   }||k rJ|| }	|	d }
|	|
 }|
||d < |||d d < qJt|rtj	j
j||dd}tdg|  }t|D ]n}|j|d   }|j|d   }||kr|| }|d }|| }|}|| }| |d  }t||||< q|t| }| tj||gdd}n
| |}|S )z

        Args:
            x: features to be upsampled.
            x_e: optional features from the encoder, if None, this branch is not in use.
        NrE   r   r"   	replicate)rW   dim)rP   torchjit
isinstanceTensorlenshaperangeanynn
functionalpadslicer\   tuplerH   cat)r   r<   rY   x_0
dimensionsZpad_listiZsize_x0Zsize_xediffpad_left	pad_rightslicescrop	crop_left
crop_rightstartendZ	dim_indexr   r   r   r      s>    

zUpCat.forward)	r   rK   rL   rM   TTTr   r!   r@   r   r   r9   r   rJ      s            2GrJ   c                      sn   e Zd Zddddddddd	fd
ddifddddfdddddddddddd fddZddddZ  ZS )r   r    r"   rE   )0   @   `         rw   F	LeakyReLUg?T)negative_sloper   instanceaffiner   rK   r#   zSequence[int]r%   r$   r&   rN   float)r(   in_channelsout_channelsfeaturesdeep_supervisionr+   r,   r-   r.   rP   	dropout_pc                   s  t    || _|| _t|d}td| d tt||d |d ||||	ddd	t||d |d ||||	ddd	d

t||d |d ||||	dddd

| _	t
|||d ||||	ddd	| _t||d |d	 ||||	ddd	| _t||d	 |d ||||	ddd	| _t||d |d ||||	ddd	| _t||d |d ||||	ddd	| _t||d	 |d |d ||||	dddddd| _t||d |d	 |d	 ||||	dddddd| _t||d |d |d ||||	dddddd| _t||d |d |d ||||	dddddd| _t||d	 |d d |d ||||	dddddd| _t||d |d	 d |d	 ||||	dddddd| _t||d |d d |d ||||	dddddd| _t||d	 |d d |d ||||	dddddd| _t||d |d	 d |d	 ||||	dddddd| _t||d	 |d d |d ||||	dddddd| _| jrd| jdkrdt| jnt | _td|f |d |d	d| _ td|f |d |d	d| _!td|f |d |d	d| _"td|f |d |d	d| _#dS )a	  
        A UNet++ implementation with 1D/2D/3D supports.

        Based on:

            Zhou et al. "UNet++: A Nested U-Net Architecture for Medical Image
            Segmentation". 4th Deep Learning in Medical Image Analysis (DLMIA)
            Workshop, DOI: https://doi.org/10.48550/arXiv.1807.10165


        Args:
            spatial_dims: number of spatial dimensions. Defaults to 3 for spatial 3D inputs.
            in_channels: number of input channels. Defaults to 1.
            out_channels: number of output channels. Defaults to 2.
            features: six integers as numbers of features.
                Defaults to ``(32, 32, 64, 128, 256, 32)``,

                - the first five values correspond to the five-level encoder feature sizes.
                - the last value corresponds to the feature size after the last upsampling.

            deep_supervision: whether to prune the network at inference time. Defaults to False. If true, returns a list,
                whose elements correspond to outputs at different nodes.
            act: activation type and arguments. Defaults to LeakyReLU.
            norm: feature normalization type and arguments. Defaults to instance norm.
            bias: whether to have a bias term in convolution blocks. Defaults to True.
                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
                if a conv layer is directly followed by a batch norm layer, bias should be False.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.

        Examples::

            # for spatial 2D
            >>> net = BasicUNetPlusPlus(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128))

            # for spatial 2D, with deep supervision enabled
            >>> net = BasicUNetPlusPlus(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), deep_supervision=True)

            # for spatial 2D, with group norm
            >>> net = BasicUNetPlusPlus(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4}))

            # for spatial 3D
            >>> net = BasicUNetPlusPlus(spatial_dims=3, features=(32, 32, 64, 128, 256, 32))

        See Also
            - :py:class:`monai.networks.nets.BasicUNet`
            - :py:class:`monai.networks.nets.DynUNet`
            - :py:class:`monai.networks.nets.UNet`

           zBasicUNetPlusPlus features: .r    r!   r   )r+   r,   r-   r.   r/   r0   r   r"   )r+   r,   r-   r.   r/   r0   dilationrE   rF   )r    r       )r"   r"   rE   )r    r       )r"   r"   r       )r    r    	   )r"   r"   r   rV   rL   T)rP   rQ   rT   r/   r0   r   r2   r3   N)$r4   r5   r   r   r	   printre   
Sequentialr   asppr   conv_0_0rB   conv_1_0conv_2_0conv_3_0conv_4_0rJ   	upcat_0_1	upcat_1_1	upcat_2_1	upcat_3_1	upcat_0_2	upcat_1_2	upcat_2_2	upcat_0_3	upcat_1_3	upcat_0_4r   Identitymc_dropout_outr   final_conv_0_1final_conv_0_2final_conv_0_3final_conv_0_4)r   r(   r   r   r   r   r+   r,   r-   r.   rP   r   fear9   r   r   r5   	  s"   A

&






  
  
  
  zBasicUNetPlusPlus.__init__r;   )r<   c                 C  sl  |  |}| |}| ||}| |}| ||}| |tj||gdd}| |}| 	|}| 
||}	| |	tj||gdd}
| |
tj|||gdd}t|j | |}| ||}| |tj||	gdd}| |tj|||
gdd}| |tj||||gdd}| |}| |}| |}| |}| |}| jrb||||g}n|g}|S )a  
        Args:
            x: input should have spatially N dimensions
                ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N-1])``, N is defined by `dimensions`.
                It is recommended to have ``dim_n % 16 == 0`` to ensure all maxpooling inputs have
                even edge lengths.

        Returns:
            A torch Tensor of "raw" predictions in shape
            ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N-1])``.
        r"   r[   )r   r   r   r   r   r   r]   rj   r   r   r   r   r   r   rb   r   r   r   r   r   r   r   r   r   r   r   )r   r<   Zx_0_0Zx_1_0Zx_0_1Zx_2_0Zx_1_1Zx_0_2Zx_3_0Zx_2_1Zx_1_2Zx_0_3Zx_4_0Zx_3_1Zx_2_2Zx_1_3Zx_0_4Z
output_0_1Z
output_0_2Z
output_0_3Z
output_0_4outputr   r   r   r   [  s4    











zBasicUNetPlusPlus.forward)r   r   r   r5   r   rA   r   r   r9   r   r     s   
(  T)!
__future__r   collections.abcr   typingr   r]   torch.nnre   torch.nn.functionalrf   r   "monai.networks.blocks.convolutionsr   monai.networks.blocks.upsampler   monai.networks.layers.factoriesr   r   monai.utils.miscr	   __all__	Dropout3dr   Moduler   r   rB   rJ   r   r
   r   r   r   r   r   r   <module>   s.   :,y   