U
    Ph"G                     @   s   d dl mZmZmZmZmZ d dlZd dlmZ d dl	m
Z
 d dlmZmZmZmZ dddgZG dd	 d	ejZG d
d dejZe ZZdS )    )ListOptionalSequenceTupleUnionN)interpolate)UnetBasicBlockUnetOutBlockUnetResBlockUnetUpBlockDynUNetDynUnetDynunetc                       s>   e Zd ZU dZeeej  ed< d fdd	Z	dd Z
  ZS )	DynUNetSkipLayerap  
    Defines a layer in the UNet topology which combines the downsample and upsample pathways with the skip connection.
    The member `next_layer` may refer to instances of this class or the final bottleneck layer at the bottom the UNet
    structure. The purpose of using a recursive class like this is to get around the Torchscript restrictions on
    looping over lists of layers and accumulating lists of output tensors which must be indexed. The `heads` list is
    shared amongst all the instances of this class and is used to store the output from the supervision heads during
    forward passes of the network.
    headsNc                    s2   t    || _|| _|| _|| _|| _|| _d S )N)super__init__
downsample
next_layerupsample
super_headr   index)selfr   r   r   r   r   r   	__class__ P/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/dynunet.pyr   %   s    
zDynUNetSkipLayer.__init__c                 C   sX   |  |}| |}| ||}| jd k	rT| jd k	rT| jdkrT| || j| jd < |S )Nr      )r   r   r   r   r   r   )r   xZdownoutZnextoutZupoutr   r   r   forward.   s    

zDynUNetSkipLayer.forward)NN)__name__
__module____qualname____doc__r   r   torchTensor__annotations__r   r   __classcell__r   r   r   r   r      s   
		r   c                       sv  e Zd ZdZdddddifddddfd	d
d	d	feeeeeee ef  eeee ef  eeee ef  eee  eeee	e
f  eee	f eee	f eeeed fddZdd Zdd Zdd Zdd Zdd Zdd ZedddZdd Zdd  Zd(ee ee eeee ef  eeee ef  ejeeeee ef   ed!d"d#Zd$d% Zed&d' Z  ZS ))r   a  
    This reimplementation of a dynamic UNet (DynUNet) is based on:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.
    `Optimized U-Net for Brain Tumor Segmentation <https://arxiv.org/pdf/2110.03352.pdf>`_.

    This model is more flexible compared with ``monai.networks.nets.UNet`` in three
    places:

        - Residual connection is supported in conv blocks.
        - Anisotropic kernel sizes and strides can be used in each layers.
        - Deep supervision heads can be added.

    The model supports 2D or 3D inputs and is consisted with four kinds of blocks:
    one input block, `n` downsample blocks, one bottleneck and `n+1` upsample blocks. Where, `n>0`.
    The first and last kernel and stride values of the input sequences are used for input block and
    bottleneck respectively, and the rest value(s) are used for downsample and upsample blocks.
    Therefore, pleasure ensure that the length of input sequences (``kernel_size`` and ``strides``)
    is no less than 3 in order to have at least one downsample and upsample blocks.

    To meet the requirements of the structure, the input size for each spatial dimension should be divisible
    by the product of all strides in the corresponding dimension. In addition, the minimal spatial size should have
    at least one dimension that has twice the size of the product of all strides.
    For example, if `strides=((1, 2, 4), 2, 2, 1)`, the spatial size should be divisible by `(4, 8, 16)`,
    and the minimal spatial size is `(8, 8, 16)` or `(4, 16, 16)` or `(4, 8, 32)`.

    The output size for each spatial dimension equals to the input size of the corresponding dimension divided by the
    stride in strides[0].
    For example, if `strides=((1, 2, 4), 2, 2, 1)` and the input size is `(64, 32, 32)`, the output size is `(64, 16, 8)`.

    For backwards compatibility with old weights, please set `strict=False` when calling `load_state_dict`.

    Usage example with medical segmentation decathlon dataset is available at:
    https://github.com/Project-MONAI/tutorials/tree/master/modules/dynunet_pipeline.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: convolution kernel size.
        strides: convolution strides for each blocks.
        upsample_kernel_size: convolution kernel size for transposed convolution layers. The values should
            equal to strides[1:].
        filters: number of output channels for each blocks. Different from nnU-Net, in this implementation we add
            this argument to make the network more flexible. As shown in the third reference, one way to determine
            this argument is like:
            ``[64, 96, 128, 192, 256, 384, 512, 768, 1024][: len(strides)]``.
            The above way is used in the network that wins task 1 in the BraTS21 Challenge.
            If not specified, the way which nnUNet used will be employed. Defaults to ``None``.
        dropout: dropout ratio. Defaults to no dropout.
        norm_name: feature normalization type and arguments. Defaults to ``INSTANCE``.
            `INSTANCE_NVFUSER` is a faster version of the instance norm layer, it can be used when:
            1) `spatial_dims=3`, 2) CUDA device is available, 3) `apex` is installed and 4) non-Windows OS is used.
        act_name: activation layer type and arguments. Defaults to ``leakyrelu``.
        deep_supervision: whether to add deep supervision head before output. Defaults to ``False``.
            If ``True``, in training mode, the forward function will output not only the final feature map
            (from `output_block`), but also the feature maps that come from the intermediate up sample layers.
            In order to unify the return type (the restriction of TorchScript), all intermediate
            feature maps are interpolated into the same size as the final feature map and stacked together
            (with a new dimension in the first axis)into one single tensor.
            For instance, if there are two intermediate feature maps with shapes: (1, 2, 16, 12) and
            (1, 2, 8, 6), and the final feature map has the shape (1, 2, 32, 24), then all intermediate feature maps
            will be interpolated into (1, 2, 32, 24), and the stacked tensor will has the shape (1, 3, 2, 32, 24).
            When calculating the loss, you can use torch.unbind to get all feature maps can compute the loss
            one by one with the ground truth, then do a weighted average for all losses to achieve the final loss.
        deep_supr_num: number of feature maps that will output during deep supervision head. The
            value should be larger than 0 and less than the number of up sample layers.
            Defaults to 1.
        res_block: whether to use residual connection based convolution blocks during the network.
            Defaults to ``False``.
        trans_bias: whether to set the bias parameter in transposed convolution layers. Defaults to ``False``.
    NINSTANCEaffineT	leakyrelu{Gz?)inplacenegative_slopeFr   )spatial_dimsin_channelsout_channelskernel_sizestridesupsample_kernel_sizefiltersdropout	norm_nameact_namedeep_supervisiondeep_supr_num	res_block
trans_biasc                    s  t    _|_|_|_|_|_|	_|
_	|_
|rHtnt_|_|d k	rl|_  nfddtt|D _ _ _ _ _d_|_|_t dgj _!jr" _#$  %j& '  d	 fdd	 jsN djgt(j jd d d j_)n2 djgt(j jd d d jj#d_)d S )
Nc                    s*   g | ]"}t d d|   dkr dndqS )         i@  i   )min.0i)r.   r   r   
<listcomp>   s     z$DynUNet.__init__.<locals>.<listcomp>r   r   c                    s  t |t |kr*tt | dt | t |dkr:|S |dkr| d|  |dd |dd |}t| |d |d |dS d}| dkr|}n&t |dkrd}|dd }nt } d|  |dd |dd ||d}|rt| |d |d |j|d d	S t| |d |d |dS )
a  
            Construct the UNet topology as a sequence of skip layers terminating with the bottleneck layer. This is
            done recursively from the top down since a recursive nn.Module subclass is being used to be compatible
            with Torchscript. Initially the length of `downsamples` will be one more than that of `superheads`
            since the `input_block` is passed to this function as the first item in `downsamples`, however this
            shouldn't be associated with a supervision head.
            z != r   Nr   )r   r   r   FT
superheads)r   r   r   r   r   )len
ValueErrorr   nn
ModuleListr   )r   downsamples	upsamples
bottleneckrE   r   Zsuper_head_flagZ
rest_heads)create_skipsr   r   r   rM      s2    	"&	z&DynUNet.__init__.<locals>.create_skipsrD   )N)*r   r   r.   r/   r0   r1   r2   r3   r6   r7   r5   r
   r   
conv_blockr;   r4   check_filtersrangerF   get_input_blockZinput_blockget_downsamplesrJ   get_bottleneckrL   get_upsamplesrK   get_output_blockoutput_blockr8   r9   r$   randr   get_deep_supervision_headsZdeep_supervision_headscheck_deep_supr_numapplyinitialize_weightscheck_kernel_stridelistskip_layers)r   r.   r/   r0   r1   r2   r3   r4   r5   r6   r7   r8   r9   r:   r;   r   )rM   r   r.   r   r      sX    






+   zDynUNet.__init__c                 C   s   | j | j }}d}t|t|ks.t|dk r6t|t|D ]n\}}|||  }}t|tsd| d}t|| jkrt|t|ts>d| d}t|| jkr>t|q>d S )NzIlength of kernel_size and strides should be the same, and no less than 3.r>   zlength of kernel_size in block z$ should be the same as spatial_dims.zlength of stride in block )r1   r2   rF   rG   	enumerate
isinstanceintr.   )r   kernelsr2   	error_msgidxk_ikernelstrider   r   r   r]      s    

zDynUNet.check_kernel_stridec                 C   s>   | j | j }}t|d }||kr*td|dk r:tdd S )Nr   zAdeep_supr_num should be less than the number of up sample layers.z&deep_supr_num should be larger than 0.)r9   r2   rF   rG   )r   r9   r2   Znum_up_layersr   r   r   rZ      s    zDynUNet.check_deep_supr_numc                 C   s:   | j }t|t| jk r"tdn|d t| j | _ d S )Nz?length of filters should be no less than the length of strides.)r4   rF   r2   rG   )r   r4   r   r   r   rP     s    
zDynUNet.check_filtersc                 C   s^   |  |}| |}| jrZ| jrZ|g}| jD ]}|t||jdd   q,tj	|ddS |S )Nr<   r   )dim)
r_   rW   trainingr8   r   appendr   shaper$   stack)r   r   outZout_allfeature_mapr   r   r   r     s    


zDynUNet.forwardc              
   C   s6   | j | j| j| jd | jd | jd | j| j| jdS )Nr   r5   )	rO   r.   r/   r4   r1   r2   r6   r7   r5   r   r   r   r   rR     s    zDynUNet.get_input_blockc              
   C   s:   | j | j| jd | jd | jd | jd | j| j| jdS )NrN   rp   )rO   r.   r4   r1   r2   r6   r7   r5   rq   r   r   r   rT   "  s    zDynUNet.get_bottleneck)re   c                 C   s   t | j| j| | j| jdS )Nrp   )r	   r.   r4   r0   r5   )r   re   r   r   r   rV   .  s    zDynUNet.get_output_blockc                 C   sP   | j d d | j dd  }}| jdd | jdd  }}| ||||| jS )Nrr   r   rN   )r4   r2   r1   get_module_listrO   )r   inprn   r2   r1   r   r   r   rS   1  s    zDynUNet.get_downsamplesc              	   C   s   | j dd  d d d | j d d d d d  }}| jdd  d d d | jdd  d d d  }}| jd d d }| j||||t|| jdS )Nr   rN   )r;   )r4   r2   r1   r3   rs   r   r;   )r   rt   rn   r2   r1   r3   r   r   r   rU   6  s    22zDynUNet.get_upsamples)r/   r0   r1   r2   rO   r3   r;   c                 C   s   g }|d k	rdt |||||D ]D\}	}
}}}| j|	|
||| j| j| j||d
}|f |}|| qnNt ||||D ]>\}	}
}}| j|	|
||| j| j| jd}|f |}|| qrt|S )N)
r.   r/   r0   r1   rh   r6   r7   r5   r3   r;   )r.   r/   r0   r1   rh   r6   r7   r5   )zipr.   r6   r7   r5   rk   rH   rI   )r   r/   r0   r1   r2   rO   r3   r;   layersin_cout_crg   rh   Z	up_kernelparamslayerr   r   r   rs   D  sF    
    


zDynUNet.get_module_listc                    s   t  fddt jD S )Nc                    s   g | ]}  |d  qS )r   )rV   r@   rq   r   r   rC   r  s     z6DynUNet.get_deep_supervision_heads.<locals>.<listcomp>)rH   rI   rQ   r9   rq   r   rq   r   rY   q  s    z"DynUNet.get_deep_supervision_headsc                 C   sN   t | tjtjtjtjfrJtjj| jdd| _| j	d k	rJtj
| j	d| _	d S )Nr+   )ar   )ra   rH   Conv3dConv2dConvTranspose3dConvTranspose2dinitkaiming_normal_weightbias	constant_)moduler   r   r   r\   t  s    
zDynUNet.initialize_weights)NF)r    r!   r"   r#   rb   r   r   r   r   strfloatboolr   r]   rZ   rP   r   rR   rT   rV   rS   rU   r   rH   Modulers   rY   staticmethodr\   r'   r   r   r   r   r   8   s`   Q



j
  -)typingr   r   r   r   r   r$   torch.nnrH   torch.nn.functionalr   Z#monai.networks.blocks.dynunet_blockr   r	   r
   r   __all__r   r   r   r   r   r   r   r   r   <module>   s   
  F