o
    i                     @  s   d dl mZ d dlZd dlmZmZ d dlZerd dlmZ	 d dl
Z
d dlmZmZ ddlmZ G dd deZdddZejfdddZdS )    )annotationsN)TYPE_CHECKINGcast)Averagelook_up_option   )CumulativeIterationMetricc                      s>   e Zd ZdZejfd fddZdddZddddZ  Z	S )ROCAUCMetrica  
    Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to:
    `sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/
    sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.
    The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    averageAverage | strreturnNonec                   s   t    || _d S N)super__init__r
   )selfr
   	__class__ V/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/metrics/rocauc.pyr   5   s   

zROCAUCMetric.__init__y_predtorch.Tensory!tuple[torch.Tensor, torch.Tensor]c                 C  s   ||fS r   r   )r   r   r   r   r   r   _compute_tensor9   s   zROCAUCMetric._compute_tensorNAverage | str | None"np.ndarray | float | npt.ArrayLikec                 C  s@   |   \}}t|tjrt|tjstdt|||p| jdS )as  
        Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration,
        This function reads the buffers and computes the area under the ROC.

        Args:
            average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
                Type of averaging performed if not binary classification. Defaults to `self.average`.

        z$y_pred and y must be PyTorch Tensor.)r   r   r
   )
get_buffer
isinstancetorchTensor
ValueErrorcompute_roc_aucr
   )r   r
   r   r   r   r   r   	aggregate<   s   
zROCAUCMetric.aggregate)r
   r   r   r   )r   r   r   r   r   r   r   )r
   r   r   r   )
__name__
__module____qualname____doc__r   MACROr   r   r#   __classcell__r   r   r   r   r	      s
    
r	   r   r   r   r   floatc                 C  s  |  |     krdkrn tdt|t| kstd| }t|dkr8td|  d tdS |t	j
ddg|j|jdsVtd|  d tdS t|}|  }||   }| |   } d	 } } }}t|D ]T}	tt||	 }
|	d |k r| |	 | |	d  kr||
7 }|d|
 7 }qz|| dkr||
7 }|d|
 7 }||7 }||||d
   7 }d }}qz|
dkr||7 }qz|d7 }qz||||   S )Nr   z7y and y_pred must be 1 dimension data with same length.zy values can not be all z(, skip AUC computation and return `Nan`.nanr   )dtypedevicez y values must be 0 or 1, but in g           )
ndimensionlenAssertionErroruniquewarningswarnitemr*   equalr   tensorr,   r-   tolistargsortcpunumpyranger   )r   r   y_uniquenindicesnnegZauctmp_posZtmp_negiy_ir   r   r   
_calculateN   sB    

rD   r
   r   r   c                 C  sp  |   }|  }|dvrtd| j d|dvr"td|j d|dkr5| jd dkr5| jdd} d}|dkrF|jd dkrF|jdd}|dkrOt| |S |j| jkrbtd	| j d
|j dt|t}|tjkrut|  | S |	dd| 	dd}} dd t
| |D }|tjkr|S |tjkrt|S |tjkrdd |D }tj||dS td| d)aF  Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to:
    `sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/
    sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.

    Args:
        y_pred: input data to compute, typical classification model output.
            the first dim must be batch, if multi-classes, it must be in One-Hot format.
            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
        y: ground truth to compute ROC AUC metric, the first dim must be batch.
            if multi-classes, it must be in One-Hot format.
            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    Raises:
        ValueError: When ``y_pred`` dimension is not one of [1, 2].
        ValueError: When ``y`` dimension is not one of [1, 2].
        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].

    Note:
        ROCAUC expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values.

    )r   r.   zPPredictions should be of shape (batch_size, num_classes) or (batch_size, ), got .zLTargets should be of shape (batch_size, num_classes) or (batch_size, ), got r.   r   )dimz.data shapes of y_pred and y do not match, got z and r   c                 S  s   g | ]	\}}t ||qS r   )rD   ).0y_pred_y_r   r   r   
<listcomp>   s    z#compute_roc_auc.<locals>.<listcomp>c                 S  s   g | ]}t |qS r   )sum)rH   rJ   r   r   r   rK      s    )weightszUnsupported average: z?, available options are ["macro", "weighted", "micro", "none"].)r/   r!   shapesqueezerD   r   r   MICROflatten	transposezipNONEr(   npmeanWEIGHTEDr
   )r   r   r
   y_pred_ndimy_ndimZ
auc_valuesrM   r   r   r   r"   s   s<   #






r"   )r   r   r   r   r   r*   )r   r   r   r   r
   r   r   r   )
__future__r   r3   typingr   r   r;   rU   numpy.typingnptr   monai.utilsr   r   metricr   r	   rD   r(   r"   r   r   r   r   <module>   s   
1&