U
    Ph,3                     @  s   d dl mZ d dlmZmZmZmZ d dlmZ d dl	m
Z d dl	mZ d dlmZ d dlmZ d dlmZ d d	lmZ ed
\ZZG dd deZ
G dd dee
ZG dd de
ZdS )    )annotations)CallableIterableIteratorSequence)Any)IterableDataset)get_worker_info)convert_tables_to_dicts)apply_transform)Randomizable)optional_importpandasc                   @  s,   e Zd ZdZdddddddZd	d
 ZdS )r   a  
    A generic dataset for iterable data source and an optional callable data transform
    when fetching a data sample. Inherit from PyTorch IterableDataset:
    https://pytorch.org/docs/stable/data.html?highlight=iterabledataset#torch.utils.data.IterableDataset.
    For example, typical input data can be web data stream which can support multi-process access.

    To accelerate the loading process, it can support multi-processing based on PyTorch DataLoader workers,
    every process executes transforms on part of every loaded data.
    Note that the order of output data may not match data source in multi-processing mode.
    And each worker process will have a different copy of the dataset object, need to guarantee
    process-safe from data source or DataLoader.

    NzIterable[Any]Callable | NoneNone)data	transformreturnc                 C  s   || _ || _d| _dS )z
        Args:
            data: input data source to load and transform to generate dataset for model.
            transform: a callable data transform on input data.
        N)r   r   source)selfr   r    r   P/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/data/iterable_dataset.py__init__+   s    zIterableDataset.__init__c                 c  sv   t  }|d k	r|jnd}|d k	r&|jnd}t| j| _t| jD ]0\}}|| |kr@| jd k	rjt| j|}|V  q@d S )N   r   )	r	   num_workersiditerr   r   	enumerater   r   )r   infor   r   iitemr   r   r   __iter__5   s    
zIterableDataset.__iter__)N)__name__
__module____qualname____doc__r   r!   r   r   r   r   r      s   
r   c                      sZ   e Zd ZdZdddddd fd	d
Zdd Zdd Z fddZdddddZ  Z	S )ShuffleBuffera  
    Extend the IterableDataset with a buffer and randomly pop items.

    Args:
        data: input data source to load and transform to generate dataset for model.
        transform: a callable data transform on input data.
        buffer_size: size of the buffer to store items and randomly pop, default to 512.
        seed: random seed to initialize the random state of all workers, set `seed += 1` in
            every iter() call, refer to the PyTorch idea:
            https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98.
        epochs: number of epochs to iterate over the dataset, default to 1, -1 means infinite epochs.

    Note:
        Both ``monai.data.DataLoader`` and ``torch.utils.data.DataLoader`` do not seed this class (as a subclass of
        ``IterableDataset``) at run time. ``persistent_workers=True`` flag (and pytorch>1.8) is therefore required
        for multiple epochs of loading when ``num_workers>0``. For example::

            import monai

            def run():
                dss = monai.data.ShuffleBuffer([1, 2, 3, 4], buffer_size=30, seed=42)

                dataloader = monai.data.DataLoader(
                    dss, batch_size=1, num_workers=2, persistent_workers=True)
                for epoch in range(3):
                    for item in dataloader:
                        print(f"epoch: {epoch} item: {item}.")

            if __name__ == '__main__':
                run()

    N   r   r   intr   )buffer_sizeseedepochsr   c                   s,   t  j||d || _|| _|| _d| _d S )Nr   r   r   )superr   sizer*   r+   _idx)r   r   r   r)   r*   r+   	__class__r   r   r   d   s
    zShuffleBuffer.__init__c                 C  s4   |  t| || j |d  }|| j< |  |S )zAReturn the item at a randomized location `self._idx` in `buffer`.)	randomizelenr/   pop)r   bufferretr   r   r   randomized_popk   s    zShuffleBuffer.randomized_popc                 c  sN   g }t | jD ](}t|| jkr,| |V  || q|rJ| |V  q8dS )zLFill a `buffer` list up to `self.size`, then generate randomly popped items.N)r   r   r4   r.   r8   append)r   r6   r    r   r   r   generate_itemr   s    zShuffleBuffer.generate_itemc                 #  s^   |  j d7  _ t j| j d | jdkr2t| jnttdD ]}t|  | j	dE dH  q<dS )z
        Randomly pop buffered items from `self.data`.
        Multiple dataloader workers sharing this dataset will generate identical item sequences.
        r   )r*   r   )r   N)
r*   r-   set_random_stater+   ranger   r(   r   r:   r   )r   _r0   r   r   r!   |   s    "zShuffleBuffer.__iter__)r.   r   c                 C  s   | j || _d S )N)Rrandintr/   )r   r.   r   r   r   r3      s    zShuffleBuffer.randomize)Nr'   r   r   )
r"   r#   r$   r%   r   r8   r:   r!   r3   __classcell__r   r   r0   r   r&   B   s   !

r&   c                      sb   e Zd ZdZddddd	d
dddddd
 fddZdddddZdd Zdd Zdd Z  Z	S )CSVIterableDataseta  
    Iterable dataset to load CSV files and generate dictionary data.
    It is particularly useful when data come from a stream, inherits from PyTorch IterableDataset:
    https://pytorch.org/docs/stable/data.html?highlight=iterabledataset#torch.utils.data.IterableDataset.

    It also can be helpful when loading extremely big CSV files that can't read into memory directly,
    just treat the big CSV file as stream input, call `reset()` of `CSVIterableDataset` for every epoch.
    Note that as a stream input, it can't get the length of dataset.

    To effectively shuffle the data in the big dataset, users can set a big buffer to continuously store
    the loaded data, then randomly pick data from the buffer for following tasks.

    To accelerate the loading process, it can support multi-processing based on PyTorch DataLoader workers,
    every process executes transforms on part of every loaded data.
    Note: the order of output data may not match data source in multi-processing mode.

    It can load data from multiple CSV files and join the tables with additional `kwargs` arg.
    Support to only load specific columns.
    And it can also group several loaded columns to generate a new column, for example,
    set `col_groups={"meta": ["meta_0", "meta_1", "meta_2"]}`, output can be::

        [
            {"image": "./image0.nii", "meta_0": 11, "meta_1": 12, "meta_2": 13, "meta": [11, 12, 13]},
            {"image": "./image1.nii", "meta_0": 21, "meta_1": 22, "meta_2": 23, "meta": [21, 22, 23]},
        ]

    Args:
        src: if provided the filename of CSV file, it can be a str, URL, path object or file-like object to load.
            also support to provide iter for stream input directly, will skip loading from filename.
            if provided a list of filenames or iters, it will join the tables.
        chunksize: rows of a chunk when loading iterable data from CSV files, default to 1000. more details:
            https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html.
        buffer_size: size of the buffer to store the loaded chunks, if None, set to `2 x chunksize`.
        col_names: names of the expected columns to load. if None, load all the columns.
        col_types: `type` and `default value` to convert the loaded columns, if None, use original data.
            it should be a dictionary, every item maps to an expected column, the `key` is the column
            name and the `value` is None or a dictionary to define the default value and data type.
            the supported keys in dictionary are: ["type", "default"]. for example::

                col_types = {
                    "subject_id": {"type": str},
                    "label": {"type": int, "default": 0},
                    "ehr_0": {"type": float, "default": 0.0},
                    "ehr_1": {"type": float, "default": 0.0},
                    "image": {"type": str, "default": None},
                }

        col_groups: args to group the loaded columns to generate a new column,
            it should be a dictionary, every item maps to a group, the `key` will
            be the new column name, the `value` is the names of columns to combine. for example:
            `col_groups={"ehr": [f"ehr_{i}" for i in range(10)], "meta": ["meta_1", "meta_2"]}`
        transform: transform to apply on the loaded items of a dictionary data.
        shuffle: whether to shuffle all the data in the buffer every time a new chunk loaded.
        seed: random seed to initialize the random state for all the workers if `shuffle` is True,
            set `seed += 1` in every iter() call, refer to the PyTorch idea:
            https://github.com/pytorch/pytorch/blob/v1.10.0/torch/utils/data/distributed.py#L98.
        kwargs_read_csv: dictionary args to pass to pandas `read_csv` function. Default to ``{"chunksize": chunksize}``.
        kwargs: additional arguments for `pandas.merge()` API to join tables.

      NFr   z3str | Sequence[str] | Iterable | Sequence[Iterable]r(   z
int | NonezSequence[str] | Nonez'dict[str, dict[str, Any] | None] | Nonezdict[str, Sequence[str]] | Noner   boolzdict | None)
src	chunksizer)   	col_names	col_types
col_groupsr   shuffler*   kwargs_read_csvc                   sr   || _ || _|d krd| n|| _|| _|| _|| _|| _|	| _|
pJd|i| _|| _	| 
 | _t jd |d d S )N   rE   r,   )rD   rE   r)   rF   rG   rH   rI   r*   rJ   kwargsresetitersr-   r   )r   rD   rE   r)   rF   rG   rH   r   rI   r*   rJ   rL   r0   r   r   r      s    
zCSVIterableDataset.__init__z:str | Sequence[str] | Iterable | Sequence[Iterable] | None)rD   c                 C  s   |dkr| j n|}t|ttfs&|fn|}g | _|D ]H}t|tr\| jtj|f| j	 q4t|t
rt| j| q4tdq4| jS )a;  
        Reset the pandas `TextFileReader` iterable object to read data. For more details, please check:
        https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration.

        Args:
            src: if not None and provided the filename of CSV file, it can be a str, URL, path object
                or file-like object to load. also support to provide iter for stream input directly,
                will skip loading from filename. if provided a list of filenames or iters, it will join the tables.
                default to `self.src`.

        Nz+`src` must be file path or iterable object.)rD   
isinstancetuplelistrN   strr9   pdread_csvrJ   r   
ValueError)r   rD   srcsr   r   r   r   rM      s    


zCSVIterableDataset.resetc                 C  s   | j D ]}|  qdS )a  
        Close the pandas `TextFileReader` iterable objects.
        If the input src is file path, TextFileReader was created internally, need to close it.
        If the input src is iterable object, depends on users requirements whether to close it in this function.
        For more details, please check:
        https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html?#iteration.

        N)rN   close)r   r   r   r   r   rW      s    	
zCSVIterableDataset.closec                 c  s:   t | j D ]*}tf || j| j| jd| jE d H  q
d S )N)dfsrF   rG   rH   )ziprN   r
   rF   rG   rH   rL   )r   chunksr   r   r   
_flattened  s    zCSVIterableDataset._flattenedc                 c  sT   | j r8|  jd7  _t|  | j| j| jd}|E d H  t|  | jdE d H  d S )Nr   )r   r   r)   r*   r,   )rI   r*   r&   r[   r   r)   r   )r   r6   r   r   r   r!     s       
zCSVIterableDataset.__iter__)	rB   NNNNNFr   N)N)
r"   r#   r$   r%   r   rM   rW   r[   r!   r@   r   r   r0   r   rA      s   @         &
rA   N)
__future__r   collections.abcr   r   r   r   typingr   torch.utils.datar   Z_TorchIterableDatasetr	   monai.data.utilsr
   monai.transformsr   monai.transforms.transformr   monai.utilsr   rS   r=   r&   rA   r   r   r   r   <module>   s   &H