o
    ,iD*                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlmZm	Z	m
Z
mZmZ d dlmZ dgZdd ddZG dd dejZd!d"ddZG dd dejZG dd dejZG dd dejZG dd dejZG dd dejZdS )#    )annotationsN)Convolution)ActConvDropoutNorm
split_args)deprecated_argVNetacttuple[str, dict] | strnchanintc                 C  s6   | dkr
dd|if} t | \}}t| }|di |S )Nprelunum_parameters )r   r   )r   r   act_nameact_argsact_typer   r   Z/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/vnet.pyget_acti_layer   s
   r   c                      s(   e Zd Zdd fd	d
Zdd Z  ZS )LUConvFspatial_dimsr   r   r   r   biasboolc              	     s4   t    t||| _t|||dd tj|d| _d S )N   r   in_channelsout_channelskernel_sizer   normr   )super__init__r   act_functionr   r   BATCH
conv_block)selfr   r   r   r   	__class__r   r   r"   "   s   
zLUConv.__init__c                 C  s   |  |}| |}|S N)r%   r#   r&   xoutr   r   r   forward0   s   

zLUConv.forwardF)r   r   r   r   r   r   r   r   __name__
__module____qualname__r"   r-   __classcell__r   r   r'   r   r       s    r   Fr   depthr   r   c                 C  s0   g }t |D ]}|t| ||| qtj| S r)   )rangeappendr   nn
Sequential)r   r   r4   r   r   layers_r   r   r   _make_nconv6   s   
r;   c                      *   e Zd Z	dd fd
dZdd Z  ZS )InputTransitionFr   r   r   r   r   r   r   r   c              	     sh   t    || dkrtd| d| d|| _|| _|| _t||| _t|||dd t	j
|d| _d S )Nr   zAout channels should be divisible by in_channels. Got in_channels=z, out_channels=.r   r   )r!   r"   
ValueErrorr   r   r   r   r#   r   r   r$   r%   )r&   r   r   r   r   r   r'   r   r   r"   ?   s$   
zInputTransition.__init__c                 C  sN   |  |}| j| j }|d|dddgd | jd  }| t||}|S )N      )r%   r   r   repeatr   r#   torchadd)r&   r+   r,   Z
repeat_numx16r   r   r   r-   W   s
   
"zInputTransition.forwardr.   
r   r   r   r   r   r   r   r   r   r   r/   r   r   r'   r   r=   =   s    r=   c                      s.   e Zd Z			dd fddZdd Z  ZS )DownTransitionN   Fr   r   r   nconvsr   r   dropout_probfloat | Nonedropout_dimr   r   c                   s   t    ttj|f }ttj|f }	ttj|f }
d| }|||dd|d| _|	|| _	t
||| _t
||| _t|||||| _|d urM|
|| _d S d | _d S )NrA   )r   strider   )r!   r"   r   CONVr   r$   r   DROPOUT	down_convbn1r   act_function1act_function2r;   opsdropout)r&   r   r   rI   r   rJ   rL   r   	conv_type	norm_typedropout_typer   r'   r   r   r"   a   s   


 zDownTransition.__init__c                 C  sP   |  | | |}| jd ur| |}n|}| |}| t||}|S r)   )rR   rQ   rP   rU   rT   rS   rC   rD   )r&   r+   downr,   r   r   r   r-   y   s   

zDownTransition.forward)NrH   F)r   r   r   r   rI   r   r   r   rJ   rK   rL   r   r   r   r/   r   r   r'   r   rG   _   s    rG   c                      s,   e Zd Z		dd fddZdd Z  ZS )UpTransitionN      ?rH   r   r   r   r   rI   r   r   rJ   tuple[float | None, float]rL   c                   s   t    ttj|f }ttj|f }	ttj|f }
|||d ddd| _|	|d | _	|d d ur8|
|d nd | _
|
|d | _t||d | _t||| _t||||| _d S )NrA   )r   rM   r   r@   )r!   r"   r   	CONVTRANSr   r$   r   rO   up_convrQ   rU   dropout2r   rR   rS   r;   rT   )r&   r   r   r   rI   r   rJ   rL   conv_trans_typerW   rX   r'   r   r   r"      s   

zUpTransition.__init__c                 C  sj   | j d ur|  |}n|}| |}| | | |}t||fd}| |}| t	||}|S )Nr@   )
rU   r`   rR   rQ   r_   rC   catrT   rS   rD   )r&   r+   Zskipxr,   ZskipxdoZxcatr   r   r   r-      s   


zUpTransition.forward)r[   rH   )r   r   r   r   r   r   rI   r   r   r   rJ   r]   rL   r   r/   r   r   r'   r   rZ      s
    	rZ   c                      r<   )OutputTransitionFr   r   r   r   r   r   r   r   c              	     sR   t    ttj|f }t||| _t|||dd tj|d| _	|||dd| _
d S )Nr   r   r@   )r   )r!   r"   r   rN   r   rR   r   r   r$   r%   conv2)r&   r   r   r   r   r   rV   r'   r   r   r"      s   
	zOutputTransition.__init__c                 C  s"   |  |}| |}| |}|S r)   )r%   rR   rd   r*   r   r   r   r-      s   


zOutputTransition.forwardr.   rF   r/   r   r   r'   r   rc      s    rc   c                      sf   e Zd ZdZedddddedddddd	d
d
dddifdddd	df	d  fddZdd Z  ZS )!r
   a  
    V-Net based on `Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
    <https://arxiv.org/pdf/1606.04797.pdf>`_.
    Adapted from `the official Caffe implementation
    <https://github.com/faustomilletari/VNet>`_. and `another pytorch implementation
    <https://github.com/mattmacy/vnet.pytorch/blob/master/vnet.py>`_.
    The model supports 2D or 3D inputs.

    Args:
        spatial_dims: spatial dimension of the input data. Defaults to 3.
        in_channels: number of input channels for the network. Defaults to 1.
            The value should meet the condition that ``16 % in_channels == 0``.
        out_channels: number of output channels for the network. Defaults to 1.
        act: activation type in the network. Defaults to ``("elu", {"inplace": True})``.
        dropout_prob_down: dropout ratio for DownTransition blocks. Defaults to 0.5.
        dropout_prob_up: dropout ratio for UpTransition blocks. Defaults to (0.5, 0.5).
        dropout_dim: determine the dimensions of dropout. Defaults to (0.5, 0.5).

            - ``dropout_dim = 1``, randomly zeroes some of the elements for each channel.
            - ``dropout_dim = 2``, Randomly zeroes out entire channels (a channel is a 2D feature map).
            - ``dropout_dim = 3``, Randomly zeroes out entire channels (a channel is a 3D feature map).
        bias: whether to have a bias term in convolution blocks. Defaults to False.
            According to `Performance Tuning Guide <https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html>`_,
            if a conv layer is directly followed by a batch norm layer, bias should be False.

    .. deprecated:: 1.2
        ``dropout_prob`` is deprecated in favor of ``dropout_prob_down`` and ``dropout_prob_up``.

    rJ   z1.2dropout_prob_downz'please use `dropout_prob_down` instead.)namesincenew_name
msg_suffixdropout_prob_upz%please use `dropout_prob_up` instead.rH   r@   eluinplaceTr\   )r\   r\   Fr   r   r   r   r   r   rK   r]   rL   r   r   c
           
        s   t    |dvrtdt||d||	d| _t|dd||	d| _t|dd||	d| _t|dd	|||	d
| _t|dd|||	d
| _	t
|ddd||d| _t
|ddd||d| _t
|ddd|| _t
|ddd|| _t|d|||	d| _d S )N)rA   rH   z spatial_dims can only be 2 or 3.   )r   r@       rA   @   rH   )rJ   r         )rJ   )r!   r"   AssertionErrorr=   in_trrG   	down_tr32	down_tr64
down_tr128
down_tr256rZ   up_tr256up_tr128up_tr64up_tr32rc   out_tr)
r&   r   r   r   r   rJ   re   rj   rL   r   r'   r   r   r"      s   
zVNet.__init__c                 C  sp   |  |}| |}| |}| |}| |}| ||}| ||}| ||}| ||}| 	|}|S r)   )
rs   rt   ru   rv   rw   rx   ry   rz   r{   r|   )r&   r+   Zout16Zout32Zout64Zout128Zout256r   r   r   r-     s   





zVNet.forward)r   r   r   r   r   r   r   r   rJ   rK   re   rK   rj   r]   rL   r   r   r   )r0   r1   r2   __doc__r	   r"   r-   r3   r   r   r'   r   r
      s,    
)r   )r   r   r   r   r.   )
r   r   r   r   r4   r   r   r   r   r   )
__future__r   rC   torch.nnr7   "monai.networks.blocks.convolutionsr   monai.networks.layers.factoriesr   r   r   r   r   monai.utilsr	   __all__r   Moduler   r;   r=   rG   rZ   rc   r
   r   r   r   r   <module>   s   "%'