U
    Ph                      @  s   d dl mZ d dlZd dlmZ d dlZd dlmZ d dlm	Z	 ddl
mZ G dd	 d	eZG d
d deZddddddddddZdddddddddZdS )    )annotationsN)Any)ignore_background)MetricReduction   )Metricc                      s@   e Zd ZdZdddddd	d
 fddZdddddZ  ZS )VarianceMetrica  
    Compute the Variance of a given T-repeats N-dimensional array/tensor. The primary usage is as an uncertainty based
    metric for Active Learning.

    It can return the spatial variance/uncertainty map based on user choice or a single scalar value via mean/sum of the
    variance for scoring purposes

    Args:
        include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector
        spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image dimensions
        scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used
        threshold: To avoid NaN's a threshold is used to replace zero's

    TFsumMb@?boolstrfloatNone)include_backgroundspatial_mapscalar_reduction	thresholdreturnc                   s&   t    || _|| _|| _|| _d S N)super__init__r   r   r   r   )selfr   r   r   r   	__class__ Z/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/metrics/active_learning_metrics.pyr   )   s
    
zVarianceMetric.__init__r   )y_predr   c                 C  s   t || j| j| j| jdS )(  
        Args:
            y_pred: Predicted segmentation, typically segmentation model output.
                It must be N-repeats, repeat-first tensor [N,C,H,W,D].

        Returns:
            Pytorch tensor of scalar value of variance as uncertainty or a spatial map of uncertainty

        )r   r   r   r   r   )compute_variancer   r   r   r   )r   r   r   r   r   __call__6   s    
zVarianceMetric.__call__)TFr	   r
   __name__
__module____qualname____doc__r   r   __classcell__r   r   r   r   r      s       r   c                      s>   e Zd ZdZddddd fdd	Zd
d
ddddZ  ZS )LabelQualityScorea  
    The assumption is that the DL model makes better predictions than the provided label quality, hence the difference
    can be treated as a label quality score

    It can be combined with variance/uncertainty for active learning frameworks to factor in the quality of label along
    with uncertainty
    Args:
        include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector
        spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image
        dimensions
        scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used

    Tr	   r   r   r   )r   r   r   c                   s   t    || _|| _d S r   )r   r   r   r   )r   r   r   r   r   r   r   X   s    
zLabelQualityScore.__init__r   torch.Tensor | None)r   yr   c                 C  s   t ||| j| jdS )r   )r   r(   r   r   )label_quality_scorer   r   )r   r   r(   r   r   r   r   ]   s    
   zLabelQualityScore.__call__)Tr	   r    r   r   r   r   r&   I   s   r&   TFmeanr
   ztorch.Tensorr   r   r   r'   )r   r   r   r   r   r   c                 C  s   |   } |s | }t| |d\} }|| | dk< t| j}|dk rP|rPtd dS | j}|d |d  g}td|D ]}	|||	  qrt	| |}
tj
|
ddd	}|r|S |tjkrt|S |tjkrt|S td
| ddS )a  
    Args:
        y_pred: [N, C, H, W, D] or [N, C, H, W] or [N, C, H] where N is repeats, C is channels and H, W, D stand for
            Height, Width & Depth
        include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector
        spatial_map: Boolean, if set to True, spatial map of variance will be returned corresponding to i/p image
            dimensions
        scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used
        threshold: To avoid NaN's a threshold is used to replace zero's
    Returns:
        A single scalar uncertainty/variance value or the spatial map of uncertainty/variance
    r   r(   r      z@Spatial map requires a 2D/3D image with N-repeats and C-channelsNr      F)dimunbiasedscalar_reduction= not supported.)r   r   lenshapewarningswarnrangeappendtorchreshapevarr   MEANr*   SUMr	   
ValueError)r   r   r   r   r   r(   n_lenZn_shape	new_shapeZeach_dim_idxZ
y_reshapedvariancer   r   r   r   l   s,    





r   )r   r(   r   r   r   c                 C  s   |   } |  }|s$t| |d\} }t| j}|dk rL|dkrLtd dS t| | }|tj	krh|S |tj
krtj|ttd|dS |tjkrtj|ttd|dS td| d	dS )
a  
    The assumption is that the DL model makes better predictions than the provided label quality, hence the difference
    can be treated as a label quality score

    Args:
        y_pred: Input data of dimension [B, C, H, W, D] or [B, C, H, W] or [B, C, H] where B is Batch-size, C is
            channels and H, W, D stand for Height, Width & Depth
        y: Ground Truth of dimension [B, C, H, W, D] or [B, C, H, W] or [B, C, H] where B is Batch-size, C is channels
            and H, W, D stand for Height, Width & Depth
        include_background: Whether to include the background of the spatial image or channel 0 of the 1-D vector
        scalar_reduction: reduction type of the metric, either 'sum' or 'mean' can be used to retrieve a single scalar
            value, if set to 'none' a spatial map will be returned

    Returns:
        A single scalar absolute difference value as score with a reduction based on sum/mean or the spatial map of
        absolute difference
    r+   r,   nonez^Reduction set to None, Spatial map return requires a 2D/3D image of B-Batchsize and C-channelsNr   )r.   r0   r1   )r   r   r2   r3   r4   r5   r8   absr   NONEr;   r*   listr6   r<   r	   r=   )r   r(   r   r   r>   Zabs_diff_mapr   r   r   r)      s     




r)   )TFr*   r
   )Tr*   )
__future__r   r4   typingr   r8   Zmonai.metrics.utilsr   monai.utilsr   metricr   r   r&   r   r)   r   r   r   r   <module>   s    0%    :   