o
    i-                     @  sr   d dl mZ d dlmZmZ d dlZd dlZd dlm	Z	 	d&d'ddZ
		 d(d)ddZd*ddZ	d+d,d$d%ZdS )-    )annotations)AnycastN)NdarrayOrTensorprobsr   coordsevaluation_masklabels_to_excludelist | Nonereturn,tuple[NdarrayOrTensor, NdarrayOrTensor, int]c           
      C  s`  t | t |kstd| j d|j dt |jdkr&|jd t |jks3td|j d|j dt| tjrA|    } t|tjrO|   }t|tjr]|   }|du rcg }t	
|}t	j|ft	jd}|t|j }| t	|d	k }td|d D ]}||vr||v r| t	||k 
 ||d < q|t | }	||tt|	fS )
ar  
    This function is modified from the official evaluation code of
    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to distinguish
    true positive and false positive predictions. A true positive prediction is defined when
    the detection point is within the annotated ground truth region.

    Args:
        probs: an array with shape (n,) that represents the probabilities of the detections.
            Where, n is the number of predicted detections.
        coords: an array with shape (n, n_dim) that represents the coordinates of the detections.
            The dimensions must be in the same order as in `evaluation_mask`.
        evaluation_mask: the ground truth mask for evaluation.
        labels_to_exclude: labels in this list will not be counted for metric calculation.

    Returns:
        fp_probs: an array that contains the probabilities of the false positive detections.
        tp_probs: an array that contains the probabilities of the True positive detections.
        num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.

    zthe length of probs z", should be the same as of coords .   zcoords z9 need to represent the same number of dimensions as mask N)dtyper   )len
ValueErrorshape
isinstancetorchTensordetachcpunumpynpmaxzerosfloat32tupleTwhereranger   int)
r   r   r   r	   	max_labeltp_probsZhittedlabelfp_probsinum_targets r'   T/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/metrics/froc.pycompute_fp_tp_probs_nd   s0   "
r)   y_coordx_coordresolution_levelr!   c                 C  s   t |tjr|   }t |tjr|   }|td| t}|td| t}t	j
||gdd}t| |||dS )a  
    This function is modified from the official evaluation code of
    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to distinguish
    true positive and false positive predictions. A true positive prediction is defined when
    the detection point is within the annotated ground truth region.

    Args:
        probs: an array with shape (n,) that represents the probabilities of the detections.
            Where, n is the number of predicted detections.
        y_coord: an array with shape (n,) that represents the Y-coordinates of the detections.
        x_coord: an array with shape (n,) that represents the X-coordinates of the detections.
        evaluation_mask: the ground truth mask for evaluation.
        labels_to_exclude: labels in this list will not be counted for metric calculation.
        resolution_level: the level at which the evaluation mask is made.

    Returns:
        fp_probs: an array that contains the probabilities of the false positive detections.
        tp_probs: an array that contains the probabilities of the True positive detections.
        num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.

       r   )axis)r   r   r   r	   )r   r   r   r   r   r   powastyper!   r   stackr)   )r   r*   r+   r   r	   r,   stackedr'   r'   r(   compute_fp_tp_probsN   s   r3   r$   np.ndarray | torch.Tensorr#   r&   
num_imagestuple[np.ndarray, np.ndarray]c           
      C  s   t | t|stdt | tjr|    } t |tjr'|   }g g }}tt	t
| t
| }|dd D ]}|| |k  |||k  q>|d |d t|t| }t|t| }	||	fS )a  
    This function is modified from the official evaluation code of
    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to compute
    the required data for plotting the Free Response Operating Characteristic (FROC) curve.

    Args:
        fp_probs: an array that contains the probabilities of the false positive detections for all
            images under evaluation.
        tp_probs: an array that contains the probabilities of the True positive detections for all
            images under evaluation.
        num_targets: the total number of targets (excluding `labels_to_exclude`) for all images under evaluation.
        num_images: the number of images under evaluation.

    z&fp and tp probs should have same type.r   Nr   )r   typeAssertionErrorr   r   r   r   r   sortedsetlistappendsumr   asarrayfloat)
r$   r#   r&   r5   Z	total_fpsZ	total_tpsZ	all_probsthreshfps_per_imagetotal_sensitivityr'   r'   r(   compute_froc_curve_dataz   s    


rC   g      ?g      ?r   r-         rA   
np.ndarrayrB   eval_thresholdsr   r   c                 C  s,   t || ddd |ddd }t |S )a  
    This function is modified from the official evaluation code of
    `CAMELYON 16 Challenge <https://camelyon16.grand-challenge.org/>`_, and used to compute
    the challenge's second evaluation metric, which is defined as the average sensitivity at
    the predefined false positive rates per whole slide image.

    Args:
        fps_per_image: the average number of false positives per image for different thresholds.
        total_sensitivity: sensitivities (true positive rates) for different thresholds.
        eval_thresholds: the false positive rates for calculating the average sensitivity. Defaults
            to (0.25, 0.5, 1, 2, 4, 8) which is the same as the CAMELYON 16 Challenge.

    N)r   interpmean)rA   rB   rH   Zinterp_sensr'   r'   r(   compute_froc_score   s   "
rL   )N)
r   r   r   r   r   r   r	   r
   r   r   )Nr   )r   r   r*   r   r+   r   r   r   r	   r
   r,   r!   r   r   )
r$   r4   r#   r4   r&   r!   r5   r!   r   r6   )rD   )rA   rG   rB   rG   rH   r   r   r   )
__future__r   typingr   r   r   r   r   monai.configr   r)   r3   rC   rL   r'   r'   r'   r(   <module>   s   =
,%