U
    Ph                     @  s2   d dl mZ d dlZd dlZdd Zdd	d
ZdS )    )annotationsNc              
   C  s   dd }t   ||| | }||| | }| d| d d| d  |   | |td  | | | j||d | W  5 Q R  S Q R X dS )a  Tensor initialization with truncated normal distribution.
    Based on:
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
    https://github.com/rwightman/pytorch-image-models

    Args:
       tensor: an n-dimensional `torch.Tensor`.
       mean: the mean of the normal distribution.
       std: the standard deviation of the normal distribution.
       a: the minimum cutoff value.
       b: the maximum cutoff value.
    c                 S  s   dt | t d  d S )N      ?       @)matherfsqrt)x r	   V/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/layers/weight_init.pynorm_cdf!   s    z(_no_grad_trunc_normal_.<locals>.norm_cdf      r   )minmaxN)	torchno_graduniform_erfinv_mul_r   r   add_clamp_)tensormeanstdabr   lur	   r	   r
   _no_grad_trunc_normal_   s    

r           r          r   c                 C  s0   |dkrt d||kr t dt| ||||S )aq  Tensor initialization with truncated normal distribution.
    Based on:
    https://github.com/rwightman/pytorch-image-models

    Args:
       tensor: an n-dimensional `torch.Tensor`
       mean: the mean of the normal distribution
       std: the standard deviation of the normal distribution
       a: the minimum cutoff value
       b: the maximum cutoff value
    r   z3the standard deviation should be greater than zero.zIminimum cutoff value (a) should be smaller than maximum cutoff value (b).)
ValueErrorr   )r   r   r   r   r   r	   r	   r
   trunc_normal_/   s
    r"   )r   r   r    r   )
__future__r   r   r   r   r"   r	   r	   r	   r
   <module>   s   