o
    /i-                     @  sX   d dl mZ d dlmZmZ d dlmZ d dlZd dlm	Z	 edZ
G dd de	ZdS )	    )annotations)CallableIterable)TypeVarN)	OptimizerTc                      sH   e Zd ZdZ						dd fddZ fddZdd ddZ  ZS )!Novograda  
    Novograd based on `Stochastic Gradient Methods with Layer-wise Adaptive Moments for Training of Deep Networks
    <https://arxiv.org/pdf/1905.11286.pdf>`_.
    The code is adapted from the implementations in `Jasper for PyTorch
    <https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/SpeechRecognition/Jasper/common/optimizers.py>`_,
    and `OpenSeq2Seq <https://github.com/NVIDIA/OpenSeq2Seq/blob/master/open_seq2seq/optimizers/novograd.py>`_.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lr: learning rate. Defaults to 1e-3.
        betas: coefficients used for computing running averages of gradient and its square. Defaults to (0.9, 0.98).
        eps: term added to the denominator to improve numerical stability. Defaults to 1e-8.
        weight_decay: weight decay (L2 penalty). Defaults to 0.
        grad_averaging: gradient averaging. Defaults to ``False``.
        amsgrad: whether to use the AMSGrad variant of this algorithm from the paper
            `On the Convergence of Adam and Beyond <https://arxiv.org/pdf/1904.09237.pdf>`_. Defaults to ``False``.
    MbP?g?g\(\?:0yE>r   Fparamsr   lrfloatbetastuple[float, float]epsweight_decaygrad_averagingboolamsgradc           	        s   d|krt d| d|krt d| d|d   kr"dk s,n t d|d  d|d   kr8dk sBn t d|d  d|krMt d	| t||||||d
}t || d S )Ng        zInvalid learning rate: zInvalid epsilon value: r   g      ?z#Invalid beta parameter at index 0:    z#Invalid beta parameter at index 1: zInvalid weight_decay value: )r   r   r   r   r   r   )
ValueErrordictsuper__init__)	selfr   r   r   r   r   r   r   defaults	__class__ [/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/optimizers/novograd.pyr   *   s   
zNovograd.__init__c                   s(   t  | | jD ]}|dd q	d S )Nr   F)r   __setstate__param_groups
setdefault)r   stategroupr   r   r    r!   D   s   
zNovograd.__setstate__NclosureCallable[[], T] | NonereturnT | Nonec                 C  s  d}|dur	| }| j D ]}|d D ]}|jdu rq|jj}|jr%td|d }| j| }t|dkr\d|d< t|j|d< t	g 
|d j|d< |r\t	g 
|d j|d	< |d |d }}	|rk|d	 }
|d
 \}}|d  d7  < tt|d}|	dkr|	| n|	|j|d| d |rtj|
|	|
d |
 |d }n	|	 |d }|| |d dkr|j|j|d d |d r|d|  ||| |jj||d  d qq|S )zPerforms a single optimization step.

        Arguments:
            closure: A closure that reevaluates the model and returns the loss. Defaults to ``None``.
        Nr   z#Sparse gradients are not supported.r   r   stepexp_avg
exp_avg_sqmax_exp_avg_sqr   r      )alpha)outr   r   r   r   )r"   graddata	is_sparseRuntimeErrorr$   lentorch
zeros_likezerostodevicesumpowcopy_mul_add_maxsqrtdiv_)r   r&   lossr%   pr1   r   r$   r+   r,   r-   beta1beta2normdenomr   r   r    r*   I   sP   



4zNovograd.step)r	   r
   r   r   FF)r   r   r   r   r   r   r   r   r   r   r   r   r   r   )N)r&   r'   r(   r)   )__name__
__module____qualname____doc__r   r!   r*   __classcell__r   r   r   r    r      s    r   )
__future__r   collections.abcr   r   typingr   r6   torch.optimr   r   r   r   r   r   r    <module>   s   