o
    i                     @  s   d dl mZ d dlmZ d dlmZ d dlmZmZm	Z	 e	dej
ed\ZZe	dej
ed\ZZer9d d	lmZ ne	dej
ed
dd\ZZG dd dZdS )    )annotations)Callable)TYPE_CHECKING)
IgniteInfomin_versionoptional_importzignite.engineEventszignite.handlersEarlyStopping)Enginer
   	decorator)as_typec                   @  sB   e Zd ZdZ				ddddZdddZd ddZdddZdS )!EarlyStopHandleru
  
    EarlyStopHandler acts as an Ignite handler to stop training if no improvement after a given number of events.
    It‘s based on the `EarlyStopping` handler in ignite.

    Args:
        patience: number of events to wait if no improvement and then stop the training.
        score_function: It should be a function taking a single argument, an :class:`~ignite.engine.engine.Engine`
            object that the handler attached, can be a trainer or validator, and return a score `float`.
            an improvement is considered if the score is higher.
        trainer: trainer engine to stop the run if no improvement, if None, must call `set_trainer()` before training.
        min_delta: a minimum increase in the score to qualify as an improvement,
            i.e. an increase of less than or equal to `min_delta`, will count as no improvement.
        cumulative_delta: if True, `min_delta` defines an increase since the last `patience` reset, otherwise,
            it defines an increase after the last event, default to False.
        epoch_level: check early stopping for every epoch or every iteration of the attached engine,
            `True` is epoch level, `False` is iteration level, default to epoch level.

    Note:
        If in distributed training and uses loss value of every iteration to detect early stopping,
        the values may be different in different ranks. When using this handler with distributed training,
        please also note that to prevent "dist.destroy_process_group()" hangs, you can use an "all_reduce" operation
        to synchronize the stop signal across all ranks. The mechanism can be implemented in the `score_function`. The following
        is an example:

        .. code-block:: python

            import os

            import torch
            import torch.distributed as dist


            def score_function(engine):
                val_metric = engine.state.metrics["val_mean_dice"]
                if dist.is_initialized():
                    device = torch.device("cuda:" + os.environ["LOCAL_RANK"])
                    val_metric = torch.tensor([val_metric]).to(device)
                    dist.all_reduce(val_metric, op=dist.ReduceOp.SUM)
                    val_metric /= dist.get_world_size()
                    return val_metric.item()
                return val_metric


        User may attach this handler to validator engine to detect validation metrics and stop the training,
        in this case, the `score_function` is executed on validator engine and `trainer` is the trainer engine.

    N        FTpatienceintscore_functionr   trainerEngine | None	min_deltafloatcumulative_deltaboolepoch_levelreturnNonec                 C  s@   || _ || _|| _|| _|| _d | _|d ur| j|d d S d S )N)r   )r   r   r   r   r   _handlerset_trainer)selfr   r   r   r   r   r    r   b/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/handlers/earlystop_handler.py__init__O   s   	zEarlyStopHandler.__init__enginer
   c                 C  s*   | j r|tj|  dS |tj|  dS )zg
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        N)r   add_event_handlerr   EPOCH_COMPLETEDITERATION_COMPLETEDr   r!   r   r   r   attachb   s   zEarlyStopHandler.attachc                 C  s    t | j| j|| j| jd| _dS )z\
        Set trainer to execute early stop if not setting properly in `__init__()`.
        )r   r   r   r   r   N)r	   r   r   r   r   r   )r   r   r   r   r   r   l   s   zEarlyStopHandler.set_trainerc                 C  s    | j d u r	td|  | d S )NzGplease set trainer in __init__() or call set_trainer() before training.)r   RuntimeErrorr%   r   r   r   __call__x   s   
zEarlyStopHandler.__call__)Nr   FT)r   r   r   r   r   r   r   r   r   r   r   r   r   r   )r!   r
   r   r   )r   r
   r   r   )__name__
__module____qualname____doc__r    r&   r   r(   r   r   r   r   r      s    4


r   N)
__future__r   collections.abcr   typingr   monai.utilsr   r   r   OPT_IMPORT_VERSIONr   _r	   ignite.enginer
   r   r   r   r   r   <module>   s   
