o
    i"                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlZd dlm	Z	 d dl
mZ eddd\ZZd	gZ				 	
	d0d1d"d	Zd2d'd(Zd3d4d*d+Z	 d5d6d.d/ZdS )7    )annotationsN)Sequence)Any)
MetaTensor)optional_importtqdm)namepoint_based_window_infererT   inputstorch.Tensor | MetaTensorroi_sizeSequence[int]	predictortorch.nn.Modulepoint_coordstorch.Tensorpoint_labelsclass_vectortorch.Tensor | Noneprompt_class	prev_mask torch.Tensor | MetaTensor | Nonepoint_startintcenter_onlyboolmarginkwargsr   returnc           #      K  s  |j d dkstdt| j dkstdtt| |\}}|t|d |d |d g|j	 }|d	urBtt||d nd	}d	}|d |d	 D ]}t
|d |d |j d
 |	|
d\}}t
|d |d |j d |	|
d\}}t
|d |d |j d |	|
d\}}tt|D ]}tt|D ]}tt|D ]}|| || || || || || f\}}}}}}td	td	tt|t|tt|t|tt|t|g}|| } || f|||||g|d|}!|d	u r!tjd|!j d |j d
 |j d |j d gdd}tjd|!j d |j d
 |j d |j d gdd}"||  |!d7  < d|"|< qqqqN||" }|d	d	d	d	|d |j d
 |d  |d |j d |d  |d |j d |d  f }|"d	d	d	d	|d |j d
 |d  |d |j d |d  |d |j d |d  f }"|d	ur|d	d	d	d	|d |j d
 |d  |d |j d |d  |d |j d |d  f }|d}||"dk  ||"dk < t| tjrt| } t|dst|| jd | jd}|S )a	  
    Point-based window inferer that takes an input image, a set of points, and a model, and returns a segmented image.
    The inferer algorithm crops the input image into patches that centered at the point sets, which is followed by
    patch inference and average output stitching, and finally returns the segmented mask.

    Args:
        inputs: [1CHWD], input image to be processed.
        roi_size: the spatial window size for inferences.
            When its components have None or non-positives, the corresponding inputs dimension will be used.
            if the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        predictor: the model. For vista3D, the output is [B, 1, H, W, D] which needs to be transposed to [1, B, H, W, D].
            Add transpose=True in kwargs for vista3d.
        point_coords: [B, N, 3]. Point coordinates for B foreground objects, each has N points.
        point_labels: [B, N]. Point labels. 0/1 means negative/positive points for regular supported or zero-shot classes.
            2/3 means negative/positive points for special supported classes (e.g. tumor, vessel).
        class_vector: [B]. Used for class-head automatic segmentation. Can be None value.
        prompt_class: [B]. The same as class_vector representing the point class and inform point head about
            supported class or zeroshot, not used for automatic segmentation. If None, point head is default
            to supported class segmentation.
        prev_mask: [1, B, H, W, D]. The value is before sigmoid. An optional tensor of previously segmented masks.
        point_start: only use points starting from this number. All points before this number is used to generate
            prev_mask. This is used to avoid re-calculating the points in previous iterations if given prev_mask.
        center_only: for each point, only crop the patch centered at this point. If false, crop 3 patches for each point.
        margin: if center_only is false, this value is the distance between point to the patch boundary.
    Returns:
        stitched_output: [1, B, H, W, D]. The value is before sigmoid.
    Notice: The function only supports SINGLE OBJECT INFERENCE with B=1.
    r      z(Only supports single object point click.r
   zInput image should be 5D.iN)r   r      )r   r   r   r   patch_coordsr   cpu)device      metaaffine)r,   r+   )shape
ValueErrorlen_pad_previous_maskcopydeepcopytorchtensortor(   _get_window_idxrangeslicer   zeros
isinstanceTensorr   hasattrr+   )#r   r   r   r   r   r   r   r   r   r   r   r   imagepadZstitched_outputpZlx_Zrx_Zly_Zry_Zlz_Zrz_ijklxrxlyrylzrzunravel_sliceZbatch_imageoutputZstitched_mask rK   \/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/vista3d/inferer.pyr	      s   -( $$$4

&&
 VV

r?   roistuple[int, int]c                 C  sr   | |d  dk rd|}}||fS | |d  |kr$|| |}}||fS t | |d  t | |d  }}||fS )z(Helper function to get the window index.r$   r   )r   )r?   rM   rN   leftrightrK   rK   rL   _get_window_idx_c   s   
"rR   tuple[list[int], list[int]]c                 C  sh   t | ||\}}|r|g|gfS td| | | }t|| | | }||| |g}	|| ||g}
|	|
fS )zGet the window index.r   )rR   maxmin)r?   rM   rN   r   r   rP   rQ   Z	left_mostZ
right_most	left_list
right_listrK   rK   rL   r6      s   r6   padvalue+tuple[torch.Tensor | MetaTensor, list[int]]c                 C  s|   g }t t| jd ddD ]}t||d  | j|  d}|d }|||| g qt|r:tjjj	| |d|d} | |fS )zHelper function to pad inputs.r    r%   r$   r   constant)r>   modevalue)
r7   r/   r-   rT   extendanyr3   nn
functionalr>   )r   r   rX   pad_sizerB   diffhalfrK   rK   rL   r0      s   r0   )NNNr   Tr
   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   )r?   r   rM   r   rN   r   r   rO   )Tr
   )r?   r   rM   r   rN   r   r   r   r   r   r   rS   )r   )r   r   r   r   rX   r   r   rY   )
__future__r   r1   collections.abcr   typingr   r3   monai.data.meta_tensorr   monai.utilsr   r   ___all__r	   rR   r6   r0   rK   rK   rK   rL   <module>   s(   	
s