o
    ,iO                     @  s8  d dl mZ d dlmZmZmZmZ d dlZd dlm	Z	 d dl
m	  mZ d dlmZmZ d dlmZ d dlmZmZ d dlmZ d dlmZmZ d d	lmZ ed
\ZZdgZG dd deZG dd de	j Z!G dd deZ"G dd deZ#G dd de#Z$G dd de#Z%G dd de#Z&G dd de	j'Z(dS )    )annotations)OptionalSequenceTupleUnionN)ConvDenseBlockConvolution)squeeze_and_excitation)ActNorm)SkipConnection)get_dropout_layerget_pool_layer)optional_importr	   Quicknatc                      s    e Zd ZdZ fddZ  ZS )SkipConnectionWithIdxa7  
    Combine the forward pass input with the result from the given submodule::
    --+--submodule--o--
      |_____________|
    The available modes are ``"cat"``, ``"add"``, ``"mul"``.
    Defaults to "cat" and dimension 1.
    Inherits from SkipConnection but provides the indizes with each forward pass.
    c                   s   t  ||fS N)superforward)selfinputindices	__class__ ^/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/quicknat.pyr   -      zSkipConnectionWithIdx.forward)__name__
__module____qualname____doc__r   __classcell__r   r   r   r   r   #   s    	r   c                      s(   e Zd ZdZ fddZdd Z  ZS )SequentialWithIdxz
    A sequential container.
    Modules will be added to it in the order they are passed in the
    constructor.
    Own implementation to work with the new indices in the forward pass.
    c                   s   t  j|  d S r   r   __init__)r   argsr   r   r   r$   9   r   zSequentialWithIdx.__init__c                 C  s    | D ]	}|||\}}q||fS r   r   )r   r   r   moduler   r   r   r   <   s   zSequentialWithIdx.forwardr   r   r   r    r$   r   r!   r   r   r   r   r"   1   s    r"   c                      s2   e Zd ZdZd
 fdd	Zdd fdd	Z  ZS )ClassifierBlocka  
    Returns a classifier block without an activation function at the top.
    It consists of a 1 * 1 convolutional layer which maps the input to a num_class channel feature map.
    The output is a probability map for each of the classes.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of classes to map to.
        strides: convolution stride. Defaults to 1.
        kernel_size: convolution kernel size. Defaults to 3.
        adn_ordering: a string representing the ordering of activation, normalization, and dropout.
        Defaults to "NDA".
        act: activation type and arguments. Defaults to PReLU.

    NAc              	     s   t  ||||||| d S r   r#   )r   spatial_dimsin_channelsout_channelsstrideskernel_sizeactadn_orderingr   r   r   r$   T   s   zClassifierBlock.__init__r   torch.Tensorc                   st   |  ^}}}|d ur0tj|dd\}}|d|dd}t|dkr,t||}|d fS tdt 	|}|d fS )Nr   dim      z;Quicknat is a 2D architecture, please check your dimension.)
sizetorchmaxviewlenFconv2d
ValueErrorr   r   )r   r   weightsr   _channeldimsZout_convr   r   r   r   W   s   zClassifierBlock.forward)Nr)   )NN)r   r1   r'   r   r   r   r   r(   B   s    r(   c                      s<   e Zd ZdZ				dd fddZdd Zdd Z  ZS )ConvConcatDenseBlocka  
    This dense block is defined as a sequence of 'Convolution' blocks. It overwrite the '_get_layer' methodto change the ordering of
    Every convolutional layer is preceded by a batch-normalization layer and a Rectifier Linear Unit (ReLU) layer.
    The first two convolutional layers are followed by a concatenation layer that concatenates
    the input feature map with outputs of the current and previous convolutional blocks.
    Kernel size of two convolutional layers kept small to limit number of paramters.
    Appropriate padding is provided so that the size of feature maps before and after convolution remains constant.
    The output channels for each convolution layer is set to 64, which acts as a bottle- neck for feature map selectivity.
    The input channel size is variable, depending on the number of dense connections.
    The third convolutional layer is also preceded by a batch normalization and ReLU,
    but has a 1 * 1 kernel size to compress the feature map size to 64.
    Args:
        in_channles: variable depending on depth of the network
        seLayer: Squeeze and Excite block to be included, defaults to None, valid options are {'NONE', 'CSE', 'SSE', 'CSSE'},
        dropout_layer: Dropout block to be included, defaults to None.
    :return: forward passed tensor
    N   @   r+   intse_layerOptional[nn.Module]dropout_layerOptional[nn.Dropout2d]r.   Sequence[int] | intnum_filtersc                   s`   d| _ t j|d|||gdd|if|d |d ur|nt | _|d ur)|| _d S t | _d S )Nr   r5   instancenum_features)r+   r*   channelsnormr.   )countr   r$   nnIdentityrF   rH   )r   r+   rF   rH   r.   rK   r   r   r   r$   {   s   
 zConvConcatDenseBlock.__init__c              
   C  s\   | j dk r| jnd}|  j d7  _ t| j||d|| jdd|ifd}t|d|dS )	a  
        After ever convolutional layer the output is concatenated with the input and the layer before.
        The concatenated output is used as input to the next convolutional layer.

        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            strides: convolution stride.
            is_top: True if this is the top block.
        r5   )r4   r4   r4   rL   rM   )r*   r+   r,   r-   r.   r/   rO   adnconv)rP   r.   r   r*   r/   rQ   
Sequentialget_submodule)r   r+   r,   dilationZ
kernelsizerT   r   r   r   
_get_layer   s   
	zConvConcatDenseBlock._get_layerc                 C  s   d}|}|}|   D ]3}t|tjtjtjfrq
||}|dkr+|}tj||fdd}|dkr9tj|||fdd}|d }q
| |}| 	|}|d fS )Nr   r4   r2   )
children
isinstancerQ   	MaxPool2dMaxUnpool2d	Dropout2dr7   catrF   rH   )r   r   r?   iresultresult1lr   r   r   r      s    


zConvConcatDenseBlock.forward)NNrC   rD   )
r+   rE   rF   rG   rH   rI   r.   rJ   rK   rE   )r   r   r   r    r$   rX   r   r!   r   r   r   r   rB   h   s    rB   c                      s0   e Zd ZdZd	 fddZd
 fdd	Z  ZS )Encodera  
    Returns a convolution dense block for the encoding (down) part of a layer of the network.
    This Encoder block downpools the data with max_pool.
    Its output is used as input to the next layer down.
    New feature: it returns the indices of the max_pool to the decoder (up) path
    at the same layer to upsample the input.

    Args:
        in_channels: number of input channels.
        max_pool: predefined max_pool layer to downsample the data.
        se_layer: Squeeze and Excite block to be included, defaults to None.
        dropout: Dropout block to be included, defaults to None.
        kernel_size : kernel size of the convolutional layers. Defaults to 5*5
        num_filters : number of input channels to each convolution block. Defaults to 64
    r+   rE   c                      t  ||||| || _d S r   )r   r$   max_pool)r   r+   re   rF   dropoutr.   rK   r   r   r   r$         
zEncoder.__init__Nc                   s(   |  |\}}t |d \}}||fS r   )re   r   r   r   r   r   Z	out_blockr?   r   r   r   r      s   zEncoder.forwardr+   rE   r   r'   r   r   r   r   rc      s    rc   c                      .   e Zd ZdZd fddZ fddZ  ZS )	Decodera  
    Returns a convolution dense block for the decoding (up) part of a layer of the network.
    This will upsample data with an unpool block before the forward.
    It uses the indices from corresponding encoder on it's level.
    Its output is used as input to the next layer up.

    Args:
        in_channels: number of input channels.
        un_pool: predefined unpool block.
        se_layer: predefined SELayer. Defaults to None.
        dropout: predefined dropout block. Defaults to None.
        kernel_size: Kernel size of convolution layers. Defaults to 5*5.
        num_filters: number of input channels to each convolution layer. Defaults to 64.
    r+   rE   c                   rd   r   )r   r$   un_pool)r   r+   rl   rF   rf   r.   rK   r   r   r   r$      rg   zDecoder.__init__c                   s&   t  |d \}}| ||}|d fS r   )r   r   rl   rh   r   r   r   r      s   zDecoder.forwardri   r'   r   r   r   r   rk      s    rk   c                      rj   )	
Bottlenecka  
    Returns the bottom or bottleneck layer at the bottom of a network linking encoder to decoder halves.
    It consists of a 5 * 5 convolutional layer and a batch normalization layer to separate
    the encoder and decoder part of the network, restricting information flow between the encoder and decoder.

    Args:
        in_channels: number of input channels.
        se_layer: predefined SELayer. Defaults to None.
        dropout: predefined dropout block. Defaults to None.
        un_pool: predefined unpool block.
        max_pool: predefined maxpool block.
        kernel_size: Kernel size of convolution layers. Defaults to 5*5.
        num_filters: number of input channels to each convolution layer. Defaults to 64.
    r+   rE   c                   s$   t  ||||| || _|| _d S r   )r   r$   re   rl   )r   r+   rF   rf   re   rl   r.   rK   r   r   r   r$     s   
zBottleneck.__init__c                   s4   |  |\}}t |d \}}| ||}|d fS r   )re   r   r   rl   rh   r   r   r   r     s   zBottleneck.forwardri   r'   r   r   r   r   rm      s    rm   c                      sb   e Zd ZdZddddddddddejejd	fd' fddZd(dd Z	e
d!d" Zd)d%d&Z  ZS )*r   a  
    Model for "Quick segmentation of NeuroAnaTomy (QuickNAT) based on a deep fully convolutional neural network.
    Refer to: "QuickNAT: A Fully Convolutional Network for Quick and Accurate Segmentation of Neuroanatomy by
    Abhijit Guha Roya, Sailesh Conjetib, Nassir Navabb, Christian Wachingera"

    QuickNAT has an encoder/decoder like 2D F-CNN architecture with 4 encoders and 4 decoders separated by a bottleneck layer.
    The final layer is a classifier block with softmax.
    The architecture includes skip connections between all encoder and decoder blocks of the same spatial resolution,
    similar to the U-Net architecture.
    All Encoder and Decoder consist of three convolutional layers all with a Batch Normalization and ReLU.
    The first two convolutional layers are followed by a concatenation layer that concatenates
    the input feature map with outputs of the current and previous convolutional blocks.
    The kernel size of the first two convolutional layers is 5*5, the third convolutional layer has a kernel size of 1*1.

    Data in the encode path is downsampled using max pooling layers instead of upsamling like UNet and in the decode path
    upsampled using max un-pooling layers instead of transpose convolutions.
    The pooling is done at the beginning of the block and the unpool afterwards.
    The indices of the max pooling in the Encoder are forwarded through the layer to be available to the corresponding Decoder.

    The bottleneck block consists of a 5 * 5 convolutional layer and a batch normalization layer
    to separate the encoder and decoder part of the network,
    restricting information flow between the encoder and decoder.

    The output feature map from the last decoder block is passed to the classifier block,
    which is a convolutional layer with 1 * 1 kernel size that maps the input to an N channel feature map,
    where N is the number of segmentation classes.

    To further explain this consider the first example network given below. This network has 3 layers with strides
    of 2 for each of the middle layers (the last layer is the bottom connection which does not down/up sample). Input
    data to this network is immediately reduced in the spatial dimensions by a factor of 2 by the first convolution of
    the residual unit defining the first layer of the encode part. The last layer of the decode part will upsample its
    input (data from the previous layer concatenated with data from the skip connection) in the first convolution. this
    ensures the final output of the network has the same shape as the input.

    The original QuickNAT implementation included a `enable_test_dropout()` mechanism for uncertainty estimation during
    testing. As the dropout layers are the only stochastic components of this network calling the train() method instead
    of eval() in testing or inference has the same effect.

    Args:
        num_classes: number of classes to segmentate (output channels).
        num_channels: number of input channels.
        num_filters: number of output channels for each convolutional layer in a Dense Block.
        kernel_size: size of the kernel of each convolutional layer in a Dense Block.
        kernel_c: convolution kernel size of classifier block kernel.
        stride_convolution: convolution stride. Defaults to 1.
        pool: kernel size of the pooling layer,
        stride_pool: stride for the pooling layer.
        se_block: Squeeze and Excite block type to be included, defaults to None. Valid options : NONE, CSE, SSE, CSSE,
        droup_out: dropout ratio. Defaults to no dropout.
        act: activation type and arguments. Defaults to PReLU.
        norm: feature normalization type and arguments. Defaults to instance norm.
        adn_ordering: a string representing the ordering of activation (A), normalization (N), and dropout (D).
            Defaults to "NA". See also: :py:class:`monai.networks.blocks.ADN`.

    Examples::

        from monai.networks.nets import QuickNAT

        # network with max pooling by a factor of 2 at each layer with no se_block.
        net = QuickNAT(
            num_classes=3,
            num_channels=1,
            num_filters=64,
            pool = 2,
            se_block = "None"
        )

    !   r4   rD   rC   r5   Noner   NAnum_classesrE   num_channelsrK   r.   rJ   kernel_cstride_convpoolstride_poolse_blockstrdrop_outfloatr/   Union[Tuple, str]rO   r0   returnc                   s   || _ || _|| _t   | |	tdd|
ifddtd||dddfddtj	||d	
d 	
fdd  d| _
d S )Nrf   pr5   )namedropout_dimr8   T)r.   stridereturn_indices	ceil_mode)r~   r*   )r.   r   layerrE   r|   	nn.Modulec                   s   | dk r | d }n
t 
}| dkr<t}td }td	}t|t|||S td 
}t}t|t||S )a  
            Builds the QuickNAT structure from the bottom up by recursing down to the bottelneck layer, then creating sequential
            blocks containing the decoder, a skip connection around the previous block, and the encoder.
            At the last layer a classifier block is added to the Sequential.

            Args:
                layer = inversproportional to the layers left to create
               r4   r5   )rm   rB   r(   r"   r   rk   rc   )r   subblockdownup
classifier_create_modelrH   rs   r.   re   rr   rq   rK   rF   rt   rl   r   r   r   }  s   
z(Quicknat.__init__.<locals>._create_modelr4   )r   rE   r|   r   )r/   rO   r0   r   r$   get_selayerr   r   rQ   r\   model)r   rq   rr   rK   r.   rs   rt   ru   rv   rw   ry   r/   rO   r0   r   r   r   r$   ^  s   
"zQuicknat.__init__c                 C  sP   |dkr
t d|S |dks|dkr&tstd|dkr!t|S t|S dS )a@  
        Returns the SEBlock defined in the initialization of the QuickNAT model.

        Args:
            n_filters: encoding half of the layer
            se_block_type: defaults to None. Valid options are None, CSE, SSE, CSSE
        Returns: Appropriate SEBlock. SSE and CSSE not implemented in Monai yet.
        ZCSEr5   SSEZCSSEzCPlease install squeeze_and_excitation locally to use SpatialSELayerN)seChannelSELayerflagImportErrorse1ZSpatialSELayerZChannelSpatialSELayer)r   	n_filtersZse_block_typer   r   r   r     s   	

zQuicknat.get_selayerc                 C  s   t |  jS )zE
        Check if model parameters are allocated on the GPU.
        )next
parametersis_cuda)r   r   r   r   r     s   zQuicknat.is_cudar   r1   c                 C  s   |  |d \}}|S r   )r   )r   r   r?   r   r   r   r     s   zQuicknat.forward)rq   rE   rr   rE   rK   rE   r.   rJ   rs   rE   rt   rE   ru   rE   rv   rE   rw   rx   ry   rz   r/   r{   rO   r{   r0   rx   r|   ro   )ro   )r   r1   r|   r1   )r   r   r   r    r
   PRELUr   INSTANCEr$   r   propertyr   r   r!   r   r   r   r   r     s(    G
;
))
__future__r   typingr   r   r   r   r7   torch.nnrQ   torch.nn.functional
functionalr;   monai.networks.blocksr   r   r	   r   monai.networks.layers.factoriesr
   r   "monai.networks.layers.simplelayersr   monai.networks.layers.utilsr   r   monai.utilsr   r   r   __all__r   rU   r"   r(   rB   rc   rk   rm   Moduler   r   r   r   r   <module>   s*   &]