o
    i!                     @  s   d dl mZ d dlZd dlmZmZ d dlZerd dlmZ	 d dl
Z
d dlmZmZ ddlmZ G dd deZdddZejfdddZdS )    )annotationsN)TYPE_CHECKINGcast)Averagelook_up_option   )CumulativeIterationMetricc                      s>   e Zd ZdZejfd fddZdddZddddZ  Z	S )AveragePrecisionMetrica  
    Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
    imbalanced. It can take values between 0.0 and 1.0, 1.0 being the best possible score.
    It summarizes a Precision-Recall curve as the weighted mean of precisions achieved at each
    threshold, with the increase in recall from the previous threshold used as the weight:

    .. math::
        \text{AP} = \sum_n (R_n - R_{n-1}) P_n
        :label: ap

    where :math:`P_n` and :math:`R_n` are the precision and recall at the :math:`n^{th}` threshold.

    Referring to: `sklearn.metrics.average_precision_score
    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.

    The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    averageAverage | strreturnNonec                   s   t    || _d S N)super__init__r
   )selfr
   	__class__ a/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/metrics/average_precision.pyr   @   s   

zAveragePrecisionMetric.__init__y_predtorch.Tensory!tuple[torch.Tensor, torch.Tensor]c                 C  s   ||fS r   r   )r   r   r   r   r   r   _compute_tensorD   s   z&AveragePrecisionMetric._compute_tensorNAverage | str | None"np.ndarray | float | npt.ArrayLikec                 C  s@   |   \}}t|tjrt|tjstdt|||p| jdS )ar  
        Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration,
        This function reads the buffers and computes the Average Precision.

        Args:
            average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
                Type of averaging performed if not binary classification. Defaults to `self.average`.

        z$y_pred and y must be PyTorch Tensor.)r   r   r
   )
get_buffer
isinstancetorchTensor
ValueErrorcompute_average_precisionr
   )r   r
   r   r   r   r   r   	aggregateG   s   
z AveragePrecisionMetric.aggregate)r
   r   r   r   )r   r   r   r   r   r   r   )r
   r   r   r   )
__name__
__module____qualname____doc__r   MACROr   r   r#   __classcell__r   r   r   r   r	      s
    "
r	   r   r   r   r   floatc           
      C  sb  |  |     krdkrn tdt|t| kstd| }t|dkr8td|  d tdS |t	j
ddg|j|jdsVtd|  d tdS t|}| jd	d
}||   }| |   } d } }}t|D ]2}tt|| }	|d |k r| | | |d  kr||	7 }qz||	7 }||7 }||| |d  7 }d}qz|| S )Nr   z7y and y_pred must be 1 dimension data with same length.zy values can not be all z', skip AP computation and return `Nan`.nanr   )dtypedevicez y values must be 0 or 1, but in T)
descendingg        )
ndimensionlenAssertionErroruniquewarningswarnitemr*   equalr   tensorr,   r-   tolistargsortcpunumpyranger   )
r   r   Zy_uniquenindicesnposapZtmp_posiZy_ir   r   r   
_calculateY   s2    
rB   r
   r   r   c                 C  sp  |   }|  }|dvrtd| j d|dvr"td|j d|dkr5| jd dkr5| jdd} d}|dkrF|jd dkrF|jdd}|dkrOt| |S |j| jkrbtd	| j d
|j dt|t}|tjkrut|  | S |	dd| 	dd}} dd t
| |D }|tjkr|S |tjkrt|S |tjkrdd |D }tj||dS td| d)a  Computes Average Precision (AP). AP is a useful metric to evaluate a classifier when the classes are
    imbalanced. It summarizes a Precision-Recall according to equation :eq:`ap`.
    Referring to: `sklearn.metrics.average_precision_score
    <https://scikit-learn.org/stable/modules/generated/sklearn.metrics.average_precision_score>`_.

    Args:
        y_pred: input data to compute, typical classification model output.
            the first dim must be batch, if multi-classes, it must be in One-Hot format.
            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
        y: ground truth to compute AP metric, the first dim must be batch.
            if multi-classes, it must be in One-Hot format.
            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    Raises:
        ValueError: When ``y_pred`` dimension is not one of [1, 2].
        ValueError: When ``y`` dimension is not one of [1, 2].
        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].

    Note:
        Average Precision expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values.

    )r      zPPredictions should be of shape (batch_size, num_classes) or (batch_size, ), got .zLTargets should be of shape (batch_size, num_classes) or (batch_size, ), got rC   r   )dimz.data shapes of y_pred and y do not match, got z and r   c                 S  s   g | ]	\}}t ||qS r   )rB   ).0Zy_pred_y_r   r   r   
<listcomp>   s    z-compute_average_precision.<locals>.<listcomp>c                 S  s   g | ]}t |qS r   )sum)rG   rH   r   r   r   rI      s    )weightszUnsupported average: z?, available options are ["macro", "weighted", "micro", "none"].)r/   r!   shapesqueezerB   r   r   MICROflatten	transposezipNONEr(   npmeanWEIGHTEDr
   )r   r   r
   Zy_pred_ndimZy_ndimZ	ap_valuesrK   r   r   r   r"   w   s<   $






r"   )r   r   r   r   r   r*   )r   r   r   r   r
   r   r   r   )
__future__r   r3   typingr   r   r;   rS   numpy.typingnptr   monai.utilsr   r   metricr   r	   rB   r(   r"   r   r   r   r   <module>   s   
<