U
    Ph6                     @  s   d dl mZ d dlmZ d dlZd dlZd dlmZ d dl	m  m
Z d dlmZmZmZ d dlmZ d dlmZmZ d dlmZ dd	gZG d
d dejZG dd	 d	eZdS )    )annotations)SequenceN)ResBlockget_conv_layerget_upsample_layer)Dropout)get_act_layerget_norm_layer)UpsampleMode	SegResNetSegResNetVAEc                      s   e Zd ZdZdddddddd	ifd
ddifddd	ddejfdddddddddddddd fddZdd Zdd ZddddZ	d d!d"d#d$Z
d d%d d&d'd(Zd d d"d)d*Z  ZS )+r   a  
    SegResNet based on `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.
    The module does not include the variational autoencoder (VAE).
    The model supports 2D or 3D inputs.

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 8.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        dropout_prob: probability of an element to be zero-ed. Defaults to ``None``.
        act: activation type and arguments. Defaults to ``RELU``.
        norm: feature normalization type and arguments. Defaults to ``GROUP``.
        norm_name: deprecating option for feature normalization type.
        num_groups: deprecating option for group norm. parameters.
        use_conv_final: if add a final convolution block to output. Defaults to ``True``.
        blocks_down: number of down sample blocks in each layer. Defaults to ``[1,2,2,4]``.
        blocks_up: number of up sample blocks in each layer. Defaults to ``[1,1,1]``.
        upsample_mode: [``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``]
            The mode of upsampling manipulations.
            Using the ``nontrainable`` modes cannot guarantee the model's reproducibility. Defaults to``nontrainable``.

            - ``deconv``, uses transposed convolution layers.
            - ``nontrainable``, uses non-trainable `linear` interpolation.
            - ``pixelshuffle``, uses :py:class:`monai.networks.blocks.SubpixelUpsample`.

                NRELUinplaceTGROUP
num_groups r   r   r      r   r   r   intfloat | Nonetuple | strstrbooltupleUpsampleMode | str)spatial_dimsinit_filtersin_channelsout_channelsdropout_probactnorm	norm_namer   use_conv_finalblocks_down	blocks_upupsample_modec                   s   t    |dkrtd|| _|| _|| _|| _|| _|| _|| _	t
|| _|rz| dkrntd| ddd|	if}|| _t|| _|
| _t|||| _|  | _|  \| _| _| || _|d k	rttj|f || _d S )N)r   r   z"`spatial_dims` can only be 2 or 3.groupzDeprecating option 'norm_name=z', please use 'norm' instead.r   )super__init__
ValueErrorr    r!   r"   r)   r*   r$   r%   r   act_modlowerr&   r
   r+   r(   r   convInit_make_down_layersdown_layers_make_up_layers	up_layers
up_samples_make_final_conv
conv_finalr   DROPOUTdropout)selfr    r!   r"   r#   r$   r%   r&   r'   r   r(   r)   r*   r+   	__class__ R/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/segresnet.pyr.   ;   s0    



zSegResNet.__init__c                   s   t  }jjjjf\}}t|D ]j\}}|d|   |dkr\t d  ddnt  }t j	|f fddt
|D  }|| q,|S )Nr   r   )stridec                   s   g | ]}t  jd qS )r&   r%   r   r%   .0_Zlayer_in_channelsr&   r<   r    r?   r@   
<listcomp>r   s     z/SegResNet._make_down_layers.<locals>.<listcomp>)nn
ModuleListr)   r    r!   r&   	enumerater   Identity
Sequentialrangeappend)r<   r4   r)   filtersiitempre_convZ
down_layerr?   rG   r@   r3   g   s    zSegResNet._make_down_layersc              
     s   t  t   }}jjjjjf\}}} t|}t|D ]r}|d||   |	t j
 fddt|| D   |	t j
td ddtd |dg  qD||fS )Nr   c                   s"   g | ]}t d   jdqS )r   rB   rC   rD   r&   Zsample_in_channelsr<   r    r?   r@   rH      s   z-SegResNet._make_up_layers.<locals>.<listcomp>r   kernel_sizer+   )rI   rJ   r+   r*   r    r!   r&   lenrN   rO   rM   r   r   )r<   r6   r7   r+   r*   rP   Zn_uprQ   r?   rT   r@   r5   w   s2    
zSegResNet._make_up_layers)r#   c                 C  s2   t t| j| j| jd| jt| j| j|dddS )Nnamer    channelsr   T)rV   bias)rI   rM   r	   r&   r    r!   r0   r   )r<   r#   r?   r?   r@   r8      s
    zSegResNet._make_final_convtorch.Tensorz'tuple[torch.Tensor, list[torch.Tensor]])xreturnc                 C  sH   |  |}| jd k	r| |}g }| jD ]}||}|| q(||fS N)r2   r$   r;   r4   rO   )r<   r^   down_xdownr?   r?   r@   encode   s    



zSegResNet.encodezlist[torch.Tensor])r^   ra   r_   c                 C  sP   t t| j| jD ](\}\}}||||d   }||}q| jrL| |}|S )Nr   )rK   zipr7   r6   r(   r9   )r<   r^   ra   rQ   upuplr?   r?   r@   decode   s    

zSegResNet.decodec                 C  s&   |  |\}}|  | ||}|S r`   )rc   reverserg   )r<   r^   ra   r?   r?   r@   forward   s    zSegResNet.forward)__name__
__module____qualname____doc__r
   NONTRAINABLEr.   r3   r5   r8   rc   rg   ri   __classcell__r?   r?   r=   r@   r      s*   

,,
c                      s   e Zd ZdZdddddddd	d
ddifdddifdddejfdddddddddddddddd fddZdd Zddd d!d"Zd#d$ Z	  Z
S )%r   a  
    SegResNetVAE based on `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.
    The module contains the variational autoencoder (VAE).
    The model supports 2D or 3D inputs.

    Args:
        input_image_size: the size of images to input into the network. It is used to
            determine the in_features of the fc layer in VAE.
        vae_estimate_std: whether to estimate the standard deviations in VAE. Defaults to ``False``.
        vae_default_std: if not to estimate the std, use the default value. Defaults to 0.3.
        vae_nz: number of latent variables in VAE. Defaults to 256.
            Where, 128 to represent mean, and 128 to represent std.
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 8.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        dropout_prob: probability of an element to be zero-ed. Defaults to ``None``.
        act: activation type and arguments. Defaults to ``RELU``.
        norm: feature normalization type and arguments. Defaults to ``GROUP``.
        use_conv_final: if add a final convolution block to output. Defaults to ``True``.
        blocks_down: number of down sample blocks in each layer. Defaults to ``[1,2,2,4]``.
        blocks_up: number of up sample blocks in each layer. Defaults to ``[1,1,1]``.
        upsample_mode: [``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``]
            The mode of upsampling manipulations.
            Using the ``nontrainable`` modes cannot guarantee the model's reproducibility. Defaults to``nontrainable``.

            - ``deconv``, uses transposed convolution layers.
            - ``nontrainable``, uses non-trainable `linear` interpolation.
            - ``pixelshuffle``, uses :py:class:`monai.networks.blocks.SubpixelUpsample`.
    Fg333333?   r   r   r   r   Nr   r   Tr   r   r   r   zSequence[int]r   floatr   r   zstr | tupler   r   r   )input_image_sizevae_estimate_stdvae_default_stdvae_nzr    r!   r"   r#   r$   r%   r&   r(   r)   r*   r+   c                   s   t  j|||||	|
|||||d || _d| _dt| jd    fdd| jD | _|| _|| _|| _	| 
  | || _d S )N)r    r!   r"   r#   r$   r%   r&   r(   r)   r*   r+      r   r   c                   s   g | ]}|d    qS )r   r?   )rE   szoomr?   r@   rH      s     z)SegResNetVAE.__init__.<locals>.<listcomp>)r-   r.   rr   smallest_filtersrX   r)   	fc_insizers   rt   ru   _prepare_vae_modulesr8   vae_conv_final)r<   rr   rs   rt   ru   r    r!   r"   r#   r$   r%   r&   r(   r)   r*   r+   r=   rx   r@   r.      s,    zSegResNetVAE.__init__c                 C  s   dt | jd  }| j| }t| jt| j }t	t
| j| j|d| jt| j|| jdddt
| j| j| jd| j| _t|| j| _t|| j| _t| j|| _t	t| j| j|ddt| j|| jdt
| j| j|d| j| _d S )Nr   r   rY   T)rA   r\   rU   rW   )rX   r)   r!   r   rz   npprodr{   rI   rM   r	   r&   r    r0   r   vae_downLinearru   vae_fc1vae_fc2vae_fc3r   r+   vae_fc_up_sample)r<   ry   Z	v_filtersZtotal_elementsr?   r?   r@   r|     s&    
z!SegResNetVAE._prepare_vae_modulesr]   )	net_input	vae_inputc              	   C  s.  |  |}|d| jj}| |}t|}|d | jr| |}t	
|}dt|d |d  td|d   d  }|||  }n | j}t|d }|||  }| |}| |}|d| jg| j }| |}t| j| jD ]\}}	||}|	|}q| |}t	||}
||
 }|S )z
        Args:
            net_input: the original input of the network.
            vae_input: the input of VAE module, which is also the output of the network's encoder.
        Fg      ?r   g:0yE>r   )r   viewr   in_featurestorch
randn_likerequires_grad_rs   r   Fsoftplusmeanlogrt   r   r0   rz   r{   r   rd   r7   r6   r}   mse_loss)r<   r   r   Zx_vaeZz_meanZz_mean_randZz_sigmaZvae_reg_lossre   rf   Zvae_mse_lossvae_lossr?   r?   r@   _get_vae_loss  s0    





0




zSegResNetVAE._get_vae_lossc                 C  sL   |}|  |\}}|  |}| ||}| jrD| ||}||fS |d fS r`   )rc   rh   rg   trainingr   )r<   r^   r   ra   r   r   r?   r?   r@   ri   F  s    zSegResNetVAE.forward)rj   rk   rl   rm   r
   rn   r.   r|   r   ri   ro   r?   r?   r=   r@   r      s&   #

0,')
__future__r   collections.abcr   numpyr~   r   torch.nnrI   torch.nn.functional
functionalr   Z%monai.networks.blocks.segresnet_blockr   r   r   monai.networks.layers.factoriesr   monai.networks.layers.utilsr   r	   monai.utilsr
   __all__Moduler   r   r?   r?   r?   r@   <module>   s    