o
    !i08                    @  s  d dl mZ d dlZd dlZd dlZd dlZd dlZd dlZd dl	Z	d dl
Z
d dlmZmZ d dlmZmZ d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZ d d
lmZmZmZmZ d dlZd dl Z d dl!m"Z" d dl#m$Z$ d dl%m&Z' d dl%m(Z( d dl)m*Z* d dl+m,Z,m-Z-m.Z. d dl/m0Z0m1Z1m2Z2m3Z3m4Z4m5Z5 d dl6m7Z7m8Z8m9Z9m:Z:m;Z;m<Z< d dl=m>Z> erd dl?m?Z? dZ@n	e<dde;d\Z?Z@e<d\ZAZBe<d\ZCZBe<d\ZDZBe<d\ZEZBG dd de'Z&G dd de&ZFG d d! d!e&ZGG d"d# d#eGZHG d$d% d%eGZIG d&d' d'e&ZJG d(d) d)e1eJZKG d*d+ d+e&ZLG d,d- d-e1e'ZMG d.d/ d/e&ZNG d0d1 d1e&ZOG d2d3 d3eGZPdS )4    )annotationsN)CallableSequence)copydeepcopy)BytesIO)	ListProxy)
ThreadPool)Path)UnpicklingError)IOTYPE_CHECKINGAnycast)Manager)DEFAULT_PROTOCOLDataset)Subset)
MetaTensor)SUPPORTED_PICKLE_MODconvert_tables_to_dictspickle_hashing)ComposeRandomizableRandomizableTrait	Transformconvert_to_contiguousreset_ops_id)MAX_SEEDconvert_to_tensorget_seedlook_up_optionmin_versionoptional_import)first)tqdmTr&   z4.47.0cupylmdbpandaszkvikio.numpyc                   @  s:   e Zd ZdZddd	d
ZdddZdddZdddZdS )r   a.  
    A generic dataset with a length property and an optional callable data transform
    when fetching a data sample.
    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,
    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

    For example, typical input data can be a list of dictionaries::

        [{                            {                            {
             'img': 'image1.nii.gz',      'img': 'image2.nii.gz',      'img': 'image3.nii.gz',
             'seg': 'label1.nii.gz',      'seg': 'label2.nii.gz',      'seg': 'label3.nii.gz',
             'extra': 123                 'extra': 456                 'extra': 789
         },                           },                           }]
    Ndatar   	transform$Sequence[Callable] | Callable | NonereturnNonec              
   C  sH   || _ zt|tst|n|| _W dS  ty# } ztd|d}~ww )az  
        Args:
            data: input data to load and transform to generate dataset for model.
            transform: a callable, sequence of callables or None. If transform is not
            a `Compose` instance, it will be wrapped in a `Compose` instance. Sequences
            of callables are applied in order and if `None` is passed, the data is returned as is.
        zH`transform` must be a callable or a list of callables that is ComposableN)r*   
isinstancer   r+   	Exception
ValueError)selfr*   r+   e r4   T/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/data/dataset.py__init__I   s   
zDataset.__init__intc                 C  
   t | jS N)lenr*   r2   r4   r4   r5   __len__W      
zDataset.__len__indexc                 C  s   | j | }| |S )z:
        Fetch single data item from `self.data`.
        r*   r+   )r2   r>   Zdata_ir4   r4   r5   
_transformZ   s   

zDataset._transformint | slice | Sequence[int]c                 C  sZ   t |tr|t| \}}}t|||}t| |dS t |tjjr(t| |dS | 	|S )z^
        Returns a `Subset` if `index` is a slice or Sequence, a data item otherwise.
        )datasetindices)
r/   slicerC   r:   ranger   collectionsabcr   r@   )r2   r>   startstopsteprC   r4   r4   r5   __getitem__a   s   

zDataset.__getitem__r9   )r*   r   r+   r,   r-   r.   r-   r7   r>   r7   )r>   rA   )__name__
__module____qualname____doc__r6   r<   r@   rK   r4   r4   r4   r5   r   9   s    

r   c                      s.   e Zd ZdZd fdd	ZddddZ  ZS )DatasetFunca  
    Execute function on the input dataset and leverage the output to act as a new Dataset.
    It can be used to load / fetch the basic dataset items, like the list of `image, label` paths.
    Or chain together to execute more complicated logic, like `partition_dataset`, `resample_datalist`, etc.
    The `data` arg of `Dataset` will be applied to the first arg of callable `func`.
    Usage example::

        data_list = DatasetFunc(
            data="path to file",
            func=monai.data.load_decathlon_datalist,
            data_list_key="validation",
            base_dir="path to base dir",
        )
        # partition dataset for every rank
        data_partition = DatasetFunc(
            data=data_list,
            func=lambda **kwargs: monai.data.partition_dataset(**kwargs)[torch.distributed.get_rank()],
            num_partitions=torch.distributed.get_world_size(),
        )
        dataset = Dataset(data=data_partition, transform=transforms)

    Args:
        data: input data for the func to process, will apply to `func` as the first arg.
        func: callable function to generate dataset items.
        kwargs: other arguments for the `func` except for the first arg.

    r*   r   funcr   r-   r.   c                   s.   t  jd d d || _|| _|| _|   d S )Nr?   )superr6   srcrS   kwargsreset)r2   r*   rS   rV   	__class__r4   r5   r6      s
   zDatasetFunc.__init__N
Any | NoneCallable | Nonec                 K  sJ   |du r| j n|}|du r| j|fi | j| _dS ||fi || _dS )aL  
        Reset the dataset items with specified `func`.

        Args:
            data: if not None, execute `func` on it, default to `self.src`.
            func: if not None, execute the `func` with specified `kwargs`, default to `self.func`.
            kwargs: other arguments for the `func` except for the first arg.

        N)rU   rS   rV   r*   )r2   r*   rS   rV   rU   r4   r4   r5   rW      s   
8zDatasetFunc.reset)r*   r   rS   r   r-   r.   )NN)r*   rZ   rS   r[   )rN   rO   rP   rQ   r6   rW   __classcell__r4   r4   rX   r5   rR   p   s    rR   c                      sd   e Zd ZdZededdfd' fddZd(ddZd)ddZdd Z	d d! Z
d"d# Zd*d%d&Z  ZS )+PersistentDataseta  
    Persistent storage of pre-computed values to efficiently manage larger than memory dictionary format data,
    it can operate transforms for specific fields.  Results from the non-random transform components are computed
    when first used, and stored in the `cache_dir` for rapid retrieval on subsequent uses.
    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,
    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

    The transforms which are supposed to be cached must implement the `monai.transforms.Transform`
    interface and should not be `Randomizable`. This dataset will cache the outcomes before the first
    `Randomizable` `Transform` within a `Compose` instance.

    For example, typical input data can be a list of dictionaries::

        [{                            {                            {
            'image': 'image1.nii.gz',    'image': 'image2.nii.gz',    'image': 'image3.nii.gz',
            'label': 'label1.nii.gz',    'label': 'label2.nii.gz',    'label': 'label3.nii.gz',
            'extra': 123                 'extra': 456                 'extra': 789
        },                           },                           }]

    For a composite transform like

    .. code-block:: python

        [ LoadImaged(keys=['image', 'label']),
        Orientationd(keys=['image', 'label'], axcodes='RAS'),
        ScaleIntensityRanged(keys=['image'], a_min=-57, a_max=164, b_min=0.0, b_max=1.0, clip=True),
        RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', spatial_size=(96, 96, 96),
                                pos=1, neg=1, num_samples=4, image_key='image', image_threshold=0),
        ToTensord(keys=['image', 'label'])]

    Upon first use a filename based dataset will be processed by the transform for the
    [LoadImaged, Orientationd, ScaleIntensityRanged] and the resulting tensor written to
    the `cache_dir` before applying the remaining random dependant transforms
    [RandCropByPosNegLabeld, ToTensord] elements for use in the analysis.

    Subsequent uses of a dataset directly read pre-processed results from `cache_dir`
    followed by applying the random dependant parts of transform processing.

    During training call `set_data()` to update input data and recompute cache content.

    Note:
        The input data must be a list of file paths and will hash them as cache keys.

        The filenames of the cached files also try to contain the hash of the transforms. In this
        fashion, `PersistentDataset` should be robust to changes in transforms. This, however, is
        not guaranteed, so caution should be used when modifying transforms to avoid unexpected
        errors. If in doubt, it is advisable to clear the cache directory.

        Cached data is expected to be tensors, primitives, or dictionaries keying to these values. Numpy arrays will
        be converted to tensors, however any other object type returned by transforms will not be loadable since
        `torch.load` will be used with `weights_only=True` to prevent loading of potentially malicious objects.
        Legacy cache files may not be loadable and may need to be recomputed.

    Lazy Resampling:
        If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
        its documentation to familiarize yourself with the interaction between `PersistentDataset` and
        lazy resampling.

    pickleNTr*   r   r+   Sequence[Callable] | Callable	cache_dirPath | str | None	hash_funcCallable[..., bytes]pickle_modulestrpickle_protocolr7   hash_transformCallable[..., bytes] | Noner   boolr-   r.   c	           	        s   t  j||d |durt|nd| _|| _|| _|| _| jdur7| j s.| jjddd | j	 s7t
dd| _|durC| | || _dS )a
  
        Args:
            data: input data file paths to load and transform to generate dataset for model.
                `PersistentDataset` expects input data to be a list of serializable
                and hashes them as cache keys using `hash_func`.
            transform: transforms to execute operations on input data.
            cache_dir: If specified, this is the location for persistent storage
                of pre-computed transformed data tensors. The cache_dir is computed once, and
                persists on disk until explicitly removed.  Different runs, programs, experiments
                may share a common cache dir provided that the transforms pre-processing is consistent.
                If `cache_dir` doesn't exist, will automatically create it.
                If `cache_dir` is `None`, there is effectively no caching.
            hash_func: a callable to compute hash from data items to be cached.
                defaults to `monai.data.utils.pickle_hashing`.
            pickle_module: string representing the module used for pickling metadata and objects,
                default to `"pickle"`. due to the pickle limitation in multi-processing of Dataloader,
                we can't use `pickle` as arg directly, so here we use a string name instead.
                if want to use other pickle module at runtime, just register like:
                >>> from monai.data import utils
                >>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
                and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
            hash_transform: a callable to compute hash from the transform information when caching.
                This may reduce errors due to transforms changing during experiments. Default to None (no hash).
                Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
            reset_ops_id: whether to set `TraceKeys.ID` to ``Tracekys.NONE``, defaults to ``True``.
                When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
                This is useful for skipping the transform instance checks when inverting applied operations
                using the cached content and with re-created transform instances.

        r?   NTparentsexist_okzcache_dir must be a directory. )rT   r6   r
   r`   rb   rd   rf   existsmkdiris_dirr1   transform_hashset_transform_hashr   )	r2   r*   r+   r`   rb   rd   rf   rg   r   rX   r4   r5   r6      s   .




zPersistentDataset.__init__hash_xform_funcc              
   C  s   g }| j  jD ]}t|tst|ts n|| qz||}W n( tyJ } zdt|vr2|d	dd |D }||}W Y d}~nd}~ww |
d| _dS )a  Get hashable transforms, and then hash them. Hashable transforms
        are deterministic transforms that inherit from `Transform`. We stop
        at the first non-deterministic transform, or first that does not
        inherit from MONAI's `Transform` class.zis not JSON serializablerm   c                 s  s    | ]}|j jV  qd S r9   )rY   rN   ).0trr4   r4   r5   	<genexpr>,      z7PersistentDataset.set_transform_hash.<locals>.<genexpr>Nutf-8)r+   flatten
transformsr/   r   r   append	TypeErrorre   joindecoderq   )r2   rs   Zhashable_transformsZ_trrq   tenamesr4   r4   r5   rr     s   z$PersistentDataset.set_transform_hashc                 C  sF   || _ | jdur| j r!tj| jdd | jjddd dS dS dS )Q
        Set the input data and delete all the out-dated cache content.

        NT)ignore_errorsrj   )r*   r`   rn   shutilrmtreero   r2   r*   r4   r4   r5   set_data0  s
   zPersistentDataset.set_datac                 C  s2   | j dd }| j ||dd}| jrt| |S )a  
        Process the data from original state up to the first random element.

        Args:
            item_transformed: The data to be transformed

        Returns:
            the transformed element up to the first identified
            random transform object

        c                 S     t | tp
t | t S r9   r/   r   r   tr4   r4   r5   <lambda>G      z2PersistentDataset._pre_transform.<locals>.<lambda>Tend	threading)r+   get_index_of_firstr   r2   item_transformedfirst_randomr4   r4   r5   _pre_transform:  s   z PersistentDataset._pre_transformc                 C  s*   | j dd }|dur| j ||d}|S )aD  
        Process the data from before the first random transform to the final state ready for evaluation.

        Args:
            item_transformed: The data to be transformed (already processed up to the first random transform)

        Returns:
            the transformed element through the random transforms

        c                 S  r   r9   r   r   r4   r4   r5   r   [  r   z3PersistentDataset._post_transform.<locals>.<lambda>NrH   )r+   r   r   r4   r4   r5   _post_transformO  s   z!PersistentDataset._post_transformc              
   C  s  d}| j dur| |d}|| j7 }| j | d }|durt| rtztj|ddW S  tyE } ztj	dkr;|W Y d}~n3d}~w t
tfys } z!dt|v sYt|t
rgtd| d	 |  n|W Y d}~nd}~ww | t|}|du r|S zit Z}t||j }tjt|d
d|t| jt| jd | r| sz
tt|| W n ty   Y nw W d   W |S W d   W |S W d   W |S W d   W |S 1 sw   Y  W |S  ty   Y |S w )a  
        A function to cache the expensive input data transform operations
        so that huge data sets (larger than computer memory) can be processed
        on the fly as needed, and intermediate results written to disk for
        future use.

        Args:
            item_transformed: The current data element to be mutated into transformed representation

        Returns:
            The transformed data_element, either from cache, or explicitly computing it.

        Warning:
            The current implementation does not encode transform information as part of the
            hashing mechanism used for generating cache names when `hash_transform` is None.
            If the transforms applied are changed in any way, the objects in the cache dir will be invalid.

        Nrx   .ptTweights_onlywin32z"Invalid magic number; corrupt filezCorrupt cache file detected: z. Deleting and recomputing.Fconvert_numericobjfrd   rf   ) r`   rb   r~   rq   is_filetorchloadPermissionErrorsysplatformr   RuntimeErrorre   r/   warningswarnunlinkr   r   tempfileTemporaryDirectoryr
   namesaver    r"   rd   r   rf   r   moveFileExistsError)r2   r   hashfiledata_item_md5r3   _item_transformed
tmpdirnametemp_hash_filer4   r4   r5   _cachechecka  sr   






zPersistentDataset._cachecheckr>   c                 C  s   |  | j| }| |S r9   )r   r*   r   )r2   r>   Zpre_random_itemr4   r4   r5   r@     s   
zPersistentDataset._transform)r*   r   r+   r_   r`   ra   rb   rc   rd   re   rf   r7   rg   rh   r   ri   r-   r.   )rs   rc   r*   r   rM   )rN   rO   rP   rQ   r   r   r6   rr   r   r   r   r   r@   r\   r4   r4   rX   r5   r]      s    A
=

@r]   c                      s>   e Zd ZdZededdfd fddZdd Zdd Z  Z	S )CacheNTransDatasetz~
    Extension of `PersistentDataset`, it can also cache the result of first N transforms, no matter it's random or not.

    r^   NTr*   r   r+   r_   cache_n_transr7   r`   ra   rb   rc   rd   re   rf   rg   rh   r   ri   r-   r.   c
           
   
     s&   t  j||||||||	d || _dS )a
  
        Args:
            data: input data file paths to load and transform to generate dataset for model.
                `PersistentDataset` expects input data to be a list of serializable
                and hashes them as cache keys using `hash_func`.
            transform: transforms to execute operations on input data.
            cache_n_trans: cache the result of first N transforms.
            cache_dir: If specified, this is the location for persistent storage
                of pre-computed transformed data tensors. The cache_dir is computed once, and
                persists on disk until explicitly removed.  Different runs, programs, experiments
                may share a common cache dir provided that the transforms pre-processing is consistent.
                If `cache_dir` doesn't exist, will automatically create it.
                If `cache_dir` is `None`, there is effectively no caching.
            hash_func: a callable to compute hash from data items to be cached.
                defaults to `monai.data.utils.pickle_hashing`.
            pickle_module: string representing the module used for pickling metadata and objects,
                default to `"pickle"`. due to the pickle limitation in multi-processing of Dataloader,
                we can't use `pickle` as arg directly, so here we use a string name instead.
                if want to use other pickle module at runtime, just register like:
                >>> from monai.data import utils
                >>> utils.SUPPORTED_PICKLE_MOD["test"] = other_pickle
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save,
                and ``monai.data.utils.SUPPORTED_PICKLE_MOD``.
            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
            hash_transform: a callable to compute hash from the transform information when caching.
                This may reduce errors due to transforms changing during experiments. Default to None (no hash).
                Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
            reset_ops_id: whether to set `TraceKeys.ID` to ``Tracekys.NONE``, defaults to ``True``.
                When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
                This is useful for skipping the transform instance checks when inverting applied operations
                using the cached content and with re-created transform instances.

        )r*   r+   r`   rb   rd   rf   rg   r   N)rT   r6   r   )
r2   r*   r+   r   r`   rb   rd   rf   rg   r   rX   r4   r5   r6     s   0

zCacheNTransDataset.__init__c                 C  s   | j || jdd}t| |S )z
        Process the data from original state up to the N element.

        Args:
            item_transformed: The data to be transformed

        Returns:
            the transformed element up to the N transform object
        Tr   )r+   r   r   r2   r   r4   r4   r5   r     s   
z!CacheNTransDataset._pre_transformc                 C  s   | j || jdS )a  
        Process the data from before the N + 1 transform to the final state ready for evaluation.

        Args:
            item_transformed: The data to be transformed (already processed up to the first N transform)

        Returns:
            the final transformed result
        r   )r+   r   r   r4   r4   r5   r     s   
z"CacheNTransDataset._post_transform)r*   r   r+   r_   r   r7   r`   ra   rb   rc   rd   re   rf   r7   rg   rh   r   ri   r-   r.   )
rN   rO   rP   rQ   r   r   r6   r   r   r\   r4   r4   rX   r5   r     s    <r   c                      sp   e Zd ZdZdeddedddfd' fddZd( fddZdd Zdd  Z	d)d!d"Z
 fd#d$Zd%d& Z  ZS )*LMDBDataseta  
    Extension of `PersistentDataset` using LMDB as the backend.

    See Also:
        :py:class:`monai.data.PersistentDataset`

    Examples:

        >>> items = [{"data": i} for i in range(5)]
        # [{'data': 0}, {'data': 1}, {'data': 2}, {'data': 3}, {'data': 4}]
        >>> lmdb_ds = monai.data.LMDBDataset(items, transform=monai.transforms.SimulateDelayd("data", delay_time=1))
        >>> print(list(lmdb_ds))  # using the cached results

    cacheZmonai_cacheTNr*   r   r+   r_   r`   
Path | strrb   rc   db_namere   progressri   rg   rh   r   lmdb_kwargsdict | Noner-   r.   c              	     s   t  j|||||||	d || _| jstd| j| d | _|
p#i | _| jdds1d| jd< d| _| j	| jd t
d	| j  d
 dS )a  
        Args:
            data: input data file paths to load and transform to generate dataset for model.
                `LMDBDataset` expects input data to be a list of serializable
                and hashes them as cache keys using `hash_func`.
            transform: transforms to execute operations on input data.
            cache_dir: if specified, this is the location for persistent storage
                of pre-computed transformed data tensors. The cache_dir is computed once, and
                persists on disk until explicitly removed.  Different runs, programs, experiments
                may share a common cache dir provided that the transforms pre-processing is consistent.
                If the cache_dir doesn't exist, will automatically create it. Defaults to "./cache".
            hash_func: a callable to compute hash from data items to be cached.
                defaults to `monai.data.utils.pickle_hashing`.
            db_name: lmdb database file name. Defaults to "monai_cache".
            progress: whether to display a progress bar.
            pickle_protocol: specifies pickle protocol when saving, with `torch.save`.
                Defaults to torch.serialization.DEFAULT_PROTOCOL. For more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
            hash_transform: a callable to compute hash from the transform information when caching.
                This may reduce errors due to transforms changing during experiments. Default to None (no hash).
                Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
            reset_ops_id: whether to set `TraceKeys.ID` to ``Tracekeys.NONE``, defaults to ``True``.
                When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
                This is useful for skipping the transform instance checks when inverting applied operations
                using the cached content and with re-created transform instances.
            lmdb_kwargs: additional keyword arguments to the lmdb environment.
                for more details please visit: https://lmdb.readthedocs.io/en/release/#environment-class
        )r*   r+   r`   rb   rf   rg   r   zcache_dir must be specified.z.lmdbmap_sizer   l        Nshow_progresszAccessing lmdb file: .)rT   r6   r   r`   r1   db_filer   get	_read_env_fill_cache_start_readerprintabsolute)r2   r*   r+   r`   rb   r   r   rf   rg   r   r   rX   r4   r5   r6     s&   )	

zLMDBDataset.__init__c                   s"   t  j|d | j| jd| _dS )r   )r*   r   N)rT   r   r   r   r   r   rX   r4   r5   r   T  s   zLMDBDataset.set_datac                 C  s.   t  }tjt||| jd |d | S )N)rf   r   )r   r   r   r    rf   seekread)r2   valoutr4   r4   r5   _safe_serialize\  s   
zLMDBDataset._safe_serializec                 C  s   t jt|dddS )NcpuT)map_locationr   )r   r   r   )r2   r   r4   r4   r5   _safe_deserializeb     zLMDBDataset._safe_deserializec                 C  sv  d| j d< tjd| j dd| j }|rtstd |jdd}tr,|r,t| j	n| j	D ]}| 
|}d\}}}|s|dkrzM| }	|	|}W d   n1 sVw   Y  |r_W q;|du ro| t|}| |}|jd	d}
|
|| W d   n1 sw   Y  d	}W nB tjy   d|d
 }}| d }|d }tdt|d?  dt|d?  d || Y n tjy   |d Y nw |s|dksA|s| d }|  td| dq/W d   n1 sw   Y  | d }|  d	| j d< || j d< | j dddu rd| j d< | j dddu r-d| j d< tjd| j dd| j S )aF  
        Check the LMDB cache and write the cache if needed. py-lmdb doesn't have a good support for concurrent write.
        This method can be used with multiple processes, but it may have a negative impact on the performance.

        Args:
            show_progress: whether to show the progress bar if possible.
        Freadonly)pathsubdirzHLMDBDataset: tqdm is not installed. not displaying the caching progress.write)F   Nr   NT   r      z!Resizing the cache database from    zMB to zMB.z;LMDB map size reached, increase size above current size of r   lock	readaheadr4   )r   r(   openr   has_tqdmr   r   beginr&   r*   rb   cursorZset_keyr   r   r   putZMapFullErrorinfor7   Zset_mapsizeZMapResizedErrorcloser1   r   )r2   r   envZ
search_txnitemkeydoneretryr   r   txnsizenew_sizer4   r4   r5   r   e  sd   
	




 !



z$LMDBDataset._fill_cache_start_readerc              
     s   | j du r| jdd| _ | j jdd}|| |}W d   n1 s&w   Y  |du r:td t |S z| 	|W S  t
yQ } ztd|d}~ww )zq
        if the item is not found in the lmdb file, resolves to the persistent cache default behaviour.

        NFr   r   z;LMDBDataset: cache key not found, running fallback caching.z)Invalid cache value, corrupted lmdb file?)r   r   r   r   rb   r   r   rT   r   r   r0   r   )r2   r   r   r*   errrX   r4   r5   r     s   


zLMDBDataset._cachecheckc                 C  sD   | j du r
|  | _ t| j  }t| j|d< | j  |d< |S )z4
        Returns: dataset info dictionary.

        Nr   filename)r   r   dictr   r:   r*   r   r   )r2   r   r4   r4   r5   r     s   

zLMDBDataset.info)r*   r   r+   r_   r`   r   rb   rc   r   re   r   ri   rg   rh   r   ri   r   r   r-   r.   r   )T)rN   rO   rP   rQ   r   r   r6   r   r   r   r   r   r   r\   r4   r4   rX   r5   r     s"    @
9r   c                
      sh   e Zd ZdZdejddddddedf
d) fddZd*ddZd+d,d!d"Z	d-d$d%Z
d. fd'd(Z  ZS )/CacheDatasetau  
    Dataset with cache mechanism that can load data and cache deterministic transforms' result during training.

    By caching the results of non-random preprocessing transforms, it accelerates the training data pipeline.
    If the requested data is not in the cache, all transforms will run normally
    (see also :py:class:`monai.data.dataset.Dataset`).

    Users can set the cache rate or number of items to cache.
    It is recommended to experiment with different `cache_num` or `cache_rate` to identify the best training speed.

    The transforms which are supposed to be cached must implement the `monai.transforms.Transform`
    interface and should not be `Randomizable`. This dataset will cache the outcomes before the first
    `Randomizable` `Transform` within a `Compose` instance.
    So to improve the caching efficiency, please always put as many as possible non-random transforms
    before the randomized ones when composing the chain of transforms.
    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,
    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

    For example, if the transform is a `Compose` of::

        transforms = Compose([
            LoadImaged(),
            EnsureChannelFirstd(),
            Spacingd(),
            Orientationd(),
            ScaleIntensityRanged(),
            RandCropByPosNegLabeld(),
            ToTensord()
        ])

    when `transforms` is used in a multi-epoch training pipeline, before the first training epoch,
    this dataset will cache the results up to ``ScaleIntensityRanged``, as
    all non-random transforms `LoadImaged`, `EnsureChannelFirstd`, `Spacingd`, `Orientationd`, `ScaleIntensityRanged`
    can be cached. During training, the dataset will load the cached results and run
    ``RandCropByPosNegLabeld`` and ``ToTensord``, as ``RandCropByPosNegLabeld`` is a randomized transform
    and the outcome not cached.

    During training call `set_data()` to update input data and recompute cache content, note that it requires
    `persistent_workers=False` in the PyTorch DataLoader.

    Note:
        `CacheDataset` executes non-random transforms and prepares cache content in the main process before
        the first epoch, then all the subprocesses of DataLoader will read the same cache content in the main process
        during training. it may take a long time to prepare cache content according to the size of expected cache data.
        So to debug or verify the program before real training, users can set `cache_rate=0.0` or `cache_num=0` to
        temporarily skip caching.

    Lazy Resampling:
        If you make use of the lazy resampling feature of `monai.transforms.Compose`, please refer to
        its documentation to familiarize yourself with the interaction between `CacheDataset` and
        lazy resampling.

    N      ?r   TFr*   r   r+   r,   	cache_numr7   
cache_ratefloatnum_workers
int | Noner   ri   
copy_cacheas_contiguoushash_as_keyrb   rc   runtime_cachebool | str | list | ListProxyr-   r.   c                   s   t  j||d || _|| _|| _|| _|| _|	| _|
| _|| _	| j	dur.t
t| j	d| _	|| _d| _g | _g | _| | dS )a  
        Args:
            data: input data to load and transform to generate dataset for model.
            transform: transforms to execute operations on input data.
            cache_num: number of items to be cached. Default is `sys.maxsize`.
                will take the minimum of (cache_num, data_length x cache_rate, data_length).
            cache_rate: percentage of cached data in total, default is 1.0 (cache all).
                will take the minimum of (cache_num, data_length x cache_rate, data_length).
            num_workers: the number of worker threads if computing cache in the initialization.
                If num_workers is None then the number returned by os.cpu_count() is used.
                If a value less than 1 is specified, 1 will be used instead.
            progress: whether to display a progress bar.
            copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
                default to `True`. if the random transforms don't modify the cached content
                (for example, randomly crop from the cached image and deepcopy the crop region)
                or if every cache item is only used once in a `multi-processing` environment,
                may set `copy=False` for better performance.
            as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
                it may help improve the performance of following logic.
            hash_as_key: whether to compute hash value of input data as the key to save cache,
                if key exists, avoid saving duplicated content. it can help save memory when
                the dataset has duplicated items or augmented dataset.
            hash_func: if `hash_as_key`, a callable to compute hash from data items to be cached.
                defaults to `monai.data.utils.pickle_hashing`.
            runtime_cache: mode of cache at the runtime. Default to `False` to prepare
                the cache content for the entire ``data`` during initialization, this potentially largely increase the
                time required between the constructor called and first mini-batch generated.
                Three options are provided to compute the cache on the fly after the dataset initialization:

                1. ``"threads"`` or ``True``: use a regular ``list`` to store the cache items.
                2. ``"processes"``: use a ListProxy to store the cache items, it can be shared among processes.
                3. A list-like object: a users-provided container to be used to store the cache items.

                For `thread-based` caching (typically for caching cuda tensors), option 1 is recommended.
                For single process workflows with multiprocessing data loading, option 2 is recommended.
                For multiprocessing workflows (typically for distributed training),
                where this class is initialized in subprocesses, option 3 is recommended,
                and the list-like object should be prepared in the main process and passed to all subprocesses.
                Not following these recommendations may lead to runtime errors or duplicated cache across processes.

        r?   Nr   r   )rT   r6   set_numset_rater   r   r   r   rb   r   maxr7   r   r   _cache
_hash_keysr   )r2   r*   r+   r   r   r   r   r   r   r   rb   r   rX   r4   r5   r6     s    7
zCacheDataset.__init__c                   s  | _ d fdd} jr5 fddt j D }|t| t|d j  _t| d j }n|t j  tt j} j	dv rP 
| _dS t j	trhd	 j	v rht dg j  _dS  j	d
u sxt j	trd j	v rdg j  _dS  j	 _dS )aA  
        Set the input data and run deterministic transforms to generate cache content.

        Note: should call this func after an entire epoch and must set `persistent_workers=False`
        in PyTorch DataLoader, because it needs to create new worker processes based on new
        generated cache content.

        data_lenr7   c                   s"   t t jt|  j |  _d S r9   )minr7   r   r   r   )r  r;   r4   r5   _compute_cache_numH  s   "z1CacheDataset.set_data.<locals>._compute_cache_numc                   s   i | ]
\}}  ||qS r4   )rb   )rt   ivr;   r4   r5   
<dictcomp>M  s    z)CacheDataset.set_data.<locals>.<dictcomp>N)FNprocessTthread)r  r7   )r*   r   	enumerater:   listr   r  valuesrE   r   _fill_cacher   r/   re   r   )r2   r*   r  mappingrC   r4   r;   r5   r   =  s(   	
 zCacheDataset.set_datar  c                 C  s   | j dkrg S |du rtt| j }| jrtstd t| j-}| jr>tr>tt	|
| j|t|ddW  d   S t|
| j|W  d   S 1 sQw   Y  dS )z
        Compute and fill the cache content from data source.

        Args:
            indices: target indices in the `self.data` source to compute cache.
                if None, use the first `cache_num` items.

        r   Nz>tqdm is not installed, will not show the caching progress bar.zLoading dataset)totaldesc)r   r  rE   r   r   r   r   r	   r   r&   imap_load_cache_itemr:   )r2   rC   pr4   r4   r5   r  b  s   
	


$zCacheDataset._fill_cacheidxc                 C  sB   | j | }| jdd }| j||dd}| jrt|tjd}|S )zN
        Args:
            idx: the index of the input data sequence.
        c                 S  r   r9   r   r   r4   r4   r5   r   ~  r   z/CacheDataset._load_cache_item.<locals>.<lambda>Tr   )memory_format)r*   r+   r   r   r   r   contiguous_format)r2   r  r   r   r4   r4   r5   r  v  s   
zCacheDataset._load_cache_itemr>   c                   s   d }| j r| | j| }|| jv r| j|}n|t|  | jk r$|}|d u r.t |S | j	d u r7t
d| j	| }|d u rJ| | }| j	|< t| jtsTtd| jdd }|d urr| jdu rit|n|}| j||d}|S )Nz@cache buffer is not initialized, please call `set_data()` first.z:transform must be an instance of monai.transforms.Compose.c                 S  r   r9   r   r   r4   r4   r5   r     r   z)CacheDataset._transform.<locals>.<lambda>Tr   )r   rb   r*   r  r>   r:   r   rT   r@   r   r   r  r/   r+   r   r1   r   r   r   )r2   r>   Zcache_indexr   r*   r   rX   r4   r5   r@     s0   


zCacheDataset._transform)r*   r   r+   r,   r   r7   r   r   r   r   r   ri   r   ri   r   ri   r   ri   rb   rc   r   r   r-   r.   r*   r   r-   r.   r9   )r-   r  )r  r7   rM   )rN   rO   rP   rQ   r   maxsizer   r6   r   r  r  r@   r\   r4   r4   rX   r5   r     s"    9
H%
r   c                      s   e Zd ZdZddejdddddddddfd> fddZd? fdd Zd@d!d"ZdAd#d$Z	d%d& Z
d'd( Zd)d* Zd+d, Zd-d. Zd/d0 Zd1d2 ZdBd4d5Zd6d7 Zd8d9 ZdAd:d;Zd<d= Z  ZS )CSmartCacheDataseta;  
    Re-implementation of the SmartCache mechanism in NVIDIA Clara-train SDK.
    At any time, the cache pool only keeps a subset of the whole dataset. In each epoch, only the items
    in the cache are used for training. This ensures that data needed for training is readily available,
    keeping GPU resources busy. Note that cached items may still have to go through a non-deterministic
    transform sequence before being fed to GPU. At the same time, another thread is preparing replacement
    items by applying the transform sequence to items not in cache. Once one epoch is completed, Smart
    Cache replaces the same number of items with replacement items.
    Smart Cache uses a simple `running window` algorithm to determine the cache content and replacement items.
    Let N be the configured number of objects in cache; and R be the number of replacement objects (R = ceil(N * r),
    where r is the configured replace rate).
    For more details, please refer to:
    https://docs.nvidia.com/clara/clara-train-archive/3.1/nvmidl/additional_features/smart_cache.html
    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,
    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

    For example, if we have 5 images: `[image1, image2, image3, image4, image5]`, and `cache_num=4`, `replace_rate=0.25`.
    so the actual training images cached and replaced for every epoch are as below::

        epoch 1: [image1, image2, image3, image4]
        epoch 2: [image2, image3, image4, image5]
        epoch 3: [image3, image4, image5, image1]
        epoch 3: [image4, image5, image1, image2]
        epoch N: [image[N % 5] ...]

    The usage of `SmartCacheDataset` contains 4 steps:

        1. Initialize `SmartCacheDataset` object and cache for the first epoch.
        2. Call `start()` to run replacement thread in background.
        3. Call `update_cache()` before every epoch to replace training items.
        4. Call `shutdown()` when training ends.

    During training call `set_data()` to update input data and recompute cache content, note to call
    `shutdown()` to stop first, then update data and call `start()` to restart.

    Note:
        This replacement will not work for below cases:
        1. Set the `multiprocessing_context` of DataLoader to `spawn`.
        2. Launch distributed data parallel with `torch.multiprocessing.spawn`.
        3. Run on windows(the default multiprocessing method is `spawn`) with `num_workers` greater than 0.
        4. Set the `persistent_workers` of DataLoader to `True` with `num_workers` greater than 0.

        If using MONAI workflows, please add `SmartCacheHandler` to the handler list of trainer,
        otherwise, please make sure to call `start()`, `update_cache()`, `shutdown()` during training.

    Args:
        data: input data to load and transform to generate dataset for model.
        transform: transforms to execute operations on input data.
        replace_rate: percentage of the cached items to be replaced in every epoch (default to 0.1).
        cache_num: number of items to be cached. Default is `sys.maxsize`.
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        cache_rate: percentage of cached data in total, default is 1.0 (cache all).
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        num_init_workers: the number of worker threads to initialize the cache for first epoch.
            If num_init_workers is None then the number returned by os.cpu_count() is used.
            If a value less than 1 is specified, 1 will be used instead.
        num_replace_workers: the number of worker threads to prepare the replacement cache for every epoch.
            If num_replace_workers is None then the number returned by os.cpu_count() is used.
            If a value less than 1 is specified, 1 will be used instead.
        progress: whether to display a progress bar when caching for the first epoch.
        shuffle: whether to shuffle the whole data list before preparing the cache content for first epoch.
            it will not modify the original input data sequence in-place.
        seed: random seed if shuffle is `True`, default to `0`.
        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
            default to `True`. if the random transforms don't modify the cache content
            or every cache item is only used once in a `multi-processing` environment,
            may set `copy=False` for better performance.
        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
            it may help improve the performance of following logic.
        runtime_cache: Default to `False`, other options are not implemented yet.
    Ng?r   r   Tr   Fr*   r   r+   r,   replace_rater   r   r7   r   num_init_workersr   num_replace_workersr   ri   shuffleseedr   r   r-   r.   c                   s"  |	r| j |
d |	| _d| _t | _d| _d| _d | _|dur$t	dt
 j||||||||dd	 | jd u r=|  | _| jt|krItd |dkrQtd|| _| jd urbtt| jd| _t|| _tt| j| t|| j | _d	d
 t| jD | _tt| j| _|   d S )Nr  r   r   Fz@Options other than `runtime_cache=False` is not implemented yet.)	r*   r+   r   r   r   r   r   r   r   z`cache_num is greater or equal than dataset length, fall back to regular monai.data.CacheDataset.zSreplace_rate must be greater than 0, otherwise, please use monai.data.CacheDataset.c                 S  s   g | ]}d qS r9   r4   rt   _r4   r4   r5   
<listcomp>(      z.SmartCacheDataset.__init__.<locals>.<listcomp>) set_random_stater  
_start_posr   Lock_update_lock_round_replace_done_replace_mgrNotImplementedErrorrT   r6   r   r  r   r:   r   r   r1   r  r   r7   
_total_numr  mathceil_replace_numrE   _replacementsr  _replace_data_idx_compute_data_idx)r2   r*   r+   r  r   r   r  r  r   r  r  r   r   r   rX   r4   r5   r6     sJ   




"zSmartCacheDataset.__init__c                   sB   |   rtd |   | jrt|}| | t | dS )z
        Set the input data and run deterministic transforms to generate cache content.

        Note: should call `shutdown()` before calling this func.

        z<SmartCacheDataset is not shutdown yet, shutdown it directly.N)	
is_startedr   r   shutdownr  r   	randomizerT   r   r   rX   r4   r5   r   ,  s   

zSmartCacheDataset.set_datac              
   C  sL   z	| j | W d S  ty% } ztd| d W Y d }~d S d }~ww )NzOinput data can't be shuffled in SmartCacheDataset with numpy.random.shuffle(): r   )Rr  r|   r   r   )r2   r*   r3   r4   r4   r5   r5  <  s    zSmartCacheDataset.randomizec                 C  sB   t | jD ]}| j| j | }|| jkr|| j8 }|| j|< qdS )zJ
        Update the replacement data position in the total data.

        N)rE   r/  r%  r   r,  r1  )r2   r  posr4   r4   r5   r2  B  s   

z#SmartCacheDataset._compute_data_idxc                 C  s   | j du rdS | j  S )zK
        Check whether the replacement thread is already started.

        NF)r*  is_aliver;   r4   r4   r5   r3  M  s   zSmartCacheDataset.is_startedc                 C  s   |   s
|   dS dS )zY
        Start the background thread to replace training items for every epoch.

        N)r3  _restartr;   r4   r4   r5   rH   T  s   zSmartCacheDataset.startc                 C  s&   d| _ tj| jdd| _| j  dS )zG
        Restart background thread if killed for some reason.

        r   T)targetdaemonN)r(  r   Threadmanage_replacementr*  rH   r;   r4   r4   r5   r9  \  s   zSmartCacheDataset._restartc                 C  s   | j H | js	 W d   dS | jd| j= | j| j |  j| j7  _| j| jkr4|  j| j8  _|   |  j	d7  _	d| _	 W d   dS 1 sNw   Y  dS )zQ
        Update the cache items with new replacement for current epoch.

        NFr   T)
r'  r)  r   r/  extendr0  r%  r,  r2  r(  r;   r4   r4   r5   _try_update_cachee  s   $z#SmartCacheDataset._try_update_cachec                 C  s*   |    |  std |  rdS dS )z
        Update cache items for current epoch, need to call this function before every epoch.
        If the cache has been shutdown before, need to restart the `_replace_mgr` thread.

        {Gz?N)rH   r?  timesleepr;   r4   r4   r5   update_cache|  s   
zSmartCacheDataset.update_cachec                 C  s`   | j # | jrd| _d| _|   d| _	 W d   dS 	 W d   dS 1 s)w   Y  dS )zK
        Wait for thread lock to shut down the background thread.

        r   FNT)r'  r)  r(  r%  r2  r;   r4   r4   r5   _try_shutdown  s   $zSmartCacheDataset._try_shutdownc                 C  sD   |   sdS |  std |  r
| jdur | jd dS dS )zC
        Shut down the background thread for replacement.

        Nr@  i,  )r3  rD  rA  rB  r*  r}   r;   r4   r4   r5   r4    s   

zSmartCacheDataset.shutdownr>   c                 C  s   | j | }| || j|< dS )zT
        Execute deterministic transforms on the new data for replacement.

        N)r1  r  r0  )r2   r>   r7  r4   r4   r5   _replace_cache_thread  s   
z'SmartCacheDataset._replace_cache_threadc                 C  sL   t | j}|| jtt| j W d   n1 sw   Y  d| _dS )z
        Compute expected items for the replacement of next epoch, execute deterministic transforms.
        It can support multi-threads to accelerate the computation progress.

        NT)r	   r  maprE  r  rE   r/  r)  )r2   r  r4   r4   r5   _compute_replacements  s   
z'SmartCacheDataset._compute_replacementsc                 C  sh   | j ' | jdkrd| _	 W d   dS | j|kr|   d| jfW  d   S 1 s-w   Y  dS )zX
        Wait thread lock and replace training items in the background thread.

        r   TN)TF)r'  r(  r)  rG  )r2   check_roundr4   r4   r5   _try_manage_replacement  s   

$z)SmartCacheDataset._try_manage_replacementc                 C  s0   d}d}|s|  |\}}td |rdS dS )z5
        Background thread for replacement.

        rH  Fr@  N)rJ  rA  rB  )r2   rI  r   r4   r4   r5   r=    s   
z$SmartCacheDataset.manage_replacementc                 C     | j S )zQ
        The dataset length is given by cache_num instead of len(data).

        )r   r;   r4   r4   r5   r<     s   zSmartCacheDataset.__len__)r*   r   r+   r,   r  r   r   r7   r   r   r  r   r  r   r   ri   r  ri   r  r7   r   ri   r   ri   r-   r.   r   r  )r-   r.   rM   )rN   rO   rP   rQ   r   r  r6   r   r5  r2  r3  rH   r9  r?  rC  rD  r4  rE  rG  rJ  r=  r<   r\   r4   r4   rX   r5   r    s<    K:

	

r  c                      s8   e Zd ZdZdd fd	d
ZdddZdddZ  ZS )
ZipDatasetaS  
    Zip several PyTorch datasets and output data(with the same index) together in a tuple.
    If the output of single dataset is already a tuple, flatten it and extend to the result.
    For example: if datasetA returns (img, imgmeta), datasetB returns (seg, segmeta),
    finally return (img, imgmeta, seg, segmeta).
    And if the datasets don't have same length, use the minimum length of them as the length
    of ZipDataset.
    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,
    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

    Examples::

        >>> zip_data = ZipDataset([[1, 2, 3], [4, 5]])
        >>> print(len(zip_data))
        2
        >>> for item in zip_data:
        >>>    print(item)
        [1, 4]
        [2, 5]

    Ndatasetsr   r+   r[   r-   r.   c                   s   t  jt||d dS )z
        Args:
            datasets: list of datasets to zip together.
            transform: a callable data transform operates on the zipped item from `datasets`.
        )r+   N)rT   r6   r  )r2   rM  r+   rX   r4   r5   r6     s   zZipDataset.__init__r7   c                 C  s   t dd | jD S )Nc                 s  s    | ]}t |V  qd S r9   )r:   )rt   rB   r4   r4   r5   rv     rw   z%ZipDataset.__len__.<locals>.<genexpr>)r  r*   r;   r4   r4   r5   r<     r   zZipDataset.__len__r>   c                 C  sN   dd }g }| j D ]}||||  q	| jd ur#d| j_| |}t|S )Nc                 S  s   t | ttfrt| S | gS r9   )r/   tupler  )xr4   r4   r5   to_list  s   z&ZipDataset._transform.<locals>.to_listF)r*   r>  r+   Z	map_itemsrN  )r2   r>   rP  r*   rB   r4   r4   r5   r@     s   


zZipDataset._transformr9   )rM  r   r+   r[   r-   r.   rL   rM   rN   rO   rP   rQ   r6   r<   r@   r\   r4   r4   rX   r5   rL    s
    
rL  c                   @  sF   e Zd ZdZ					ddddZdddZddddZdddZdS ) ArrayDataseta4  
    Dataset for segmentation and classification tasks based on array format input data and transforms.
    It ensures the same random seeds in the randomized transforms defined for image, segmentation and label.
    The `transform` can be :py:class:`monai.transforms.Compose` or any other callable object.
    For example:
    If train based on Nifti format images without metadata, all transforms can be composed::

        img_transform = Compose(
            [
                LoadImage(image_only=True),
                EnsureChannelFirst(),
                RandAdjustContrast()
            ]
        )
        ArrayDataset(img_file_list, img_transform=img_transform)

    If training based on images and the metadata, the array transforms can not be composed
    because several transforms receives multiple parameters or return multiple values. Then Users need
    to define their own callable method to parse metadata from `LoadImage` or set `affine` matrix
    to `Spacing` transform::

        class TestCompose(Compose):
            def __call__(self, input_):
                img, metadata = self.transforms[0](input_)
                img = self.transforms[1](img)
                img, _, _ = self.transforms[2](img, metadata["affine"])
                return self.transforms[3](img), metadata
        img_transform = TestCompose(
            [
                LoadImage(image_only=False),
                EnsureChannelFirst(),
                Spacing(pixdim=(1.5, 1.5, 3.0)),
                RandAdjustContrast()
            ]
        )
        ArrayDataset(img_file_list, img_transform=img_transform)

    Examples::

        >>> ds = ArrayDataset([1, 2, 3, 4], lambda x: x + 0.1)
        >>> print(ds[0])
        1.1

        >>> ds = ArrayDataset(img=[1, 2, 3, 4], seg=[5, 6, 7, 8])
        >>> print(ds[0])
        [1, 5]

    Nimgr   img_transformr[   segSequence | Noneseg_transformlabelslabel_transformr-   r.   c           	      C  sZ   ||f||f||fg}| j t d dd |D }t|dkr#|d nt|| _d| _dS )a  
        Initializes the dataset with the filename lists. The transform `img_transform` is applied
        to the images and `seg_transform` to the segmentations.

        Args:
            img: sequence of images.
            img_transform: transform to apply to each element in `img`.
            seg: sequence of segmentations.
            seg_transform: transform to apply to each element in `seg`.
            labels: sequence of labels.
            label_transform: transform to apply to each element in `labels`.

        r  c                 S  s*   g | ]}|d  durt |d  |d qS )r   Nr   r   )rt   rO  r4   r4   r5   r"  U  s   * z)ArrayDataset.__init__.<locals>.<listcomp>r   r   N)r$  r!   r:   rL  rB   _seed)	r2   rS  rT  rU  rW  rX  rY  itemsrM  r4   r4   r5   r6   =  s
   
zArrayDataset.__init__r7   c                 C  r8   r9   )r:   rB   r;   r4   r4   r5   r<   Z  r=   zArrayDataset.__len__r*   rZ   c                 C  s   t | jjtdd| _d S )Nuint32)dtype)r7   r6  randintr   rZ  r   r4   r4   r5   r5  ]  s   zArrayDataset.randomizer>   c                 C  sv   |    t| jtr#| jjD ]}t|dd }t|tr"|j| jd qt| jdd }t|tr6|j| jd | j| S )Nr+   r  )	r5  r/   rB   rL  r*   getattrr   r$  rZ  )r2   r>   rB   r+   r4   r4   r5   rK   `  s   


zArrayDataset.__getitem__)NNNNN)rS  r   rT  r[   rU  rV  rW  r[   rX  rV  rY  r[   r-   r.   rL   r9   )r*   rZ   r-   r.   rM   )rN   rO   rP   rQ   r6   r<   r5  rK   r4   r4   r4   r5   rR    s    4
rR  c                      s:   e Zd ZdZ		dd fddZdd ZdddZ  ZS )NPZDictItemDataseta=  
    Represents a dataset from a loaded NPZ file. The members of the file to load are named in the keys of `keys` and
    stored under the keyed name. All loaded arrays must have the same 0-dimension (batch) size. Items are always dicts
    mapping names to an item extracted from the loaded arrays.
    If passing slicing indices, will return a PyTorch Subset, for example: `data: Subset = dataset[1:4]`,
    for more details, please check: https://pytorch.org/docs/stable/data.html#torch.utils.data.Subset

    Args:
        npzfile: Path to .npz file or stream containing .npz file data
        keys: Maps keys to load from file to name to store in dataset
        transform: Transform to apply to batch dict
        other_keys: secondary data to load from file and store in dict `other_keys`, not returned by __getitem__
    Nr4   npzfilestr | IOkeysdict[str, str]r+   $Callable[..., dict[str, Any]] | None
other_keysSequence[str] | Nonec              	     s   t |tr|nd| _t|| _t|  fdd| j D | _| jt	tt
| j  jd | _|d u r8i n fdd|D | _| j D ]\}}|jd | jkrdtd| j d| d|jd  qGt g | d S )	NZSTREAMc                   s   i | ]	\}}| | qS r4   r4   )rt   ZdatakZstoredkdatr4   r5   r        z/NPZDictItemDataset.__init__.<locals>.<dictcomp>r   c                   s   i | ]}| | qS r4   r4   )rt   krh  r4   r5   r    r   z:All loaded arrays must have the same first dimension size z	, array `z` has size )r/   re   ra  r   rc  npr   r[  arraysr   r%   r  shapelengthrf  r1   rT   r6   )r2   ra  rc  r+   rf  rk  r  rX   rh  r5   r6   }  s&   

" zNPZDictItemDataset.__init__c                 C  rK  r9   )ro  r;   r4   r4   r5   r<     s   zNPZDictItemDataset.__len__r>   r7   c                   s^    fdd| j  D }| jd ur| |n|}t|ts)t|tr+t|d tr+|S td)Nc                   s   i | ]	\}}||  qS r4   r4   )rt   rk  r  r>   r4   r5   r    rj  z1NPZDictItemDataset._transform.<locals>.<dictcomp>r   zIWith a dict supplied to Compose, should return a dict or a list of dicts.)rm  r[  r+   r/   r   r  AssertionError)r2   r>   r*   resultr4   rp  r5   r@     s
   "zNPZDictItemDataset._transform)Nr4   )ra  rb  rc  rd  r+   re  rf  rg  rM   rQ  r4   r4   rX   r5   r`  n  s    r`  c                      s2   e Zd ZdZ							dd fddZ  ZS )
CSVDataseta
  
    Dataset to load data from CSV files and generate a list of dictionaries,
    every dictionary maps to a row of the CSV file, and the keys of dictionary
    map to the column names of the CSV file.

    It can load multiple CSV files and join the tables with additional `kwargs` arg.
    Support to only load specific rows and columns.
    And it can also group several loaded columns to generate a new column, for example,
    set `col_groups={"meta": ["meta_0", "meta_1", "meta_2"]}`, output can be::

        [
            {"image": "./image0.nii", "meta_0": 11, "meta_1": 12, "meta_2": 13, "meta": [11, 12, 13]},
            {"image": "./image1.nii", "meta_0": 21, "meta_1": 22, "meta_2": 23, "meta": [21, 22, 23]},
        ]

    Args:
        src: if provided the filename of CSV file, it can be a str, URL, path object or file-like object to load.
            also support to provide pandas `DataFrame` directly, will skip loading from filename.
            if provided a list of filenames or pandas `DataFrame`, it will join the tables.
        row_indices: indices of the expected rows to load. it should be a list,
            every item can be a int number or a range `[start, end)` for the indices.
            for example: `row_indices=[[0, 100], 200, 201, 202, 300]`. if None,
            load all the rows in the file.
        col_names: names of the expected columns to load. if None, load all the columns.
        col_types: `type` and `default value` to convert the loaded columns, if None, use original data.
            it should be a dictionary, every item maps to an expected column, the `key` is the column
            name and the `value` is None or a dictionary to define the default value and data type.
            the supported keys in dictionary are: ["type", "default"]. for example::

                col_types = {
                    "subject_id": {"type": str},
                    "label": {"type": int, "default": 0},
                    "ehr_0": {"type": float, "default": 0.0},
                    "ehr_1": {"type": float, "default": 0.0},
                    "image": {"type": str, "default": None},
                }

        col_groups: args to group the loaded columns to generate a new column,
            it should be a dictionary, every item maps to a group, the `key` will
            be the new column name, the `value` is the names of columns to combine. for example:
            `col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
        transform: transform to apply on the loaded items of a dictionary data.
        kwargs_read_csv: dictionary args to pass to pandas `read_csv` function.
        kwargs: additional arguments for `pandas.merge()` API to join tables.

    NrU   str | Sequence[str] | Nonerow_indicesSequence[int | str] | None	col_namesrg  	col_types'dict[str, dict[str, Any] | None] | None
col_groupsdict[str, Sequence[str]] | Noner+   r[   kwargs_read_csvr   c                   s   t |ttfs
|fn|}	g }
|	D ]*}t |tr+|
|r$tj|fi |nt| qt |tjr7|
| qtdt	d|
||||d|}t
 j||d d S )Nz.`src` must be file path or pandas `DataFrame`.)dfsru  rw  rx  rz  r?   r4   )r/   rN  r  re   r{   pdread_csv	DataFramer1   r   rT   r6   )r2   rU   ru  rw  rx  rz  r+   r|  rV   Zsrcsr}  r  r*   rX   r4   r5   r6     s   
(
zCSVDataset.__init__)NNNNNNN)rU   rt  ru  rv  rw  rg  rx  ry  rz  r{  r+   r[   r|  r   )rN   rO   rP   rQ   r6   r\   r4   r4   rX   r5   rs    s    1rs  c                      sB   e Zd ZdZeddfd fddZdd Zdd Zdd Z  Z	S )
GDSDataseta  
    An extension of the PersistentDataset using direct memory access(DMA) data path between
    GPU memory and storage, thus avoiding a bounce buffer through the CPU. This direct path can increase system
    bandwidth while decreasing latency and utilization load on the CPU and GPU.

    A tutorial is available: https://github.com/Project-MONAI/tutorials/blob/main/modules/GDS_dataset.ipynb.

    See also: https://github.com/rapidsai/kvikio
    NTr*   r   r+   r_   r`   ra   devicer7   rb   rc   rg   rh   r   ri   rV   r   r-   r.   c           	   	     s0   t  jd||||||d| || _i | _dS )aM  
        Args:
            data: input data file paths to load and transform to generate dataset for model.
                `GDSDataset` expects input data to be a list of serializable
                and hashes them as cache keys using `hash_func`.
            transform: transforms to execute operations on input data.
            cache_dir: If specified, this is the location for gpu direct storage
                of pre-computed transformed data tensors. The cache_dir is computed once, and
                persists on disk until explicitly removed.  Different runs, programs, experiments
                may share a common cache dir provided that the transforms pre-processing is consistent.
                If `cache_dir` doesn't exist, will automatically create it.
                If `cache_dir` is `None`, there is effectively no caching.
            device: target device to put the output Tensor data. Note that only int can be used to
                specify the gpu to be used.
            hash_func: a callable to compute hash from data items to be cached.
                defaults to `monai.data.utils.pickle_hashing`.
            hash_transform: a callable to compute hash from the transform information when caching.
                This may reduce errors due to transforms changing during experiments. Default to None (no hash).
                Other options are `pickle_hashing` and `json_hashing` functions from `monai.data.utils`.
            reset_ops_id: whether to set `TraceKeys.ID` to ``Tracekys.NONE``, defaults to ``True``.
                When this is enabled, the traced transform instance IDs will be removed from the cached MetaTensors.
                This is useful for skipping the transform instance checks when inverting applied operations
                using the cached content and with re-created transform instances.

        )r*   r+   r`   rb   rg   r   Nr4   )rT   r6   r  _meta_cache)	r2   r*   r+   r`   r  rb   rg   r   rV   rX   r4   r5   r6     s   $	
zGDSDataset.__init__c              	   C  s  d}| j dur| |d}|| j7 }| j | d }|dur@| r@tj| j	 t	|t
ri }|D ]=}| j|j d| dd}tj| d| |d tdd	||< t|| |d
 d| j d||< ||| d< q8|W  d   S t	|tjtjfr| j|j dd}tj| |d tdd	}t||d
 d| j d}ttdd | }	t|	r||fW  d   S |W  d   S dd tt|D }t|D ]L\}
}|D ]E}| j|j d| d|
 d}tj| d| d|
 |d tdd	}t||
 |d
 d| j d}||
 ||| d|i qq|W  d   S 1 s;w   Y  | t|}|du rN|S t	|t
r|D ],}| d| }|j d| d}t	|| tjtjfr|  || || qV|  S nGt	|tjtjfr| }|j d}|  ||| n,t|D ]'\}
}|D ]}| d| d|
 }|j d| d|
 }|  ||| qqt!|d"  |S )a  
        In order to enable direct storage to the GPU when loading the hashfile, rewritten this function.
        Note that in this function, it will always return `torch.Tensor` when load data from cache.

        Args:
            item_transformed: The current data element to be mutated into transformed representation

        Returns:
            The transformed data_element, either from cache, or explicitly computing it.

        Warning:
            The current implementation does not encode transform information as part of the
            hashing mechanism used for generating cache names when `hash_transform` is None.
            If the transforms applied are changed in any way, the objects in the cache dir will be invalid.

        Nrx   r   -z-meta)meta_hash_file_namer]  r4   )r]  likern  zcuda:)r  Z
_meta_dictc                 S  s   | dvS )N)r]  rn  r4   )r   r4   r4   r5   r   N  s    z(GDSDataset._cachecheck.<locals>.<lambda>c                 S  s   g | ]}i qS r4   r4   r   r4   r4   r5   r"  S  r#  z*GDSDataset._cachecheck.<locals>.<listcomp>z-meta-a)#r`   rb   r~   rq   r   cpcudaDevicer  r/   r   _load_meta_cacher   kvikio_numpyfromfileemptyr    reshaperl  ndarrayr   Tensorr  filterrc  ri   rE   r:   r
  updater   r   _create_new_cacher   r   )r2   r   r   r   r   rk  Zmeta_k_meta_dataZfiltered_keysr  Z_itemZmeta_i_kZitem_kr   data_hashfiler  r4   r4   r5   r   (  sz   


(&	 ""
zGDSDataset._cachecheckc              	   C  sl  t |tr
t|jni | j|< t |tr|jn|}t |tjr#| }|j	| j| d< t
|j| j| d< t|| zpt a}| j| }t|| }tjt| j| dd|t| jt| jd | r}| sz
tt
|| W n ty|   Y nw W d    W d S W d    W d S W d    W d S W d    W d S 1 sw   Y  W d S  ty   Y d S w )Nrn  r]  Fr   r   )r/   r   r   metar  arrayr   r  numpyrn  re   r]  r  tofiler   r   r`   r
   r   r    r"   rd   r   rf   r   r   r   r   r   )r2   r*   r  r  Z_item_transformed_datar   Zmeta_hash_filer   r4   r4   r5   r  w  sD   


	&zGDSDataset._create_new_cachec                 C  s(   || j v r
| j | S tj| j| ddS )NTr   )r  r   r   r`   )r2   r  r4   r4   r5   r    s   

zGDSDataset._load_meta_cache)r*   r   r+   r_   r`   ra   r  r7   rb   rc   rg   rh   r   ri   rV   r   r-   r.   )
rN   rO   rP   rQ   r   r6   r   r  r  r\   r4   r4   rX   r5   r    s    0O r  )Q
__future__r   collections.abcrF   r-  r   r   r   r   rA  r   r   r   r   r   ior   Zmultiprocessing.managersr   Zmultiprocessing.poolr	   pathlibr
   r^   r   typingr   r   r   r   r  rl  r   torch.multiprocessingr   Ztorch.serializationr   torch.utils.datar   Z_TorchDatasetr   monai.data.meta_tensorr   monai.data.utilsr   r   r   monai.transformsr   r   r   r   r   r   monai.utilsr   r    r!   r"   r#   r$   monai.utils.miscr%   r&   r   r  r!  r(   r~  r  rR   r]   r   r   r   r  rL  rR  r`  rs  r  r4   r4   r4   r5   <module>   sl     72  ^ ; l  22c4K