o
    ix                     @  s   d dl mZ d dlZd dlZd dlZd dlZd dlmZ d dlm	Z	 d dl
mZ d dlZd dlmZ d dlmZ dgZG d	d dZdS )
    )annotationsN)Hashable)
ModuleType)Any)DEFAULT_PROTOCOL)PathLikeStateCacherc                   @  sD   e Zd ZdZddeefdddZ	ddddZd ddZdd Z	dS )!r   a  Class to cache and retrieve the state of an object.

    Objects can either be stored in memory or on disk. If stored on disk, they can be
    stored in a given directory, or alternatively a temporary location will be used.

    If necessary/possible, restored objects will be returned to their original device.

    Example:

    >>> state_cacher = StateCacher(memory_cache, cache_dir=cache_dir)
    >>> state_cacher.store("model", model.state_dict())
    >>> model.load_state_dict(state_cacher.retrieve("model"))
    NT	in_memorybool	cache_dirPathLike | Noneallow_overwritepickle_moduler   pickle_protocolintreturnNonec                 C  sN   || _ |du rt n|| _tj| jstd|| _|| _	|| _
i | _dS )aQ  Constructor.

        Args:
            in_memory: boolean to determine if the object will be cached in memory or on
                disk.
            cache_dir: directory for data to be cached if `in_memory==False`. Defaults
                to using a temporary directory. Any created files will be deleted during
                the `StateCacher`'s destructor.
            allow_overwrite: allow the cache to be overwritten. If set to `False`, an
                error will be thrown if a matching already exists in the list of cached
                objects.
            pickle_module: module used for pickling metadata and objects, default to `pickle`.
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.

        Nz+Given `cache_dir` is not a valid directory.)r	   tempfile
gettempdirr   ospathisdir
ValueErrorr   r   r   cached)selfr	   r   r   r   r    r   Z/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/utils/state_cacher.py__init__-   s   
zStateCacher.__init__keyr   data_objr   ModuleType | None
int | Nonec                 C  s   || j v r| jstd| jr| j |dt|ii dS tj	| j
d| dt|  d}| j |d|ii tj|||du rD| jn||du rL| jn|d t|dr_|j| j | d< dS dS )	a  
        Store a given object with the given key name.

        Args:
            key: key of the data object to store.
            data_obj: data object to store.
            pickle_module: module used for pickling metadata and objects, default to `self.pickle_module`.
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
            pickle_protocol: can be specified to override the default protocol, default to `self.pickle_protocol`.
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.

        z6Cached key already exists and overwriting is disabled.objZstate__z.ptN)r"   fr   r   device)r   r   RuntimeErrorr	   updatecopydeepcopyr   r   joinr   idtorchsaver   r   hasattrr%   )r   r   r   r   r   fnr   r   r   storeR   s   "
zStateCacher.storec                 C  s   || j vrtd| d| jr| j | d S | j | d }tj|s,td| dtj|dd dd	}d
| j | v rG|	| j | d
 }|S )z2Retrieve the object stored under a given key name.zTarget z was not cached.r"   zFailed to load state in z. File doesn't exist anymore.c                 S  s   | S )Nr   )storagelocationr   r   r   <lambda>   s    z&StateCacher.retrieve.<locals>.<lambda>T)map_locationweights_onlyr%   )
r   KeyErrorr	   r   r   existsr&   r,   loadto)r   r   r/   r   r   r   r   retrievet   s   
zStateCacher.retrievec                 C  sD   | j s| jD ]}tj| j| d rt| j| d  qdS dS )z>If necessary, delete any cached files existing in `cache_dir`.r"   N)r	   r   r   r   r7   remove)r   kr   r   r   __del__   s   
zStateCacher.__del__)r	   r
   r   r   r   r
   r   r   r   r   r   r   )NN)
r   r   r   r   r   r    r   r!   r   r   )r   r   r   r   )
__name__
__module____qualname____doc__pickler   r   r0   r:   r=   r   r   r   r   r      s    &
")
__future__r   r(   r   rB   r   collections.abcr   typesr   typingr   r,   Ztorch.serializationr   monai.config.type_definitionsr   __all__r   r   r   r   r   <module>   s   