o
    ,i*                     @  s   d dl mZ d dlmZ d dlmZ d dlZd dlmZ d dl	m
Z
mZ d dlmZmZ d dlmZ g dZG d	d
 d
ejZG dd dejZG dd dejZG dd dejZe Z ZZdS )    )annotations)Sequence)OptionalN)ConvolutionUpSample)ConvPool)ensure_tuple_rep)	BasicUnet	Basicunet	basicunet	BasicUNetc                      &   e Zd ZdZ	dd fddZ  ZS )TwoConvztwo convolutions.        spatial_dimsintin_chnsout_chnsactstr | tuplenormbiasbooldropoutfloat | tuplec           
   
     sV   t    t|||||||dd}t|||||||dd}	| d| | d|	 dS )  
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.

           )r   r   r   r   paddingconv_0conv_1N)super__init__r   
add_module)
selfr   r   r   r   r   r   r   r   r    	__class__ `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/basic_unet.pyr"      s   
zTwoConv.__init__r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   __name__
__module____qualname____doc__r"   __classcell__r'   r'   r%   r(   r          
r   c                      r   )Downz-maxpooling downsampling and two convolutions.r   r   r   r   r   r   r   r   r   r   r   r   c           
        sL   t    td|f dd}t|||||||}	| d| | d|	 dS )r   MAX   kernel_sizemax_poolingconvsN)r!   r"   r   r   r#   )
r$   r   r   r   r   r   r   r   r7   r8   r%   r'   r(   r"   @   s
   
zDown.__init__r)   r*   r+   r'   r'   r%   r(   r2   =   r1   r2   c                      s<   e Zd ZdZ							d$d% fddZd&d"d#Z  ZS )'UpCatzHupsampling, concatenation with the encoder feature map, two convolutionsr   deconvdefaultlinearTr   r   r   cat_chnsr   r   r   r   r   r   r   r   upsamplestrpre_convnn.Module | str | Noneinterp_modealign_cornersbool | Nonehalvesis_padc              
     sn   t    |	dkr|
du r|}n|r|d n|}t|||d|	|
||d| _t||| |||||| _|| _dS )a6  
        Args:
            spatial_dims: number of spatial dimensions.
            in_chns: number of input channels to be upsampled.
            cat_chns: number of channels from the encoder.
            out_chns: number of output channels.
            act: activation type and arguments.
            norm: feature normalization type and arguments.
            bias: whether to have a bias term in convolution blocks.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
            pre_conv: a conv block applied before upsampling.
                Only used in the "nontrainable" or "pixelshuffle" mode.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
            align_corners: set the align_corners parameter for upsample. Defaults to True.
                Only used in the "nontrainable" mode.
            halves: whether to halve the number of channels during upsampling.
                This parameter does not work on ``nontrainable`` mode if ``pre_conv`` is `None`.
            is_pad: whether to pad upsampling features to fit features from encoder. Defaults to True.

        nontrainableNr4   )moder@   rB   rC   )r!   r"   r   r>   r   r8   rF   )r$   r   r   r=   r   r   r   r   r   r>   r@   rB   rC   rE   rF   Zup_chnsr%   r'   r(   r"   _   s    
(

zUpCat.__init__xtorch.Tensorx_eOptional[torch.Tensor]c                 C  s   |  |}|durXtj|tjrX| jrJt|jd }dg|d  }t|D ]}|j| d  |j| d  kr@d||d d < q&tj	j
||d}| tj||gdd}|S | |}|S )z

        Args:
            x: features to be upsampled.
            x_e: optional features from the encoder, if None, this branch is not in use.
        Nr4   r   r   	replicate)dim)r>   torchjit
isinstanceTensorrF   lenshaperangenn
functionalpadr8   cat)r$   rI   rK   x_0
dimensionsspir'   r'   r(   forward   s   
 
zUpCat.forward)r   r:   r;   r<   TTT)r   r   r   r   r=   r   r   r   r   r   r   r   r   r   r   r   r>   r?   r@   rA   rB   r?   rC   rD   rE   r   rF   r   )rI   rJ   rK   rL   )r,   r-   r.   r/   r"   r^   r0   r'   r'   r%   r(   r9   \   s    :r9   c                	      sN   e Zd Zddddddddfd	d
difdddf	d" fddZd#d d!Z  ZS )$r      r   r4   )    r`   @         r`   	LeakyReLUg?T)negative_slopeinplaceinstanceaffiner   r:   r   r   in_channelsout_channelsfeaturesSequence[int]r   r   r   r   r   r   r   r>   r?   c
                   sr  t    t|d}
td|
 d t|||d ||||| _t||
d |
d ||||| _t||
d |
d ||||| _t||
d |
d ||||| _	t||
d |
d ||||| _
t||
d |
d |
d |||||		| _t||
d |
d |
d |||||		| _t||
d |
d |
d |||||		| _t||
d |
d |
d	 |||||	d
d
| _td|f |
d	 |dd| _dS )u  
        A UNet implementation with 1D/2D/3D supports.

        Based on:

            Falk et al. "U-Net – Deep Learning for Cell Counting, Detection, and
            Morphometry". Nature Methods 16, 67–70 (2019), DOI:
            http://dx.doi.org/10.1038/s41592-018-0261-2

        Args:
            spatial_dims: number of spatial dimensions. Defaults to 3 for spatial 3D inputs.
            in_channels: number of input channels. Defaults to 1.
            out_channels: number of output channels. Defaults to 2.
            features: six integers as numbers of features.
                Defaults to ``(32, 32, 64, 128, 256, 32)``,

                - the first five values correspond to the five-level encoder feature sizes.
                - the last value corresponds to the feature size after the last upsampling.

            act: activation type and arguments. Defaults to LeakyReLU.
            norm: feature normalization type and arguments. Defaults to instance norm.
            bias: whether to have a bias term in convolution blocks. Defaults to True.
                According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
                if a conv layer is directly followed by a batch norm layer, bias should be False.
            dropout: dropout ratio. Defaults to no dropout.
            upsample: upsampling mode, available options are
                ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.

        Examples::

            # for spatial 2D
            >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128))

            # for spatial 2D, with group norm
            >>> net = BasicUNet(spatial_dims=2, features=(64, 128, 256, 512, 1024, 128), norm=("group", {"num_groups": 4}))

            # for spatial 3D
            >>> net = BasicUNet(spatial_dims=3, features=(32, 32, 64, 128, 256, 32))

        See Also

            - :py:class:`monai.networks.nets.DynUNet`
            - :py:class:`monai.networks.nets.UNet`

           zBasicUNet features: .r   r   r4   r_         F)rE   convr5   N)r!   r"   r	   printr   r   r2   down_1down_2down_3down_4r9   upcat_4upcat_3upcat_2upcat_1r   
final_conv)r$   r   ri   rj   rk   r   r   r   r   r>   Zfear%   r'   r(   r"      s   
9
&&&* zBasicUNet.__init__rI   rJ   c                 C  sp   |  |}| |}| |}| |}| |}| ||}| ||}| ||}	| |	|}
| 	|
}|S )a  
        Args:
            x: input should have spatially N dimensions
                ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N-1])``, N is defined by `spatial_dims`.
                It is recommended to have ``dim_n % 16 == 0`` to ensure all maxpooling inputs have
                even edge lengths.

        Returns:
            A torch Tensor of "raw" predictions in shape
            ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N-1])``.
        )
r   rs   rt   ru   rv   rw   rx   ry   rz   r{   )r$   rI   x0x1x2x3x4u4u3u2u1logitsr'   r'   r(   r^      s   





zBasicUNet.forward)r   r   ri   r   rj   r   rk   rl   r   r   r   r   r   r   r   r   r>   r?   )rI   rJ   )r,   r-   r.   r"   r^   r0   r'   r'   r%   r(   r      s    
Jr   )
__future__r   collections.abcr   typingr   rO   torch.nnrV   monai.networks.blocksr   r   monai.networks.layers.factoriesr   r   monai.utilsr	   __all__
Sequentialr   r2   Moduler9   r   r
   r   r   r'   r'   r'   r(   <module>   s   "Vh