o
    i	C                     @  s`   d Z ddlmZ ddlmZmZ ddlZdgZ		d d!ddZ	d"ddZ
d#ddZd$ddZdS )%a  
This script is almost same with https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/evaluator/detection/matching.py
The changes include 1) code reformatting, 2) docstrings,
3) allow input args gt_ignore to be optional. (If so, no GT boxes will be ignored.)
    )annotations)CallableSequenceNmatching_batchd   iou_fn.Callable[[np.ndarray, np.ndarray], np.ndarray]iou_thresholdsSequence[float]
pred_boxesSequence[np.ndarray]pred_classespred_scoresgt_boxes
gt_classes	gt_ignore6Sequence[Sequence[bool]] | Sequence[np.ndarray] | Nonemax_detectionsintreturn&list[dict[int, dict[str, np.ndarray]]]c	                 C  s   g }	|du rdd |D }t ||||||D ]X\}
}}}}}t||}i }|D ]@}||k}||k}t|sCt||| |d||< q(t|sSt||| d||< q(t| |
| || || || ||d||< q(|	| q|	S )a!  
    Match boxes of a batch to corresponding ground truth for each category
    independently.

    Args:
        iou_fn: compute overlap for each pair
        iou_thresholds: defined which IoU thresholds should be evaluated
        pred_boxes: predicted boxes from single batch; List[[D, dim * 2]],
            D number of predictions
        pred_classes: predicted classes from a single batch; List[[D]],
            D number of predictions
        pred_scores: predicted score for each bounding box; List[[D]],
            D number of predictions
        gt_boxes: ground truth boxes; List[[G, dim * 2]], G number of ground
            truth
        gt_classes: ground truth classes; List[[G]], G number of ground truth
        gt_ignore: specified if which ground truth boxes are not counted as
            true positives. If not given, when use all the gt_boxes.
            (detections which match theses boxes are not counted as false
            positives either); List[[G]], G number of ground truth
        max_detections: maximum number of detections which should be evaluated

    Returns:
        List[Dict[int, Dict[str, np.ndarray]]], each Dict[str, np.ndarray] corresponds to an image.
        Dict has the following keys.

        - `dtMatches`: matched detections [T, D], where T = number of
          thresholds, D = number of detections
        - `gtMatches`: matched ground truth boxes [T, G], where T = number
          of thresholds, G = number of ground truth
        - `dtScores`: prediction scores [D] detection scores
        - `gtIgnore`: ground truth boxes which should be ignored
          [G] indicate whether ground truth should be ignored
        - `dtIgnore`: detections which should be ignored [T, D],
          indicate which detections should be ignored

    Example:

        .. code-block:: python

            from monai.data.box_utils import box_iou
            from monai.apps.detection.metrics.coco import COCOMetric
            from monai.apps.detection.metrics.matching import matching_batch
            # 3D example outputs of one image from detector
            val_outputs_all = [
                    {"boxes": torch.tensor([[1,1,1,3,4,5]],dtype=torch.float16),
                    "labels": torch.randint(3,(1,)),
                    "scores": torch.randn((1,)).absolute()},
            ]
            val_targets_all = [
                    {"boxes": torch.tensor([[1,1,1,2,6,4]],dtype=torch.float16),
                    "labels": torch.randint(3,(1,))},
            ]

            coco_metric = COCOMetric(
                classes=['c0','c1','c2'], iou_list=[0.1], max_detection=[10]
            )
            results_metric = matching_batch(
                iou_fn=box_iou,
                iou_thresholds=coco_metric.iou_thresholds,
                pred_boxes=[val_data_i["boxes"].numpy() for val_data_i in val_outputs_all],
                pred_classes=[val_data_i["labels"].numpy() for val_data_i in val_outputs_all],
                pred_scores=[val_data_i["scores"].numpy() for val_data_i in val_outputs_all],
                gt_boxes=[val_data_i["boxes"].numpy() for val_data_i in val_targets_all],
                gt_classes=[val_data_i["labels"].numpy() for val_data_i in val_targets_all],
            )
            val_metric_dict = coco_metric(results_metric)
            print(val_metric_dict)
    Nc                 S  s   g | ]}t |d qS )F)np	full_like).0Zgt_c r   g/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/detection/metrics/matching.py
<listcomp>   s    z"matching_batch.<locals>.<listcomp>)r	   r   r   )r	   r   )r   r   r   r   r   r   r	   )zipr   union1dany_matching_no_gt_matching_no_pred#_matching_single_image_single_classappend)r   r	   r   r   r   r   r   r   r   resultsZpboxesZpclassesZpscoresZgboxesZgclassesZgignoreZimg_classesresultcZ	pred_maskZgt_maskr   r   r   r   J   s8   P


	
np.ndarraydict[str, np.ndarray]c           	      C  s   t j| dd}|d| }|| }t|}t g gt|  }t t| |f}t t| |f}|||t g d|dS )a  
    Matching result with not ground truth in image

    Args:
        iou_thresholds: defined which IoU thresholds should be evaluated
        dt_scores: predicted scores
        max_detections: maximum number of allowed detections per image.
            This functions uses this parameter to stay consistent with
            the actual matching function which needs this limit.

    Returns:
        computed matching, a Dict[str, np.ndarray]

        - `dtMatches`: matched detections [T, D], where T = number of
          thresholds, D = number of detections
        - `gtMatches`: matched ground truth boxes [T, G], where T = number
          of thresholds, G = number of ground truth
        - `dtScores`: prediction scores [D] detection scores
        - `gtIgnore`: ground truth boxes which should be ignored
          [G] indicate whether ground truth should be ignored
        - `dtIgnore`: detections which should be ignored [T, D],
          indicate which detections should be ignored
    	mergesortkindN	dtMatchesZ	gtMatchesdtScoresgtIgnoredtIgnore)r   argsortlenarrayzerosreshape)	r	   r   r   dt_ind	dt_scores	num_predsgt_matchdt_match	dt_ignorer   r   r   r       s   r    c                 C  sr   t g }t g gt|  }t g gt|  }|jdkr dn|jd }t t| |f}||||d|dS )a  
    Matching result with no predictions

    Args:
        iou_thresholds: defined which IoU thresholds should be evaluated
        gt_ignore: specified if which ground truth boxes are not counted as
            true positives (detections which match theses boxes are not
            counted as false positives either); [G], G number of ground truth

    Returns:
        dict: computed matching

        - `dtMatches`: matched detections [T, D], where T = number of
          thresholds, D = number of detections
        - `gtMatches`: matched ground truth boxes [T, G], where T = number
          of thresholds, G = number of ground truth
        - `dtScores`: prediction scores [D] detection scores
        - `gtIgnore`: ground truth boxes which should be ignored
          [G] indicate whether ground truth should be ignored
        - `dtIgnore`: detections which should be ignored [T, D],
          indicate which detections should be ignored
    r   r,   r-   )r   r4   r3   sizeshaper5   r6   )r	   r   r8   r;   r<   Zn_gtr:   r   r   r   r!      s   
r!   c                 C  s  t j| dd}|d| }|| }|| }t j|dd}|| }|| }| ||}	|	jd |	jd }
}t t||f}t t||
f}t t||
f}t|D ]e\}}t|D ]\\}}t|dg}d}t|D ]0\}}|||f dkryql|dkr|| dkr|| dkr n|	||f |k rql|	||f }|}ql|dkrq\t|| |||f< d|||f< d|||f< q\qT||||d|dS )	a  
    Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py

    Args:
        iou_fn: compute overlap for each pair
        iou_thresholds: defined which IoU thresholds should be evaluated
        pred_boxes: predicted boxes from single batch; [D, dim * 2], D number
            of predictions
        pred_scores: predicted score for each bounding box; [D], D number of
            predictions
        gt_boxes: ground truth boxes; [G, dim * 2], G number of ground truth
        gt_ignore: specified if which ground truth boxes are not counted as
            true positives (detections which match theses boxes are not
            counted as false positives either); [G], G number of ground truth
        max_detections: maximum number of detections which should be evaluated

    Returns:
        dict: computed matching

        - `dtMatches`: matched detections [T, D], where T = number of
          thresholds, D = number of detections
        - `gtMatches`: matched ground truth boxes [T, G], where T = number
          of thresholds, G = number of ground truth
        - `dtScores`: prediction scores [D] detection scores
        - `gtIgnore`: ground truth boxes which should be ignored
          [G] indicate whether ground truth should be ignored
        - `dtIgnore`: detections which should be ignored [T, D],
          indicate which detections should be ignored
    r)   r*   Nr      gA?r,   r-   )	r   r2   r>   r5   r3   	enumerateminr   r6   )r   r   r   r   r   r   r	   r7   Zgt_indZiousr9   Znum_gtsr:   r;   r<   ZtindtZdind_dioumZgind_gr   r   r   r"     sJ   '
  r"   )Nr   )r   r   r	   r
   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   )r	   r
   r   r'   r   r   r   r(   )r	   r
   r   r'   r   r(   )r   r   r   r'   r   r'   r   r'   r   r'   r   r   r	   r
   r   r(   )__doc__
__future__r   collections.abcr   r   numpyr   __all__r   r    r!   r"   r   r   r   r   <module>   s   :
r
-'