U
    PhD*                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlmZm	Z	m
Z
mZmZ d dlmZ dgZdddd	d
dZG dd dejZdddddddddZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZdS )    )annotationsN)Convolution)ActConvDropoutNorm
split_args)deprecated_argVNettuple[str, dict] | strint)actnchanc                 C  s2   | dkrdd|if} t | \}}t| }|f |S )Nprelunum_parameters)r   r   )r   r   act_nameact_argsact_type r   M/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/vnet.pyget_acti_layer   s
    r   c                      s2   e Zd Zd
ddddd fddZdd	 Z  ZS )LUConvFr   r   bool)spatial_dimsr   r   biasc              	     s4   t    t||| _t|||dd tj|d| _d S )N   r   in_channelsout_channelskernel_sizer   normr   )super__init__r   act_functionr   r   BATCH
conv_block)selfr   r   r   r   	__class__r   r   r"   "   s    
zLUConv.__init__c                 C  s   |  |}| |}|S N)r%   r#   r&   xoutr   r   r   forward0   s    

zLUConv.forward)F__name__
__module____qualname__r"   r-   __classcell__r   r   r'   r   r       s   r   Fr   )r   r   depthr   r   c                 C  s0   g }t |D ]}|t| ||| qtj| S r)   )rangeappendr   nn
Sequential)r   r   r3   r   r   layers_r   r   r   _make_nconv6   s    r:   c                      s4   e Zd Zd
dddddd fddZdd	 Z  ZS )InputTransitionFr   r   r   r   r   r   r   r   c              	     sh   t    || dkr,td| d| d|| _|| _|| _t||| _t|||dd t	j
|d| _d S )Nr   zAout channels should be divisible by in_channels. Got in_channels=z, out_channels=.r   r   )r!   r"   
ValueErrorr   r   r   r   r#   r   r   r$   r%   )r&   r   r   r   r   r   r'   r   r   r"   ?   s$    
zInputTransition.__init__c                 C  sN   |  |}| j| j }|d|dddgd | jd  }| t||}|S )N      )r%   r   r   repeatr   r#   torchadd)r&   r+   r,   Z
repeat_numx16r   r   r   r-   W   s
    
"zInputTransition.forward)Fr.   r   r   r'   r   r;   =   s    r;   c                	      s8   e Zd Zddddddddd fd	d
Zdd Z  ZS )DownTransitionN   Fr   r   float | Noner   )r   r   nconvsr   dropout_probdropout_dimr   c                   s   t    ttj|f }ttj|f }	ttj|f }
d| }|||dd|d| _|	|| _	t
||| _t
||| _t|||||| _|d k	r|
|nd | _d S )Nr@   )r   strider   )r!   r"   r   CONVr   r$   r   DROPOUT	down_convbn1r   act_function1act_function2r:   opsdropout)r&   r   r   rH   r   rI   rJ   r   	conv_type	norm_typedropout_typer   r'   r   r   r"   a   s    


zDownTransition.__init__c                 C  sP   |  | | |}| jd k	r,| |}n|}| |}| t||}|S r)   )rP   rO   rN   rS   rR   rQ   rB   rC   )r&   r+   downr,   r   r   r   r-   y   s    

zDownTransition.forward)NrF   Fr.   r   r   r'   r   rE   _   s
       rE   c                	      s8   e Zd Zddddddddd fddZd	d
 Z  ZS )UpTransitionN      ?rF   r   r   tuple[float | None, float])r   r   r   rH   r   rI   rJ   c                   s   t    ttj|f }ttj|f }	ttj|f }
|||d ddd| _|	|d | _	|d d k	rp|
|d nd | _
|
|d | _t||d | _t||| _t||||| _d S )Nr@   )r   rK   r   r?   )r!   r"   r   	CONVTRANSr   r$   r   rM   up_convrO   rS   dropout2r   rP   rQ   r:   rR   )r&   r   r   r   rH   r   rI   rJ   conv_trans_typerU   rV   r'   r   r   r"      s    

zUpTransition.__init__c                 C  sj   | j d k	r|  |}n|}| |}| | | |}t||fd}| |}| t	||}|S )Nr?   )
rS   r^   rP   rO   r]   rB   catrR   rQ   rC   )r&   r+   Zskipxr,   ZskipxdoZxcatr   r   r   r-      s    


zUpTransition.forward)rY   rF   r.   r   r   r'   r   rX      s   	   rX   c                      s4   e Zd Zd
dddddd fddZdd	 Z  ZS )OutputTransitionFr   r   r   r<   c              	     sR   t    ttj|f }t||| _t|||dd tj|d| _	|||dd| _
d S )Nr   r   r?   )r   )r!   r"   r   rL   r   rP   r   r   r$   r%   conv2)r&   r   r   r   r   r   rT   r'   r   r   r"      s    
	zOutputTransition.__init__c                 C  s"   |  |}| |}| |}|S r)   )r%   rP   rb   r*   r   r   r   r-      s    


zOutputTransition.forward)Fr.   r   r   r'   r   ra      s    ra   c                      sz   e Zd ZdZedddddedddddd	d
d
dddifdddd	df	dddddddddd	 fddZdd Z  ZS )r
   a  
    V-Net based on `Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
    <https://arxiv.org/pdf/1606.04797.pdf>`_.
    Adapted from `the official Caffe implementation
    <https://github.com/faustomilletari/VNet>`_. and `another pytorch implementation
    <https://github.com/mattmacy/vnet.pytorch/blob/master/vnet.py>`_.
    The model supports 2D or 3D inputs.

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        in_channels: number of input channels for the network. Defaults to 1.
            The value should meet the condition that ``16 % in_channels == 0``.
        out_channels: number of output channels for the network. Defaults to 1.
        act: activation type in the network. Defaults to ``("elu", {"inplace": True})``.
        dropout_prob_down: dropout ratio for DownTransition blocks. Defaults to 0.5.
        dropout_prob_up: dropout ratio for UpTransition blocks. Defaults to (0.5, 0.5).
        dropout_dim: determine the dimensions of dropout. Defaults to (0.5, 0.5).

            - ``dropout_dim = 1``, randomly zeroes some of the elements for each channel.
            - ``dropout_dim = 2``, Randomly zeroes out entire channels (a channel is a 2D feature map).
            - ``dropout_dim = 3``, Randomly zeroes out entire channels (a channel is a 3D feature map).
        bias: whether to have a bias term in convolution blocks. Defaults to False.
            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
            if a conv layer is directly followed by a batch norm layer, bias should be False.

    .. deprecated:: 1.2
        ``dropout_prob`` is deprecated in favor of ``dropout_prob_down`` and ``dropout_prob_up``.

    rI   z1.2dropout_prob_downz'please use `dropout_prob_down` instead.)namesincenew_name
msg_suffixdropout_prob_upz%please use `dropout_prob_up` instead.rF   r?   eluinplaceTrZ   )rZ   rZ   Fr   r   rG   r[   r   )	r   r   r   r   rI   rc   rh   rJ   r   c
           
        s   t    |dkrtdt||d||	d| _t|dd||	d| _t|dd||	d| _t|dd	|||	d
| _t|dd|||	d
| _	t
|ddd||d| _t
|ddd||d| _t
|ddd|| _t
|ddd|| _t|d|||	d| _d S )N)r@   rF   z spatial_dims can only be 2 or 3.   )r   r?       r@   @   rF   )rI   r         )rI   )r!   r"   AssertionErrorr;   in_trrE   	down_tr32	down_tr64
down_tr128
down_tr256rX   up_tr256up_tr128up_tr64up_tr32ra   out_tr)
r&   r   r   r   r   rI   rc   rh   rJ   r   r'   r   r   r"      s    
zVNet.__init__c                 C  sp   |  |}| |}| |}| |}| |}| ||}| ||}| ||}| ||}| 	|}|S r)   )
rq   rr   rs   rt   ru   rv   rw   rx   ry   rz   )r&   r+   Zout16Zout32Zout64Zout128Zout256r   r   r   r-     s    





zVNet.forward)r/   r0   r1   __doc__r	   r"   r-   r2   r   r   r'   r   r
      s0      
()r   )F)
__future__r   rB   torch.nnr6   "monai.networks.blocks.convolutionsr   monai.networks.layers.factoriesr   r   r   r   r   monai.utilsr	   __all__r   Moduler   r:   r;   rE   rX   ra   r
   r   r   r   r   <module>   s   "%'