U
    Phf                     @  s   d dl mZ d dlZd dlmZmZ d dlZer<d dlmZ	 d dl
Z
d dlmZmZ ddlmZ G dd deZd	d	d
dddZejfd	d	dddddZdS )    )annotationsN)TYPE_CHECKINGcast)Averagelook_up_option   )CumulativeIterationMetricc                      sR   e Zd ZdZejfddd fddZdddd	d
dZddddddZ  Z	S )ROCAUCMetrica  
    Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to:
    `sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/
    sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.
    The input `y_pred` and `y` can be a list of `channel-first` Tensor or a `batch-first` Tensor.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    Average | strNone)averagereturnc                   s   t    || _d S N)super__init__r   )selfr   	__class__ I/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/metrics/rocauc.pyr   5   s    
zROCAUCMetric.__init__torch.Tensorz!tuple[torch.Tensor, torch.Tensor]y_predyr   c                 C  s   ||fS r   r   )r   r   r   r   r   r   _compute_tensor9   s    zROCAUCMetric._compute_tensorNzAverage | str | None"np.ndarray | float | npt.ArrayLikec                 C  s@   |   \}}t|tjr$t|tjs,tdt|||p:| jdS )as  
        Typically `y_pred` and `y` are stored in the cumulative buffers at each iteration,
        This function reads the buffers and computes the area under the ROC.

        Args:
            average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
                Type of averaging performed if not binary classification. Defaults to `self.average`.

        z$y_pred and y must be PyTorch Tensor.)r   r   r   )
get_buffer
isinstancetorchTensor
ValueErrorcompute_roc_aucr   )r   r   r   r   r   r   r   	aggregate<   s    
zROCAUCMetric.aggregate)N)
__name__
__module____qualname____doc__r   MACROr   r   r"   __classcell__r   r   r   r   r	      s   r	   r   floatr   c                 C  s  |  |     krdkr0n nt|t| ks8td| }t|dkrjtd|  d tdS |t	j
ddg|j|jdstd|  d tdS t|}|  }||   }| |   } d	 } } }}t|D ]}	tt||	 }
|	d |k r:| |	 | |	d  kr:||
7 }|d|
 7 }q|| dkr||
7 }|d|
 7 }||7 }||||d
   7 }d }}q|
dkr||7 }q|d7 }q||||   S )Nr   z7y and y_pred must be 1 dimension data with same length.zy values can not be all z(, skip AUC computation and return `Nan`.nanr   )dtypedevicez y values must be 0 or 1, but in g           )
ndimensionlenAssertionErroruniquewarningswarnitemr)   equalr   tensorr+   r,   tolistargsortcpunumpyranger   )r   r   Zy_uniquenindicesnnegZaucZtmp_posZtmp_negiZy_ir   r   r   
_calculateN   s>    0$


r@   r
   r   )r   r   r   r   c                 C  sz  |   }|  }|dkr*td| j d|dkrDtd|j d|dkrj| jd dkrj| jdd} d}|dkr|jd dkr|jdd}|dkrt| |S |j| jkrtd	| j d
|j dt|t}|tjkrt|  | S |	dd| 	dd }} dd t
| |D }|tjkr(|S |tjkr>t|S |tjkrfdd |D }tj||dS td| ddS )aF  Computes Area Under the Receiver Operating Characteristic Curve (ROC AUC). Referring to:
    `sklearn.metrics.roc_auc_score <https://scikit-learn.org/stable/modules/generated/
    sklearn.metrics.roc_auc_score.html#sklearn.metrics.roc_auc_score>`_.

    Args:
        y_pred: input data to compute, typical classification model output.
            the first dim must be batch, if multi-classes, it must be in One-Hot format.
            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
        y: ground truth to compute ROC AUC metric, the first dim must be batch.
            if multi-classes, it must be in One-Hot format.
            for example: shape `[16]` or `[16, 1]` for a binary data, shape `[16, 2]` for 2 classes data.
        average: {``"macro"``, ``"weighted"``, ``"micro"``, ``"none"``}
            Type of averaging performed if not binary classification.
            Defaults to ``"macro"``.

            - ``"macro"``: calculate metrics for each label, and find their unweighted mean.
                This does not take label imbalance into account.
            - ``"weighted"``: calculate metrics for each label, and find their average,
                weighted by support (the number of true instances for each label).
            - ``"micro"``: calculate metrics globally by considering each element of the label
                indicator matrix as a label.
            - ``"none"``: the scores for each class are returned.

    Raises:
        ValueError: When ``y_pred`` dimension is not one of [1, 2].
        ValueError: When ``y`` dimension is not one of [1, 2].
        ValueError: When ``average`` is not one of ["macro", "weighted", "micro", "none"].

    Note:
        ROCAUC expects y to be comprised of 0's and 1's. `y_pred` must be either prob. estimates or confidence values.

    )r   r-   zPPredictions should be of shape (batch_size, num_classes) or (batch_size, ), got .zLTargets should be of shape (batch_size, num_classes) or (batch_size, ), got r-   r   )dimz.data shapes of y_pred and y do not match, got z and r   c                 S  s   g | ]\}}t ||qS r   )r@   ).0Zy_pred_y_r   r   r   
<listcomp>   s     z#compute_roc_auc.<locals>.<listcomp>c                 S  s   g | ]}t |qS r   )sum)rD   rE   r   r   r   rF      s     )weightszUnsupported average: z?, available options are ["macro", "weighted", "micro", "none"].N)r.   r    shapesqueezer@   r   r   MICROflatten	transposezipNONEr'   npmeanWEIGHTEDr   )r   r   r   Zy_pred_ndimZy_ndimZ
auc_valuesrH   r   r   r   r!   s   s<    #



r!   )
__future__r   r2   typingr   r   r:   rP   numpy.typingnptr   monai.utilsr   r   metricr   r	   r@   r'   r!   r   r   r   r   <module>   s   1&