o
    i                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlmZ d dl	m
Z
mZmZ er1d dlmZmZ nede
jed	\ZZede
jed
\ZZG dd dZdS )    )annotationsNbisect_right)Callable)TYPE_CHECKING)
IgniteInfomin_versionoptional_import)EngineEventszignite.enginer
   r   c                   @  sv   e Zd ZdZ			d.d/ddZdd Zd0ddZd0ddZed1d"d#Z	ed2d%d&Z
ed3d(d)Zed4d,d-ZdS )5ParamSchedulerHandlera^  
    General purpose scheduler for parameters values. By default it can schedule in a linear, exponential, step or
    multistep function. One can also pass Callables to have customized scheduling logic.

    Args:
        parameter_setter (Callable): Function that sets the required parameter
        value_calculator (Union[str,Callable]): Either a string ('linear', 'exponential', 'step' or 'multistep')
         or Callable for custom logic.
        vc_kwargs (Dict): Dictionary that stores the required parameters for the value_calculator.
        epoch_level (bool): Whether the step is based on epoch or iteration. Defaults to False.
        name (Optional[str]): Identifier of logging.logger to use, if None, defaulting to ``engine.logger``.
        event (Optional[str]): Event to which the handler attaches. Defaults to Events.ITERATION_COMPLETED.
    FNparameter_setterr   value_calculatorstr | Callable	vc_kwargsdictepoch_levelboolname
str | Noneeventc                 C  sb   || _ |d ur	|ntj| _| j| j| j| jd| _|| _	|| _
| j|d| _t|| _|| _d S )N)linearexponentialstepZ	multistep)r   )r   r   ITERATION_COMPLETEDr   _linear_exponential_step
_multistep_calculators_parameter_setter
_vc_kwargs_get_value_calculator_value_calculatorlogging	getLoggerlogger_name)selfr   r   r   r   r   r    r)   d/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/handlers/parameter_scheduler.py__init__+   s   	
zParamSchedulerHandler.__init__c                 C  s:   t |tr
| j| S t|r|S tdt| j  d)Nz.value_calculator must be either a string from z or a Callable.)
isinstancestrr   callable
ValueErrorlistkeys)r(   r   r)   r)   r*   r"   E   s   

z+ParamSchedulerHandler._get_value_calculatorenginer
   returnNonec                 C  sD   | j r|jj| jd< n|jj| jd< | jdi | j}| | d S )Ncurrent_stepr)   )r   stateepochr!   	iterationr#   r    )r(   r2   	new_valuer)   r)   r*   __call__N   s
   zParamSchedulerHandler.__call__c                 C  s$   | j du r	|j| _|| j|  dS )zT
        Args:
            engine: Ignite Engine that is used for training.
        N)r'   r&   add_event_handlerr   )r(   r2   r)   r)   r*   attachW   s   
zParamSchedulerHandler.attachinitial_valuefloatstep_constantintstep_max_value	max_valuer5   c                 C  sL   ||kr
d}| | S ||kr||  }| | S ||  ||  ||  }| | S )a|  
        Keeps the parameter value to zero until step_zero steps passed and then linearly increases it to 1 until an
        additional step_one steps passed. Continues the trend until it reaches max_value.

        Args:
            initial_value (float): Starting value of the parameter.
            step_constant (int): Step index until parameter's value is kept constant.
            step_max_value (int): Step index at which parameter's value becomes max_value.
            max_value (float): Max parameter value.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        g        r)   )r=   r?   rA   rB   r5   deltar)   r)   r*   r   `   s   zParamSchedulerHandler._lineargammac                 C  s   | ||  S )a  
        Decays the parameter value by gamma every step.

        Based on the closed form of ExponentialLR from Pytorch:
        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ExponentialLR.html.

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        r)   )r=   rD   r5   r)   r)   r*   r   {   s   z"ParamSchedulerHandler._exponential	step_sizec                 C  s   | |||   S )a  
        Decays the parameter value by gamma every step_size.

        Based on StepLR from Pytorch:
        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html.

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            step_size (int): Period of parameter value decay.
            current_step (int): Current step index.

        Returns
            float: new parameter value
        r)   )r=   rD   rE   r5   r)   r)   r*   r      s   zParamSchedulerHandler._step
milestones	list[int]c                 C  s   | |t ||  S )aO  
        Decays the parameter value by gamma once the number of steps reaches one of the milestones.

        Based on MultiStepLR from Pytorch.
        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html.

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            milestones (List[int]): List of step indices. Must be increasing.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        r   )r=   rD   rF   r5   r)   r)   r*   r      s   z ParamSchedulerHandler._multistep)FNN)r   r   r   r   r   r   r   r   r   r   r   r   )r2   r
   r3   r4   )r=   r>   r?   r@   rA   r@   rB   r>   r5   r@   r3   r>   )r=   r>   rD   r>   r5   r@   r3   r>   )
r=   r>   rD   r>   rE   r@   r5   r@   r3   r>   )
r=   r>   rD   r>   rF   rG   r5   r@   r3   r>   )__name__
__module____qualname____doc__r+   r"   r:   r<   staticmethodr   r   r   r   r)   r)   r)   r*   r      s"    
	
		r   )
__future__r   r$   bisectr   collections.abcr   typingr   monai.utilsr   r   r	   ignite.enginer
   r   OPT_IMPORT_VERSION_r   r)   r)   r)   r*   <module>   s   