U
    |Ph                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlmZ d dl	m
Z
 d dlmZmZ erjd dlmZmZ n(ed	e
jed
\ZZed	e
jed\ZZG dd dZdS )    )annotationsNbisect_right)Callable)TYPE_CHECKING)
IgniteInfo)min_versionoptional_import)EngineEventszignite.enginer
   r   c                   @  s   e Zd ZdZd$ddddddd	d
dZdd ZdddddZdddddZedddddddddZ	edddddddZ
eddddddddZeddd ddd!d"d#ZdS )%ParamSchedulerHandlera^  
    General purpose scheduler for parameters values. By default it can schedule in a linear, exponential, step or
    multistep function. One can also pass Callables to have customized scheduling logic.

    Args:
        parameter_setter (Callable): Function that sets the required parameter
        value_calculator (Union[str,Callable]): Either a string ('linear', 'exponential', 'step' or 'multistep')
         or Callable for custom logic.
        vc_kwargs (Dict): Dictionary that stores the required parameters for the value_calculator.
        epoch_level (bool): Whether the step is based on epoch or iteration. Defaults to False.
        name (Optional[str]): Identifier of logging.logger to use, if None, defaulting to ``engine.logger``.
        event (Optional[str]): Event to which the handler attaches. Defaults to Events.ITERATION_COMPLETED.
    FNr   zstr | Callabledictboolz
str | None)parameter_settervalue_calculator	vc_kwargsepoch_levelnameeventc                 C  sb   || _ |d k	r|ntj| _| j| j| j| jd| _|| _	|| _
| j|d| _t|| _|| _d S )N)linearexponentialstepZ	multistep)r   )r   r   ITERATION_COMPLETEDr   _linear_exponential_step
_multistep_calculators_parameter_setter
_vc_kwargs_get_value_calculator_value_calculatorlogging	getLoggerlogger_name)selfr   r   r   r   r   r    r'   W/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/handlers/parameter_scheduler.py__init__,   s    	zParamSchedulerHandler.__init__c                 C  s>   t |tr| j| S t|r |S tdt| j  dd S )Nz.value_calculator must be either a string from z or a Callable.)
isinstancestrr   callable
ValueErrorlistkeys)r&   r   r'   r'   r(   r    F   s    

z+ParamSchedulerHandler._get_value_calculatorr
   None)enginereturnc                 C  s@   | j r|jj| jd< n|jj| jd< | jf | j}| | d S )Ncurrent_step)r   stateepochr   	iterationr!   r   )r&   r1   	new_valuer'   r'   r(   __call__O   s
    zParamSchedulerHandler.__call__c                 C  s$   | j dkr|j| _|| j|  dS )zT
        Args:
            engine: Ignite Engine that is used for training.
        N)r%   r$   add_event_handlerr   )r&   r1   r'   r'   r(   attachX   s    
zParamSchedulerHandler.attachfloatint)initial_valuestep_constantstep_max_value	max_valuer3   r2   c                 C  s@   ||krd}n*||kr ||  }n||  ||  ||  }| | S )a|  
        Keeps the parameter value to zero until step_zero steps passed and then linearly increases it to 1 until an
        additional step_one steps passed. Continues the trend until it reaches max_value.

        Args:
            initial_value (float): Starting value of the parameter.
            step_constant (int): Step index until parameter's value is kept constant.
            step_max_value (int): Step index at which parameter's value becomes max_value.
            max_value (float): Max parameter value.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        g        r'   )r=   r>   r?   r@   r3   deltar'   r'   r(   r   a   s    
zParamSchedulerHandler._linear)r=   gammar3   r2   c                 C  s   | ||  S )a  
        Decays the parameter value by gamma every step.

        Based on the closed form of ExponentialLR from Pytorch:
        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.ExponentialLR.html.

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        r'   )r=   rB   r3   r'   r'   r(   r   |   s    z"ParamSchedulerHandler._exponential)r=   rB   	step_sizer3   r2   c                 C  s   | |||   S )a  
        Decays the parameter value by gamma every step_size.

        Based on StepLR from Pytorch:
        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.StepLR.html.

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            step_size (int): Period of parameter value decay.
            current_step (int): Current step index.

        Returns
            float: new parameter value
        r'   )r=   rB   rC   r3   r'   r'   r(   r      s    zParamSchedulerHandler._stepz	list[int])r=   rB   
milestonesr3   r2   c                 C  s   | |t ||  S )aO  
        Decays the parameter value by gamma once the number of steps reaches one of the milestones.

        Based on MultiStepLR from Pytorch.
        https://pytorch.org/docs/stable/generated/torch.optim.lr_scheduler.MultiStepLR.html.

        Args:
            initial_value (float): Starting value of the parameter.
            gamma (float): Multiplicative factor of parameter value decay.
            milestones (List[int]): List of step indices. Must be increasing.
            current_step (int): Current step index.

        Returns:
            float: new parameter value
        r   )r=   rB   rD   r3   r'   r'   r(   r      s    z ParamSchedulerHandler._multistep)FNN)__name__
__module____qualname____doc__r)   r    r8   r:   staticmethodr   r   r   r   r'   r'   r'   r(   r      s       			r   )
__future__r   r"   bisectr   collections.abcr   typingr   monai.configr   monai.utilsr   r	   ignite.enginer
   r   OPT_IMPORT_VERSION_r   r'   r'   r'   r(   <module>   s   