U
    Phg                     @  sj   d Z ddlmZ ddlZddlZddlmZ ddlm	Z	 ddl
ZG dd dZdddd	d
ddddZdS )z
This script is almost same with https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/evaluator/detection/coco.py
The changes include 1) code reformatting, 2) docstrings.
    )annotationsN)Sequence)Anyc                   @  s   e Zd Zd3ddddddd	d
dZddddddZdddddZddddZdddddZddd d!d"Zddd d#d$Z	e
d4d'd(d)d*d+d,d-d.Ze
d5d'd)d)d*d+d,d/d0Zdddd1d2Zd%S )6
COCOMetric皙?      ?g      ?r   r   g?      d   TzSequence[str]zSequence[float]zSequence[int]bool)classesiou_list	iou_rangemax_detection	per_classverbosec           	   	   C  s  || _ || _|| _t|}tj|d |d tt|d |d  |d  d dd}t||| _	|| _
t|ddtjf | j	tj kd | _t|ddtjf | j	tj kd | _| j	| j |k r| j	| j |k stdtjdd	ttd
d dd| _|| _dS )a	  
        Class to compute COCO metrics
        Metrics computed includes,

        - mAP over the IoU range specified by `iou_range` at last value of `max_detection`
        - AP values at IoU thresholds specified by `iou_list` at last value of `max_detection`
        - AR over max detections thresholds defined by `max_detection` (over iou range)

        Args:
            classes (Sequence[str]): name of each class (index needs to correspond to predicted class indices!)
            iou_list (Sequence[float]): specific thresholds where ap is evaluated and saved
            iou_range (Sequence[float]): (start, stop, step) for mAP iou thresholds
            max_detection (Sequence[int]): maximum number of detections per image
            verbose (bool): log time needed for evaluation

        Example:

            .. code-block:: python

                from monai.data.box_utils import box_iou
                from monai.apps.detection.metrics.coco import COCOMetric
                from monai.apps.detection.metrics.matching import matching_batch
                # 3D example outputs of one image from detector
                val_outputs_all = [
                        {"boxes": torch.tensor([[1,1,1,3,4,5]],dtype=torch.float16),
                        "labels": torch.randint(3,(1,)),
                        "scores": torch.randn((1,)).absolute()},
                ]
                val_targets_all = [
                        {"boxes": torch.tensor([[1,1,1,2,6,4]],dtype=torch.float16),
                        "labels": torch.randint(3,(1,))},
                ]

                coco_metric = COCOMetric(
                    classes=['c0','c1','c2'], iou_list=[0.1], max_detection=[10]
                )
                results_metric = matching_batch(
                    iou_fn=box_iou,
                    iou_thresholds=coco_metric.iou_thresholds,
                    pred_boxes=[val_data_i["boxes"].numpy() for val_data_i in val_outputs_all],
                    pred_classes=[val_data_i["labels"].numpy() for val_data_i in val_outputs_all],
                    pred_scores=[val_data_i["scores"].numpy() for val_data_i in val_outputs_all],
                    gt_boxes=[val_data_i["boxes"].numpy() for val_data_i in val_targets_all],
                    gt_classes=[val_data_i["labels"].numpy() for val_data_i in val_targets_all],
                )
                val_metric_dict = coco_metric(results_metric)
                print(val_metric_dict)
        r   r      T)endpointNzxRequire self.iou_thresholds[self.iou_list_idx] == iou_list_np and self.iou_thresholds[self.iou_range_idx] == _iou_range.g        g      ?g      Y@)r   r   r   nparraylinspaceintroundunion1diou_thresholdsr   nonzeronewaxisiou_list_idxiou_range_idxall
ValueErrorrecall_thresholdsmax_detections)	selfr   r   r   r   r   r   Ziou_list_npZ
_iou_range r'   V/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/detection/metrics/coco.py__init__L   s.    9
  $ **"zCOCOMetric.__init__r   z5tuple[dict[str, float], dict[str, np.ndarray] | None])argskwargsreturnc                 O  s   | j ||S )a  
        Compute metric. See :func:`compute` for more information.

        Args:
            *args: positional arguments passed to :func:`compute`
            **kwargs: keyword arguments passed to :func:`compute`

        Returns:
            dict[str, float]: dictionary with scalar values for evaluation
            dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
        )compute)r&   r*   r+   r'   r'   r(   __call__   s    zCOCOMetric.__call__
np.ndarrayNone)r*   r,   c                 G  sH   t |  }|D ]2}|jd |krtd|jd  d|   dqdS )z
        Check if shape of input in first dimension is consistent with expected IoU values
        (assumes IoU dimension is the first dimension)

        Args:
            args: array like inputs with shape function
        r   zIRequire arg.shape[0] == len(self.get_iou_thresholds()). Got arg.shape[0]=z, self.get_iou_thresholds()=.N)lenget_iou_thresholdsshaper#   )r&   r*   Znum_iousargr'   r'   r(   check_number_of_iou   s    zCOCOMetric.check_number_of_iou)r,   c                 C  s
   t | jS )z
        Return IoU thresholds needed for this metric in an numpy array

        Returns:
            Sequence[float]: IoU thresholds [M], M is the number of thresholds
        )listr   )r&   r'   r'   r(   r3      s    zCOCOMetric.get_iou_thresholdsz&list[dict[int, dict[str, np.ndarray]]]ztuple[dict[str, float], None])results_listr,   c                 C  s   | j rtd t }| j|d}| j rJt }td|| dd i }|| | || | | j rt }td|| dd |dfS )	a  
        Compute COCO metrics

        Args:
            results_list (list[dict[int, dict[str, np.ndarray]]]): list with results per image (in list)
                per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`.

                - `dtMatches`: matched detections [T, D], where T = number of
                  thresholds, D = number of detections
                - `gtMatches`: matched ground truth boxes [T, G], where T = number
                  of thresholds, G = number of ground truth
                - `dtScores`: prediction scores [D] detection scores
                - `gtIgnore`: ground truth boxes which should be ignored
                  [G] indicate whether ground truth should be ignored
                - `dtIgnore`: detections which should be ignored [T, D],
                  indicate which detections should be ignored

        Returns:
            dict[str, float], dictionary with coco metrics
        z Start COCO metric computation...)r8   z(Statistics for COCO metrics finished (t=z0.2fzs).zCOCO metrics computed in t=zs.N)r   loggerinfotime_compute_statisticsupdate_compute_ap_compute_ar)r&   r8   ticdataset_statisticstocresultsr'   r'   r(   r-      s    
zCOCOMetric.computezdict[str, np.ndarray | list]zdict[str, float])rA   r,   c                 C  sd  i }| j rd| j d dd| j d dd| j d dd| jd  }| j|| jdd	||< | jrt| jD ]^\}}| d
| j d dd| j d dd| j d dd| jd  	}| j|| j|dd||< qj| jD ]}d| j| dd| jd  }| j||gdd	||< | jrt| jD ]D\}}| d| j| dd| jd  }| j||g|dd||< qq|S )a  
        Compute AP metrics

        Args:
            dataset_statistics (list[dict[int, dict[str, np.ndarray]]]): list with result s per image (in list)
                per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`.

                - `dtMatches`: matched detections [T, D], where T = number of
                  thresholds, D = number of detections
                - `gtMatches`: matched ground truth boxes [T, G], where T = number
                  of thresholds, G = number of ground truth
                - `dtScores`: prediction scores [D] detection scores
                - `gtIgnore`: ground truth boxes which should be ignored
                  [G] indicate whether ground truth should be ignored
                - `dtIgnore`: detections which should be ignored [T, D],
                  indicate which detections should be ignored
        ZmAP_IoU_r   .2f_r   r   _MaxDet_iou_idxmax_det_idxZ	_mAP_IoU_rI   cls_idxrJ   ZAP_IoU_Z_AP_IoU_)	r   r%   
_select_apr!   r   	enumerater   r    r   )r&   rA   rC   keyrL   cls_stridxr'   r'   r(   r>      s.    8<   
"zCOCOMetric._compute_apc           	      C  sZ  i }t | jD ]\}}d| jd dd| jd dd| jd dd| }| j||d||< | jrt | jD ]T\}}| d	| jd dd| jd dd| jd dd| 	}| j|||d
||< qlq| jD ]}d| j| dd| jd  }| j||dd||< | jrt | jD ]B\}}| d| j| dd| jd  }| j|||dd||< qq|S )a  
        Compute AR metrics

        Args:
            dataset_statistics (list[dict[int, dict[str, np.ndarray]]]): list with result s per image (in list)
                per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`.

                - `dtMatches`: matched detections [T, D], where T = number of
                  thresholds, D = number of detections
                - `gtMatches`: matched ground truth boxes [T, G], where T = number
                  of thresholds, G = number of ground truth
                - `dtScores`: prediction scores [D] detection scores
                - `gtIgnore`: ground truth boxes which should be ignored
                  [G] indicate whether ground truth should be ignored
                - `dtIgnore`: detections which should be ignored [T, D],
                  indicate which detections should be ignored
        ZmAR_IoU_r   rD   rE   r   r   rF   )rJ   Z	_mAR_IoU_)rL   rJ   ZAR_IoU_rG   rH   Z_AR_IoU_rK   )rN   r%   r   
_select_arr   r   r    r   )	r&   rA   rC   rJ   max_detrO   rL   rP   rQ   r'   r'   r(   r?     s"    46
"zCOCOMetric._compute_arNrG   dictz#int | list[int] | np.ndarray | Nonezint | Sequence[int] | Noner   float)rA   rI   rL   rJ   r,   c                 C  sL   | d }|dk	r|| }|dk	r2|d|ddf }|d|f }t t|S )a  
        Compute average precision

        Args:
            dataset_statistics (dict): computed statistics over dataset

                - `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
                  detection thresholds
                - `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
                - `precision`: Precision values at specified recall thresholds
                  [num_iou_th, num_recall_th, num_classes, num_max_detections]
                - `scores`: Scores corresponding to specified recall thresholds
                  [num_iou_th, num_recall_th, num_classes, num_max_detections]
            iou_idx: index of IoU values to select for evaluation(if None, all values are used)
            cls_idx: class indices to select, if None all classes will be selected
            max_det_idx (int): index to select max detection threshold from data

        Returns:
            np.ndarray: AP value
        	precisionN.)rU   r   mean)rA   rI   rL   rJ   precr'   r'   r(   rM   H  s    zCOCOMetric._select_apc                 C  sl   | d }|dk	r|| }|dk	r2|d|ddf }|d|f }t ||dk dkrVdS tt||dk S )a  
        Compute average recall

        Args:
            dataset_statistics (dict): computed statistics over dataset

                - `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
                  detection thresholds
                - `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
                - `precision`: Precision values at specified recall thresholds
                  [num_iou_th, num_recall_th, num_classes, num_max_detections]
                - `scores`: Scores corresponding to specified recall thresholds
                  [num_iou_th, num_recall_th, num_classes, num_max_detections]
            iou_idx: index of IoU values to select for evaluation(if None, all values are used)
            cls_idx: class indices to select, if None all classes will be selected
            max_det_idx (int): index to select max detection threshold from data

        Returns:
            np.ndarray: recall value
        recallN.rG   r   g      )r2   rU   r   rW   )rA   rI   rL   rJ   recr'   r'   r(   rR   k  s    zCOCOMetric._select_arc              	     sz  t | j}t | j}t | j}t | j}t||||f }t|||f }t||||f }t| jD ]\ }	t| jD ]\}
 fdd|D }t |dkrt	d|	  qt
fdd|D }tj| dd}|| }tj
fdd|D d	d
dd|f }tj
fdd|D d	d
dd|f }| || t
dd |D }t|dk}|dkrt	d|	  qt|t|}tt|t|}tj|d	d
jtjd}tj|d	d
jtjd}tt||D ]r\}\}}t|t| }}t|||| j|\}}}||| |
f< |||dd |
f< |||dd |
f< qqql||||g|||dS )a  
        Compute statistics needed for COCO metrics (mAP, AP of individual classes, mAP@IoU_Thresholds, AR)
        Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py

        Args:
            results_list (list[dict[int, dict[str, np.ndarray]]]): list with result s per image (in list)
                per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`.

                - `dtMatches`: matched detections [T, D], where T = number of
                  thresholds, D = number of detections
                - `gtMatches`: matched ground truth boxes [T, G], where T = number
                  of thresholds, G = number of ground truth
                - `dtScores`: prediction scores [D] detection scores
                - `gtIgnore`: ground truth boxes which should be ignored
                  [G] indicate whether ground truth should be ignored
                - `dtIgnore`: detections which should be ignored [T, D],
                  indicate which detections should be ignored

        Returns:
            dict: computed statistics over dataset
                - `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
                  detection thresholds
                - `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
                - `precision`: Precision values at specified recall thresholds
                  [num_iou_th, num_recall_th, num_classes, num_max_detections]
                - `scores`: Scores corresponding to specified recall thresholds
                  [num_iou_th, num_recall_th, num_classes, num_max_detections]
        c                   s   g | ]} |kr|  qS r'   r'   .0r)rL   r'   r(   
<listcomp>  s      z2COCOMetric._compute_statistics.<locals>.<listcomp>r   z4WARNING, no results found for coco metric for class c                   s   g | ]}|d  d  qS )ZdtScoresr   r'   r[   rS   r'   r(   r^     s     	mergesort)kindc                   s$   g | ]}|d  ddd f qS )Z	dtMatchesNr   r'   r[   r_   r'   r(   r^     s     r   )axisNc                   s$   g | ]}|d  ddd f qS )ZdtIgnoreNr   r'   r[   r_   r'   r(   r^     s     c                 S  s   g | ]}|d  qS )ZgtIgnorer'   r[   r'   r'   r(   r^     s     z/WARNING, no gt found for coco metric for class )dtype)countsrY   rV   scores)r2   r   r$   r   r%   r   onesrN   r9   warningconcatenateargsortr6   count_nonzerological_andlogical_notcumsumastypefloat32zipr   _compute_stats_single_threshold)r&   r8   Z
num_iou_thnum_recall_thnum_classesZnum_max_detectionsrV   rY   re   Zcls_irJ   rC   Z	dt_scoresindsdt_scores_sortedZ
dt_matchesZ
dt_ignoresZ	gt_ignorenum_gtZtpsfpstp_sumZfp_sumZth_indtpfpr]   psr'   )rL   rS   r(   r<     sN    



((

zCOCOMetric._compute_statistics)r   r	   r
   TT)NNrG   )NNrG   )__name__
__module____qualname__r)   r.   r6   r3   r-   r>   r?   staticmethodrM   rR   r<   r'   r'   r'   r(   r   J   s.        T	(/*   "   &r   r/   znp.ndarray | Sequence[float]r   z$tuple[float, np.ndarray, np.ndarray])ry   rz   ru   r$   rv   r,   c                 C  s  t |}| | }| ||  td  }t | r8|d }nd}t|f}	t|f}
| }|	 }	tt | d ddD ](}|| ||d  krx|| ||d < qxtj||dd}z.t|D ] \}}|| |	|< || |
|< qW n tk
r   Y nX |t	|	t	|
fS )a  
    Compute recall value, precision curve and scores thresholds
    Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py

    Args:
        tp (np.ndarray): cumsum over true positives [R], R is the number of detections
        fp (np.ndarray): cumsum over false positives [R], R is the number of detections
        dt_scores_sorted (np.ndarray): sorted (descending) scores [R], R is the number of detections
        recall_thresholds (Sequence[float]): recall thresholds which should be evaluated
        num_gt (int): number of ground truth bounding boxes (excluding boxes which are ignored)

    Returns:
        - float, overall recall for given IoU value
        - np.ndarray, precision values at defined recall values
          [RTH], where RTH is the number of recall thresholds
        - np.ndarray, prediction scores corresponding to recall values
          [RTH], where RTH is the number of recall thresholds
    r   rG   r   left)side)
r2   r   spacingzerostolistrangesearchsortedrN   BaseExceptionr   )ry   rz   ru   r$   rv   rr   rcprrY   rV   Z	th_scoresirt   Zsave_idxZarray_indexr'   r'   r(   rq     s*    
rq   )__doc__
__future__r   loggingr9   r;   collections.abcr   typingr   numpyr   r   rq   r'   r'   r'   r(   <module>;   s      !