U
    Ph-                  	   @  s   d dl mZ d dlmZmZ d dlZd dlZd dlm	Z	 dddddddd	d
Z
dddddddddddZddddddddZddddddddZdS )    )annotations)AnycastN)NdarrayOrTensorr   zlist | Nonez,tuple[NdarrayOrTensor, NdarrayOrTensor, int])probscoordsevaluation_masklabels_to_excludereturnc           
      C  sf  t | t |ks*td| j d|j dt |jdkrL|jd t |jksftd|j d|j dt| tjr|    } t|tjr|   }t|tjr|   }|dkrg }t	
|}t	j|ft	jd}|t|j }| t	|d	k }td|d D ]8}||kr||kr| t	||k 
 ||d < q|t | }	||tt|	fS )
ar  
    This function is modified from the official evaluation code of
    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to distinguish
    true positive and false positive predictions. A true positive prediction is defined when
    the detection point is within the annotated ground truth region.

    Args:
        probs: an array with shape (n,) that represents the probabilities of the detections.
            Where, n is the number of predicted detections.
        coords: an array with shape (n, n_dim) that represents the coordinates of the detections.
            The dimensions must be in the same order as in `evaluation_mask`.
        evaluation_mask: the ground truth mask for evaluation.
        labels_to_exclude: labels in this list will not be counted for metric calculation.

    Returns:
        fp_probs: an array that contains the probabilities of the false positive detections.
        tp_probs: an array that contains the probabilities of the True positive detections.
        num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.

    zthe length of probs z", should be the same as of coords .   zcoords z9 need to represent the same number of dimensions as mask N)dtyper   )len
ValueErrorshape
isinstancetorchTensordetachcpunumpynpmaxzerosfloat32tupleTwhereranger   int)
r   r   r   r	   	max_labeltp_probsZhittedlabelfp_probsinum_targets r%   G/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/metrics/froc.pycompute_fp_tp_probs_nd   s.    "
"r'   r   )r   y_coordx_coordr   r	   resolution_levelr
   c                 C  s   t |tjr|   }t |tjr8|   }|td| t}|td| t}t	j
||gdd}t| |||dS )a  
    This function is modified from the official evaluation code of
    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to distinguish
    true positive and false positive predictions. A true positive prediction is defined when
    the detection point is within the annotated ground truth region.

    Args:
        probs: an array with shape (n,) that represents the probabilities of the detections.
            Where, n is the number of predicted detections.
        y_coord: an array with shape (n,) that represents the Y-coordinates of the detections.
        x_coord: an array with shape (n,) that represents the X-coordinates of the detections.
        evaluation_mask: the ground truth mask for evaluation.
        labels_to_exclude: labels in this list will not be counted for metric calculation.
        resolution_level: the level at which the evaluation mask is made.

    Returns:
        fp_probs: an array that contains the probabilities of the false positive detections.
        tp_probs: an array that contains the probabilities of the True positive detections.
        num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.

       r   )axis)r   r   r   r	   )r   r   r   r   r   r   powastyper   r   stackr'   )r   r(   r)   r   r	   r*   stackedr%   r%   r&   compute_fp_tp_probsN   s       r1   znp.ndarray | torch.Tensorztuple[np.ndarray, np.ndarray])r"   r!   r$   
num_imagesr
   c           
      C  s   t | t|stdt | tjr2|    } t |tjrN|   }g g  }}tt	t
| t
| }|dd D ](}|| |k  |||k  q||d |d t|t| }t|t| }	||	fS )a  
    This function is modified from the official evaluation code of
    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to compute
    the required data for plotting the Free Response Operating Characteristic (FROC) curve.

    Args:
        fp_probs: an array that contains the probabilities of the false positive detections for all
            images under evaluation.
        tp_probs: an array that contains the probabilities of the True positive detections for all
            images under evaluation.
        num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.
        num_images: the number of images under evaluation.

    z&fp and tp probs should have same type.r   Nr   )r   typeAssertionErrorr   r   r   r   r   sortedsetlistappendsumr   asarrayfloat)
r"   r!   r$   r2   Z	total_fpsZ	total_tpsZ	all_probsthreshfps_per_imagetotal_sensitivityr%   r%   r&   compute_froc_curve_dataz   s     


r?   g      ?g      ?r   r+         z
np.ndarrayr   r   )r=   r>   eval_thresholdsr
   c                 C  s,   t || ddd |ddd }t |S )a  
    This function is modified from the official evaluation code of
    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to compute
    the challenge's second evaluation metric, which is defined as the average sensitivity at
    the predefined false positive rates per whole slide image.

    Args:
        fps_per_image: the average number of false positives per image for different thresholds.
        total_sensitivity: sensitivities (true positive rates) for different thresholds.
        eval_thresholds: the false positive rates for calculating the average sensitivity. Defaults
            to (0.25, 0.5, 1, 2, 4, 8) which is the same as the CAMELYON 16 Challenge.

    N)r   interpmean)r=   r>   rC   Zinterp_sensr%   r%   r&   compute_froc_score   s    "rG   )N)Nr   )r@   )
__future__r   typingr   r   r   r   r   monai.configr   r'   r1   r?   rG   r%   r%   r%   r&   <module>   s    =  ,% 