o
    ig                     @  s\   d Z ddlmZ ddlZddlZddlmZ ddlm	Z	 ddl
ZG dd dZdddZdS )z
This script is almost same with https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/evaluator/detection/coco.py
The changes include 1) code reformatting, 2) docstrings.
    )annotationsN)Sequence)Anyc                   @  s   e Zd Z					d:d;ddZd<ddZd=ddZd>ddZd?d!d"Zd@d&d'Zd@d(d)Z	e
	*	*	+dAdBd4d5Ze
	*	*	+dAdCd6d7ZdDd8d9Zd*S )E
COCOMetric皙?      ?g      ?r   r   g?      d   TclassesSequence[str]iou_listSequence[float]	iou_rangemax_detectionSequence[int]	per_classboolverbosec           	   	   C  s  || _ || _|| _t|}tj|d |d tt|d |d  |d  d dd}t||| _	|| _
t|ddtjf | j	tj kd | _t|ddtjf | j	tj kd | _| j	| j |k rt| j	| j |k sxtdtjdd	ttd
d dd| _|| _dS )a	  
        Class to compute COCO metrics
        Metrics computed includes,

        - mAP over the IoU range specified by `iou_range` at last value of `max_detection`
        - AP values at IoU thresholds specified by `iou_list` at last value of `max_detection`
        - AR over max detections thresholds defined by `max_detection` (over iou range)

        Args:
            classes (Sequence[str]): name of each class (index needs to correspond to predicted class indices!)
            iou_list (Sequence[float]): specific thresholds where ap is evaluated and saved
            iou_range (Sequence[float]): (start, stop, step) for mAP iou thresholds
            max_detection (Sequence[int]): maximum number of detections per image
            verbose (bool): log time needed for evaluation

        Example:

            .. code-block:: python

                from monai.data.box_utils import box_iou
                from monai.apps.detection.metrics.coco import COCOMetric
                from monai.apps.detection.metrics.matching import matching_batch
                # 3D example outputs of one image from detector
                val_outputs_all = [
                        {"boxes": torch.tensor([[1,1,1,3,4,5]],dtype=torch.float16),
                        "labels": torch.randint(3,(1,)),
                        "scores": torch.randn((1,)).absolute()},
                ]
                val_targets_all = [
                        {"boxes": torch.tensor([[1,1,1,2,6,4]],dtype=torch.float16),
                        "labels": torch.randint(3,(1,))},
                ]

                coco_metric = COCOMetric(
                    classes=['c0','c1','c2'], iou_list=[0.1], max_detection=[10]
                )
                results_metric = matching_batch(
                    iou_fn=box_iou,
                    iou_thresholds=coco_metric.iou_thresholds,
                    pred_boxes=[val_data_i["boxes"].numpy() for val_data_i in val_outputs_all],
                    pred_classes=[val_data_i["labels"].numpy() for val_data_i in val_outputs_all],
                    pred_scores=[val_data_i["scores"].numpy() for val_data_i in val_outputs_all],
                    gt_boxes=[val_data_i["boxes"].numpy() for val_data_i in val_targets_all],
                    gt_classes=[val_data_i["labels"].numpy() for val_data_i in val_targets_all],
                )
                val_metric_dict = coco_metric(results_metric)
                print(val_metric_dict)
        r   r      T)endpointNzxRequire self.iou_thresholds[self.iou_list_idx] == iou_list_np and self.iou_thresholds[self.iou_range_idx] == _iou_range.        g      ?g      Y@)r   r   r   nparraylinspaceintroundunion1diou_thresholdsr   nonzeronewaxisiou_list_idxiou_range_idxall
ValueErrorrecall_thresholdsmax_detections)	selfr   r   r   r   r   r   Ziou_list_npZ
_iou_range r+   c/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/detection/metrics/coco.py__init__L   s(   9
2**"
zCOCOMetric.__init__argsr   kwargsreturn5tuple[dict[str, float], dict[str, np.ndarray] | None]c                 O  s   | j |i |S )a  
        Compute metric. See :func:`compute` for more information.

        Args:
            *args: positional arguments passed to :func:`compute`
            **kwargs: keyword arguments passed to :func:`compute`

        Returns:
            dict[str, float]: dictionary with scalar values for evaluation
            dict[str, np.ndarray]: dictionary with arrays, e.g. for visualization of graphs
        )compute)r*   r.   r/   r+   r+   r,   __call__   s   zCOCOMetric.__call__
np.ndarrayNonec                 G  sH   t |  }|D ]}|jd |kr!td|jd  d|   dqdS )z
        Check if shape of input in first dimension is consistent with expected IoU values
        (assumes IoU dimension is the first dimension)

        Args:
            args: array like inputs with shape function
        r   zIRequire arg.shape[0] == len(self.get_iou_thresholds()). Got arg.shape[0]=z, self.get_iou_thresholds()=.N)lenget_iou_thresholdsshaper'   )r*   r.   Znum_iousargr+   r+   r,   check_number_of_iou   s   zCOCOMetric.check_number_of_iouc                 C  s
   t | jS )z
        Return IoU thresholds needed for this metric in an numpy array

        Returns:
            Sequence[float]: IoU thresholds [M], M is the number of thresholds
        )listr!   )r*   r+   r+   r,   r8      s   
zCOCOMetric.get_iou_thresholdsresults_list&list[dict[int, dict[str, np.ndarray]]]tuple[dict[str, float], None]c                 C  s   | j rtd t }| j|d}| j r%t }td|| dd i }|| | || | | j rJt }td|| dd |dfS )	a  
        Compute COCO metrics

        Args:
            results_list (list[dict[int, dict[str, np.ndarray]]]): list with results per image (in list)
                per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`.

                - `dtMatches`: matched detections [T, D], where T = number of
                  thresholds, D = number of detections
                - `gtMatches`: matched ground truth boxes [T, G], where T = number
                  of thresholds, G = number of ground truth
                - `dtScores`: prediction scores [D] detection scores
                - `gtIgnore`: ground truth boxes which should be ignored
                  [G] indicate whether ground truth should be ignored
                - `dtIgnore`: detections which should be ignored [T, D],
                  indicate which detections should be ignored

        Returns:
            dict[str, float], dictionary with coco metrics
        z Start COCO metric computation...)r=   z(Statistics for COCO metrics finished (t=z0.2fzs).zCOCO metrics computed in t=zs.N)r   loggerinfotime_compute_statisticsupdate_compute_ap_compute_ar)r*   r=   ticdataset_statisticstocresultsr+   r+   r,   r2      s   
zCOCOMetric.computerH   dict[str, np.ndarray | list]dict[str, float]c                 C  sb  i }| j red| j d dd| j d dd| j d dd| jd  }| j|| jdd	||< | jret| jD ]/\}}| d
| j d dd| j d dd| j d dd| jd  	}| j|| j|dd||< q5| jD ]F}d| j| dd| jd  }| j||gdd	||< | jrt| jD ]!\}}| d| j| dd| jd  }| j||g|dd||< qqh|S )a  
        Compute AP metrics

        Args:
            dataset_statistics (list[dict[int, dict[str, np.ndarray]]]): list with result s per image (in list)
                per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`.

                - `dtMatches`: matched detections [T, D], where T = number of
                  thresholds, D = number of detections
                - `gtMatches`: matched ground truth boxes [T, G], where T = number
                  of thresholds, G = number of ground truth
                - `dtScores`: prediction scores [D] detection scores
                - `gtIgnore`: ground truth boxes which should be ignored
                  [G] indicate whether ground truth should be ignored
                - `dtIgnore`: detections which should be ignored [T, D],
                  indicate which detections should be ignored
        ZmAP_IoU_r   .2f_r   r   _MaxDet_iou_idxmax_det_idxZ	_mAP_IoU_rR   cls_idxrS   ZAP_IoU_Z_AP_IoU_)	r   r)   
_select_apr%   r   	enumerater   r$   r!   )r*   rH   rJ   keyrU   cls_stridxr+   r+   r,   rE      s>   ,




"zCOCOMetric._compute_apc           	      C  sX  i }t | jD ]Z\}}d| jd dd| jd dd| jd dd| }| j||d||< | jrat | jD ]*\}}| d	| jd dd| jd dd| jd dd| 	}| j|||d
||< q6q| jD ]D}d| j| dd| jd  }| j||dd||< | jrt | jD ] \}}| d| j| dd| jd  }| j|||dd||< qqe|S )a  
        Compute AR metrics

        Args:
            dataset_statistics (list[dict[int, dict[str, np.ndarray]]]): list with result s per image (in list)
                per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`.

                - `dtMatches`: matched detections [T, D], where T = number of
                  thresholds, D = number of detections
                - `gtMatches`: matched ground truth boxes [T, G], where T = number
                  of thresholds, G = number of ground truth
                - `dtScores`: prediction scores [D] detection scores
                - `gtIgnore`: ground truth boxes which should be ignored
                  [G] indicate whether ground truth should be ignored
                - `dtIgnore`: detections which should be ignored [T, D],
                  indicate which detections should be ignored
        ZmAR_IoU_r   rM   rN   r   r   rO   )rS   Z	_mAR_IoU_)rU   rS   ZAR_IoU_rP   rQ   Z_AR_IoU_rT   )rW   r)   r   
_select_arr   r   r$   r!   )	r*   rH   rJ   rS   max_detrX   rU   rY   rZ   r+   r+   r,   rF     s6   4



"zCOCOMetric._compute_arNrP   dictrR   #int | list[int] | np.ndarray | NonerU   int | Sequence[int] | NonerS   r   floatc                 C  sL   | d }|dur|| }|dur|d|ddf }|d|f }t t|S )a  
        Compute average precision

        Args:
            dataset_statistics (dict): computed statistics over dataset

                - `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
                  detection thresholds
                - `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
                - `precision`: Precision values at specified recall thresholds
                  [num_iou_th, num_recall_th, num_classes, num_max_detections]
                - `scores`: Scores corresponding to specified recall thresholds
                  [num_iou_th, num_recall_th, num_classes, num_max_detections]
            iou_idx: index of IoU values to select for evaluation(if None, all values are used)
            cls_idx: class indices to select, if None all classes will be selected
            max_det_idx (int): index to select max detection threshold from data

        Returns:
            np.ndarray: AP value
        	precisionN.)r`   r   mean)rH   rR   rU   rS   precr+   r+   r,   rV   H  s   zCOCOMetric._select_apc                 C  sl   | d }|dur|| }|dur|d|ddf }|d|f }t ||dk dkr+dS tt||dk S )a  
        Compute average recall

        Args:
            dataset_statistics (dict): computed statistics over dataset

                - `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
                  detection thresholds
                - `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
                - `precision`: Precision values at specified recall thresholds
                  [num_iou_th, num_recall_th, num_classes, num_max_detections]
                - `scores`: Scores corresponding to specified recall thresholds
                  [num_iou_th, num_recall_th, num_classes, num_max_detections]
            iou_idx: index of IoU values to select for evaluation(if None, all values are used)
            cls_idx: class indices to select, if None all classes will be selected
            max_det_idx (int): index to select max detection threshold from data

        Returns:
            np.ndarray: recall value
        recallN.rP   r   g      )r7   r`   r   rb   )rH   rR   rU   rS   recr+   r+   r,   r[   k  s   zCOCOMetric._select_arc              	     sv  t | j}t | j}t | j}t | j}t||||f }t|||f }t||||f }t| jD ]\ }	t| jD ]\}
 fdd|D }t |dkr[t	d|	  q?t
fdd|D }tj| dd}|| }tj
fdd|D d	d
dd|f }tj
fdd|D d	d
dd|f }| || t
dd |D }tt|dk}|dkrt	d|	  q?t|t|}tt|t|}tj|d	d
jtjd}tj|d	d
jtjd}tt||D ]8\}\}}t|t|}}t|||| j|\}}}||| |
f< |||dd |
f< |||dd |
f< qq?q6||||g|||dS )a  
        Compute statistics needed for COCO metrics (mAP, AP of individual classes, mAP@IoU_Thresholds, AR)
        Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py

        Args:
            results_list (list[dict[int, dict[str, np.ndarray]]]): list with result s per image (in list)
                per category (dict). Inner dict contains multiple results obtained by :func:`box_matching_batch`.

                - `dtMatches`: matched detections [T, D], where T = number of
                  thresholds, D = number of detections
                - `gtMatches`: matched ground truth boxes [T, G], where T = number
                  of thresholds, G = number of ground truth
                - `dtScores`: prediction scores [D] detection scores
                - `gtIgnore`: ground truth boxes which should be ignored
                  [G] indicate whether ground truth should be ignored
                - `dtIgnore`: detections which should be ignored [T, D],
                  indicate which detections should be ignored

        Returns:
            dict: computed statistics over dataset
                - `counts`: Number of thresholds, Number recall thresholds, Number of classes, Number of max
                  detection thresholds
                - `recall`: Computed recall values [num_iou_th, num_classes, num_max_detections]
                - `precision`: Precision values at specified recall thresholds
                  [num_iou_th, num_recall_th, num_classes, num_max_detections]
                - `scores`: Scores corresponding to specified recall thresholds
                  [num_iou_th, num_recall_th, num_classes, num_max_detections]
        c                   s   g | ]
} |v r|  qS r+   r+   .0r)rU   r+   r,   
<listcomp>      z2COCOMetric._compute_statistics.<locals>.<listcomp>r   z4WARNING, no results found for coco metric for class c                   s   g | ]
}|d  d  qS )ZdtScoresr   r+   rf   r\   r+   r,   ri     rj   	mergesort)kindc                   $   g | ]}|d  ddd f qS )Z	dtMatchesNr   r+   rf   rk   r+   r,   ri        $ r   )axisNc                   rn   )ZdtIgnoreNr   r+   rf   rk   r+   r,   ri     ro   c                 S  s   g | ]}|d  qS )ZgtIgnorer+   rf   r+   r+   r,   ri     s    z/WARNING, no gt found for coco metric for class )dtype)countsrd   ra   scores)r7   r!   r(   r   r)   r   onesrW   r@   warningconcatenateargsortr;   r   count_nonzerological_andlogical_notcumsumastypefloat32zipr   _compute_stats_single_threshold)r*   r=   Z
num_iou_thnum_recall_thnum_classesZnum_max_detectionsra   rd   rs   Zcls_irS   rJ   Z	dt_scoresindsdt_scores_sortedZ
dt_matchesZ
dt_ignoresZ	gt_ignorenum_gtZtpsfpstp_sumZfp_sumZth_indtpfprh   psr+   )rU   r\   r,   rC     sR   



((
'zCOCOMetric._compute_statistics)r   r	   r
   TT)r   r   r   r   r   r   r   r   r   r   r   r   )r.   r   r/   r   r0   r1   )r.   r4   r0   r5   )r0   r   )r=   r>   r0   r?   )rH   rK   r0   rL   )NNrP   )
rH   r]   rR   r^   rU   r_   rS   r   r0   r`   )
rH   r]   rR   r_   rU   r_   rS   r   r0   r`   )r=   r>   r0   rK   )__name__
__module____qualname__r-   r3   r;   r8   r2   rE   rF   staticmethodrV   r[   rC   r+   r+   r+   r,   r   J   s0    
T


	
(
/*"&r   r   r4   r   r   r(   np.ndarray | Sequence[float]r   r   r0   $tuple[float, np.ndarray, np.ndarray]c                 C  s   t |}| | }| ||  td  }t | r|d }nd}dg| }	t|f}
| }tt | d ddD ]}|| ||d  krK|| ||d < q7tj||dd}zt|D ]\}}|| |	|< || |
|< qYW n	 tyt   Y nw |t	|	t	|
fS )a  
    Compute recall value, precision curve and scores thresholds
    Adapted from https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py

    Args:
        tp (np.ndarray): cumsum over true positives [R], R is the number of detections
        fp (np.ndarray): cumsum over false positives [R], R is the number of detections
        dt_scores_sorted (np.ndarray): sorted (descending) scores [R], R is the number of detections
        recall_thresholds (Sequence[float]): recall thresholds which should be evaluated
        num_gt (int): number of ground truth bounding boxes (excluding boxes which are ignored)

    Returns:
        - float, overall recall for given IoU value
        - np.ndarray, precision values at defined recall values
          [RTH], where RTH is the number of recall thresholds
        - np.ndarray, prediction scores corresponding to recall values
          [RTH], where RTH is the number of recall thresholds
    r   rP   r   r   left)side)
r7   r   spacingzerostolistrangesearchsortedrW   BaseExceptionr   )r   r   r   r(   r   r   rcprrd   ra   Z	th_scoresir   Zsave_idxZarray_indexr+   r+   r,   r     s.   

r   )r   r4   r   r4   r   r4   r(   r   r   r   r0   r   )__doc__
__future__r   loggingr@   rB   collections.abcr   typingr   numpyr   r   r   r+   r+   r+   r,   <module>   s   :   !