U
    Ph                     @  s   d dl mZ d dlZd dlZd dlZd dlZd dlmZmZ d dl	m
Z
 d dlmZ d dlZd dlmZmZmZmZmZ d dlmZ d dlmZ d d	lmZmZmZmZmZmZ d d
l m!Z!m"Z" d dl#m$Z$ ddddgZ%G dd de"eZ&G dd de"eZ'G dd de"eZ(G dd dZ)dS )    )annotationsN)CallableSequence)Path)Any)DCM_FILENAME_REGEXdownload_tcia_series_instanceget_tcia_metadataget_tcia_ref_uidmatch_tcia_ref_uid_in_study)download_and_extract)PathLike)CacheDatasetPydicomReaderload_decathlon_datalistload_decathlon_propertiespartition_datasetselect_cross_validation_folds)
LoadImagedRandomizable)ensure_tupleMedNISTDatasetDecathlonDatasetCrossValidationTciaDatasetc                   @  s   e Zd ZdZdZdZdZdZdddd	d	ej	d
dddddfddddddddddddddddddZ
dddddZddddZddd d!d"Zd#S )$r   aQ
  
    The Dataset to automatically download MedNIST data and generate items for training, validation or test.
    It's based on `CacheDataset` to accelerate the training process.

    Args:
        root_dir: target directory to download and load MedNIST dataset.
        section: expected data section, can be: `training`, `validation` or `test`.
        transform: transforms to execute operations on input data.
        download: whether to download and extract the MedNIST from resource link, default is False.
            if expected file already exists, skip downloading even set it to True.
            user can manually copy `MedNIST.tar.gz` file or `MedNIST` folder to root directory.
        seed: random seed to randomly split training, validation and test datasets, default is 0.
        val_frac: percentage of validation fraction in the whole dataset, default is 0.1.
        test_frac: percentage of test fraction in the whole dataset, default is 0.1.
        cache_num: number of items to be cached. Default is `sys.maxsize`.
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        cache_rate: percentage of cached data in total, default is 1.0 (cache all).
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        num_workers: the number of worker threads if computing cache in the initialization.
            If num_workers is None then the number returned by os.cpu_count() is used.
            If a value less than 1 is specified, 1 will be used instead.
        progress: whether to display a progress bar when downloading dataset and computing the transform cache content.
        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
            default to `True`. if the random transforms don't modify the cached content
            (for example, randomly crop from the cached image and deepcopy the crop region)
            or if every cache item is only used once in a `multi-processing` environment,
            may set `copy=False` for better performance.
        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
            it may help improve the performance of following logic.
        runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
            the cache content at initialization. See: :py:class:`monai.data.CacheDataset`.

    Raises:
        ValueError: When ``root_dir`` is not a directory.
        RuntimeError: When ``dataset_dir`` doesn't exist and downloading is not selected (``download=False``).

    z]https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/MedNIST.tar.gzZ 0bc7306e7427e00ad1c5526a6677552dzMedNIST.tar.gzZMedNIST Fr   g?      ?   Tr   strSequence[Callable] | Callableboolintfloatz
int | NoneNone)root_dirsection	transformdownloadseedval_frac	test_frac	cache_num
cache_ratenum_workersprogress
copy_cacheas_contiguousruntime_cachereturnc                 C  s   t |}| std|| _|| _|| _| j|d || j }|| j }d| _	|rlt
| j||| jd|d | std| d| |}|dkrtd	}tj| ||||	|
||||d

 d S )N,Root directory root_dir must be a directory.r(   r   md5urlfilepath
output_dirZhash_valZ	hash_typer.   Cannot find dataset directory: *, please use download=True to download it.r   image	datar&   r+   r,   r-   r.   r/   r0   r1   )r   is_dir
ValueErrorr%   r)   r*   set_random_statecompressed_file_namedataset_folder_name	num_classr   resourcer5   RuntimeError_generate_data_listr   r   __init__)selfr$   r%   r&   r'   r(   r)   r*   r+   r,   r-   r.   r/   r0   r1   tarfile_namedataset_dirr>   r   r   H/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/datasets.pyrH   [   sL    

	

zMedNISTDataset.__init__
np.ndarrayr>   r2   c                 C  s   | j | d S NRshufflerI   r>   r   r   rL   	randomize   s    zMedNISTDataset.randomizer2   c                 C  s   | j S )zGet number of classes.)rD   rI   r   r   rL   get_num_classes   s    zMedNISTDataset.get_num_classes
list[dict]rK   r2   c           	        sf  t tdd  D t| _fddt| jD fddt| jD }g g g  t| jD ]>}|  |g||    | g||   qpt}t|}| 	| t
|| j }t
|| j }| jdkr|d| }nN| jdkr ||||  }n0| jd	kr>||| d }ntd
| j d fdd|D S )zu
        Raises:
            ValueError: When ``section`` is not one of ["training", "validation", "test"].

        c                 s  s   | ]}|  r|j V  qd S rO   r?   name.0xr   r   rL   	<genexpr>   s      z5MedNISTDataset._generate_data_list.<locals>.<genexpr>c                   s&   g | ]}d d  |    D qS )c                 S  s   g | ]
}| qS r   r   r\   r   r   rL   
<listcomp>   s     zAMedNISTDataset._generate_data_list.<locals>.<listcomp>.<listcomp>)iterdirr]   i)class_namesrK   r   rL   r`      s     z6MedNISTDataset._generate_data_list.<locals>.<listcomp>c                   s   g | ]}t  | qS r   )lenrb   )image_filesr   rL   r`      s     testN
validationtrainingzUnsupported section: z;, available options are ["training", "validation", "test"].c                   s$   g | ]}| |  | d qS ))r<   label
class_namer   rb   )rk   image_classimage_files_listr   rL   r`      s   )r   sortedra   re   rD   rangeextendnparangerT   r!   r*   r)   r%   r@   )	rI   rK   Znum_eachrc   lengthindicesZtest_length
val_lengthZsection_indicesr   )rk   rd   rK   rl   rf   rm   rL   rG      s:    


z"MedNISTDataset._generate_data_listN)__name__
__module____qualname____doc__rE   r5   rB   rC   sysmaxsizerH   rT   rW   rG   r   r   r   rL   r   /   s*   &,9c                   @  s   e Zd ZdZdddddddd	d
dd
Zddddddddddd
Zddddejddddddfdddd d!d"d#d"d#d"d!d!d!d!d$d%d&d'Zd(d)d*d+Z	d(d$d,d-d.Z
d<d0d1d2d3d4Zdd5d6d7d8Zd5d5d9d:d;Zd/S )=r   a  
    The Dataset to automatically download the data of Medical Segmentation Decathlon challenge
    (http://medicaldecathlon.com/) and generate items for training, validation or test.
    It will also load these properties from the JSON config file of dataset. user can call `get_properties()`
    to get specified properties or all the properties loaded.
    It's based on :py:class:`monai.data.CacheDataset` to accelerate the training process.

    Args:
        root_dir: user's local directory for caching and loading the MSD datasets.
        task: which task to download and execute: one of list ("Task01_BrainTumour", "Task02_Heart",
            "Task03_Liver", "Task04_Hippocampus", "Task05_Prostate", "Task06_Lung", "Task07_Pancreas",
            "Task08_HepaticVessel", "Task09_Spleen", "Task10_Colon").
        section: expected data section, can be: `training`, `validation` or `test`.
        transform: transforms to execute operations on input data.
            for further usage, use `EnsureChannelFirstd` to convert the shape to [C, H, W, D].
        download: whether to download and extract the Decathlon from resource link, default is False.
            if expected file already exists, skip downloading even set it to True.
            user can manually copy tar file or dataset folder to the root directory.
        val_frac: percentage of validation fraction in the whole dataset, default is 0.2.
        seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0.
            note to set same seed for `training` and `validation` sections.
        cache_num: number of items to be cached. Default is `sys.maxsize`.
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        cache_rate: percentage of cached data in total, default is 1.0 (cache all).
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        num_workers: the number of worker threads if computing cache in the initialization.
            If num_workers is None then the number returned by os.cpu_count() is used.
            If a value less than 1 is specified, 1 will be used instead.
        progress: whether to display a progress bar when downloading dataset and computing the transform cache content.
        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
            default to `True`. if the random transforms don't modify the cached content
            (for example, randomly crop from the cached image and deepcopy the crop region)
            or if every cache item is only used once in a `multi-processing` environment,
            may set `copy=False` for better performance.
        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
            it may help improve the performance of following logic.
        runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
            the cache content at initialization. See: :py:class:`monai.data.CacheDataset`.

    Raises:
        ValueError: When ``root_dir`` is not a directory.
        ValueError: When ``task`` is not one of ["Task01_BrainTumour", "Task02_Heart",
            "Task03_Liver", "Task04_Hippocampus", "Task05_Prostate", "Task06_Lung", "Task07_Pancreas",
            "Task08_HepaticVessel", "Task09_Spleen", "Task10_Colon"].
        RuntimeError: When ``dataset_dir`` doesn't exist and downloading is not selected (``download=False``).

    Example::

        transform = Compose(
            [
                LoadImaged(keys=["image", "label"]),
                EnsureChannelFirstd(keys=["image", "label"]),
                ScaleIntensityd(keys="image"),
                ToTensord(keys=["image", "label"]),
            ]
        )

        val_data = DecathlonDataset(
            root_dir="./", task="Task09_Spleen", transform=transform, section="validation", seed=12345, download=True
        )

        print(val_data[0]["image"], val_data[0]["label"])

    zGhttps://msd-for-monai.s3-us-west-2.amazonaws.com/Task01_BrainTumour.tarzAhttps://msd-for-monai.s3-us-west-2.amazonaws.com/Task02_Heart.tarzAhttps://msd-for-monai.s3-us-west-2.amazonaws.com/Task03_Liver.tarzGhttps://msd-for-monai.s3-us-west-2.amazonaws.com/Task04_Hippocampus.tarzDhttps://msd-for-monai.s3-us-west-2.amazonaws.com/Task05_Prostate.tarz@https://msd-for-monai.s3-us-west-2.amazonaws.com/Task06_Lung.tarzDhttps://msd-for-monai.s3-us-west-2.amazonaws.com/Task07_Pancreas.tarzIhttps://msd-for-monai.s3-us-west-2.amazonaws.com/Task08_HepaticVessel.tarzBhttps://msd-for-monai.s3-us-west-2.amazonaws.com/Task09_Spleen.tarzAhttps://msd-for-monai.s3-us-west-2.amazonaws.com/Task10_Colon.tar)
ZTask01_BrainTumourZTask02_HeartZTask03_LiverZTask04_HippocampusZTask05_ProstateZTask06_LungZTask07_PancreasZTask08_HepaticVesselZTask09_SpleenZTask10_ColonZ 240a19d752f0d9e9101544901065d872Z 06ee59366e1e5124267b774dbd654057Z a90ec6c4aa7f6a3d087205e23d4e6397Z 9d24dba78a72977dbd1d2e110310f31bZ 35138f08b1efaef89d7424d2bcc928dbZ 8afd997733c7fc0432f71255ba4e52dcZ 4f7080cfca169fa8066d17ce6eb061e4Z 641d79e80ec66453921d997fbf12a29cZ 410d4a301da4e5b2f6f86ec3ddba524eZ bad7a188931dc2f6acf72b08eb6202d0r   Fr   皙?r   r   Tr   r   r   r    r!   r"   r#   )r$   taskr%   r&   r'   r(   r)   r+   r,   r-   r.   r/   r0   r1   r2   c                 C  s  t |}| std|| _|| _| j|d || jkrZtd| dt| j  d|| }| d}|rt	| j| ||| j
| d|d | std	| d
tg | _| |}dddddddddg	}t|d || _|dkrtddg}tj| ||||	|
||||d
 d S )Nr3   r4   zUnsupported task: z, available options are: .z.tarr5   r6   r:   r;   r[   description	referencelicenceZtensorImageSizeZmodalitylabelsZnumTrainingZnumTestdataset.jsonr   r<   rj   r=   )r   r?   r@   r%   r)   rA   rE   listkeysr   r5   existsrF   rq   arrayrt   rG   r   _propertiesr   r   rH   )rI   r$   r}   r%   r&   r'   r(   r)   r+   r,   r-   r.   r/   r0   r1   rK   rJ   r>   Zproperty_keysr   r   rL   rH      sd    
 
	

zDecathlonDataset.__init__rM   rU   c                 C  s   | j S zD
        Get the indices of datalist used in this dataset.

        rt   rV   r   r   rL   get_indicesg  s    zDecathlonDataset.get_indicesrN   c                 C  s   | j | d S rO   rP   rS   r   r   rL   rT   n  s    zDecathlonDataset.randomizeNzSequence[str] | str | Nonedict)r   r2   c                   s2   |dkr j S  j dk	r. fddt|D S i S )z
        Get the loaded properties of dataset with specified keys.
        If no keys specified, return all the loaded properties.

        Nc                   s   i | ]}| j | qS r   )r   )r]   keyrV   r   rL   
<dictcomp>z  s      z3DecathlonDataset.get_properties.<locals>.<dictcomp>)r   r   )rI   r   r   rV   rL   get_propertiesq  s
    
zDecathlonDataset.get_propertiesrX   rY   c                 C  s4   t |}| jdkrdnd}t|d d|}| |S )N)ri   rh   ri   rg   r   T)r   r%   r   _split_datalist)rI   rK   r%   datalistr   r   rL   rG   }  s    z$DecathlonDataset._generate_data_listr   r2   c                   st   | j dkr S t }t|}| | t|| j }| j dkrR||d  | _n|d | | _ fdd| jD S )Nrg   ri   c                   s   g | ]} | qS r   r   rb   r   r   rL   r`     s     z4DecathlonDataset._split_datalist.<locals>.<listcomp>r%   re   rq   rr   rT   r!   r)   rt   rI   r   rs   rt   ru   r   r   rL   r     s    



z DecathlonDataset._split_datalist)N)rv   rw   rx   ry   rE   r5   rz   r{   rH   r   rT   r   rG   r   r   r   r   rL   r      sP   B,Gc                   @  s   e Zd ZdZdddddddd	ed
dejddddddfdddddddddddddddddddddddddZddddZddddd Z	dddd!d"d#Z
dd$d%d&d'Zd$d$d(d)d*Zd+S ),r   aA  
    The Dataset to automatically download the data from a public The Cancer Imaging Archive (TCIA) dataset
    and generate items for training, validation or test.

    The Highdicom library is used to load dicom data with modality "SEG", but only a part of collections are
    supported, such as: "C4KC-KiTS", "NSCLC-Radiomics", "NSCLC-Radiomics-Interobserver1", " QIN-PROSTATE-Repeatability"
    and "PROSTATEx". Therefore, if "seg" is included in `keys` of the `LoadImaged` transform and loading some
    other collections, errors may be raised. For supported collections, the original "SEG" information may not
    always be consistent for each dicom file. Therefore, to avoid creating different format of labels, please use
    the `label_dict` argument of `PydicomReader` when calling the `LoadImaged` transform. The prepared label dicts
    of collections that are mentioned above is also saved in: `monai.apps.tcia.TCIA_LABEL_DICT`. You can also refer
    to the second example bellow.


    This class is based on :py:class:`monai.data.CacheDataset` to accelerate the training process.

    Args:
        root_dir: user's local directory for caching and loading the TCIA dataset.
        collection: name of a TCIA collection.
            a TCIA dataset is defined as a collection. Please check the following list to browse
            the collection list (only public collections can be downloaded):
            https://www.cancerimagingarchive.net/collections/
        section: expected data section, can be: `training`, `validation` or `test`.
        transform: transforms to execute operations on input data.
            for further usage, use `EnsureChannelFirstd` to convert the shape to [C, H, W, D].
            If not specified, `LoadImaged(reader="PydicomReader", keys=["image"])` will be used as the default
            transform. In addition, we suggest to set the argument `labels` for `PydicomReader` if segmentations
            are needed to be loaded. The original labels for each dicom series may be different, using this argument
            is able to unify the format of labels.
        download: whether to download and extract the dataset, default is False.
            if expected file already exists, skip downloading even set it to True.
            user can manually copy tar file or dataset folder to the root directory.
        download_len: number of series that will be downloaded, the value should be larger than 0 or -1, where -1 means
            all series will be downloaded. Default is -1.
        seg_type: modality type of segmentation that is used to do the first step download. Default is "SEG".
        modality_tag: tag of modality. Default is (0x0008, 0x0060).
        ref_series_uid_tag: tag of referenced Series Instance UID. Default is (0x0020, 0x000e).
        ref_sop_uid_tag: tag of referenced SOP Instance UID. Default is (0x0008, 0x1155).
        specific_tags: tags that will be loaded for "SEG" series. This argument will be used in
            `monai.data.PydicomReader`. Default is [(0x0008, 0x1115), (0x0008,0x1140), (0x3006, 0x0010),
            (0x0020,0x000D), (0x0010,0x0010), (0x0010,0x0020), (0x0020,0x0011), (0x0020,0x0012)].
        fname_regex: a regular expression to match the file names when the input is a folder.
            If provided, only the matched files will be included. For example, to include the file name
            "image_0001.dcm", the regular expression could be `".*image_(\d+).dcm"`.
            Default to `"^(?!.*LICENSE).*"`, ignoring any file name containing `"LICENSE"`.
        val_frac: percentage of validation fraction in the whole dataset, default is 0.2.
        seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0.
            note to set same seed for `training` and `validation` sections.
        cache_num: number of items to be cached. Default is `sys.maxsize`.
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        cache_rate: percentage of cached data in total, default is 0.0 (no cache).
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        num_workers: the number of worker threads if computing cache in the initialization.
            If num_workers is None then the number returned by os.cpu_count() is used.
            If a value less than 1 is specified, 1 will be used instead.
        progress: whether to display a progress bar when downloading dataset and computing the transform cache content.
        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
            default to `True`. if the random transforms don't modify the cached content
            (for example, randomly crop from the cached image and deepcopy the crop region)
            or if every cache item is only used once in a `multi-processing` environment,
            may set `copy=False` for better performance.
        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
            it may help improve the performance of following logic.
        runtime_cache: whether to compute cache at the runtime, default to `False` to prepare
            the cache content at initialization. See: :py:class:`monai.data.CacheDataset`.

    Example::

        # collection is "Pancreatic-CT-CBCT-SEG", seg_type is "RTSTRUCT"
        data = TciaDataset(
            root_dir="./", collection="Pancreatic-CT-CBCT-SEG", seg_type="RTSTRUCT", download=True
        )

        # collection is "C4KC-KiTS", seg_type is "SEG", and load both images and segmentations
        from monai.apps.tcia import TCIA_LABEL_DICT
        transform = Compose(
            [
                LoadImaged(reader="PydicomReader", keys=["image", "seg"], label_dict=TCIA_LABEL_DICT["C4KC-KiTS"]),
                EnsureChannelFirstd(keys=["image", "seg"]),
                ResampleToMatchd(keys="image", key_dst="seg"),
            ]
        )
        data = TciaDataset(
            root_dir="./", collection="C4KC-KiTS", section="validation", seed=12345, download=True
        )

        print(data[0]["seg"].shape)

    r   FZSEG)   `   )       )r   iU  ))r   i  )r   i@  )i0     )r      )r   r   )r   r   )r      )r      r   r|   g        r   Tr   r   r   r    r!   tupler"   r#   )r$   
collectionr%   r&   r'   download_lenseg_typemodality_tagref_series_uid_tagref_sop_uid_tagspecific_tagsfname_regexr(   r)   r+   r,   r-   r.   r/   r0   r1   r2   c                 C  sN  t |}| std|| _|| _|| _|| _|	| _|
| _| j	|d t
j||}t|}||g7 }|| _|rtd| d| dd}|dkr|d | }t|dkrtd| d	| |D ]}| || qt
j|std
| d|| _tg | _| || _|dkr*tdgd| jd}tj| | j||||||||d
 d S )Nr3   r4   zgetSeries?Collection=z
&Modality=ZSeriesInstanceUID)query	attributer   z"Cannot find data with collection: z seg_type: r:   r~   r   r<   r   )r   readerr   r=   )r   r?   r@   r%   r)   r   r   r   r   rA   ospathjoinr   	load_tagsr	   re   _download_series_reference_datar   rF   r   rq   r   rt   rG   r   r   r   rH   )rI   r$   r   r%   r&   r'   r   r   r   r   r   r   r   r(   r)   r+   r,   r-   r.   r/   r0   r1   download_dirr   Zseg_series_list
series_uidr   r   rL   rH     sX    !
 
zTciaDataset.__init__rM   rU   c                 C  s   | j S r   r   rV   r   r   rL   r   A  s    zTciaDataset.get_indicesrN   c                 C  s   | j | d S rO   rP   rS   r   r   rL   rT   H  s    zTciaDataset.randomize)r   r   r2   c                 C  s  t j|d|}t|||dd dd tt |D }t j||d }td| jd|}|j	rj|j	n|j
}|std	| d
 d}|jr|jn|j}|std| d d}t|}t j|||| j }	t j|||d}
g }|D ]}t j||}td| jd|}|| j j| jkrt|d| j| jd}|dkrft|d| j| jd}t|j|}|dkr|| q|std| d nt|d ||
dd t j|	st||	 dS )z
        First of all, download a series from TCIA according to `series_uid`.
        Then find all referenced series and download.
        rawF)r   r   r9   Z	check_md5c                 S  s   g | ]}| d r|qS )z.dcm)endswithr]   fr   r   rL   r`   T  s     
 z?TciaDataset._download_series_reference_data.<locals>.<listcomp>r   T)Zstop_before_pixelsr   z+unable to find patient name of dicom file: z, use 'patient' instead.Zpatientz,unable to find series number of dicom file: z, use '0' instead.r<   )Zfind_sopr   r    z<Cannot find the referenced Series Instance UID from series: r~   N)r   r   r   r   rn   listdirr   r   readZ	PatientIDZPatientNamewarningswarnZSeriesNumberZAcquisitionNumberr   r   lowerr   valuer
   r   r   r   ZStudyInstanceUIDappendr   shutilcopytree)rI   r   r   Zseg_first_dirZdicom_filesZdcm_pathds
patient_id
series_numZseg_dirZdcm_dirZref_uid_listZdcm_fileZref_uidZref_sop_uidr   r   rL   r   K  sh          
   z+TciaDataset._download_series_reference_datarX   rY   c           
      C  s   t |}g }dd t|D }|D ]}dd ttj||D }|D ]`}| j }tj|||d}tj||||}	tj|r|d|||	i qJ|||	i qJq$| 	|S )Nc                 S  s$   g | ]}|  r|jd kr|jqS )r   rZ   r   r   r   rL   r`     s      
 z3TciaDataset._generate_data_list.<locals>.<listcomp>c                 S  s   g | ]}|  r|jqS r   rZ   r   r   r   rL   r`     s      r<   )
r   r   scandirr   r   r   r   r   r   r   )
rI   rK   r   Zpatient_listr   Zseries_listr   Zseg_key
image_pathZ	mask_pathr   r   rL   rG     s    
zTciaDataset._generate_data_listr   c                   st   | j dkr S t }t|}| | t|| j }| j dkrR||d  | _n|d | | _ fdd| jD S )Nrg   ri   c                   s   g | ]} | qS r   r   rb   r   r   rL   r`     s     z/TciaDataset._split_datalist.<locals>.<listcomp>r   r   r   r   rL   r     s    



zTciaDataset._split_datalistN)rv   rw   rx   ry   r   rz   r{   rH   r   rT   r   rG   r   r   r   r   rL   r     s2   _
:R8c                   @  s:   e Zd ZdZdddddddd	d
ZddddddZdS )r   aa  
    Cross validation dataset based on the general dataset which must have `_split_datalist` API.

    Args:
        dataset_cls: dataset class to be used to create the cross validation partitions.
            It must have `_split_datalist` API.
        nfolds: number of folds to split the data for cross validation.
        seed: random seed to randomly shuffle the datalist before splitting into N folds, default is 0.
        dataset_params: other additional parameters for the dataset_cls base class.

    Example of 5 folds cross validation training::

        cvdataset = CrossValidation(
            dataset_cls=DecathlonDataset,
            nfolds=5,
            seed=12345,
            root_dir="./",
            task="Task09_Spleen",
            section="training",
            transform=train_transform,
            download=True,
        )
        dataset_fold0_train = cvdataset.get_dataset(folds=[1, 2, 3, 4])
        dataset_fold0_val = cvdataset.get_dataset(folds=0, transform=val_transform, download=False)
        # execute training for fold 0 ...

        dataset_fold1_train = cvdataset.get_dataset(folds=[0, 2, 3, 4])
        dataset_fold1_val = cvdataset.get_dataset(folds=1, transform=val_transform, download=False)
        # execute training for fold 1 ...

        ...

        dataset_fold4_train = ...
        # execute training for fold 4 ...

       r   objectr!   r   r#   )dataset_clsnfoldsr(   dataset_paramsr2   c                 K  s.   t |dstd|| _|| _|| _|| _d S )Nr   z,dataset class must have _split_datalist API.)hasattrr@   r   r   r(   r   )rI   r   r   r(   r   r   r   rL   rH     s    
zCrossValidation.__init__zSequence[int] | int)foldsr   r2   c                   sD   | j | jt| j}|| G  fddd| j}|f |S )a  
        Generate dataset based on the specified fold indices in the cross validation group.

        Args:
            folds: index of folds for training or validation, if a list of values, concatenate the data.
            dataset_params: other additional parameters for the dataset_cls base class, will override
                the same parameters in `self.dataset_params`.

        c                      s$   e Zd Zddd fddZdS )z4CrossValidation.get_dataset.<locals>._NsplitsDatasetrX   r   c                   s   t |dd}t| dS )NT)r>   Znum_partitionsrR   r(   )
partitionsr   )r   r   )rI   r   r>   r   r   r(   r   rL   r     s    zDCrossValidation.get_dataset.<locals>._NsplitsDataset._split_datalistN)rv   rw   rx   r   r   r   r   rL   _NsplitsDataset  s   r   )r   r(   r   r   updater   )rI   r   r   Zdataset_params_r   r   r   rL   get_dataset  s    


zCrossValidation.get_datasetN)r   r   )rv   rw   rx   ry   rH   r   r   r   r   rL   r     s   %)*
__future__r   r   r   rz   r   collections.abcr   r   pathlibr   typingr   numpyrq   Zmonai.apps.tciar   r   r	   r
   r   Zmonai.apps.utilsr   monai.config.type_definitionsr   Z
monai.datar   r   r   r   r   r   monai.transformsr   r   monai.utilsr   __all__r   r   r   r   r   r   r   rL   <module>   s.      P  