U
    Ph-                     @  sX   d dl mZ d dlmZmZ d dlmZ d dlZd dlm	Z	 edZ
G dd de	ZdS )	    )annotations)CallableIterable)TypeVarN)	OptimizerTc                	      sR   e Zd ZdZdddd	ddd
d
d fddZ fddZddddddZ  ZS )Novograda  
    Novograd based on `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks
    <https://arxiv.org/pdf/1905.11286.pdf>`_.
    The code is adapted from the implementations in `Jasper for PyTorch
    <https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper/common/optimizers.py>`_,
    and `OpenSeq2Seq <https://github.com/NVIDIA/OpenSeq2Seq/blob/master/open_seq2seq/optimizers/novograd.py>`_.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lr: learning rate. Defaults to 1e-3.
        betas: coefficients used for computing running averages of gradient and its square. Defaults to (0.9, 0.98).
        eps: term added to the denominator to improve numerical stability. Defaults to 1e-8.
        weight_decay: weight decay (L2 penalty). Defaults to 0.
        grad_averaging: gradient averaging. Defaults to ``False``.
        amsgrad: whether to use the AMSGrad variant of this algorithm from the paper
            `On the Convergence of Adam and Beyond <https://arxiv.org/pdf/1904.09237.pdf>`_. Defaults to ``False``.
    MbP?g?g\(\?:0yE>r   Fr   floatztuple[float, float]bool)paramslrbetasepsweight_decaygrad_averagingamsgradc           	        s   d|krt d| d|kr,t d| d|d   krDdk sXn t d|d  d|d   krpdk sn t d|d  d|krt d	| t||||||d
}t || d S )Ng        zInvalid learning rate: zInvalid epsilon value: r   g      ?z#Invalid beta parameter at index 0:    z#Invalid beta parameter at index 1: zInvalid weight_decay value: )r   r   r   r   r   r   )
ValueErrordictsuper__init__)	selfr   r   r   r   r   r   r   defaults	__class__ N/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/optimizers/novograd.pyr   *   s&    
     zNovograd.__init__c                   s(   t  | | jD ]}|dd qd S )Nr   F)r   __setstate__param_groups
setdefault)r   stategroupr   r   r   r    D   s    
zNovograd.__setstate__NzCallable[[], T] | NonezT | None)closurereturnc                 C  s  d}|dk	r| }| j D ]}|d D ]}|jdkr8q&|jj}|jrNtd|d }| j| }t|dkrd|d< t|j|d< t	g 
|d j|d< |rt	g 
|d j|d	< |d |d  }}	|r|d	 }
|d
 \}}|d  d7  < tt|d}|	dkr|	| n|	|j|d| d |r`tj|
|	|
d |
 |d }n|	 |d }|| |d dkr|j|j|d d |d r|d|  ||| |jj||d  d q&q|S )zPerforms a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns the loss. Defaults to ``None``.
        Nr   z#Sparse gradients are not supported.r   r   stepexp_avg
exp_avg_sqmax_exp_avg_sqr   r      )alpha)outr   r   r   r   )r!   graddata	is_sparseRuntimeErrorr#   lentorch
zeros_likezerostodevicesumpowcopy_mul_add_maxsqrtdiv_)r   r%   lossr$   pr.   r   r#   r(   r)   r*   beta1beta2normdenomr   r   r   r'   I   sN    




zNovograd.step)r	   r
   r   r   FF)N)__name__
__module____qualname____doc__r   r    r'   __classcell__r   r   r   r   r      s          r   )
__future__r   collections.abcr   r   typingr   r3   torch.optimr   r   r   r   r   r   r   <module>   s   