o
    #iL                     @  s\  d dl mZ d dlZd dlZd dlmZmZmZmZm	Z	m
Z
mZ d dlmZ d dlmZ d dlmZ d dlmZ d dlZd dlZd dlmZ d d	lmZ d d
lmZ d dlmZmZ d dl m!Z!m"Z"m#Z#m$Z$m%Z% d dl&m'Z'm(Z(m)Z)m*Z*m+Z+ erd dl,m,Z, dZ-n	e+dde*d\Z,Z-g dZ.G dd dZ/G dd dZ0G dd deZ1G dd deZ2dS )    )annotationsN)Callable	GeneratorHashableIterableIteratorMappingSequence)deepcopy)	ListProxy)
ThreadPool)TYPE_CHECKING)KeysCollection)NdarrayTensor)IterableDataset)
iter_patchpickle_hashing)ComposeRandomizableTrait	Transformapply_transformconvert_to_contiguous)NumpyPadModeensure_tuplefirstmin_versionoptional_import)tqdmTr   z4.47.0)PatchDatasetGridPatchDataset	PatchIter
PatchIterdc                   @  s,   e Zd ZdZdejfdd
dZdddZdS )r    z
    Return a patch generator with predefined properties such as `patch_size`.
    Typically used with :py:class:`monai.data.GridPatchDataset`.

     
patch_sizeSequence[int]	start_posmode
str | Nonepad_optsdictc                 K  s(   dt | | _t|| _|| _|| _dS )a  

        Args:
            patch_size: size of patches to generate slices for, 0/None selects whole dimension
            start_pos: starting position in the array, default is 0 for each dimension
            mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
                ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
                (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
                One of the listed string values or a user supplied function.
                If None, no wrapping is performed. Defaults to ``"wrap"``.
                See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
                https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
                requires pytorch >= 1.10 for best compatibility.
            pad_opts: other arguments for the `np.pad` function.
                note that `np.pad` treats channel dimension as the first dimension.

        Note:
            The `patch_size` is the size of the
            patch to sample from the input arrays. It is assumed the arrays first dimension is the channel dimension which
            will be yielded in its entirety so this should not be specified in `patch_size`. For example, for an input 3D
            array with 1 channel of size (1, 20, 20, 20) a regular grid sampling of eight patches (1, 10, 10, 10) would be
            specified by a `patch_size` of (10, 10, 10).

        NN)tupler#   r   r%   r&   r(   )selfr#   r%   r&   r(   r"   r"   Y/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/data/grid_dataset.py__init__1   s   

zPatchIter.__init__arrayr   return7Generator[tuple[NdarrayTensor, np.ndarray], None, None]c                 c  s0    t |f| j| jdd| jd| jE dH  dS )zO
        Args:
            array: the image to generate patches from.

        g        F)r#   r%   overlap	copy_backr&   N)r   r#   r%   r&   r(   )r,   r/   r"   r"   r-   __call__U   s   zPatchIter.__call__N)r#   r$   r%   r$   r&   r'   r(   r)   )r/   r   r0   r1   )__name__
__module____qualname____doc__r   WRAPr.   r4   r"   r"   r"   r-   r    *   s    	$r    c                   @  s8   e Zd ZdZdZdZdZdejfdddZ	dddZ
dS )r!   aA  
    Dictionary-based wrapper of :py:class:`monai.data.PatchIter`.
    Return a patch generator for dictionary data and the coordinate, Typically used
    with :py:class:`monai.data.GridPatchDataset`.
    Suppose all the expected fields specified by `keys` have same shape.

    Args:
        keys: keys of the corresponding items to iterate patches.
        patch_size: size of patches to generate slices for, 0/None selects whole dimension
        start_pos: starting position in the array, default is 0 for each dimension
        mode: available modes: (Numpy) {``"constant"``, ``"edge"``, ``"linear_ramp"``, ``"maximum"``,
            ``"mean"``, ``"median"``, ``"minimum"``, ``"reflect"``, ``"symmetric"``, ``"wrap"``, ``"empty"``}
            (PyTorch) {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}.
            One of the listed string values or a user supplied function.
            If None, no wrapping is performed. Defaults to ``"wrap"``.
            See also: https://numpy.org/doc/stable/reference/generated/numpy.pad.html
            https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
            requires pytorch >= 1.10 for best compatibility.
        pad_opts: other arguments for the `np.pad` function.
            note that `np.pad` treats channel dimension as the first dimension.

    Zpatch_coordsoriginal_spatial_shaper%   r"   keysr   r#   r$   r&   r'   c                 K  s&   t || _td|||d|| _d S )N)r#   r%   r&   r"   )r   r;   r    
patch_iter)r,   r;   r#   r%   r&   r(   r"   r"   r-   r.      s   
zPatchIterd.__init__data Mapping[Hashable, NdarrayTensor]r0   JGenerator[tuple[Mapping[Hashable, NdarrayTensor], np.ndarray], None, None]c                 #  s    t |  tj jdd  }t fddjD  D ]@}|d d }dd tj|D }t  tjD ]
}t | ||< q=||j< ||j	< j
j|j< ||fV  qd S )N   c                   s   g | ]	}  | qS r"   )r<   ).0keydr,   r"   r-   
<listcomp>       z'PatchIterd.__call__.<locals>.<listcomp>r   c                 S  s   i | ]	\}}||d  qS )r   r"   )rA   kvr"   r"   r-   
<dictcomp>   rF   z'PatchIterd.__call__.<locals>.<dictcomp>)r)   r   r;   shapezipset
differencer
   
coords_keyoriginal_spatial_shape_keyr<   r%   start_pos_key)r,   r=   r:   patchcoordsretrG   r"   rC   r-   r4      s   

zPatchIterd.__call__N)r;   r   r#   r$   r%   r$   r&   r'   )r=   r>   r0   r?   )r5   r6   r7   r8   rN   rO   rP   r   r9   r.   r4   r"   r"   r"   r-   r!   f   s    r!   c                
      sn   e Zd ZdZdddejdddddef
d, fddZd-d d!Zd.d/d#d$Z	d0d&d'Z
d(d) Z fd*d+Z  ZS )1r   a  
    Yields patches from data read from an image dataset.
    Typically used with `PatchIter` or `PatchIterd` so that the patches are chosen in a contiguous grid sampling scheme.

     .. code-block:: python

        import numpy as np

        from monai.data import GridPatchDataset, DataLoader, PatchIter, RandShiftIntensity

        # image-level dataset
        images = [np.arange(16, dtype=float).reshape(1, 4, 4),
                  np.arange(16, dtype=float).reshape(1, 4, 4)]
        # image-level patch generator, "grid sampling"
        patch_iter = PatchIter(patch_size=(2, 2), start_pos=(0, 0))
        # patch-level intensity shifts
        patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)

        # construct the dataset
        ds = GridPatchDataset(data=images,
                              patch_iter=patch_iter,
                              transform=patch_intensity)
        # use the grid patch dataset
        for item in DataLoader(ds, batch_size=2, num_workers=2):
            print("patch size:", item[0].shape)
            print("coordinates:", item[1])

        # >>> patch size: torch.Size([2, 1, 2, 2])
        #     coordinates: tensor([[[0, 1], [0, 2], [0, 2]],
        #                          [[0, 1], [2, 4], [0, 2]]])

    Args:
        data: the data source to read image data from.
        patch_iter: converts an input image (item from dataset) into a iterable of image patches.
            `patch_iter(dataset[idx])` must yield a tuple: (patches, coordinates).
            see also: :py:class:`monai.data.PatchIter` or :py:class:`monai.data.PatchIterd`.
        transform: a callable data transform operates on the patches.
        with_coordinates: whether to yield the coordinates of each patch, default to `True`.
        cache: whether to use cache mache mechanism, default to `False`.
            see also: :py:class:`monai.data.CacheDataset`.
        cache_num: number of items to be cached. Default is `sys.maxsize`.
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        cache_rate: percentage of cached data in total, default is 1.0 (cache all).
            will take the minimum of (cache_num, data_length x cache_rate, data_length).
        num_workers: the number of worker threads if computing cache in the initialization.
            If num_workers is None then the number returned by os.cpu_count() is used.
            If a value less than 1 is specified, 1 will be used instead.
        progress: whether to display a progress bar.
        copy_cache: whether to `deepcopy` the cache content before applying the random transforms,
            default to `True`. if the random transforms don't modify the cached content
            (for example, randomly crop from the cached image and deepcopy the crop region)
            or if every cache item is only used once in a `multi-processing` environment,
            may set `copy=False` for better performance.
        as_contiguous: whether to convert the cached NumPy array or PyTorch tensor to be contiguous.
            it may help improve the performance of following logic.
        hash_func: a callable to compute hash from data items to be cached.
            defaults to `monai.data.utils.pickle_hashing`.

    NTFg      ?r@   r=   Iterable | Sequencer<   r   	transformCallable | Nonewith_coordinatesboolcache	cache_numint
cache_ratefloatnum_workers
int | Noneprogress
copy_cacheas_contiguous	hash_funcCallable[..., bytes]r0   Nonec                   s   t  j|d d |d urt|tst|}|| _|| _|| _|| _|| _|	| _	|
| _
|| _|| _|| _| jd urAtt| jd| _g | _g | _|| _d | _| jd ur[| jdd | _| jrnt|trgtd| | d S d S )Nr=   rU   r@   c                 S  s   t | tp
t | t S r*   )
isinstancer   r   )tr"   r"   r-   <lambda>  s    z+GridPatchDataset.__init__.<locals>.<lambda>z+Data can not be iterator when cache is True)superr.   rg   r   r<   patch_transformrW   set_numset_rater`   ra   rb   rc   r^   maxr[   _cache_cache_otherrY   first_randomget_index_of_firstr   	TypeErrorset_data)r,   r=   r<   rU   rW   rY   rZ   r\   r^   r`   ra   rb   rc   	__class__r"   r-   r.      s8   


zGridPatchDataset.__init__r	   c                   s   | _  fddt j D }tt jtt| j t| _t|d j  _	t|
 d j }t | \ _ _dS )aA  
        Set the input data and run deterministic transforms to generate cache content.

        Note: should call this func after an entire epoch and must set `persistent_workers=False`
        in PyTorch DataLoader, because it needs to create new worker processes based on new
        generated cache content.

        c                   s   i | ]
\}}  ||qS r"   )rc   )rA   irH   r,   r"   r-   rI     s    z-GridPatchDataset.set_data.<locals>.<dictcomp>N)r=   	enumerateminr[   rl   lenrm   rZ   list
_hash_keysvaluesrK   _fill_cachero   rp   )r,   r=   mappingindicesr"   rx   r-   rt   	  s   	&zGridPatchDataset.set_datar|   c                 C  s   | j dkrg S |du rtt| j }| jrtstd | jr#tr#tndd }t| j	}t||
| j|t|ddW  d   S 1 sGw   Y  dS )z
        Compute and fill the cache content from data source.

        Args:
            indices: target indices in the `self.data` source to compute cache.
                if None, use the first `cache_num` items.

        r   Nz>tqdm is not installed, will not show the caching progress bar.c                 [  s   | S r*   r"   )rH   _r"   r"   r-   ri   +  s    z.GridPatchDataset._fill_cache.<locals>.<lambda>zLoading dataset)totaldesc)rZ   r|   ranger`   has_tqdmwarningswarnr   r   r^   imap_load_cache_itemr{   )r,   r   Zpfuncpr"   r"   r-   r     s   
	

$zGridPatchDataset._fill_cacheidxc                 C  s   | j | }g g }}| |D ]1^}}| jdur!| j|| jdd}| jr+t|tjd}| jr;t	|dkr;|
|d  |
| q||fS )zN
        Args:
            idx: the index of the input data sequence.
        NT)end	threading)memory_formatr   )r=   r<   rq   rk   rb   r   torchcontiguous_formatrW   r{   append)r,   r   itemZpatch_cacheZother_cacherQ   othersr"   r"   r-   r   /  s   


z!GridPatchDataset._load_cache_itemc                 k  s\    |D ](^}}|}| j dur| j |fi |}| jr(t|dkr(||d fV  q|V  qdS )z
        yield patches optionally post-processed by transform.

        Args:
            src: a iterable of image patches.
            apply_args: other args for `self.patch_transform`.

        Nr   )rk   rW   r{   )r,   srcZ
apply_argsrQ   r   	out_patchr"   r"   r-   _generate_patchesA  s   	
z"GridPatchDataset._generate_patchesc                 #  s    | j rZd }t  D ]L}| |}|| jv r| j|}|d u r-| | |E d H  q| jd u r6t	d| j| }| j
| }| jrGt|n|}| jt||| jdE d H  qd S t  D ]}| | |E d H  q_d S )NzNCache buffer is not initialized, please call `set_data()` before epoch begins.)start)rY   rj   __iter__rc   r}   indexr   r<   ro   RuntimeErrorrp   ra   r
   rK   rq   )r,   cache_indeximagerB   r=   otherru   r"   r-   r   S  s*   




zGridPatchDataset.__iter__)r=   rT   r<   r   rU   rV   rW   rX   rY   rX   rZ   r[   r\   r]   r^   r_   r`   rX   ra   rX   rb   rX   rc   rd   r0   re   )r=   r	   r0   re   r*   )r0   r|   )r   r[   )r5   r6   r7   r8   sysmaxsizer   r.   rt   r   r   r   r   __classcell__r"   r"   ru   r-   r      s$    @
,
r   c                      s<   e Zd ZdZ	dd fddZdddZ fddZ  ZS )r   a  
    Yields patches from data read from an image dataset.
    The patches are generated by a user-specified callable `patch_func`,
    and are optionally post-processed by `transform`.
    For example, to generate random patch samples from an image dataset:

    .. code-block:: python

        import numpy as np

        from monai.data import PatchDataset, DataLoader
        from monai.transforms import RandSpatialCropSamples, RandShiftIntensity

        # image dataset
        images = [np.arange(16, dtype=float).reshape(1, 4, 4),
                  np.arange(16, dtype=float).reshape(1, 4, 4)]
        # image patch sampler
        n_samples = 5
        sampler = RandSpatialCropSamples(roi_size=(3, 3), num_samples=n_samples,
                                         random_center=True, random_size=False)
        # patch-level intensity shifts
        patch_intensity = RandShiftIntensity(offsets=1.0, prob=1.0)
        # construct the patch dataset
        ds = PatchDataset(dataset=images,
                          patch_func=sampler,
                          samples_per_image=n_samples,
                          transform=patch_intensity)

        # use the patch dataset, length: len(images) x samplers_per_image
        print(len(ds))

        >>> 10

        for item in DataLoader(ds, batch_size=2, shuffle=True, num_workers=2):
            print(item.shape)

        >>> torch.Size([2, 1, 3, 3])

    r@   Nr=   r	   
patch_funcr   samples_per_imager[   rU   rV   r0   re   c                   s:   t  j|dd || _|dkrtdt|| _|| _dS )a  
        Args:
            data: an image dataset to extract patches from.
            patch_func: converts an input image (item from dataset) into a sequence of image patches.
                patch_func(dataset[idx]) must return a sequence of patches (length `samples_per_image`).
            samples_per_image: `patch_func` should return a sequence of `samples_per_image` elements.
            transform: transform applied to each patch.
        Nrf   r   z-sampler_per_image must be a positive integer.)rj   r.   r   
ValueErrorr[   r   rk   )r,   r=   r   r   rU   ru   r"   r-   r.     s   

zPatchDataset.__init__c                 C  s   t | j| j S r*   )r{   r=   r   rx   r"   r"   r-   __len__  s   zPatchDataset.__len__c                 #  sn    t   D ].}| |}t|| jkrtd| j d|D ]}|}| jd ur0t| j|dd}|V  qqd S )NzA`patch_func` must return a sequence of length: samples_per_image=.F)	map_items)rj   r   r   r{   r   RuntimeWarningrk   r   )r,   r   patchesrQ   r   ru   r"   r-   r     s   

zPatchDataset.__iter__)r@   N)
r=   r	   r   r   r   r[   rU   rV   r0   re   )r0   r[   )r5   r6   r7   r8   r.   r   r   r   r"   r"   ru   r-   r   n  s    )
r   )3
__future__r   r   r   collections.abcr   r   r   r   r   r   r	   copyr
   multiprocessing.managersr   multiprocessing.poolr   typingr   numpynpr   monai.configr   monai.config.type_definitionsr   Zmonai.data.iterable_datasetr   monai.data.utilsr   r   monai.transformsr   r   r   r   r   monai.utilsr   r   r   r   r   r   r   __all__r    r!   r   r   r"   r"   r"   r-   <module>   s4   $<: O