o
    /iR                     @  s   d dl mZ d dlZd dlZd dlmZmZ d dlmZm	Z	m
Z
 d dlmZmZmZ d dlmZ d dlZd dlmZ d dlmZmZ d d	lmZ d d
lmZmZmZ g dZG dd deZG dd deZG dd deZ dS )    )annotationsN)ABCabstractmethod)CallableIterableSequence)_emptyisclass	signature)Any)iter_patch_position)BaseWSIReader	WSIReader)convert_to_tensor)PathLikeensure_tupleensure_tuple_rep)SplitterSlidingWindowSplitterWSISlidingWindowSplitterc                   @  sF   e Zd ZdZddd	d
ZedddZedddZedddZdS )r   a9  
    A base class for splitting the inputs into iterable tuple of patches and locations
    Extend this class to support operations for `PatchInference`, e.g. SlidingPatchSplitter.

    Args:
        patch_size: the size of patches to be generated.
        device: the device where the patches are generated.
    N
patch_sizeSequence[int] | intdevicetorch.device | str | NonereturnNonec                 C  s   || _ || _d S Nr   r   )selfr   r    r   Y/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/inferers/splitter.py__init__)   s   
zSplitter.__init__inputsr   tuplec                 C     t d| jj d)aD  
        Return the input spatial shape.

        Args:
            inputs: either a tensor of shape BCHW[D], representing a batch of images,
                or a filename (str) or list of filenames to the image(s).

        Raises:
            NotImplementedError: When the subclass does not override this method.

        	Subclass  must implement this method.NotImplementedError	__class____name__r   r"   r   r   r    get_input_shape-   s   zSplitter.get_input_shapec                 C  r$   )a  
        Return the actual spatial shape covered by the output split patches.
        For instance, if the input image is padded, the actual spatial shape will be enlarged
        and not the same as input spatial shape.

        Args:
            inputs: either a tensor of shape BCHW[D], representing a batch of images,
                or a filename (str) or list of filenames to the image(s).

        Raises:
            NotImplementedError: When the subclass does not override this method.

        r%   r&   r'   r+   r   r   r    get_padded_shape<   s   zSplitter.get_padded_shape,Iterable[tuple[torch.Tensor, Sequence[int]]]c                 C  r$   )a  
        Split the input image (or batch of images) into patches and return pairs of (patch, location).
        Where location is the coordinate of top left [front] corner of a patch.

        Args:
            inputs: either a tensor of shape BCHW[D], representing a batch of images,
                or a filename (str) or list of filenames to the image(s).

        Raises:
            NotImplementedError: When the subclass does not override this method.

        r%   r&   r'   r+   r   r   r    __call__M   s   zSplitter.__call__r   )r   r   r   r   r   r   r"   r   r   r#   r"   r   r   r.   )	r*   
__module____qualname____doc__r!   r   r,   r-   r/   r   r   r   r    r      s    	r   c                      sv   e Zd ZdZ						d.d/ fddZedd Zdd Zd0ddZd1d$d%Z	d2d'd(Z
d2d)d*Zd3d,d-Z  ZS )4r   a  
    Splits the input into patches with sliding window strategy and a possible overlap.
    It also allows offsetting the starting position and filtering the patches.

    Args:
        patch_size : the size of the patches to be generated.
        offset: the amount of offset for the patches with respect to the original input.  Defaults to 0.
        overlap: the amount of overlap between patches in each dimension. It can be either a float in
            the range of [0.0, 1.0) that defines relative overlap to the patch size, or it can be a non-negative int
            that defines number of pixels for overlap. Defaults to 0.0.
        filter_fn: a callable to filter patches. It should accepts exactly two parameters (patch, location), and
            return True for a patch to keep. Defaults to no filtering.
        pad_mode: string define the mode for `torch.nn.functional.pad`. The acceptable values are
            `"constant"`, `"reflect"`, `"replicate"`, `"circular"` or `None`. Default to `"constant"`.
            If None, no padding will be applied, so it will drop the patches crossing the border of
            the image (either when the offset is negative or the image is non-divisible by the patch_size).
        pad_value: the value for `"constant"` padding. Defaults to 0.
        device: the device where the patches are generated. Defaults to the device of inputs.

    Note:
        When a scaler value is provided for `patch_size`, `offset`, or `overlap`,
            it is broadcasted to all the spatial dimensions.
            r   Nconstantr   r   overlap-Sequence[float] | float | Sequence[int] | intoffset	filter_fnCallable | Nonepad_mode
str | None	pad_valuefloat | intr   r   r   r   c                   s   t  j||d || _tt|d tr'tdd t|D r'td| dtdd t|D r:td| d	|| _| 	|| _
|| _|| _| js`td
d t|D rbtd| j dd S d S )Nr   r   c                 s  s     | ]}|d k p|dkV  qdS )r5   g      ?Nr   .0ovr   r   r    	<genexpr>   s    z1SlidingWindowSplitter.__init__.<locals>.<genexpr>z1Relative overlap must be between 0.0 and 1.0 but zZ is given. If you wish to use number of pixels as overlap, please provide integer numbers.c                 s      | ]}|d k V  qdS r   Nr   r@   r   r   r    rC          z1Number of pixels for overlap cannot be negative. z is given. c                 s  rD   rE   r   rA   offr   r   r    rC      rF   zDNegative `offset`requires a valid padding mode but `mode` is set to .)superr!   r9   
isinstancer   floatany
ValueErrorr7   _validate_filter_fnr:   r<   r>   )r   r   r7   r9   r:   r<   r>   r   r)   r   r    r!   w   s   
(
zSlidingWindowSplitter.__init__c                 C  s   t | r9t| }t|j}tdd |j D }|dk r(td|  d| d|dkr7td|  d| d| S | d urGtd	t|  d
| S )Nc                 S  s   g | ]	}|j tu r|qS r   )defaultr   )rA   vr   r   r    
<listcomp>   s    z=SlidingWindowSplitter._validate_filter_fn.<locals>.<listcomp>   z``filter_fn` requires to accept at least two parameters (patch, location).The provided callable (z) has z parameters.z``filter_fn` can have at most two positional parameters (patch, location).The provided callable (z positional parameters.zN`filter_fn` should be a callable with two input parameters (patch, location). 
 is given.)callabler
   len
parametersvaluesrN   type)r:   sign_paramsZnum_pos_paramsr   r   r    rO      s8   
	z)SlidingWindowSplitter._validate_filter_fnc                 C  s   dgd | }| j s|dfS dd |D |dd d< g }t||||D ]2\}}	}
}|
dkr0d}n t|trD|	| |
 t|
|
|   }n|	| |
 t|
|  }|| q#||d d d< |t|dd d fS )Nr   rT   Fc                 s  s    | ]	}t |d  V  qdS rE   )minrG   r   r   r    rC          z<SlidingWindowSplitter._calculate_pad_size.<locals>.<genexpr>   )r<   ziprK   rL   roundappendrM   )r   spatial_shapespatial_ndimr   r9   r7   pad_sizeZend_paddingshrH   psrB   
pad_amountr   r   r    _calculate_pad_size   s   
z)SlidingWindowSplitter._calculate_pad_sizerc   Sequence[int]Ltuple[tuple[int, ...], tuple[float, ...] | tuple[int, ...], tuple[int, ...]]c                   s   t |}t| j|}t| j| t fddt |D  tdd t |D r5td  d| dt| j|}t|||D ]$\}}}|| k rVtd| d| d	||kretd
| d| dqA| |fS )Nc                 3  s,    | ]\}}|r
|nt  d  d V  qdS rE   )rZ   )rA   opr7   r   r    rC      s   * zDSlidingWindowSplitter._get_valid_shape_parameters.<locals>.<genexpr>c                 s  s    | ]	\}}||kV  qd S r   r   )rA   rB   rg   r   r   r    rC      r^   z`overlap` (z$) cannot be larger than patch size (z).zNegative `offset` (z&) cannot be larger than `patch_size` (z) in magnitude.z
`offset` (z%) cannot be larger than inputs size ()	rW   r   r   r7   r#   r`   rM   rN   r9   )r   rc   rd   r   r9   rH   rg   rf   r   rn   r    _get_valid_shape_parameters   s   

z1SlidingWindowSplitter._get_valid_shape_parametersr"   r   locationtuple[int, ...]c                 C  s.   t d fd tdd t||D  }|| S )NrT   c                 s  s"    | ]\}}t ||| V  qd S r   )slice)rA   locrg   r   r   r    rC           z3SlidingWindowSplitter._get_patch.<locals>.<genexpr>)rr   r#   r`   )r   r"   rp   r   slicesr   r   r    
_get_patch   s   &z SlidingWindowSplitter._get_patchr#   c                 C  s   t |jdd S )a  
        Return the input spatial shape.

        Args:
            inputs: either a tensor of shape BCHW[D], representing a batch of images,
                or a filename (str) or list of filenames to the image(s).

        Returns:
            spatial_shape
        rT   N)r#   shaper+   r   r   r    r,      s   z%SlidingWindowSplitter.get_input_shapec           
   	   C  st   |  |}| js
|S t|}| |\}}}| |||||\}}tdd t||ddd |ddd D }	|	S )a  
        Return the actual spatial shape covered by the output split patches.
        For instance, if the input image is padded, the actual spatial shape will be enlarged
        and not the same as input spatial shape.

        Args:
            inputs: either a tensor of shape BCHW[D], representing a batch of images,
                or a filename (str) or list of filenames to the image(s).

        Returns:
            padded_spatial_shape

        c                 s  "    | ]\}}}|| | V  qd S r   r   rA   ssrg   per   r   r    rC      rt   z9SlidingWindowSplitter.get_padded_shape.<locals>.<genexpr>r_   NrT   )r,   r<   rW   ro   ri   r#   r`   )
r   r"   rc   rd   r   r7   r9   re   _padded_spatial_shaper   r   r    r-      s   
.z&SlidingWindowSplitter.get_padded_shaper.   c              	   c  s@   t |tjstdt| d|jdd }t|}| |\}}}| |||||\}}| j	rbt
|rbtjjj||ddd | j	| jd}|jdd }|rbtdd t||d	dd D }t||||d
D ]3}	| ||	|}
t|
| jd}
|rtdd t|	|d	dd D }	| jdu s| |
|	r|
|	fV  qjdS )a/  Split the input tensor into patches and return patches and locations.

        Args:
            inputs: either a torch.Tensor with BCHW[D] dimensions, representing an image or a batch of images

        Yields:
            tuple[torch.Tensor, Sequence[int]]: yields tuple of patch and location
        zThe input should be a tensor. rU   rT   N)modevaluec                 s      | ]	\}}|| V  qd S r   r   rA   rH   rm   r   r   r    rC     r^   z1SlidingWindowSplitter.__call__.<locals>.<genexpr>r_   Fr   c                 s      | ]	\}}|| V  qd S r   r   rA   rs   rm   r   r   r    rC   !  r^   )rK   torchTensorrN   rZ   rw   rW   ro   ri   r<   rM   nn
functionalpadr>   r#   r`   r   rv   r   r   r:   )r   r"   rc   rd   r   r7   r9   re   is_start_paddedrp   patchr   r   r    r/      s*   
$""
zSlidingWindowSplitter.__call__)r5   r   Nr6   r   N)r   r   r7   r8   r9   r   r:   r;   r<   r=   r>   r?   r   r   r   r   )rc   rj   r   rk   r"   r   rp   rq   r   rq   r   r   r0   r1   )r*   r2   r3   r4   r!   staticmethodrO   ri   ro   rv   r,   r-   r/   __classcell__r   r   rP   r    r   ^   s"    




r   c                      sX   e Zd ZdZ						d)d* fddZd+ddZd,d d!Zd-d#d$Zd.d'd(Z  Z	S )/r   a$  
    Splits the whole slide image input into patches with sliding window strategy and a possible overlap.
    This extracts patches from file without loading the entire slide into memory.
    It also allows offsetting the starting position and filtering the patches.

    Args:
        patch_size : the size of the patches to be generated.
        offset: the amount of offset for the patches with respect to the original input.  Defaults to 0.
        overlap: the amount of overlap between patches in each dimension. It can be either a float in
            the range of [0.0, 1.0) that defines relative overlap to the patch size, or it can be a non-negative int
            that defines number of pixels for overlap. Defaults to 0.0.
        filter_fn: a callable to filter patches. It should accepts exactly two parameters (patch, location), and
            return True for a patch to keep. Defaults to no filtering.
        pad_mode: define the mode for padding. Either "constant" or None. Default to "constant".
            Padding is only supported with "OpenSlide" or "cuCIM" backend, and the filling value is 256.
        device: the device where the patches are generated. Defaults to the device of inputs.
        reader: the module to be used for loading whole slide imaging. If `reader` is

            - a string, it defines the backend of `monai.data.WSIReader`. Defaults to "OpenSlide".
            - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
            - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.

            To obtain an optimized performance please use either "cuCIM" or "OpenSlide" backend.
        reader_kwargs: the arguments to pass to `WSIReader` or the provided whole slide reader class.
            For instance, level=2, dtype=torch.float32, etc.
            Note that if `level` is not provided, `level=0` is assumed.

    Note:
        When a scaler value is provided for `patch_size`, `offset`, or `overlap`,
        it is broadcasted to all the spatial dimensions.
    r5   r   Nr6   	OpenSlider   r   r7   r8   r9   r:   r;   r<   r=   r   r   reader0str | BaseWSIReader | type[BaseWSIReader] | Nonereader_kwargsdictr   r   c           	        sr   |r|dkrt d| dt j||||||d | || | jj dvr7td| jj  d d S d S )Nr6   zFThe underlying wsi readers only support for constant padding. pad_mod=rU   )r   r7   r9   r:   r   r<   )	openslidecucimzWSIReader with z backend is not supported for efficiently loading patches. This may cause an significant slow down and a large memory foot print. Please use other backends such as 'OpenSlide' or 'cuCIM' instead.)	rN   rJ   r!   _set_readerr   backendlowerwarningswarn)	r   r   r7   r9   r:   r<   r   r   r   rP   r   r    r!   H  s   
z!WSISlidingWindowSplitter.__init__c                 C  sz   |  || _ t|trtdd|i| j | _dS t|r+t|tr+|di | j | _dS t|tr5|| _dS td| d)a  
        Set the WSI reader object based on the input reader

        Args:
            reader: the module to be used for loading whole slide imaging. If `reader` is

                - a string, it defines the backend of `monai.data.WSIReader`. Defaults to cuCIM.
                - a class (inherited from `BaseWSIReader`), it is initialized and set as wsi_reader.
                - an instance of a class inherited from `BaseWSIReader`, it is set as the wsi_reader.
        r   zUnsupported reader type: rI   Nr   )	r   rK   strr   r   r	   
issubclassr   rN   )r   r   r   r   r   r    r   d  s   


z$WSISlidingWindowSplitter._set_readerr"   r   rp   rq   c                 C  s   | j j|||d\}}|d  S )N)wsirp   size)r   get_data)r   r"   rp   r   r   r|   r   r   r    rv   z  s   z#WSISlidingWindowSplitter._get_patchr#   c                 C  s(   | j |}| jdd}| j ||S )a  
        Return the input spatial shape.

        Args:
            inputs: either a tensor of shape BCHW[D], representing a batch of images,
                or a filename (str) or list of filenames to the image(s).

        Returns:
            spatial_shape

        levelr   )r   readr   getget_size)r   r"   r   r   r   r   r    r,   ~  s   z(WSISlidingWindowSplitter.get_input_shapePathLike | Sequence[PathLike]r.   c              	   #  s   t |tst |trt|dkrtd|d }t |ttjfs+tdt| d| j	|}| j
dd}| j|| | j||}t|}|dkrVtd| d	| |\}}}| |||||\}	}
t|	rtd
d t||	ddd |	ddd D }|
rtdd t||	ddd D }t||||dD ]>}t fdd|D }| |||}t|| jd}|
rtdd t||	ddd D }| jdu s| ||r||fV  qdS )zSplit the input tensor into patches and return patches and locations.

        Args:
            inputs: the file path to a whole slide image.

        Yields:
            tuple[torch.Tensor, Sequence[int]]: yields tuple of patch and location
        r_   zSOnly batch size of one would work for wsi image. Please provide one path at a time.r   z7The input should be the path to the whole slide image. rU   r   rT   z"WSIReader only support 2D images. z spatial dimension is provided.c                 s  rx   r   r   ry   r   r   r    rC     rt   z4WSISlidingWindowSplitter.__call__.<locals>.<genexpr>Nc                 s  r   r   r   r   r   r   r    rC     r^   Fc                 3  s    | ]	}t |  V  qd S r   )ra   )rA   rs   downsample_ratior   r    rC     r^   r   c                 s  r   r   r   r   r   r   r    rC     r^   )rK   r   r   rW   rN   osr   rZ   r   r   r   r   get_downsample_ratior   ro   ri   rM   r#   r`   r   rv   r   r   r:   )r   r"   r   r   rc   rd   r   r7   r9   re   r   rp   	location_r   r   r   r    r/     s<   
.""
z!WSISlidingWindowSplitter.__call__)r5   r   Nr6   Nr   )r   r   r7   r8   r9   r   r:   r;   r<   r=   r   r   r   r   r   r   r   r   )r   r   r   r   r   r   r   r0   )r"   r   r   r.   )
r*   r2   r3   r4   r!   r   rv   r,   r/   r   r   r   rP   r    r   '  s    #


r   )!
__future__r   r   r   abcr   r   collections.abcr   r   r   inspectr   r	   r
   typingr   r   monai.data.utilsr   monai.data.wsi_readerr   r   monai.transforms.utility.arrayr   monai.utils.miscr   r   r   __all__r   r   r   r   r   r   r    <module>   s"   ? J