U
    Ph.                     @  s   d dl mZ d dlmZ d dlZd dlmZmZ d dlm	Z	m
Z
mZmZ ddlmZ edd	d
\ZZeddd
\ZZddgZG dd deZG dd deZG dd deZdS )    )annotations)castN)do_metric_reductionignore_background)MetricReductionconvert_to_numpyconvert_to_tensoroptional_import   )CumulativeIterationMetricz)MetricsReloaded.metrics.pairwise_measuresBinaryPairwiseMeasures)nameMultiClassPairwiseMeasuresMetricsReloadedBinaryMetricsReloadedCategoricalc                      sR   e Zd ZdZdejdfdddddd fd	d
ZddddddZdd Z  Z	S )MetricsReloadedWrapperao  Base class for defining MetricsReloaded metrics as a CumulativeIterationMetric.

    Args:
        metric_name: Name of a metric from the MetricsReloaded package.
        include_background: whether to include computation on the first channel of
            the predicted output. Defaults to ``True``.
        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
            Here `not_nans` count the number of not nans for the metric,
            thus its shape equals to the shape of the metric.

    TFstrboolMetricReduction | strNonemetric_nameinclude_background	reductionget_not_nansreturnc                   s&   t    || _|| _|| _|| _d S )N)super__init__r   r   r   r   selfr   r   r   r   	__class__ J/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/metrics/wrapper.pyr   /   s
    
zMetricsReloadedWrapper.__init__NzMetricReduction | str | Nonez0torch.Tensor | tuple[torch.Tensor, torch.Tensor])r   r   c                 C  sB   |   }t|tjstdt||p(| j\}}| jr>||fS |S )Nz-the data to aggregate must be PyTorch Tensor.)
get_buffer
isinstancetorchTensor
ValueErrorr   r   r   )r   r   datafnot_nansr"   r"   r#   	aggregate<   s
    z MetricsReloadedWrapper.aggregatec                 C  s2   |  }|  }| js&t||d\}}|||jfS )z.Prepares onehot encoded input for metric call.)y_predy)floatr   r   device)r   r-   r.   r"   r"   r#   prepare_onehotF   s
    z%MetricsReloadedWrapper.prepare_onehot)N)
__name__
__module____qualname____doc__r   MEANr   r,   r1   __classcell__r"   r"   r    r#   r      s    
r   c                      sJ   e Zd ZdZdejdfdddddd fd	d
ZddddddZ  ZS )r   a  
    Wraps the binary pairwise metrics of MetricsReloaded.

    Args:
        metric_name: Name of a binary metric from the MetricsReloaded package.
        include_background: whether to include computation on the first channel of
            the predicted output. Defaults to ``True``.
        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
            Here `not_nans` count the number of not nans for the metric,
            thus its shape equals to the shape of the metric.

    Example:

    .. code-block:: python

        import torch
        from monai.metrics import MetricsReloadedBinary

        metric_name = "Cohens Kappa"
        metric = MetricsReloadedBinary(metric_name=metric_name)

        # first iteration
        # shape [batch=1, channel=1, 2, 2]
        y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
        y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])
        print(metric(y_pred, y))

        # second iteration
        # shape [batch=1, channel=1, 2, 2]
        y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])
        y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])
        print(metric(y_pred, y))

        # aggregate
        # shape ([batch=2, channel=1])
        print(metric.aggregate(reduction="none"))  # tensor([[0.5], [0.2]])

        # reset
        metric.reset()

    TFr   r   r   r   r   c                   s   t  j||||d d S N)r   r   r   r   )r   r   r   r    r"   r#   r   }   s    zMetricsReloadedBinary.__init__torch.Tensorr-   r.   r   c                 C  s   |  ||\}}}| }|dk r2td| d|jd dksN|jd dkrptd|jd  d|jd  dt|}t|}t||ttd|d	d
}| j|j	krtd| j |j	| j  }t
||dS )a#  Computes a binary (single-class) MetricsReloaded metric from a batch of
        predictions and references.

        Args:
            y_pred: Prediction with dimensions (batch, channel, *spatial), where channel=1.
                The values should be binarized.
            y: Ground-truth with dimensions (batch, channel, *spatial), where channel=1.
                The values should be binarized.

        Raises:
            ValueError: when `y_pred` has less than three dimensions.
            ValueError: when second dimension ~= 1

           Hy_pred should have at least 3 dimensions (batch, channel, spatial), got .r
   zy_pred.shape[1]=z and y.shape[1]=z should be one.   h㈵>)axis	smooth_drUnsupported metric: r0   )r1   
ndimensionr(   shaper   r   tupleranger   metricsr   )r   r-   r.   r0   dimsbpmmetricr"   r"   r#   _compute_tensor   s    "z%MetricsReloadedBinary._compute_tensor	r2   r3   r4   r5   r   r6   r   rL   r7   r"   r"   r    r#   r   O   s   0c                      sN   e Zd ZdZdejddfdddddd	d
 fddZddddddZ  ZS )r   a  
    Wraps the categorical pairwise metrics of MetricsReloaded.


    Args:
        metric_name: Name of a categorical metric from the MetricsReloaded package.
        include_background: whether to include computation on the first channel of
            the predicted output. Defaults to ``True``.
        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
            Here `not_nans` count the number of not nans for the metric,
            thus its shape equals to the shape of the metric.
        smooth_dr: a small constant added to the denominator to avoid nan. OBS: should be greater than zero.

    Example:

    .. code-block:: python

        import torch
        from monai.metrics import MetricsReloadedCategorical

        metric_name = "Weighted Cohens Kappa"
        metric = MetricsReloadedCategorical(metric_name=metric_name)

        # first iteration
        # shape [bach=1, channel=3, 2, 2]
        y_pred = torch.tensor([[[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]])
        y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]])
        print(metric(y_pred, y))

        # second iteration
        # shape [batch=1, channel=3, 2, 2]
        y_pred = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, 0], [0, 0]]]])
        y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]])
        print(metric(y_pred, y))

        # aggregate
        # shape ([batch=2, channel=1])
        print(metric.aggregate(reduction="none"))  # tensor([[0.2727], [0.6000]])

        # reset
        metric.reset()

    TFr?   r   r   r   r/   r   )r   r   r   r   rA   r   c                   s   t  j||||d || _d S r8   )r   r   rA   )r   r   r   r   r   rA   r    r"   r#   r      s    z#MetricsReloadedCategorical.__init__r9   r:   c                 C  s
  |  ||\}}}| }|dk r2td| d|jd }||jd |jd d}|d}||jd |jd d}|d}| }t|}t|}t||tt	d|| j
tt	|dd	}| j|jkrtd
| j |j| j  }|d }ttjt||dS )a  Computes a categorical (multi-class) MetricsReloaded metric from a batch of
        predictions and references.

        Args:
            y_pred: Prediction with dimensions (batch, channel, *spatial). The values should be
                one-hot encoded and binarized.
            y: Ground-truth with dimensions (batch, channel, *spatial). The values should be 1
                one-hot encoded and binarized.

        Raises:
            ValueError: when `y_pred` has less than three dimensions.

        r;   r<   r=   r
   r   )r   r>   r
   T)r@   rA   list_values	is_onehotrB   ).NrC   )r1   rD   r(   rE   reshapepermuter   r   rF   rG   rA   listr   rH   r   r&   r'   r   )r   r-   r.   r0   rI   num_classesrJ   rK   r"   r"   r#   rL      s2    




z*MetricsReloadedCategorical._compute_tensorrM   r"   r"   r    r#   r      s   2)
__future__r   typingr   r&   monai.metrics.utilsr   r   monai.utilsr   r   r   r	   rK   r   r   _r   __all__r   r   r   r"   r"   r"   r#   <module>   s    
0g