o
    i>                     @  s   d dl mZ d dlZd dlZd dlZd dlmZ d dlmZm	Z	 d dl
mZmZmZmZ edejed\ZZerGd dlmZ d d	lmZmZ nedejed
\ZZedejed\ZZedejed\ZZG dd dZdS )    )annotationsN)Mapping)TYPE_CHECKINGAny)
IgniteInfo	is_scalarmin_versionoptional_importzignite.engineEvents)Engine)
Checkpoint	DiskSaverr   zignite.handlersr   r   c                   @  s|   e Zd ZdZ														d5d6d d!Zd7d#d$Zd8d'd(Zd)d* Zd8d+d,Zd9d/d0Z	d8d1d2Z
d8d3d4ZdS ):CheckpointSaveraR  
    CheckpointSaver acts as an Ignite handler to save checkpoint data into files.
    It supports to save according to metrics result, epoch number, iteration number
    and last model or exception.

    Args:
        save_dir: the target directory to save the checkpoints.
        save_dict: source objects that save to the checkpoint. examples::

            {'network': net, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

        name: identifier of logging.logger to use, if None, defaulting to ``engine.logger``.
        file_prefix: prefix for the filenames to which objects will be saved.
        save_final: whether to save checkpoint or session at final iteration or exception.
            If checkpoints are to be saved when an exception is raised, put this handler before
            `StatsHandler` in the handler list, because the logic with Ignite can only trigger
            the first attached handler for `EXCEPTION_RAISED` event.
        final_filename: set a fixed filename to save the final model if `save_final=True`.
            If None, default to `checkpoint_final_iteration=N.pt`.
        save_key_metric: whether to save checkpoint or session when the value of key_metric is
            higher than all the previous values during training.keep 4 decimal places of metric,
            checkpoint name is: {file_prefix}_key_metric=0.XXXX.pth.
        key_metric_name: the name of key_metric in ignite metrics dictionary.
            If None, use `engine.state.key_metric` instead.
        key_metric_n_saved: save top N checkpoints or sessions, sorted by the value of key
            metric in descending order.
        key_metric_filename: set a fixed filename to set the best metric model, if not None,
            `key_metric_n_saved` should be 1 and only keep the best metric model.
        key_metric_save_state: whether to save the tracking list of key metric in the checkpoint file.
            if `True`, then will save an object in the checkpoint file with key `checkpointer` to be
            consistent with the `include_self` arg of `Checkpoint` in ignite:
            https://pytorch.org/ignite/v0.4.5/generated/ignite.handlers.checkpoint.Checkpoint.html.
            typically, it's used to resume training and compare current metric with previous N values.
        key_metric_greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise,
            save the first equally scored model. default to `False`.
        key_metric_negative_sign: whether adding a negative sign to the metric score to compare metrics,
            because for error-like metrics, smaller is better(objects with larger score are retained).
            default to `False`.
        epoch_level: save checkpoint during training for every N epochs or every N iterations.
            `True` is epoch level, `False` is iteration level.
        save_interval: save checkpoint every N epochs, default is 0 to save no checkpoint.
        n_saved: save latest N checkpoints of epoch level or iteration level, 'None' is to save all.

    Note:
        CheckpointHandler can be used during training, validation or evaluation.
        example of saved files:

            - checkpoint_iteration=400.pt
            - checkpoint_iteration=800.pt
            - checkpoint_epoch=1.pt
            - checkpoint_final_iteration=1000.pt
            - checkpoint_key_metric=0.9387.pt

    N F   Tr   save_dirstr	save_dictdictname
str | Nonefile_prefix
save_finalboolfinal_filenamesave_key_metrickey_metric_namekey_metric_n_savedintkey_metric_filenamekey_metric_save_statekey_metric_greater_or_equalkey_metric_negative_signepoch_levelsave_intervaln_saved
int | NonereturnNonec              
     sH  |d u rt d|_|d urt|dkst d|_t|_|_|_d _	d _
d _|_|_G dd dt}|rWdd
d}tj|jjd||dd_	|rd fdd}|
d urm|	dkrmtdtj|j|
d||d|	||d_
|dkrdfdd}tj|jd||jrdnd|d_d S d S )Nz/must provide directory to save the checkpoints.r   z$must provide source objects to save.c                      sB   e Zd ZdZdd fddZdd fddZd fddZ  ZS )z,CheckpointSaver.__init__.<locals>._DiskSaverzK
            Enhance the DiskSaver to support fixed filename.

            Ndirnamer   filenamer   c                   s   t  j|ddd || _d S )NF)r)   Zrequire_emptyatomic)super__init__r*   )selfr)   r*   	__class__ a/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/handlers/checkpoint_saver.pyr-      s   
z5CheckpointSaver.__init__.<locals>._DiskSaver.__init__
checkpointr   metadataMapping | Noner'   r(   c                   s&   | j d ur| j }t j|||d d S )N)r3   r*   r4   )r*   r,   __call__)r.   r3   r*   r4   r/   r1   r2   r6      s   
z5CheckpointSaver.__init__.<locals>._DiskSaver.__call__c                   s"   | j d ur| j }t j|d d S )N)r*   )r*   r,   remove)r.   r*   r/   r1   r2   r7      s   
z3CheckpointSaver.__init__.<locals>._DiskSaver.removeN)r)   r   r*   r   )r3   r   r*   r   r4   r5   r'   r(   )r*   r   r'   r(   )__name__
__module____qualname____doc__r-   r6   r7   __classcell__r1   r1   r/   r2   
_DiskSaver{   s
    r>   enginer   r'   r   c                 S  s   | j jS r8   )state	iterationr?   r1   r1   r2   _final_func   s   z-CheckpointSaver.__init__.<locals>._final_func)r)   r*   Zfinal_iteration)to_savesave_handlerfilename_prefixscore_function
score_namec                   sz   t  tr }nt| jdr| jj}ntd  d| jj| }t|s3t	d| d| d dS r9d| S d| S )Nr   z>Incompatible values: save_key_metric=True and key_metric_name=.zkey metric is not a scalar value, skip metric comparison and don't save a model.please use other metrics as key metric, or change the `reduction` mode to 'mean'.got metric: =r   )

isinstancer   hasattrr@   r   
ValueErrormetricsr   warningswarn)r?   metric_namemetric)r   r"   r1   r2   _score_func   s$   


z-CheckpointSaver.__init__.<locals>._score_funcr   zSif using fixed filename to save the best metric model, we should only save 1 model.
key_metric)rD   rE   rF   rG   rH   r%   include_selfZgreater_or_equalc                   s    j r| jjS | jjS r8   )r#   r@   epochrA   rB   )r.   r1   r2   _interval_func   s   z0CheckpointSaver.__init__.<locals>._interval_func)r)   rW   rA   )rD   rE   rF   rG   rH   r%   )r?   r   r'   r   )AssertionErrorr   lenr   logging	getLoggerloggerr#   r$   _final_checkpoint_key_metric_checkpoint_interval_checkpoint_name_final_filenamer   r   rN   )r.   r   r   r   r   r   r   r   r   r   r   r    r!   r"   r#   r$   r%   r>   rC   rT   rX   r1   )r   r"   r.   r2   r-   Y   sb   

zCheckpointSaver.__init__
state_dictc                 C  s(   | j dur| j | dS td dS )a  
        Utility to resume the internal state of key metric tracking list if configured to save
        checkpoints based on the key metric value.
        Note to set `key_metric_save_state=True` when saving the previous checkpoint.

        Example::

            CheckpointSaver(
                ...
                save_key_metric=True,
                key_metric_save_state=True,  # config to also save the state of this saver
            ).attach(engine)
            engine.run(...)

            # resumed training with a new CheckpointSaver
            saver = CheckpointSaver(save_key_metric=True, ...)
            # load the previous key metric tracking list into saver
            CheckpointLoader("/test/model.pt"), {"checkpointer": saver}).attach(engine)

        NzFno key metric checkpoint saver to resume the key metric tracking list.)r_   load_state_dictrP   rQ   )r.   rc   r1   r1   r2   rd      s   
zCheckpointSaver.load_state_dictr?   r   c                 C  s   | j du r	|j| _| jdur|tj| j |tj| j | j	dur+|tj
| j | jdurO| jrA|tj
| jd| j dS |tj| jd| j dS dS )zg
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        N)every)ra   r]   r^   add_event_handlerr
   Z	COMPLETED	completedZEXCEPTION_RAISEDexception_raisedr_   EPOCH_COMPLETEDmetrics_completedr`   r#   r$   interval_completedITERATION_COMPLETEDr.   r?   r1   r1   r2   attach   s   



zCheckpointSaver.attachc                 C  sX   | j d ur(| j j}t|dkr*|d}| j j|j | jd|j  d S d S d S )Nr   z)Deleted previous saved final checkpoint: )	r^   Z_savedrZ   poprE   r7   r*   r]   info)r.   saveditemr1   r1   r2   _delete_previous_final_ckpt   s   

z+CheckpointSaver._delete_previous_final_ckptc                 C  s   t | js	td|   | | | jdu rtt| jds#td| jdur2tj	| j
| j}n| jj}| jd|  dS )zCallback for train or validation/evaluation completed Event.
        Save final checkpoint if configure save_final is True.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        0Error: _final_checkpoint function not specified.Nrp   .Error, provided logger has not info attribute.z)Train completed, saved final checkpoint: callabler^   rY   rs   r]   rM   rb   ospathjoinr   Zlast_checkpointrp   )r.   r?   _final_checkpoint_pathr1   r1   r2   rg     s   



zCheckpointSaver.completede	Exceptionc                 C  s   t | js	td|   | | | jdu rtt| jds#td| jdur2tj	| j
| j}n| jj}| jd|  |)a  Callback for train or validation/evaluation exception raised Event.
        Save current data as final checkpoint if configure save_final is True. This callback may be skipped
        because the logic with Ignite can only trigger the first attached handler for `EXCEPTION_RAISED` event.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            e: the exception caught in Ignite during engine.run().
        rt   Nrp   ru   z-Exception raised, saved the last checkpoint: rv   )r.   r?   r|   r{   r1   r1   r2   rh     s   
	


z CheckpointSaver.exception_raisedc                 C  s    t | js	td| | dS )zCallback to compare metrics and save models in train or validation when epoch completed.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        z5Error: _key_metric_checkpoint function not specified.N)rw   r_   rY   rm   r1   r1   r2   rj   2  s   
z!CheckpointSaver.metrics_completedc                 C  sx   t | js	td| | | jdu rtt| jdstd| jr/| jd|jj  dS | jd|jj	  dS )zCallback for train epoch/iteration completed Event.
        Save checkpoint if configure save_interval = N

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        z3Error: _interval_checkpoint function not specified.Nrp   ru   zSaved checkpoint at epoch: zSaved checkpoint at iteration: )
rw   r`   rY   r]   rM   r#   rp   r@   rW   rA   rm   r1   r1   r2   rk   <  s   


z"CheckpointSaver.interval_completed)Nr   FNFNr   NFFFTr   N)"r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r   r!   r   r"   r   r#   r   r$   r   r%   r&   r'   r(   )rc   r   r'   r(   )r?   r   r'   r(   )r?   r   r|   r}   r'   r(   )r9   r:   r;   r<   r-   rd   rn   rs   rg   rh   rj   rk   r1   r1   r1   r2   r   !   s0    ;
v




r   )
__future__r   r[   rx   rP   collections.abcr   typingr   r   monai.utilsr   r   r   r	   OPT_IMPORT_VERSIONr
   _ignite.enginer   Zignite.handlersr   r   r   r1   r1   r1   r2   <module>   s   