o
    ,ic8                     @  s  d dl mZ d dlZd dlmZ d dlmZ d dlmZ d dl	Z	d dl	m
Z
 d dlmZmZ d dlmZ d d	lmZ d d
lmZ d dlmZ d dlmZ d dlmZmZ g dZG dd dZe Zee ee G dd de
jZ G dd de
j!Z"G dd de
jZ#e#Z$dS )    )annotationsN)Sequence)locate)Any)nn)BaseEncoderUpSample)Conv)get_act_layer)EfficientNetEncoder)UpCat)ResNetEncoder)InterpolateModeoptional_import)FlexibleUNetFlexUNetFLEXUNET_BACKBONEFlexUNetEncoderRegisterc                   @  s"   e Zd ZdZdd Zd	ddZdS )
r   az  
    A register to regist backbones for the flexible unet. All backbones can be found in
    register_dict. Please notice each output of backbone must be 2x downsample in spatial
    dimension of last output. For example, if given a 512x256 2D image and a backbone with
    4 outputs. Then spatial size of each encoder output should be 256x128, 128x64, 64x32
    and 32x16.
    c                 C  s
   i | _ d S )N)register_dict)self r   c/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/flexible_unet.py__init__*   s   
z FlexUNetEncoderRegister.__init__nametype[Any] | strc                 C  s   t |tr$td| d\}}|st| }|}t |ts$td| dt|ts1t	| d |
 }| }| }| }t|t|  krYt|  krYt|ks\J  J t|D ]\}}	||| || || d}
|
| j|	< q`dS )z
        Register a given class to the encoder dict. Please notice that input class must be a
        subclass of BaseEncoder.
        zmonai.networks.nets)r   zCannot find z class.zl would better be derived from monai.networks.blocks.BaseEncoder or implement all interfaces specified by it.)typefeature_numberfeature_channel	parameterN)
isinstancestrr   r   r   
ValueError
issubclassr   warningswarnget_encoder_namesnum_outputsnum_channels_per_outputget_encoder_parameterslen	enumerater   )r   r   tmp_namehas_built_inZname_string_listZfeature_number_listZfeature_channel_listparameter_listcntname_stringZcur_dictr   r   r   register_class-   s0   



6z&FlexUNetEncoderRegister.register_classN)r   r   )__name__
__module____qualname____doc__r   r0   r   r   r   r   r   !   s    r   c                      s.   e Zd ZdZd fddZdd ddZ  ZS )!UNetDecoderaE  
    UNet Decoder.
    This class refers to `segmentation_models.pytorch
    <https://github.com/qubvel/segmentation_models.pytorch>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        encoder_channels: number of output channels for all feature maps in encoder.
            `len(encoder_channels)` should be no less than 2.
        decoder_channels: number of output channels for all feature maps in decoder.
            `len(decoder_channels)` should equal to `len(encoder_channels) - 1`.
        act: activation type and arguments.
        norm: feature normalization type and arguments.
        dropout: dropout ratio.
        bias: whether to have a bias term in convolution blocks in this decoder.
        upsample: upsampling mode, available options are
            ``"deconv"``, ``"pixelshuffle"``, ``"nontrainable"``.
        pre_conv: a conv block applied before upsampling.
            Only used in the "nontrainable" or "pixelshuffle" mode.
        interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
            Only used in the "nontrainable" mode.
        align_corners: set the align_corners parameter for upsample. Defaults to True.
            Only used in the "nontrainable" mode.
        is_pad: whether to pad upsampling features to fit the encoder spatial dims.

    spatial_dimsintencoder_channelsSequence[int]decoder_channelsactstr | tuplenormdropoutfloat | tuplebiasboolupsampler    pre_conv
str | Noneinterp_modealign_cornersbool | Noneis_padc                   s   t    t|dk rtdt|t|d krtd|d gt|d d  }t|dd d d d dg }dgt|d  }|d g }t||||D ]\}}}}|t||||||||||	|
|||d	 qQt	|| _
d S )
N   z:the length of `encoder_channels` should be no less than 2.   zD`len(decoder_channels)` should equal to `len(encoder_channels) - 1`.r   TF)r6   in_chnscat_chnsout_chnsr;   r=   r>   r@   rB   rC   rE   rF   halvesrH   )superr   r)   r!   listappendzipr   r   
ModuleListblocks)r   r6   r8   r:   r;   r=   r>   r@   rB   rC   rE   rF   rH   in_channelsZskip_channelsrO   rU   Zin_chnZskip_chnZout_chnZhalve	__class__r   r   r   q   s<   
 
zUNetDecoder.__init__   featureslist[torch.Tensor]skip_connectc                 C  sl   |d d d d d }|dd  d d d }|d }t | jD ]\}}||k r,|| }nd }|||}q|S )NrK   rJ   r   )r*   rU   )r   rZ   r\   skipsxiblockskipr   r   r   forward   s   
zUNetDecoder.forward)r6   r7   r8   r9   r:   r9   r;   r<   r=   r<   r>   r?   r@   rA   rB   r    rC   rD   rE   r    rF   rG   rH   rA   )rY   )rZ   r[   r\   r7   r1   r2   r3   r4   r   rb   __classcell__r   r   rW   r   r5   U   s    /r5   c                      s*   e Zd ZdZ			dd fddZ  ZS )SegmentationHeada  
    Segmentation head.
    This class refers to `segmentation_models.pytorch
    <https://github.com/qubvel/segmentation_models.pytorch>`_.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels for the block.
        out_channels: number of output channels for the block.
        kernel_size: kernel size for the conv layer.
        act: activation type and arguments.
        scale_factor: multiplier for spatial size. Has to match input size if it is a tuple.

       N      ?r6   r7   rV   out_channelskernel_sizer;   tuple | str | Nonescale_factorfloatc           
        sp   t t j|f ||||d d}t }|dkr!t||dd tjd}|d ur*t|}	nt }	t 	|||	 d S )NrI   )rV   rh   ri   paddingrg   nontrainable)r6   rk   moderC   rE   )
r	   CONVr   Identityr   r   LINEARr
   rP   r   )
r   r6   rV   rh   ri   r;   rk   
conv_layerZup_layerZ	act_layerrW   r   r   r      s    	
zSegmentationHead.__init__)rf   Nrg   )r6   r7   rV   r7   rh   r7   ri   r7   r;   rj   rk   rl   )r1   r2   r3   r4   r   rd   r   r   rW   r   re      s    re   c                      sV   e Zd ZdZdddddddfd	d
difddddddfd, fd&d'Zd-d*d+Z  ZS ).r   zN
    A flexible implementation of UNet-like encoder-decoder architecture.
    F)      @          rI   batchgMbP?g?)epsmomentumreluinplaceTg        rn   defaultnearestrV   r7   rh   backboner    
pretrainedrA   r:   tupler6   r=   r<   r;   r>   r?   decoder_biasrB   rC   rE   rH   returnNonec                   s*  t    |tjvrtd| dtj  d|dvr tdtj| }|| _|| _|d }d|v r;d|v r;d	|v s?td
|d }|dkrKtd|d| }|d | _|	|||d t
|gt|d  }|d }|di || _t||||||	|
|||d|d| _t||d |ddd| _dS )a.
  
        A flexible implement of UNet, in which the backbone/encoder can be replaced with
        any efficient or residual network. Currently the input must have a 2 or 3 spatial dimension
        and the spatial size of each dimension must be a multiple of 32 if is_pad parameter
        is False.
        Please notice each output of backbone must be 2x downsample in spatial dimension
        of last output. For example, if given a 512x256 2D image and a backbone with 4 outputs.
        Spatial size of each encoder output should be 256x128, 128x64, 64x32 and 32x16.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            backbone: name of backbones to initialize, only support efficientnet and resnet right now,
                can be from [efficientnet-b0, ..., efficientnet-b8, efficientnet-l2, resnet10, ..., resnet200].
            pretrained: whether to initialize pretrained weights. ImageNet weights are available for efficient networks
                if spatial_dims=2 and batch norm is used. MedicalNet weights are available for residual networks
                if spatial_dims=3 and in_channels=1. Default to False.
            decoder_channels: number of output channels for all feature maps in decoder.
                `len(decoder_channels)` should equal to `len(encoder_channels) - 1`,default
                to (256, 128, 64, 32, 16).
            spatial_dims: number of spatial dimensions, default to 2.
            norm: normalization type and arguments, default to ("batch", {"eps": 1e-3,
                "momentum": 0.1}).
            act: activation type and arguments, default to ("relu", {"inplace": True}).
            dropout: dropout ratio, default to 0.0.
            decoder_bias: whether to have a bias term in decoder's convolution blocks.
            upsample: upsampling mode, available options are``"deconv"``, ``"pixelshuffle"``,
                ``"nontrainable"``.
            pre_conv:a conv block applied before upsampling. Only used in the "nontrainable" or
                "pixelshuffle" mode, default to `default`.
            interp_mode: {``"nearest"``, ``"linear"``, ``"bilinear"``, ``"bicubic"``, ``"trilinear"``}
                Only used in the "nontrainable" mode.
            is_pad: whether to pad upsampling features to fit features from encoder. Default to True.
                If this parameter is set to "True", the spatial dim of network input can be arbitrary
                size, which is not supported by TensorRT. Otherwise, it must be a multiple of 32.
        zinvalid model_name z found, must be one of .)rI   rf   z spatial_dims can only be 2 or 3.r   r6   rV   r   zWThe backbone init method must have spatial_dims, in_channels and pretrained parameters.r      zBFlexible unet can only accept no more than 5 encoder feature maps.NrJ   )r6   rV   r   r   r   )r6   r8   r:   r;   r=   r>   r@   rB   rE   rC   rF   rH   rK   rf   )r6   rV   rh   ri   r;   r   )rP   r   r   r   r!   keysr   r6   r\   updater   rQ   encoderr5   decoderre   segmentation_head)r   rV   rh   r   r   r:   r6   r=   r;   r>   r   rB   rC   rE   rH   r   Zencoder_parametersZencoder_feature_numr8   Zencoder_typerW   r   r   r      sZ   
5


zFlexibleUNet.__init__inputstorch.Tensorc                 C  s*   |}|  |}| || j}| |}|S )as  
        Do a typical encoder-decoder-header inference.

        Args:
            inputs: input should have spatially N dimensions ``(Batch, in_channels, dim_0[, dim_1, ..., dim_N])``,
                N is defined by `dimensions`.

        Returns:
            A torch Tensor of "raw" predictions in shape ``(Batch, out_channels, dim_0[, dim_1, ..., dim_N])``.

        )r   r   r\   r   )r   r   r^   Zenc_outZdecoder_outZx_segr   r   r   rb   J  s
   

zFlexibleUNet.forward)rV   r7   rh   r7   r   r    r   rA   r:   r   r6   r7   r=   r<   r;   r<   r>   r?   r   rA   rB   r    rC   r    rE   r    rH   rA   r   r   )r   r   rc   r   r   rW   r   r      s    	
jr   )%
__future__r   r#   collections.abcr   pydocr   typingr   torchr   monai.networks.blocksr   r   monai.networks.layers.factoriesr	   monai.networks.layers.utilsr
   monai.networks.netsr   Zmonai.networks.nets.basic_unetr   Zmonai.networks.nets.resnetr   monai.utilsr   r   __all__r   r   r0   Moduler5   
Sequentialre   r   r   r   r   r   r   <module>   s0   /

Z, 