U
    PhH                     @  s   d dl mZ d dlZd dlZd dlmZmZ d dlZd dl	Z	d dl
mZ d dlmZ d dlmZ d dlmZmZ d dlmZmZmZ d d	lmZmZ d d
lmZmZmZ dddgZG dd deZG dd deeZ G dd deZ!dS )    )annotationsN)CallableSequence)Dataset)
MetaTensor)iter_patch_position)BaseWSIReader	WSIReader)ForegroundMaskRandomizableapply_transform)convert_to_dst_typeensure_tuple_rep)
CommonKeysProbMapKeysWSIPatchKeysPatchWSIDatasetSlidingPatchWSIDatasetMaskedPatchWSIDatasetc                	      s   e Zd ZdZd ddddd	d	d
d fddZddddZddddZddddZddddZddddZ	ddddZ
ddddZ  ZS )!r   a'  
    This dataset extracts patches from whole slide images (without loading the whole image)
    It also reads labels for each patch and provides each patch with its associated class labels.

    Args:
        data: the list of input samples including image, location, and label (see the note below for more details).
        patch_size: the size of patch to be extracted from the whole slide image.
        patch_level: the level at which the patches to be extracted (default to 0).
        transform: transforms to be executed on input data.
        include_label: whether to load and include labels in the output
        center_location: whether the input location information is the position of the center of the patch
        additional_meta_keys: the list of keys for items to be copied to the output metadata from the input data
        reader: the module to be used for loading whole slide imaging. If `reader` is

            - a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
            - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
            - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.

        kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class

    Returns:
        dict: a dictionary of loaded image (in MetaTensor format) along with the labels (if requested).
        {"image": MetaTensor, "label": torch.Tensor}

    Note:
        The input data has the following form as an example:

        .. code-block:: python

            [
                {"image": "path/to/image1.tiff", "location": [200, 500], "label": 0},
                {"image": "path/to/image2.tiff", "location": [100, 700], "patch_size": [20, 20], "patch_level": 2, "label": 1}
            ]

    NTcuCIMr   int | tuple[int, int] | None
int | NoneCallable | NoneboolzSequence[str] | None)data
patch_sizepatch_level	transforminclude_labelcenter_locationadditional_meta_keysc	           
        s   t  || |d krd | _nt|d| _|| _|d kr<d}|  t|trbtf ||d|	| _nLt	
|rt|tr|f d|i|	| _n"t|tr|| _ntd| d| jj| _|| _|| _|pg | _i | _d S )N   r   )backendlevelr#   zUnsupported reader type: .)super__init__r   r   r   
isinstancestrr	   
wsi_readerinspectisclass
issubclassr   
ValueErrorr"   r   r   r    wsi_object_dict)
selfr   r   r   r   r   r   r    readerkwargs	__class__ L/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/data/wsi_datasets.pyr&   E   s(    



zPatchWSIDataset.__init__dictsamplec                 C  s0   |t j }|| jkr&| j|| j|< | j| S N)r   IMAGEr.   r)   read)r/   r8   
image_pathr4   r4   r5   _get_wsi_objectr   s    

zPatchWSIDataset._get_wsi_objectc                 C  s   t j|tj t jdS )N)dtype)torchtensorr   LABELfloat32r/   r8   r4   r4   r5   
_get_labelx   s    zPatchWSIDataset._get_labelc                   s:   | j r,|   fddttD S  tj S d S )Nc                   s&   g | ]} t j | | d   qS )r!   )r   LOCATION).0ir8   sizer4   r5   
<listcomp>~   s     z1PatchWSIDataset._get_location.<locals>.<listcomp>)r   	_get_sizerangelenr   rE   rC   r4   rH   r5   _get_location{   s    
zPatchWSIDataset._get_locationc                 C  s   | j d kr|tjdS | j S )Nr   )r   getr   LEVELrC   r4   r4   r5   
_get_level   s    
zPatchWSIDataset._get_levelc                 C  s"   | j d krt|tjdS | j S )Nr!   )r   r   rO   r   SIZErC   r4   r4   r5   rK      s    
zPatchWSIDataset._get_sizec                 C  sL   | j dkri | _| |}| |}| |}| |}| jj||||dS )NZ	openslide)wsilocationrI   r#   )r"   r.   r=   rN   rQ   rK   r)   get_data)r/   r8   wsi_objrT   r#   rI   r4   r4   r5   	_get_data   s    




zPatchWSIDataset._get_dataint)indexc                 C  sn   | j | }| |\}}| jD ]}|| ||< qtjt||di}| jrX| ||tj< | j	rjt
| j	|S |S )N)meta)r   rW   r    r   r:   r   r   rD   rA   r   r   )r/   rY   r8   imagemetadatakeyoutputr4   r4   r5   
_transform   s    

zPatchWSIDataset._transform)NNNTTNr   )__name__
__module____qualname____doc__r&   r=   rD   rN   rQ   rK   rW   r_   __classcell__r4   r4   r2   r5   r       s    '        -
c                      sr   e Zd ZdZdddddddddejejejfddfdd	d
dddddddddd fddZdd Z	dd Z
  ZS )r   av	  
    This dataset extracts patches in sliding-window manner from whole slide images (without loading the whole image).
    It also reads labels for each patch and provides each patch with its associated class labels.

    Args:
        data: the list of input samples including image, location, and label (see the note below for more details).
        patch_size: the size of patch to be extracted from the whole slide image.
        patch_level: the level at which the patches to be extracted (default to 0).
        mask_level: the resolution level at which the mask/map is created (for `ProbMapProducer` for instance).
        overlap: the amount of overlap of neighboring patches in each dimension (a value between 0.0 and 1.0).
            If only one float number is given, it will be applied to all dimensions. Defaults to 0.0.
        offset: the offset of image to extract patches (the starting position of the upper left patch).
        offset_limits: if offset is set to "random", a tuple of integers defining the lower and upper limit of the
            random offset for all dimensions, or a tuple of tuples that defines the limits for each dimension.
        transform: transforms to be executed on input data.
        include_label: whether to load and include labels in the output
        center_location: whether the input location information is the position of the center of the patch
        additional_meta_keys: the list of keys for items to be copied to the output metadata from the input data
        reader: the module to be used for loading whole slide imaging. Defaults to cuCIM. If `reader` is

            - a string, it defines the backend of `monai.data.WSIReader`.
            - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader,
            - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.

        seed: random seed to randomly generate offsets. Defaults to 0.
        kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class

    Note:
        The input data has the following form as an example:

        .. code-block:: python

            [
                {"image": "path/to/image1.tiff"},
                {"image": "path/to/image2.tiff", "patch_size": [20, 20], "patch_level": 2}
            ]

        Unlike `MaskedPatchWSIDataset`, this dataset does not filter any patches.
    Nr   g        )r   r   Fr   r   r   r   rX   ztuple[float, float] | floatztuple[int, int] | int | strz@tuple[tuple[int, int], tuple[int, int]] | tuple[int, int] | Noner   r   Sequence[str])r   r   r   
mask_leveloverlapoffsetoffset_limitsr   r   r   r    seedc                   s  t  jf g ||||	|
||d| || _| | d| _t|tr|dkrd| _|  |d krfd | _qt|trt|d t	r||f| _qt|d tr|| _qt
dqt
dqt
d| d	nt|d
| _|| _|  t|| _| jD ]}| |}| j| qd S )Nr   r   r   r   r   r   r    r0   FrandomTr   zUThe offset limits should be either a tuple of integers or tuple of tuple of integers.z$The offset limits should be a tuple.zInvalid string for offset "zc". It should be either "random" as a string,an integer, or a tuple of integers defining the offset.r!   )r%   r&   rg   set_random_staterandom_offsetr'   r(   ri   tuplerX   r-   r   rh   rf   list
image_data_evaluate_patch_locationsr   extend)r/   r   r   r   rf   rg   rh   ri   r   r   r   r    r0   rj   r1   r8   patch_samplesr2   r4   r5   r&      sP    	







zSlidingPatchWSIDataset.__init__c                   sL    j rF jd kr*tdd  |D }n j}t fdd|D S  jS )Nc                 s  s   | ]}| |fV  qd S r9   r4   )rF   sr4   r4   r5   	<genexpr>  s     z5SlidingPatchWSIDataset._get_offset.<locals>.<genexpr>c                 3  s    | ]\}} j ||V  qd S r9   )Rrandint)rF   lowhighr/   r4   r5   rv     s     )rn   ri   ro   rK   rh   )r/   r8   ri   r4   r{   r5   _get_offset  s    
z"SlidingPatchWSIDataset._get_offsetc              
     s  |  }| }| }| j|d}| j|| j}| j|| t fdd|D }| 	}tt
t|||| jdd}	t|	|d  t| }
|tjj< |tjj< tjtj tjj< t|	tjj< t| j|| jtjj< fddt|	|
D S )z@Calculate the location for each patch in a sliding-window mannerr   c                   s   g | ]}|  qS r4   r4   rF   ppatch_ratior4   r5   rJ   )  s     zDSlidingPatchWSIDataset._evaluate_patch_locations.<locals>.<listcomp>F)
image_sizer   	start_posrg   paddedr!   c                   s.   g | ]&\}} t jjt|tjj|iqS r4   r   rE   valuenparrayr   rF   locZmask_locr7   r4   r5   rJ   ;  s   )rK   rQ   r=   r)   get_sizeget_downsample_ratiorf   r   r   r|   rp   r   rg   roundfloatr   rR   r   rP   ospathbasenamer   r:   r   NAMErM   COUNTzip)r/   r8   r   r   rV   Zwsi_size
mask_ratiopatch_size_0rh   patch_locationsmask_locationsr4   r   r8   r5   rr     s8    



    
z0SlidingPatchWSIDataset._evaluate_patch_locations)r`   ra   rb   rc   r   rE   rR   r   r&   r|   rr   rd   r4   r4   r2   r5   r      s    +*A	c                
      sV   e Zd ZdZddddddejejfdfdddd	d
dddd fddZdd Z  Z	S )r   a4  
    This dataset extracts patches from whole slide images at the locations where foreground mask
    at a given level is non-zero.

    Args:
        data: the list of input samples including image, location, and label (see the note below for more details).
        patch_size: the size of patch to be extracted from the whole slide image.
        patch_level: the level at which the patches to be extracted (default to 0).
        mask_level: the resolution level at which the mask is created.
        transform: transforms to be executed on input data.
        include_label: whether to load and include labels in the output
        center_location: whether the input location information is the position of the center of the patch
        additional_meta_keys: the list of keys for items to be copied to the output metadata from the input data
        reader: the module to be used for loading whole slide imaging. Defaults to cuCIM. If `reader` is

            - a string, it defines the backend of `monai.data.WSIReader`.
            - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader,
            - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.

        kwargs: additional arguments to pass to `WSIReader` or provided whole slide reader class

    Note:
        The input data has the following form as an example:

        .. code-block:: python

            [
                {"image": "path/to/image1.tiff"},
                {"image": "path/to/image2.tiff", "size": [20, 20], "level": 2}
            ]

    N   Fr   r   r   r   rX   r   r   re   )r   r   r   rf   r   r   r   r    c
                   s^   t  jf g |||||||	d|
 || _|  t|| _| jD ]}| |}| j| q>d S )Nrk   )r%   r&   rf   rp   rq   rr   r   rs   )r/   r   r   r   rf   r   r   r   r    r0   r1   r8   rt   r2   r4   r5   r&   c  s$    	


zMaskedPatchWSIDataset.__init__c                   s*  |  }| }| }| jj|| jd\}}ttt	ddid||dd }t
| j}| j|| j}	| j|| t fdd|D }
t|d	 t|	 |
d
  t}|tjj< |tjj< tjtj tjj< t|tjj< |j tjj< fddt!||D S )zUCalculate the location for each patch based on the mask at different resolution level)r#   Sotsu)hsv_threshold)dstr   c                   s   g | ]}|  qS r4   r4   r}   r   r4   r5   rJ     s     zCMaskedPatchWSIDataset._evaluate_patch_locations.<locals>.<listcomp>g      ?r!   c                   s.   g | ]&\}} t jjt|tjj|iqS r4   r   r   r7   r4   r5   rJ     s   )"rK   rQ   r=   r)   rU   rf   r   squeezer   r
   vstacknonzeroTr   r   r   r   astyperX   r   rR   r   rP   r   r   r   r   r:   r   r   rM   r   shaper   )r/   r8   r   r   rV   rS   _maskr   r   r   r   r4   r   r5   rr     s$    


$$
z/MaskedPatchWSIDataset._evaluate_patch_locations)
r`   ra   rb   rc   r   rE   r   r&   rr   rd   r4   r4   r2   r5   r   A  s   $
"!)"
__future__r   r*   r   collections.abcr   r   numpyr   r?   
monai.datar   monai.data.meta_tensorr   monai.data.utilsr   Zmonai.data.wsi_readerr   r	   monai.transformsr
   r   r   monai.utilsr   r   monai.utils.enumsr   r   r   __all__r   r   r   r4   r4   r4   r5   <module>   s$   
  