U
    vPh                      @  s   d dl mZ d dlZd dlZd dlmZ d dlZd dlmZ d dl	m
Z
 d dlmZmZ edejed\ZZed	ejed
\ZZerd dlmZ nedejed\ZZG dd dZdS )    )annotationsN)TYPE_CHECKING)
IgniteInfo)copy_model_state)min_versionoptional_importzignite.engineEventszignite.handlers
Checkpoint)Enginer
   c                	   @  sL   e Zd ZdZdddddddd	d
ddZdd	dddZdd	dddZdS )CheckpointLoaderu  
    CheckpointLoader acts as an Ignite handler to load checkpoint data from file.
    It can load variables for network, optimizer, lr_scheduler, etc.
    If saving checkpoint after `torch.nn.DataParallel`, need to save `model.module` instead
    as PyTorch recommended and then use this loader to load the model.

    Usage example::

        trainer = SupervisedTrainer(...)
        save_dict = {
            "trainer": trainer,
            "net": network,
            "opt": optimizer,
            "lr": lr_scheduler,
        }

        map_location = "cuda:0"
        # checkpoint needs to have same save_dict for this to work
        handler = CheckpointLoader(load_path="/test/checkpoint.pt", load_dict=save_dict, map_location=map_location, strict=True)
        handler(trainer)
        # Trainer now has the same state as stored, including the number of epochs and iterations completed
        # so you can resume an interrupted training at the place where it left

    Args:
        load_path: the file path of checkpoint, it should be a PyTorch `pth` file.
        load_dict: target objects that load checkpoint to. examples::

            {'network': net, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}

        name: identifier of logging.logger to use, if None, defaulting to ``engine.logger``.
        map_location: when loading the module for distributed training/evaluation,
            need to provide an appropriate map_location argument to prevent a process
            to step into others’ devices. If map_location is missing, torch.load will
            first load the module to CPU and then copy each parameter to where it was
            saved, which would result in all processes on the same machine using the
            same set of devices.
        strict: whether to strictly enforce that the keys and data shape in the `state_dict` of every item
            of `load_dict` match the `state_dict` of the corresponding items of checkpoint, default to `True`.
        strict_shape: whether to enforce the data shape of the matched layers in the checkpoint,
            `if `False`, it will skip the layers that have different data shape with checkpoint content,
            and ignore the `strict` arg. this can be useful advanced feature for transfer learning.
            users should totally understand which layers will have different shape. default to `True`.

    Note: if `strict_shape=False`, will only load checkpoint for `torch.nn.Module` and skip other
        items in the `load_dict`. For example, if the shape of some layers in current model can't
        match the checkpoint, the `parameter_group` of current optimizer may also can't match the
        checkpoint, so skip loading checkpoint for optimizer.

        For more details about loading checkpoint, please refer to:
        https://pytorch.org/ignite/v0.4.5/generated/ignite.handlers.checkpoint.Checkpoint.html
        #ignite.handlers.checkpoint.Checkpoint.load_objects.
        https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.load_state_dict.

    NTstrdictz
str | Nonezdict | NoneboolNone)	load_path	load_dictnamemap_locationstrictstrict_shapereturnc                 C  sv   |d krt d|| _|d ks*t|dkr2t dt|| _|| _|| _|| _|rf|sft	
d d}|| _|| _d S )Nz+must provide clear path to load checkpoint.r   z$must provide target objects to load.z=as `strict_shape` is already False, change `strict` to False.F)AssertionErrorr   lenlogging	getLoggerloggerr   _namer   warningswarnr   r   )selfr   r   r   r   r   r    r    U/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/handlers/checkpoint_loader.py__init__X   s    	
zCheckpointLoader.__init__r
   )enginer   c                 C  s$   | j dkr|j| _|tj|  dS )g
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        N)r   r   add_event_handlerr   STARTED)r   r#   r    r    r!   attachp   s    
zCheckpointLoader.attachc           	      C  s  t j| j| jd}t| j d \}}t| jdkrF||krF||i}| jsg }| j D ]B\}}t	|t j
jrt||ddd ||< qZtd || qZ|D ]}| j| q|jj}tj| j|| jd |dk	r|jj|krtd	|jj d
| d||j_| jd| j  dS )r$   )r   r      F)inplacezO`strict_shape` is False, load checkpoint for model, skip others in `load_dict`.)to_load
checkpointr   NzEpoch count (z>) in checkpoint is larger than the `engine.state.max_epochs` (z) of engine. To further train from checkpoint, construct trainer with `max_epochs` larger than checkpoint's epoch count. To use checkpoint for inference, no need to load state_dict for the engine.zRestored all variables from )torchloadr   r   listr   itemsr   r   
isinstancennModuler   r   r   appendpopstate
max_epochsr	   load_objectsr   epoch
ValueErrorr   info)	r   r#   r+   k_Z	pop_itemsobjiZprior_max_epochsr    r    r!   __call__y   s*    
zCheckpointLoader.__call__)NNTT)__name__
__module____qualname____doc__r"   r'   r?   r    r    r    r!   r       s   ;    	r   )
__future__r   r   r   typingr   r,   monai.configr   monai.networks.utilsr   monai.utilsr   r   OPT_IMPORT_VERSIONr   r<   r	   ignite.enginer
   r   r    r    r    r!   <module>   s   