o
    ,i6                     @  s   d dl mZ d dlmZ d dlZd dlZd dlmZ d dl	m  m
Z d dlmZmZmZ d dlmZ d dlmZmZ d dlmZ dd	gZG d
d dejZG dd	 d	eZdS )    )annotations)SequenceN)ResBlockget_conv_layerget_upsample_layer)Dropout)get_act_layerget_norm_layer)UpsampleMode	SegResNetSegResNetVAEc                      s   e Zd ZdZdddddddd	ifd
ddifddd	ddejfd6 fd"d#Zd$d% Zd&d' Zd7d(d)Z	d8d.d/Z
d9d2d3Zd:d4d5Z  ZS );r   a  
    SegResNet based on `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.
    The module does not include the variational autoencoder (VAE).
    The model supports 2D or 3D inputs.

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 8.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        dropout_prob: probability of an element to be zero-ed. Defaults to ``None``.
        act: activation type and arguments. Defaults to ``RELU``.
        norm: feature normalization type and arguments. Defaults to ``GROUP``.
        norm_name: deprecating option for feature normalization type.
        num_groups: deprecating option for group norm. parameters.
        use_conv_final: if add a final convolution block to output. Defaults to ``True``.
        blocks_down: number of down sample blocks in each layer. Defaults to ``[1,2,2,4]``.
        blocks_up: number of up sample blocks in each layer. Defaults to ``[1,1,1]``.
        upsample_mode: [``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``]
            The mode of upsampling manipulations.
            Using the ``nontrainable`` modes cannot guarantee the model's reproducibility. Defaults to``nontrainable``.

            - ``deconv``, uses transposed convolution layers.
            - ``nontrainable``, uses non-trainable `linear` interpolation.
            - ``pixelshuffle``, uses :py:class:`monai.networks.blocks.SubpixelUpsample`.

                NRELUinplaceTGROUP
num_groups r   r   r      r   r   r   spatial_dimsintinit_filtersin_channelsout_channelsdropout_probfloat | Noneacttuple | strnorm	norm_namestruse_conv_finalboolblocks_downtuple	blocks_upupsample_modeUpsampleMode | strc                   s   t    |dvrtd|| _|| _|| _|| _|| _|| _|| _	t
|| _|r=| dkr7td| ddd|	if}|| _t|| _|
| _t|||| _|  | _|  \| _| _| || _|d urrttj|f || _d S d S )N)r   r   z"`spatial_dims` can only be 2 or 3.groupzDeprecating option 'norm_name=z', please use 'norm' instead.r   )super__init__
ValueErrorr   r   r   r'   r)   r   r    r   act_modlowerr"   r
   r*   r%   r   convInit_make_down_layersdown_layers_make_up_layers	up_layers
up_samples_make_final_conv
conv_finalr   DROPOUTdropout)selfr   r   r   r   r   r    r"   r#   r   r%   r'   r)   r*   	__class__ _/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/segresnet.pyr.   ;   s2   



zSegResNet.__init__c                   s   t  }jjjjf\}}t|D ]6\}}|d|   |dkr.t d  ddnt  }t j	|g fddt
|D R  }|| q|S )Nr   r   )stridec                   s   g | ]}t  jd qS )r"   r    r   r    .0_Zlayer_in_channelsr"   r<   r   r?   r@   
<listcomp>r   s    z/SegResNet._make_down_layers.<locals>.<listcomp>)nn
ModuleListr'   r   r   r"   	enumerater   Identity
Sequentialrangeappend)r<   r4   r'   filtersiitempre_convZ
down_layerr?   rG   r@   r3   g   s   zSegResNet._make_down_layersc              
     s   t  t  }}jjjjjf\}}} t|}t|D ]9}|d||   |	t j
 fddt|| D   |	t j
td ddtd |dg  q"||fS )Nr   c                   s"   g | ]}t d   jdqS )r   rB   rC   rD   r"   Zsample_in_channelsr<   r   r?   r@   rH      s    z-SegResNet._make_up_layers.<locals>.<listcomp>r   kernel_sizer*   )rI   rJ   r*   r)   r   r   r"   lenrN   rO   rM   r   r   )r<   r6   r7   r*   r)   rP   Zn_uprQ   r?   rT   r@   r5   w   s2   
zSegResNet._make_up_layersc                 C  s2   t t| j| j| jd| jt| j| j|dddS )Nnamer   channelsr   T)rV   bias)rI   rM   r	   r"   r   r   r0   r   )r<   r   r?   r?   r@   r8      s
   zSegResNet._make_final_convxtorch.Tensorreturn'tuple[torch.Tensor, list[torch.Tensor]]c                 C  sH   |  |}| jd ur| |}g }| jD ]}||}|| q||fS N)r2   r   r;   r4   rO   )r<   r]   down_xdownr?   r?   r@   encode   s   



zSegResNet.encoderb   list[torch.Tensor]c                 C  sP   t t| j| jD ]\}\}}||||d   }||}q	| jr&| |}|S )Nr   )rK   zipr7   r6   r%   r9   )r<   r]   rb   rQ   upuplr?   r?   r@   decode   s   

zSegResNet.decodec                 C  s&   |  |\}}|  | ||}|S ra   )rd   reverseri   )r<   r]   rb   r?   r?   r@   forward   s   zSegResNet.forward)r   r   r   r   r   r   r   r   r   r   r    r!   r"   r!   r#   r$   r   r   r%   r&   r'   r(   r)   r(   r*   r+   )r   r   )r]   r^   r_   r`   )r]   r^   rb   re   r_   r^   )r]   r^   r_   r^   )__name__
__module____qualname____doc__r
   NONTRAINABLEr.   r3   r5   r8   rd   ri   rk   __classcell__r?   r?   r=   r@   r      s,    

,



c                      sl   e Zd ZdZdddddddd	d
ddifdddifdddejfd4 fd)d*Zd+d, Zd5d0d1Zd2d3 Z	  Z
S )6r   a  
    SegResNetVAE based on `3D MRI brain tumor segmentation using autoencoder regularization
    <https://arxiv.org/pdf/1810.11654.pdf>`_.
    The module contains the variational autoencoder (VAE).
    The model supports 2D or 3D inputs.

    Args:
        input_image_size: the size of images to input into the network. It is used to
            determine the in_features of the fc layer in VAE.
        vae_estimate_std: whether to estimate the standard deviations in VAE. Defaults to ``False``.
        vae_default_std: if not to estimate the std, use the default value. Defaults to 0.3.
        vae_nz: number of latent variables in VAE. Defaults to 256.
            Where, 128 to represent mean, and 128 to represent std.
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        init_filters: number of output channels for initial convolution layer. Defaults to 8.
        in_channels: number of input channels for the network. Defaults to 1.
        out_channels: number of output channels for the network. Defaults to 2.
        dropout_prob: probability of an element to be zero-ed. Defaults to ``None``.
        act: activation type and arguments. Defaults to ``RELU``.
        norm: feature normalization type and arguments. Defaults to ``GROUP``.
        use_conv_final: if add a final convolution block to output. Defaults to ``True``.
        blocks_down: number of down sample blocks in each layer. Defaults to ``[1,2,2,4]``.
        blocks_up: number of up sample blocks in each layer. Defaults to ``[1,1,1]``.
        upsample_mode: [``"deconv"``, ``"nontrainable"``, ``"pixelshuffle"``]
            The mode of upsampling manipulations.
            Using the ``nontrainable`` modes cannot guarantee the model's reproducibility. Defaults to``nontrainable``.

            - ``deconv``, uses transposed convolution layers.
            - ``nontrainable``, uses non-trainable `linear` interpolation.
            - ``pixelshuffle``, uses :py:class:`monai.networks.blocks.SubpixelUpsample`.
    Fg333333?   r   r   r   r   Nr   r   Tr   r   r   r   input_image_sizeSequence[int]vae_estimate_stdr&   vae_default_stdfloatvae_nzr   r   r   r   r   r   r   r    str | tupler"   r!   r%   r'   r(   r)   r*   r+   c                   s   t  j|||||	|
|||||d || _d| _dt| jd    fdd| jD | _|| _|| _|| _	| 
  | || _d S )N)r   r   r   r   r   r    r"   r%   r'   r)   r*      r   r   c                   s   g | ]}|d    qS )r   r?   )rE   szoomr?   r@   rH      s    z)SegResNetVAE.__init__.<locals>.<listcomp>)r-   r.   rs   smallest_filtersrX   r'   	fc_insizeru   rv   rx   _prepare_vae_modulesr8   vae_conv_final)r<   rs   ru   rv   rx   r   r   r   r   r   r    r"   r%   r'   r)   r*   r=   r|   r@   r.      s,   zSegResNetVAE.__init__c                 C  s   dt | jd  }| j| }t| jt| j }t	t
| j| j|d| jt| j|| jdddt
| j| j| jd| j| _t|| j| _t|| j| _t| j|| _t	t| j| j|ddt| j|| jdt
| j| j|d| j| _d S )Nr   r   rY   T)rA   r\   rU   rW   )rX   r'   r   r   r~   npprodr   rI   rM   r	   r"   r   r0   r   vae_downLinearrx   vae_fc1vae_fc2vae_fc3r   r*   vae_fc_up_sample)r<   r}   Z	v_filtersZtotal_elementsr?   r?   r@   r     s&   

z!SegResNetVAE._prepare_vae_modules	net_inputr^   	vae_inputc              	   C  s.  |  |}|d| jj}| |}t|}|d | jrH| |}t	
|}dt|d |d  td|d   d  }|||  }n| j}t|d }|||  }| |}| |}|d| jg| j }| |}t| j| jD ]\}}	||}|	|}qy| |}t	||}
||
 }|S )z
        Args:
            net_input: the original input of the network.
            vae_input: the input of VAE module, which is also the output of the network's encoder.
        Fg      ?r   g:0yE>r   )r   viewr   in_featurestorch
randn_likerequires_grad_ru   r   Fsoftplusmeanlogrv   r   r0   r~   r   r   rf   r7   r6   r   mse_loss)r<   r   r   Zx_vaeZz_meanZz_mean_randz_sigmaZvae_reg_lossrg   rh   Zvae_mse_lossvae_lossr?   r?   r@   _get_vae_loss  s0   





0




zSegResNetVAE._get_vae_lossc                 C  sL   |}|  |\}}|  |}| ||}| jr"| ||}||fS |d fS ra   )rd   rj   ri   trainingr   )r<   r]   r   rb   r   r   r?   r?   r@   rk   F  s   zSegResNetVAE.forward)rs   rt   ru   r&   rv   rw   rx   r   r   r   r   r   r   r   r   r   r   r   r    ry   r"   r!   r%   r&   r'   r(   r)   r(   r*   r+   )r   r^   r   r^   )rl   rm   rn   ro   r
   rp   r.   r   r   rk   rq   r?   r?   r=   r@   r      s(    #

,
')
__future__r   collections.abcr   numpyr   r   torch.nnrI   torch.nn.functional
functionalr   Z%monai.networks.blocks.segresnet_blockr   r   r   monai.networks.layers.factoriesr   monai.networks.layers.utilsr   r	   monai.utilsr
   __all__Moduler   r   r?   r?   r?   r@   <module>   s    