U
    {Ph>                     @  s   d dl mZ d dlZd dlZd dlZd dlmZ d dlmZm	Z	 d dl
mZ d dlmZmZmZ edejed\ZZerd d	lmZ d d
lmZmZ n<edejed\ZZedejed\ZZedejed\ZZG dd dZdS )    )annotationsN)Mapping)TYPE_CHECKINGAny)
IgniteInfo)	is_scalarmin_versionoptional_importzignite.engineEvents)Engine)
Checkpoint	DiskSaverr   zignite.handlersr   r   c                   @  s   e Zd ZdZd%dd	d
ddd
dd
dd
ddddddddddZd	ddddZdddddZdd ZdddddZdddddd Z	dddd!d"Z
dddd#d$ZdS )&CheckpointSaveraR  
    CheckpointSaver acts as an Ignite handler to save checkpoint data into files.
    It supports to save according to metrics result, epoch number, iteration number
    and last model or exception.

    Args:
        save_dir: the target directory to save the checkpoints.
        save_dict: source objects that save to the checkpoint. examples::

            {'network': net, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

        name: identifier of logging.logger to use, if None, defaulting to ``engine.logger``.
        file_prefix: prefix for the filenames to which objects will be saved.
        save_final: whether to save checkpoint or session at final iteration or exception.
            If checkpoints are to be saved when an exception is raised, put this handler before
            `StatsHandler` in the handler list, because the logic with Ignite can only trigger
            the first attached handler for `EXCEPTION_RAISED` event.
        final_filename: set a fixed filename to save the final model if `save_final=True`.
            If None, default to `checkpoint_final_iteration=N.pt`.
        save_key_metric: whether to save checkpoint or session when the value of key_metric is
            higher than all the previous values during training.keep 4 decimal places of metric,
            checkpoint name is: {file_prefix}_key_metric=0.XXXX.pth.
        key_metric_name: the name of key_metric in ignite metrics dictionary.
            If None, use `engine.state.key_metric` instead.
        key_metric_n_saved: save top N checkpoints or sessions, sorted by the value of key
            metric in descending order.
        key_metric_filename: set a fixed filename to set the best metric model, if not None,
            `key_metric_n_saved` should be 1 and only keep the best metric model.
        key_metric_save_state: whether to save the tracking list of key metric in the checkpoint file.
            if `True`, then will save an object in the checkpoint file with key `checkpointer` to be
            consistent with the `include_self` arg of `Checkpoint` in ignite:
            https://pytorch.org/ignite/v0.4.5/generated/ignite.handlers.checkpoint.Checkpoint.html.
            typically, it's used to resume training and compare current metric with previous N values.
        key_metric_greater_or_equal: if `True`, the latest equally scored model is stored. Otherwise,
            save the first equally scored model. default to `False`.
        key_metric_negative_sign: whether adding a negative sign to the metric score to compare metrics,
            because for error-like metrics, smaller is better(objects with larger score are retained).
            default to `False`.
        epoch_level: save checkpoint during training for every N epochs or every N iterations.
            `True` is epoch level, `False` is iteration level.
        save_interval: save checkpoint every N epochs, default is 0 to save no checkpoint.
        n_saved: save latest N checkpoints of epoch level or iteration level, 'None' is to save all.

    Note:
        CheckpointHandler can be used during training, validation or evaluation.
        example of saved files:

            - checkpoint_iteration=400.pt
            - checkpoint_iteration=800.pt
            - checkpoint_epoch=1.pt
            - checkpoint_final_iteration=1000.pt
            - checkpoint_key_metric=0.9387.pt

    N F   Tr   strdict
str | Noneboolintz
int | NoneNone)save_dir	save_dictnamefile_prefix
save_finalfinal_filenamesave_key_metrickey_metric_namekey_metric_n_savedkey_metric_filenamekey_metric_save_statekey_metric_greater_or_equalkey_metric_negative_signepoch_levelsave_intervaln_savedreturnc              
     s\  |d krt d|_|d k	r*t|dks2t d|_t|_|_|_d _	d _
d _|_|_G dd dt}|rdddd	d
}tj|jjd||dd_	|rddd fdd}|
d k	r|	dkrtdtj|j|
d||d|	||d_
|dkrXdddfdd}tj|jd||jrLdnd|d_d S )Nz/must provide directory to save the checkpoints.r   z$must provide source objects to save.c                      sX   e Zd ZdZdddd fddZdddd	d
d fddZdd
d fddZ  ZS )z,CheckpointSaver.__init__.<locals>._DiskSaverzK
            Enhance the DiskSaver to support fixed filename.

            Nr   r   dirnamefilenamec                   s   t  j|ddd || _d S )NF)r)   Zrequire_emptyatomic)super__init__r*   )selfr)   r*   	__class__ T/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/handlers/checkpoint_saver.pyr-      s    z5CheckpointSaver.__init__.<locals>._DiskSaver.__init__r   zMapping | Noner   )
checkpointr*   metadatar'   c                   s&   | j d k	r| j }t j|||d d S )N)r3   r*   r4   )r*   r,   __call__)r.   r3   r*   r4   r/   r1   r2   r5      s    
z5CheckpointSaver.__init__.<locals>._DiskSaver.__call__)r*   r'   c                   s"   | j d k	r| j }t j|d d S )N)r*   )r*   r,   remove)r.   r*   r/   r1   r2   r6      s    
z3CheckpointSaver.__init__.<locals>._DiskSaver.remove)N)N)__name__
__module____qualname____doc__r-   r5   r6   __classcell__r1   r1   r/   r2   
_DiskSaver|   s   r<   r   r   enginer'   c                 S  s   | j jS N)state	iterationr>   r1   r1   r2   _final_func   s    z-CheckpointSaver.__init__.<locals>._final_funcr(   Zfinal_iteration)to_savesave_handlerfilename_prefixscore_function
score_namec                   sv   t  tr }n&t| jdr&| jj}ntd  d| jj| }t|sft	d| d| d dS rndnd| S )Nr   z>Incompatible values: save_key_metric=True and key_metric_name=.zkey metric is not a scalar value, skip metric comparison and don't save a model.please use other metrics as key metric, or change the `reduction` mode to 'mean'.got metric: =r   )

isinstancer   hasattrr@   r   
ValueErrormetricsr   warningswarn)r>   metric_namemetric)r   r#   r1   r2   _score_func   s    


z-CheckpointSaver.__init__.<locals>._score_funcr   zSif using fixed filename to save the best metric model, we should only save 1 model.
key_metric)rD   rE   rF   rG   rH   r&   include_selfZgreater_or_equalc                   s    j r| jjS | jjS r?   )r$   r@   epochrA   rB   )r.   r1   r2   _interval_func   s    z0CheckpointSaver.__init__.<locals>._interval_func)r)   rW   rA   )rD   rE   rF   rG   rH   r&   )AssertionErrorr   lenr   logging	getLoggerloggerr$   r%   _final_checkpoint_key_metric_checkpoint_interval_checkpoint_name_final_filenamer   r   rN   )r.   r   r   r   r   r   r   r   r   r   r    r!   r"   r#   r$   r%   r&   r<   rC   rT   rX   r1   )r   r#   r.   r2   r-   Z   s`    

zCheckpointSaver.__init__)
state_dictr'   c                 C  s&   | j dk	r| j | n
td dS )a  
        Utility to resume the internal state of key metric tracking list if configured to save
        checkpoints based on the key metric value.
        Note to set `key_metric_save_state=True` when saving the previous checkpoint.

        Example::

            CheckpointSaver(
                ...
                save_key_metric=True,
                key_metric_save_state=True,  # config to also save the state of this saver
            ).attach(engine)
            engine.run(...)

            # resumed training with a new CheckpointSaver
            saver = CheckpointSaver(save_key_metric=True, ...)
            # load the previous key metric tracking list into saver
            CheckpointLoader("/test/model.pt"), {"checkpointer": saver}).attach(engine)

        NzFno key metric checkpoint saver to resume the key metric tracking list.)r_   load_state_dictrP   rQ   )r.   rc   r1   r1   r2   rd      s    
zCheckpointSaver.load_state_dictr   r=   c                 C  s   | j dkr|j| _| jdk	r<|tj| j |tj| j | j	dk	rV|tj
| j | jdk	r| jr|tj
| jd| j n|tj| jd| j dS )zg
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        N)every)ra   r]   r^   add_event_handlerr
   Z	COMPLETED	completedZEXCEPTION_RAISEDexception_raisedr_   EPOCH_COMPLETEDmetrics_completedr`   r$   r%   interval_completedITERATION_COMPLETEDr.   r>   r1   r1   r2   attach   s    



zCheckpointSaver.attachc                 C  sP   | j d k	rL| j j}t|dkrL|d}| j j|j | jd|j  d S )Nr   z)Deleted previous saved final checkpoint: )	r^   Z_savedrZ   poprE   r6   r*   r]   info)r.   saveditemr1   r1   r2   _delete_previous_final_ckpt   s    

z+CheckpointSaver._delete_previous_final_ckptc                 C  s   t | jstd|   | | | jdkr2tt| jdsFtd| jdk	rdtj	| j
| j}n| jj}| jd|  dS )zCallback for train or validation/evaluation completed Event.
        Save final checkpoint if configure save_final is True.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        0Error: _final_checkpoint function not specified.Nrp   .Error, provided logger has not info attribute.z)Train completed, saved final checkpoint: callabler^   rY   rs   r]   rM   rb   ospathjoinr   Zlast_checkpointrp   )r.   r>   _final_checkpoint_pathr1   r1   r2   rg     s    



zCheckpointSaver.completed	Exception)r>   er'   c                 C  s   t | jstd|   | | | jdkr2tt| jdsFtd| jdk	rdtj	| j
| j}n| jj}| jd|  |dS )a  Callback for train or validation/evaluation exception raised Event.
        Save current data as final checkpoint if configure save_final is True. This callback may be skipped
        because the logic with Ignite can only trigger the first attached handler for `EXCEPTION_RAISED` event.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
            e: the exception caught in Ignite during engine.run().
        rt   Nrp   ru   z-Exception raised, saved the last checkpoint: rv   )r.   r>   r}   r{   r1   r1   r2   rh     s    	



z CheckpointSaver.exception_raisedc                 C  s    t | jstd| | dS )zCallback to compare metrics and save models in train or validation when epoch completed.

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        z5Error: _key_metric_checkpoint function not specified.N)rw   r_   rY   rm   r1   r1   r2   rj   3  s    
z!CheckpointSaver.metrics_completedc                 C  sv   t | jstd| | | jdkr*tt| jds>td| jr\| jd|jj  n| jd|jj	  dS )zCallback for train epoch/iteration completed Event.
        Save checkpoint if configure save_interval = N

        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        z3Error: _interval_checkpoint function not specified.Nrp   ru   zSaved checkpoint at epoch: zSaved checkpoint at iteration: )
rw   r`   rY   r]   rM   r$   rp   r@   rW   rA   rm   r1   r1   r2   rk   =  s    


z"CheckpointSaver.interval_completed)Nr   FNFNr   NFFFTr   N)r7   r8   r9   r:   r-   rd   rn   rs   rg   rh   rj   rk   r1   r1   r1   r2   r   "   s.   ;              0v
r   )
__future__r   r[   rx   rP   collections.abcr   typingr   r   monai.configr   monai.utilsr   r   r	   OPT_IMPORT_VERSIONr
   _ignite.enginer   Zignite.handlersr   r   r   r1   r1   r1   r2   <module>   s   