U
    PhC                     @  s   d dl mZ d dlZd dlZd dlZd dlZd dlmZ d dlm	Z	m
Z
 d dlZd dlmZ d dlmZ dgZG dd dZdS )	    )annotationsN)
ModuleType)AnyHashable)DEFAULT_PROTOCOL)PathLikeStateCacherc                   @  sb   e Zd ZdZddeefddddddd	d
dZdddddddddZdddddZdd Z	dS )r   a  Class to cache and retrieve the state of an object.

    Objects can either be stored in memory or on disk. If stored on disk, they can be
    stored in a given directory, or alternatively a temporary location will be used.

    If necessary/possible, restored objects will be returned to their original device.

    Example:

    >>> state_cacher = StateCacher(memory_cache, cache_dir=cache_dir)
    >>> state_cacher.store("model", model.state_dict())
    >>> model.load_state_dict(state_cacher.retrieve("model"))
    NTboolzPathLike | Noner   intNone)	in_memory	cache_dirallow_overwritepickle_modulepickle_protocolreturnc                 C  sN   || _ |dkrt n|| _tj| js2td|| _|| _	|| _
i | _dS )aJ  Constructor.

        Args:
            in_memory: boolean to determine if the object will be cached in memory or on
                disk.
            cache_dir: directory for data to be cached if `in_memory==False`. Defaults
                to using a temporary directory. Any created files will be deleted during
                the `StateCacher`'s destructor.
            allow_overwrite: allow the cache to be overwritten. If set to `False`, an
                error will be thrown if a matching already exists in the list of cached
                objects.
            pickle_module: module used for pickling metadata and objects, default to `pickle`.
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
            pickle_protocol: can be specified to override the default protocol, default to `2`.
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.

        Nz+Given `cache_dir` is not a valid directory.)r   tempfile
gettempdirr   ospathisdir
ValueErrorr   r   r   cached)selfr   r   r   r   r    r   M/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/utils/state_cacher.py__init__,   s    zStateCacher.__init__r   r   zModuleType | Nonez
int | None)keydata_objr   r   r   c                 C  s   || j kr| jstd| jr:| j |dt|ii n~tj	| j
d| dt|  d}| j |d|ii tj|||dkr| jn||dkr| jn|d t|dr|j| j | d< dS )	a  
        Store a given object with the given key name.

        Args:
            key: key of the data object to store.
            data_obj: data object to store.
            pickle_module: module used for pickling metadata and objects, default to `self.pickle_module`.
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
            pickle_protocol: can be specified to override the default protocol, default to `self.pickle_protocol`.
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.

        z6Cached key already exists and overwriting is disabled.objZstate__z.ptN)r   fr   r   device)r   r   RuntimeErrorr   updatecopydeepcopyr   r   joinr   idtorchsaver   r   hasattrr"   )r   r   r   r   r   fnr   r   r   storeQ   s    "
zStateCacher.store)r   r   c                 C  s   || j krtd| d| jr.| j | d S | j | d }tj|sXtd| dtj|dd d}d	| j | kr|	| j | d	 }|S )
z2Retrieve the object stored under a given key name.zTarget z was not cached.r   zFailed to load state in z. File doesn't exist anymore.c                 S  s   | S )Nr   )storagelocationr   r   r   <lambda>~       z&StateCacher.retrieve.<locals>.<lambda>)map_locationr"   )
r   KeyErrorr   r   r   existsr#   r)   loadto)r   r   r,   r   r   r   r   retrieves   s    
zStateCacher.retrievec                 C  s@   | j s<| jD ].}tj| j| d rt| j| d  qdS )z>If necessary, delete any cached files existing in `cache_dir`.r   N)r   r   r   r   r4   remove)r   kr   r   r   __del__   s    
zStateCacher.__del__)NN)
__name__
__module____qualname____doc__pickler   r   r-   r7   r:   r   r   r   r   r      s   &   ")
__future__r   r%   r   r?   r   typesr   typingr   r   r)   Ztorch.serializationr   Zmonai.config.type_definitionsr   __all__r   r   r   r   r   <module>   s   