U
    Ph                  	   @  s  d Z ddlmZ ddlZddlZddlmZmZ ddlm	Z	m
Z
 ddlmZ ddlZddlZddlmZmZ ddlmZ dd	lmZ dd
lmZmZ ddgZdZejZG dd deZG dd deZG dd deZ G dd deZ!G dd deZ"G dd deZ#ee e!e"e#gZ$eZ%d`dddddddd Z&dad!dd"d#d$Z'd%dd%d&d'd(Z(dbd%d!d!d%d)d*d+Z)dcd%d!d%d,d-d.Z*d%d%d/d0d1Z+ddd%d%d3d%d4d5d6Z,ded%d%d8d9d:d;d<Z-d%d8d/d=d>Z.d%d%d/d?d@Z/ejfdAdAdBdCdDdEdFZ0d%d%d%dGdHdIZ1d%d%d%dGdJdKZ2d%d%d%dGdLdMZ3dfdNdOdOd8dPdQdRdSZ4dgd%dOd8dTdUdVdWZ5dXe1fd%d%d3ddYd%dZd[d\Z6dXe1fd%d%d%d3ddYd%d]d^d_Z7dS )ha  
This utility module mainly supports rectangular bounding boxes with a few
different parameterizations and methods for converting between them. It
provides reliable access to the spatial coordinates of the box vertices in the
"canonical ordering":
[xmin, ymin, xmax, ymax] for 2D and [xmin, ymin, zmin, xmax, ymax, zmax] for 3D.
We currently define this ordering as `monai.data.box_utils.StandardMode` and
the rest of the detection pipelines mainly assumes boxes in `StandardMode`.
    )annotationsN)ABCabstractmethod)CallableSequence)deepcopy)NdarrayOrTensorNdarrayTensor)look_up_option)BoxModeName)convert_data_typeconvert_to_dst_type      g        c                   @  sZ   e Zd ZU dZi Zded< edddddZed	d
dddZ	edd	dddZ
dS )BoxModea+  
    An abstract class of a ``BoxMode``.

    A ``BoxMode`` is callable that converts box mode of ``boxes``, which are Nx4 (2D) or Nx6 (3D) torch tensor or ndarray.
    ``BoxMode`` has several subclasses that represents different box modes, including

    - :class:`~monai.data.box_utils.CornerCornerModeTypeA`:
      represents [xmin, ymin, xmax, ymax] for 2D and [xmin, ymin, zmin, xmax, ymax, zmax] for 3D
    - :class:`~monai.data.box_utils.CornerCornerModeTypeB`:
      represents [xmin, xmax, ymin, ymax] for 2D and [xmin, xmax, ymin, ymax, zmin, zmax] for 3D
    - :class:`~monai.data.box_utils.CornerCornerModeTypeC`:
      represents [xmin, ymin, xmax, ymax] for 2D and [xmin, ymin, xmax, ymax, zmin, zmax] for 3D
    - :class:`~monai.data.box_utils.CornerSizeMode`:
      represents [xmin, ymin, xsize, ysize] for 2D and [xmin, ymin, zmin, xsize, ysize, zsize] for 3D
    - :class:`~monai.data.box_utils.CenterSizeMode`:
      represents [xcenter, ycenter, xsize, ysize] for 2D and [xcenter, ycenter, zcenter, xsize, ysize, zsize] for 3D

    We currently define ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,
    and monai detection pipelines mainly assume ``boxes`` are in ``StandardMode``.

    The implementation should be aware of:

    - remember to define class variable ``name``,
      a dictionary that maps ``spatial_dims`` to :class:`~monai.utils.enums.BoxModeName`.
    - :func:`~monai.data.box_utils.BoxMode.boxes_to_corners` and :func:`~monai.data.box_utils.BoxMode.corners_to_boxes`
      should not modify inputs in place.
    zdict[int, BoxModeName]nameintstr)spatial_dimsreturnc                 C  s   | j | jS )z
        Get the mode name for the given spatial dimension using class variable ``name``.

        Args:
            spatial_dims: number of spatial dimensions of the bounding boxes.

        Returns:
            ``str``: mode string name
        )r   value)clsr    r   I/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/data/box_utils.pyget_nameT   s    zBoxMode.get_nametorch.Tensortupleboxesr   c                 C  s   t d| jj ddS )a`  
        Convert the bounding boxes of the current mode to corners.

        Args:
            boxes: bounding boxes, Nx4 or Nx6 torch tensor

        Returns:
            ``tuple``: corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor.
            It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

        Example:
            .. code-block:: python

                boxes = torch.ones(10,6)
                boxmode = BoxMode()
                boxmode.boxes_to_corners(boxes) # will return a 6-element tuple, each element is a 10x1 tensor
        	Subclass  must implement this method.NNotImplementedError	__class____name__)selfr   r   r   r   boxes_to_cornersa   s    zBoxMode.boxes_to_cornersr   cornersr   c                 C  s   t d| jj ddS )a  
        Convert the given box corners to the bounding boxes of the current mode.

        Args:
            corners: corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor.
                It represents (xmin, ymin, xmax, ymax) or (xmin, ymin, zmin, xmax, ymax, zmax)

        Returns:
            ``Tensor``: bounding boxes, Nx4 or Nx6 torch tensor

        Example:
            .. code-block:: python

                corners = (torch.ones(10,1), torch.ones(10,1), torch.ones(10,1), torch.ones(10,1))
                boxmode = BoxMode()
                boxmode.corners_to_boxes(corners) # will return a 10x4 tensor
        r   r    Nr!   )r%   r(   r   r   r   corners_to_boxesv   s    zBoxMode.corners_to_boxesN)r$   
__module____qualname____doc__r   __annotations__classmethodr   r   r&   r)   r   r   r   r   r   4   s   
r   c                   @  s>   e Zd ZdZejejdZdddddZddd	d
dZ	dS )CornerCornerModeTypeAav  
    A subclass of ``BoxMode``.

    Also represented as "xyxy" or "xyzxyz", with format of
    [xmin, ymin, xmax, ymax] or [xmin, ymin, zmin, xmax, ymax, zmax].

    Example:
        .. code-block:: python

            CornerCornerModeTypeA.get_name(spatial_dims=2) # will return "xyxy"
            CornerCornerModeTypeA.get_name(spatial_dims=3) # will return "xyzxyz"
    r   r   r   r   r   c                 C  s   |j ddd}|S )N   dim)split)r%   r   r(   r   r   r   r&      s    z&CornerCornerModeTypeA.boxes_to_cornersr   r'   c                 C  s   t jt|dd}|S )Nr2   r3   )torchcatr   )r%   r(   r   r   r   r   r)      s    z&CornerCornerModeTypeA.corners_to_boxesN)
r$   r*   r+   r,   r   XYXYXYZXYZr   r&   r)   r   r   r   r   r/      s   r/   c                   @  s>   e Zd ZdZejejdZdddddZddd	d
dZ	dS )CornerCornerModeTypeBav  
    A subclass of ``BoxMode``.

    Also represented as "xxyy" or "xxyyzz", with format of
    [xmin, xmax, ymin, ymax] or [xmin, xmax, ymin, ymax, zmin, zmax].

    Example:
        .. code-block:: python

            CornerCornerModeTypeB.get_name(spatial_dims=2) # will return "xxyy"
            CornerCornerModeTypeB.get_name(spatial_dims=3) # will return "xxyyzz"
    r0   r   r   r   c           
      C  sl   t |d}|dkr>|jddd\}}}}}}||||||f}	n*|dkrh|jddd\}}}}||||f}	|	S Nr   r   r1   r2   r3   r   get_spatial_dimsr5   )
r%   r   r   xminxmaxyminymaxzminzmaxr(   r   r   r   r&      s    
z&CornerCornerModeTypeB.boxes_to_cornersr   r'   c                 C  sx   t |d}|dkrFtj|d |d |d |d |d |d fdd	}n.|dkrttj|d |d |d |d fdd	}|S 
Nr(   r   r   r1      r      r2   r3   )r>   r6   r7   r%   r(   r   r   r   r   r   r)      s    
4&z&CornerCornerModeTypeB.corners_to_boxesN)
r$   r*   r+   r,   r   XXYYXXYYZZr   r&   r)   r   r   r   r   r:      s   r:   c                   @  s>   e Zd ZdZejejdZdddddZddd	d
dZ	dS )CornerCornerModeTypeCav  
    A subclass of ``BoxMode``.

    Also represented as "xyxy" or "xyxyzz", with format of
    [xmin, ymin, xmax, ymax] or [xmin, ymin, xmax, ymax, zmin, zmax].

    Example:
        .. code-block:: python

            CornerCornerModeTypeC.get_name(spatial_dims=2) # will return "xyxy"
            CornerCornerModeTypeC.get_name(spatial_dims=3) # will return "xyxyzz"
    r0   r   r   r   c           
      C  sX   t |d}|dkr>|jddd\}}}}}}||||||f}	n|dkrT|jddd}	|	S r;   r=   )
r%   r   r   r?   rA   r@   rB   rC   rD   r(   r   r   r   r&      s    
z&CornerCornerModeTypeC.boxes_to_cornersr   r'   c                 C  sd   t |d}|dkrFtj|d |d |d |d |d |d fdd	}n|dkr`tjt|dd	}|S rE   )r>   r6   r7   r   rI   r   r   r   r)      s    
4z&CornerCornerModeTypeC.corners_to_boxesN)
r$   r*   r+   r,   r   r8   XYXYZZr   r&   r)   r   r   r   r   rL      s   
rL   c                   @  s>   e Zd ZdZejejdZdddddZddd	d
dZ	dS )CornerSizeModeam  
    A subclass of ``BoxMode``.

    Also represented as "xywh" or "xyzwhd", with format of
    [xmin, ymin, xsize, ysize] or [xmin, ymin, zmin, xsize, ysize, zsize].

    Example:
        .. code-block:: python

            CornerSizeMode.get_name(spatial_dims=2) # will return "xywh"
            CornerSizeMode.get_name(spatial_dims=3) # will return "xyzwhd"
    r0   r   r   r   c                 C  s(  |j }t|d}|dkr|jddd\}}}}}}	||t jtdjddj|d }
||t jtdjddj|d }||	t jtdjddj|d }||||
||f}nt|d	kr$|jddd\}}}}||t jtdjddj|d }
||t jtdjddj|d }|||
|f}|S )
Nr<   r   r1   r2   r3   dtyper   minr   rP   r>   r5   	TO_REMOVEtoCOMPUTE_DTYPEclamp)r%   r   	box_dtyper   r?   rA   rC   whdr@   rB   rD   r(   r   r   r   r&      s    
$$$
$$zCornerSizeMode.boxes_to_cornersr   r'   c           
      C  s   t |d}|dkrz|d |d |d |d |d |d f\}}}}}}tj||||| t || t || t fdd	}	nR|dkr|d |d |d |d f\}}}}tj|||| t || t fdd	}	|	S )
NrF   r   r   r1   r   rG   rH   r2   r3   r>   r6   r7   rT   
r%   r(   r   r?   rA   rC   r@   rB   rD   r   r   r   r   r)     s    
4& $&zCornerSizeMode.corners_to_boxesN)
r$   r*   r+   r,   r   XYWHXYZWHDr   r&   r)   r   r   r   r   rN      s   rN   c                   @  s>   e Zd ZdZejejdZdddddZddd	d
dZ	dS )CenterSizeModeam  
    A subclass of ``BoxMode``.

    Also represented as "ccwh" or "cccwhd", with format of
    [xmin, ymin, xsize, ysize] or [xmin, ymin, zmin, xsize, ysize, zsize].

    Example:
        .. code-block:: python

            CenterSizeMode.get_name(spatial_dims=2) # will return "ccwh"
            CenterSizeMode.get_name(spatial_dims=3) # will return "cccwhd"
    r0   r   r   r   c                 C  s  |j }t|d}|dkr6|jddd\}}}}}}	||t d jtdjdd	j|d }
||t d jtdjdd	j|d }||t d jtdjdd	j|d }||t d jtdjdd	j|d }||	t d jtdjdd	j|d }||	t d jtdjdd	j|d }|
|||||f}n|d
kr|jddd\}}}}||t d jtdjdd	j|d }
||t d jtdjdd	j|d }||t d jtdjdd	j|d }||t d jtdjdd	j|d }|
|||f}|S )Nr<   r   r1   r2   r3          @rO   r   rQ   r   rS   )r%   r   rX   r   xcyczcrY   rZ   r[   r?   r@   rA   rB   rC   rD   r(   r   r   r   r&   1  s&    

((((((
((((zCenterSizeMode.boxes_to_cornersr   r'   c           
      C  s  t |d}|dkr|d |d |d |d |d |d f\}}}}}}tj|| t d || t d || t d || t || t || t fd	d
}	nl|dkr
|d |d |d |d f\}}}}tj|| t d || t d || t || t fd	d
}	|	S )NrF   r   r   r1   r   rG   rH   ra   r2   r3   r\   r]   r   r   r   r)   I  s0    
4



$

	zCenterSizeMode.corners_to_boxesN)
r$   r*   r+   r,   r   CCWHCCCWHDr   r&   r)   r   r   r   r   r`   !  s   r`   z torch.Tensor | np.ndarray | NonezSequence | Nonez0Sequence[int] | torch.Tensor | np.ndarray | Noner   )r   pointsr(   spatial_sizer   c                 C  s  t  }| dk	rt| jdkrP| jd dkr>td| j dntd| j dt| jd d tkrxtd| j d|t| jd d  |dk	rt|jdkr|jd dkrtd|j d	ntd|j dt|jd tkrtd|j d|t|jd  |dk	rXt|d tkrFtd
t| d|t|d  |dk	rt|tkrtd| d|t| t|}t|dkrtdt|dkrt|d }t|ddgd}t|S tddS )a1  
    Get spatial dimension for the giving setting and check the validity of them.
    Missing input is allowed. But at least one of the input value should be given.
    It raises ValueError if the dimensions of multiple inputs do not match with each other.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray
        points: point coordinates, [x, y] or [x, y, z], Nx2 or Nx3 torch tensor or ndarray
        corners: corners of boxes, 4-element or 6-element tuple, each element is a Nx1 torch tensor or ndarray
        spatial_size: The spatial size of the image where the boxes are attached.
                len(spatial_size) should be in [2, 3].

    Returns:
        ``int``: spatial_dims, number of spatial dimensions of the bounding boxes.

    Example:
        .. code-block:: python

            boxes = torch.ones(10,6)
            get_spatial_dims(boxes, spatial_size=[100,200,200]) # will return 3
            get_spatial_dims(boxes, spatial_size=[100,200]) # will raise ValueError
            get_spatial_dims(boxes) # will return 3
    Nr   r   zPCurrently we support only boxes with shape [N,4] or [N,6], got boxes with shape z^. Please reshape it with boxes = torch.reshape(boxes, [0, 4]) or torch.reshape(boxes, [0, 6])..r1   zRCurrently we support only points with shape [N,2] or [N,3], got points with shape za. Please reshape it with points = torch.reshape(points, [0, 2]) or torch.reshape(points, [0, 3]).z\Currently we support only boxes with shape [N,4] or [N,6], got box corner tuple with length zNCurrently we support only boxes on 2-D and 3-D images, got image spatial_size z1At least one of the inputs needs to be non-empty.r   )	supportedz?The dimensions of multiple inputs should match with each other.)	setlenshape
ValueErrorr   SUPPORTED_SPATIAL_DIMSaddlistr
   )r   rg   r(   rh   Zspatial_dims_setZspatial_dims_listr   r   r   r   r>   m  sb    



r>   z$str | BoxMode | type[BoxMode] | None)moder   c                 O  s   t | tr| S t| r,t| tr,| ||S t | tr~tD ]B}tD ]8}t|rBt|trB||| krB|||    S qBq:| dk	rt	d|  dt
||S )a`	  
    This function that return a :class:`~monai.data.box_utils.BoxMode` object giving a representation of box mode

    Args:
        mode: a representation of box mode. If it is not given, this func will assume it is ``StandardMode()``.

    Note:
        ``StandardMode`` = :class:`~monai.data.box_utils.CornerCornerModeTypeA`,
        also represented as "xyxy" for 2D and "xyzxyz" for 3D.

        mode can be:
            #. str: choose from :class:`~monai.utils.enums.BoxModeName`, for example,
                - "xyxy": boxes has format [xmin, ymin, xmax, ymax]
                - "xyzxyz": boxes has format [xmin, ymin, zmin, xmax, ymax, zmax]
                - "xxyy": boxes has format [xmin, xmax, ymin, ymax]
                - "xxyyzz": boxes has format [xmin, xmax, ymin, ymax, zmin, zmax]
                - "xyxyzz": boxes has format [xmin, ymin, xmax, ymax, zmin, zmax]
                - "xywh": boxes has format [xmin, ymin, xsize, ysize]
                - "xyzwhd": boxes has format [xmin, ymin, zmin, xsize, ysize, zsize]
                - "ccwh": boxes has format [xcenter, ycenter, xsize, ysize]
                - "cccwhd": boxes has format [xcenter, ycenter, zcenter, xsize, ysize, zsize]
            #. BoxMode class: choose from the subclasses of :class:`~monai.data.box_utils.BoxMode`, for example,
                - CornerCornerModeTypeA: equivalent to "xyxy" or "xyzxyz"
                - CornerCornerModeTypeB: equivalent to "xxyy" or "xxyyzz"
                - CornerCornerModeTypeC: equivalent to "xyxy" or "xyxyzz"
                - CornerSizeMode: equivalent to "xywh" or "xyzwhd"
                - CenterSizeMode: equivalent to "ccwh" or "cccwhd"
            #. BoxMode object: choose from the subclasses of :class:`~monai.data.box_utils.BoxMode`, for example,
                - CornerCornerModeTypeA(): equivalent to "xyxy" or "xyzxyz"
                - CornerCornerModeTypeB(): equivalent to "xxyy" or "xxyyzz"
                - CornerCornerModeTypeC(): equivalent to "xyxy" or "xyxyzz"
                - CornerSizeMode(): equivalent to "xywh" or "xyzwhd"
                - CenterSizeMode(): equivalent to "ccwh" or "cccwhd"
            #. None: will assume mode is ``StandardMode()``

    Returns:
        BoxMode object

    Example:
        .. code-block:: python

            mode = "xyzxyz"
            get_boxmode(mode) # will return CornerCornerModeTypeA()
    NzUnsupported box mode: ri   )
isinstancer   inspectisclass
issubclassr   SUPPORTED_MODESro   r   rn   StandardMode)rr   argskwargsmnr   r   r   get_boxmode  s    -


"r}   r   )r   r   r   c                 C  sF   t | tj^}}|jd dkr2t|d|d g}t|| d^}}|S )a  
    When boxes are empty, this function standardize it to shape of (0,4) or (0,6).

    Args:
        boxes: bounding boxes, Nx4 or Nx6 or empty torch tensor or ndarray
        spatial_dims: number of spatial dimensions of the bounding boxes.

    Returns:
        bounding boxes with shape (N,4) or (N,6), N can be 0.

    Example:
        .. code-block:: python

            boxes = torch.ones(0,)
            standardize_empty_box(boxes, 3)
    r   r   srcdst)r   r6   Tensorrm   reshaper   )r   r   boxes_t_	boxes_dstr   r   r   standardize_empty_box  s
    r   )r   src_modedst_moder   c                 C  s   | j d dkr| S t|}t|}t|t|r8t| S t| tj^}}||}t	|d}t
d|D ]*}	|||	  ||	 k  dkrftd qf||}
t|
| d^}}|S )a  
    This function converts the boxes in src_mode to the dst_mode.

    Args:
        boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray.
        src_mode: source box mode. If it is not given, this func will assume it is ``StandardMode()``.
            It follows the same format with ``mode`` in :func:`~monai.data.box_utils.get_boxmode`.
        dst_mode: target box mode. If it is not given, this func will assume it is ``StandardMode()``.
            It follows the same format with ``mode`` in :func:`~monai.data.box_utils.get_boxmode`.

    Returns:
        bounding boxes with target mode, with same data type as ``boxes``, does not share memory with ``boxes``

    Example:
        .. code-block:: python

            boxes = torch.ones(10,4)
            # The following three lines are equivalent
            # They convert boxes with format [xmin, ymin, xmax, ymax] to [xcenter, ycenter, xsize, ysize].
            convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode="ccwh")
            convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode)
            convert_box_mode(boxes=boxes, src_mode="xyxy", dst_mode=monai.data.box_utils.CenterSizeMode())
    r   r<   BGiven boxes has invalid values. The box size must be non-negative.r~   )rm   r}   rs   typer   r   r6   r   r&   r>   rangesumwarningswarnr)   r   )r   r   r   Zsrc_boxmodeZdst_boxmoder   r   r(   r   axisZboxes_t_dstr   r   r   r   convert_box_mode"  s    


r   )r   rr   r   c                 C  s   t | |t dS )a  
    Convert given boxes to standard mode.
    Standard mode is "xyxy" or "xyzxyz",
    representing box format of [xmin, ymin, xmax, ymax] or [xmin, ymin, zmin, xmax, ymax, zmax].

    Args:
        boxes: source bounding boxes, Nx4 or Nx6 torch tensor or ndarray.
        mode: source box mode. If it is not given, this func will assume it is ``StandardMode()``.
            It follows the same format with ``mode`` in :func:`~monai.data.box_utils.get_boxmode`.

    Returns:
        bounding boxes with standard mode, with same data type as ``boxes``, does not share memory with ``boxes``

    Example:
        .. code-block:: python

            boxes = torch.ones(10,6)
            # The following two lines are equivalent
            # They convert boxes with format [xmin, xmax, ymin, ymax, zmin, zmax] to [xmin, ymin, zmin, xmax, ymax, zmax]
            convert_box_to_standard_mode(boxes=boxes, mode="xxyyzz")
            convert_box_mode(boxes=boxes, src_mode="xxyyzz", dst_mode="xyzxyz")
    r   r   r   )r   rx   )r   rr   r   r   r   convert_box_to_standard_mode^  s    r   r   c                 C  s(   t | d}t| ttdddd|f S )z
    Compute center points of boxes

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``

    Returns:
        center points with size of (N, spatial_dims)

    r<   r   N)r>   r   rx   r`   )r   r   r   r   r   box_centersz  s    
r   {Gz?float)centersr   epsr   c                   s   t  d fddtD  fddtD  }t tjrftj|ddjdd}||kS tj|ddt	jddd |kS )	a  
    Checks which center points are within boxes

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``.
        centers: center points, Nx2 or Nx3 torch tensor or ndarray.
        eps: minimum distance to border of boxes.

    Returns:
        boolean array indicating which center points are within the boxes, sized (N,).

    Reference:
        https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/ops.py

    r<   c                   s,   g | ]$}d d |f  d d |f  qS Nr   .0r   )r   r   r   r   
<listcomp>  s     z$centers_in_boxes.<locals>.<listcomp>c                   s0   g | ](} d d | f d d |f  qS r   r   r   r   r   r   r   r   r     s    r1   )r   r3   r   )
r>   r   rs   npndarraystackrR   r6   rU   rV   )r   r   r   Zcenter_to_borderZmin_center_to_borderr   r   r   centers_in_boxes  s    
$r   Tboolz8tuple[NdarrayOrTensor, NdarrayOrTensor, NdarrayOrTensor])boxes1boxes2	euclideanr   c           	   	   C  s   t | t|s8tdt|  dt| dt|  d t| tj^}}t|tj^}}t|t	}t|t	}|r|dddf |d  
dd }n|dddf |d  d}t|||f| d^\}}}}|||fS )	aB  
    Distance of center points between two sets of boxes

    Args:
        boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        euclidean: computed the euclidean distance otherwise it uses the l1 distance

    Returns:
        - The pairwise distances for every element in boxes1 and boxes2,
          with size of (N,M) and same data type as ``boxes1``.
        - Center points of boxes1, with size of (N,spatial_dims) and same data type as ``boxes1``.
        - Center points of boxes2, with size of (M,spatial_dims) and same data type as ``boxes1``.

    Reference:
        https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/ops.py

    
boxes1 is , while boxes2 is . The result will be ri   Nr   r2   r~   )rs   r   r   r   r   r6   r   r   rU   rV   powr   sqrtr   )	r   r   r   boxes1_tr   boxes2_tZcenter1Zcenter2distsr   r   r   boxes_center_distance  s    **r   c                 C  sP   t | d}td|D ]6}| dd|| f | dd|f k  dkr dS qdS )z
    This function checks whether the box size is non-negative.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``

    Returns:
        whether ``boxes`` is valid
    r<   r   NFT)r>   r   r   )r   r   r   r   r   r   is_valid_box_values  s
    

,r   c                 C  s   t | stdt| d}| dd|f | dddf  t }td|D ]0}|| dd|| f | dd|f  t  }qHt|tj^}}| 	 s|
 	 r|jtjkrtdntd|S )a  
    This function computes the area (2D) or volume (3D) of each box.
    Half precision is not recommended for this function as it may cause overflow, especially for 3D images.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``

    Returns:
        area (2D) or volume (3D) of boxes, with size of (N,).

    Example:
        .. code-block:: python

            boxes = torch.ones(10,6)
            # we do computation with torch.float32 to avoid overflow
            compute_dtype = torch.float32
            area = box_area(boxes=boxes.to(dtype=compute_dtype))  # torch.float32, size of (10,)
    r   r<   Nr   r1   zUBox area is NaN or Inf. boxes is float16. Please change to float32 and test it again.zBox area is NaN or Inf.)r   rn   r>   rT   r   r   r6   r   isnananyisinfrP   float16)r   r   arear   Zarea_tr   r   r   r   box_area  s    
$.
r   r   ztorch.dtypez!tuple[torch.Tensor, torch.Tensor])r   r   compute_dtyper   c                 C  s   t | d}t| j|dd}t|j|dd}t| dddd|f |ddd|f j|d}t| ddd|df |dd|df j|d}|| t jdd}tj|ddd}	|dddf | |	 }
|	|
fS )	a(  
    This internal function computes the intersection and union area of two set of boxes.

    Args:
        boxes1: bounding boxes, Nx4 or Nx6 torch tensor. The box mode is assumed to be ``StandardMode``
        boxes2: bounding boxes, Mx4 or Mx6 torch tensor. The box mode is assumed to be ``StandardMode``
        compute_dtype: default torch.float32, dtype with which the results will be computed

    Returns:
        inter, with size of (N,M) and dtype of ``compute_dtype``.
        union, with size of (N,M) and dtype of ``compute_dtype``.

    r<   rO   Nr   rQ   r2   Fr4   keepdim)	r>   r   rU   r6   maxrR   rT   rW   prod)r   r   r   r   area1area2ltrbwhinterunionr   r   r   _box_inter_union  s    
..r   )r   r   r   c           
   	   C  s   t | t|s8tdt|  dt| dt|  d t| tj^}}t|tj^}}|j}t||t	d\}}||t
t	j  }|j|d}t| st| rtdt|| d^}	}|	S )	a  
    Compute the intersection over union (IoU) of two set of boxes.

    Args:
        boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``

    Returns:
        IoU, with size of (N,M) and same data type as ``boxes1``

    r   r   r   ri   r   rO   zBox IoU is NaN or Inf.r~   )rs   r   r   r   r   r6   r   rP   r   rV   finfor   rU   r   r   r   rn   r   )
r   r   r   r   r   rX   r   r   Ziou_tiour   r   r   box_iou4  s    *r   c              	   C  s  t | t|s8tdt|  dt| dt|  d t| tj^}}t|tj^}}t|d}|j}t	||t
d\}}||tt
j  }	t|dddd|f |ddd|f jt
d}
t|ddd|df |dd|df jt
d}||
 t jd	d
}tj|ddd}|	|| |tt
j   }|j|d}t| sft| rntdt|| d^}}|S )a  
    Compute the generalized intersection over union (GIoU) of two sets of boxes.
    The two inputs can have different shapes and the func return an NxM matrix,
    (in contrary to :func:`~monai.data.box_utils.box_pair_giou` , which requires the inputs to have the same
    shape and returns ``N`` values).

    Args:
        boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        boxes2: bounding boxes, Mx4 or Mx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``

    Returns:
        GIoU, with size of (N,M) and same data type as ``boxes1``

    Reference:
        https://giou.stanford.edu/GIoU.pdf

    r   r   r   ri   r<   r   NrO   r   rQ   r2   Fr   Box GIoU is NaN or Inf.r~   )rs   r   r   r   r   r6   r   r>   rP   r   rV   r   r   rR   rU   r   rT   rW   r   r   r   r   rn   r   )r   r   r   r   r   r   rX   r   r   r   r   r   r   	enclosuregiou_tgiour   r   r   box_giouZ  s,    *
.. r   c              	   C  s<  t | t|s8tdt|  dt| dt|  d t| tj^}}t|tj^}}|j|jkrltdt	|d}|j
}t|jtdd}t|jtdd}t|ddd|f |ddd|f jtd}	t|dd|df |dd|df jtd}
|
|	 t jd	d
}tj|ddd}|| | }||ttj  }t|ddd|f |ddd|f jtd}	t|dd|df |dd|df jtd}
|
|	 t jd	d
}tj|ddd}||| |ttj   }|j|d}t| s t| r(tdt|| d^}}|S )a  
    Compute the generalized intersection over union (GIoU) of a pair of boxes.
    The two inputs should have the same shape and the func return an (N,) array,
    (in contrary to :func:`~monai.data.box_utils.box_giou` , which does not require the inputs to have the same
    shape and returns ``NxM`` matrix).

    Args:
        boxes1: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        boxes2: bounding boxes, same shape with boxes1. The box mode is assumed to be ``StandardMode``

    Returns:
        paired GIoU, with size of (N,) and same data type as ``boxes1``

    Reference:
        https://giou.stanford.edu/GIoU.pdf

    r   r   r   ri   z7boxes1 and boxes2 should be paired and have same shape.r<   rO   Nr   rQ   r2   Fr   r   r~   )rs   r   r   r   r   r6   r   rm   rn   r>   rP   r   rU   rV   r   rR   rT   rW   r   r   r   r   r   r   r   )r   r   r   r   r   r   rX   r   r   r   r   r   r   r   r   r   r   r   r   r   r   box_pair_giou  sD    *
,,,, r   r	   zSequence[int] | NdarrayOrTensorz%tuple[NdarrayTensor, NdarrayOrTensor])r   	roi_startroi_endremove_emptyr   c                 C  s  t | tjd  }|jtd}t||ddd tj}t||ddd tj}t||}t	| |d}t
d|D ]}|dd|f j|| || t d|dd|f< |dd|| f j|| || t d|dd|| f< |dd|f  || 8  < |dd|| f  || 8  < qv|r|dd|f |dddf d t k}	t
d|D ]6}|	|dd|| f |dd|f d t k@ }	q\||	 }ntj|dddf dtjd	}	t|| d
^}
}t|	| |	jd^}}|
|fS )a  
    This function generate the new boxes when the corresponding image is cropped to the given ROI.
    When ``remove_empty=True``, it makes sure the bounding boxes are within the new cropped image.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        roi_start: voxel coordinates for start of the crop ROI, negative values allowed.
        roi_end: voxel coordinates for end of the crop ROI, negative values allowed.
        remove_empty: whether to remove the boxes that are actually empty

    Returns:
        - cropped boxes, boxes[keep], does not share memory with original boxes
        - ``keep``, it indicates whether each box in ``boxes`` are kept when ``remove_empty=True``.
    r   rO   T)r   r   wrap_sequencer   rh   N)rR   r   r1   )
fill_valuerP   r~   r   r   rP   )r   r6   r   clonerU   rV   r   int16maximumr>   r   rW   rT   	full_liker   rP   )r   r   r   r   r   Zroi_start_tZ	roi_end_tr   r   Zkeep_tZ
boxes_keepr   keepr   r   r   spatial_crop_boxes  s.    2 
"(4
r   z'tuple[NdarrayOrTensor, NdarrayOrTensor])r   rh   r   r   c                 C  s"   t | |d}t| dg| ||dS )ad  
    This function clips the ``boxes`` to makes sure the bounding boxes are within the image.

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        spatial_size: The spatial size of the image where the boxes are attached. len(spatial_size) should be in [2, 3].
        remove_empty: whether to remove the boxes that are actually empty

    Returns:
        - clipped boxes, boxes[keep], does not share memory with original boxes
        - ``keep``, it indicates whether each box in ``boxes`` are kept when ``remove_empty=True``.
    r   r   )r   r   r   )r>   r   )r   rh   r   r   r   r   r   clip_boxes_to_image  s    r   r2   r   )r   scores
nms_threshmax_proposalsbox_overlap_metricr   c                 C  sr  | j d dkr(ttg | tjdd S | j d |j d krTtd| j  d|j  t| tj^}}t||^}}tj	|ddd}t
||ddf }	g }
tttd|	j d j|jtjd}t|dkrVt|d  }|
| t|
|  krd	krn nqV||	|ddf |	||d	 ddf }||k }d
|d< || }q||
 }t|| |jdd S )a  
    Non-maximum suppression (NMS).

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        scores: prediction scores of the boxes, sized (N,). This function keeps boxes with higher scores.
        nms_thresh: threshold of NMS. Discards all overlapping boxes with box_overlap > nms_thresh.
        max_proposals: maximum number of boxes it keeps.
            If ``max_proposals`` = -1, there is no limit on the number of boxes that are kept.
        box_overlap_metric: the metric to compute overlap between boxes.

    Returns:
        Indexes of ``boxes`` that are kept after NMS.

    Example:
        .. code-block:: python

            boxes = torch.ones(10,6)
            scores = torch.ones(10)
            keep = non_max_suppression(boxes, scores, num_thresh=0.1)
            boxes_after_nms = boxes[keep]
    r   r   z:boxes and scores should have same length, got boxes shape z, scores shape T)r4   
descendingN)devicerP   r1   F)rm   r   r   arrayr6   longrn   r   r   argsortr   rq   r   rU   r   rl   r   itemappendflattenrP   )r   r   r   r   r   r   r   scores_tZ	sort_idxsZ
boxes_sortpickidxsiZbox_overlapZto_keep_idxZpick_idxr   r   r   non_max_suppression.  s.    (
 *
r   )r   r   labelsr   r   r   r   c                 C  s   | j d dkr(ttg | tjdd S t| tjtjd^}}t||^}}t||tjd^}	}|	 }
|	
||
d  }| |dddf  }t|||||}t|| |jdd S )a  
    Performs non-maximum suppression in a batched fashion.
    Each labels value correspond to a category, and NMS will not be applied between elements of different categories.

    Adapted from https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/nms.py

    Args:
        boxes: bounding boxes, Nx4 or Nx6 torch tensor or ndarray. The box mode is assumed to be ``StandardMode``
        scores: prediction scores of the boxes, sized (N,). This function keeps boxes with higher scores.
        labels: indices of the categories for each one of the boxes. sized(N,), value range is (0, num_classes)
        nms_thresh: threshold of NMS. Discards all overlapping boxes with box_overlap > nms_thresh.
        max_proposals: maximum number of boxes it keeps.
            If ``max_proposals`` = -1, there is no limit on the number of boxes that are kept.
        box_overlap_metric: the metric to compute overlap between boxes.

    Returns:
        Indexes of ``boxes`` that are kept after NMS.
    r   r   rO   r1   N)rm   r   r   r   r6   r   r   r   float32r   rU   r   rP   )r   r   r   r   r   r   r   r   r   Zlabels_tZmax_coordinateoffsetsZboxes_for_nmsr   r   r   r   batched_nmsx  s    r   )NNNN)N)NN)N)r   )T)T)T)8r,   
__future__r   rt   r   abcr   r   collections.abcr   r   copyr   numpyr   r6   monai.config.type_definitionsr   r	   monai.utilsr
   monai.utils.enumsr   Zmonai.utils.type_conversionr   r   ro   rT   r   rV   r   r/   r:   rL   rN   r`   rw   rx   r>   r}   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   <module>   sp   
X%$1G    \>  =   ++&&:R : O