o
    i.                     @  s   d dl mZ d dlmZ d dlZd dlmZmZ d dlm	Z	m
Z
mZmZ ddlmZ edd	d
\ZZeddd
\ZZddgZG dd deZG dd deZG dd deZdS )    )annotations)castN)do_metric_reductionignore_background)MetricReductionconvert_to_numpyconvert_to_tensoroptional_import   )CumulativeIterationMetricz)MetricsReloaded.metrics.pairwise_measuresBinaryPairwiseMeasures)nameMultiClassPairwiseMeasuresMetricsReloadedBinaryMetricsReloadedCategoricalc                      sB   e Zd ZdZdejdfd fddZ	ddddZdd Z  Z	S )MetricsReloadedWrapperao  Base class for defining MetricsReloaded metrics as a CumulativeIterationMetric.

    Args:
        metric_name: Name of a metric from the MetricsReloaded package.
        include_background: whether to include computation on the first channel of
            the predicted output. Defaults to ``True``.
        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
            Here `not_nans` count the number of not nans for the metric,
            thus its shape equals to the shape of the metric.

    TFmetric_namestrinclude_backgroundbool	reductionMetricReduction | strget_not_nansreturnNonec                   s&   t    || _|| _|| _|| _d S N)super__init__r   r   r   r   selfr   r   r   r   	__class__ W/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/metrics/wrapper.pyr   /   s
   

zMetricsReloadedWrapper.__init__NMetricReduction | str | None0torch.Tensor | tuple[torch.Tensor, torch.Tensor]c                 C  sB   |   }t|tjstdt||p| j\}}| jr||fS |S )Nz-the data to aggregate must be PyTorch Tensor.)
get_buffer
isinstancetorchTensor
ValueErrorr   r   r   )r   r   datafnot_nansr"   r"   r#   	aggregate<   s
   z MetricsReloadedWrapper.aggregatec                 C  s2   |  }|  }| jst||d\}}|||jfS )z.Prepares onehot encoded input for metric call.)y_predy)floatr   r   device)r   r/   r0   r"   r"   r#   prepare_onehotF   s
   z%MetricsReloadedWrapper.prepare_onehot
r   r   r   r   r   r   r   r   r   r   r   )r   r$   r   r%   )
__name__
__module____qualname____doc__r   MEANr   r.   r3   __classcell__r"   r"   r    r#   r      s    
r   c                      s6   e Zd ZdZdejdfd fddZdddZ  ZS )r   a  
    Wraps the binary pairwise metrics of MetricsReloaded.

    Args:
        metric_name: Name of a binary metric from the MetricsReloaded package.
        include_background: whether to include computation on the first channel of
            the predicted output. Defaults to ``True``.
        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
            Here `not_nans` count the number of not nans for the metric,
            thus its shape equals to the shape of the metric.

    Example:

    .. code-block:: python

        import torch
        from monai.metrics import MetricsReloadedBinary

        metric_name = "Cohens Kappa"
        metric = MetricsReloadedBinary(metric_name=metric_name)

        # first iteration
        # shape [batch=1, channel=1, 2, 2]
        y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 1.0]]]])
        y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])
        print(metric(y_pred, y))

        # second iteration
        # shape [batch=1, channel=1, 2, 2]
        y_pred = torch.tensor([[[[1.0, 0.0], [0.0, 0.0]]]])
        y = torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]])
        print(metric(y_pred, y))

        # aggregate
        # shape ([batch=2, channel=1])
        print(metric.aggregate(reduction="none"))  # tensor([[0.5], [0.2]])

        # reset
        metric.reset()

    TFr   r   r   r   r   r   r   r   r   c                   s   t  j||||d d S N)r   r   r   r   )r   r   r   r    r"   r#   r   }   s   
zMetricsReloadedBinary.__init__r/   torch.Tensorr0   c                 C  s   |  ||\}}}| }|dk rtd| d|jd dks'|jd dkr8td|jd  d|jd  dt|}t|}t||ttd|d	d
}| j|j	vr[td| j |j	| j  }t
||dS )a#  Computes a binary (single-class) MetricsReloaded metric from a batch of
        predictions and references.

        Args:
            y_pred: Prediction with dimensions (batch, channel, *spatial), where channel=1.
                The values should be binarized.
            y: Ground-truth with dimensions (batch, channel, *spatial), where channel=1.
                The values should be binarized.

        Raises:
            ValueError: when `y_pred` has less than three dimensions.
            ValueError: when second dimension ~= 1

           Hy_pred should have at least 3 dimensions (batch, channel, spatial), got .r
   zy_pred.shape[1]=z and y.shape[1]=z should be one.   h㈵>)axis	smooth_drUnsupported metric: r2   )r3   
ndimensionr*   shaper   r   tupleranger   metricsr   )r   r/   r0   r2   dimsbpmmetricr"   r"   r#   _compute_tensor   s   "z%MetricsReloadedBinary._compute_tensorr4   r/   r<   r0   r<   r   r<   	r5   r6   r7   r8   r   r9   r   rN   r:   r"   r"   r    r#   r   O   s    0c                      s8   e Zd ZdZdejddfd fddZdddZ  ZS )r   a  
    Wraps the categorical pairwise metrics of MetricsReloaded.


    Args:
        metric_name: Name of a categorical metric from the MetricsReloaded package.
        include_background: whether to include computation on the first channel of
            the predicted output. Defaults to ``True``.
        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
            Here `not_nans` count the number of not nans for the metric,
            thus its shape equals to the shape of the metric.
        smooth_dr: a small constant added to the denominator to avoid nan. OBS: should be greater than zero.

    Example:

    .. code-block:: python

        import torch
        from monai.metrics import MetricsReloadedCategorical

        metric_name = "Weighted Cohens Kappa"
        metric = MetricsReloadedCategorical(metric_name=metric_name)

        # first iteration
        # shape [bach=1, channel=3, 2, 2]
        y_pred = torch.tensor([[[[0, 0], [0, 1]], [[0, 0], [0, 0]], [[1, 1], [1, 0]]]])
        y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]])
        print(metric(y_pred, y))

        # second iteration
        # shape [batch=1, channel=3, 2, 2]
        y_pred = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [1, 0]], [[0, 0], [0, 0]]]])
        y = torch.tensor([[[[1, 0], [0, 1]], [[0, 1], [0, 0]], [[0, 0], [1, 0]]]])
        print(metric(y_pred, y))

        # aggregate
        # shape ([batch=2, channel=1])
        print(metric.aggregate(reduction="none"))  # tensor([[0.2727], [0.6000]])

        # reset
        metric.reset()

    TFrA   r   r   r   r   r   r   r   rC   r1   r   r   c                   s   t  j||||d || _d S r;   )r   r   rC   )r   r   r   r   r   rC   r    r"   r#   r      s   
z#MetricsReloadedCategorical.__init__r/   r<   r0   c                 C  s
  |  ||\}}}| }|dk rtd| d|jd }||jd |jd d}|d}||jd |jd d}|d}| }t|}t|}t||tt	d|| j
tt	|dd	}| j|jvrptd
| j |j| j  }|d }ttjt||dS )a  Computes a categorical (multi-class) MetricsReloaded metric from a batch of
        predictions and references.

        Args:
            y_pred: Prediction with dimensions (batch, channel, *spatial). The values should be
                one-hot encoded and binarized.
            y: Ground-truth with dimensions (batch, channel, *spatial). The values should be 1
                one-hot encoded and binarized.

        Raises:
            ValueError: when `y_pred` has less than three dimensions.

        r=   r>   r?   r
   r   )r   r@   r
   T)rB   rC   list_values	is_onehotrD   ).NrE   )r3   rF   r*   rG   reshapepermuter   r   rH   rI   rC   listr   rJ   r   r(   r)   r   )r   r/   r0   r2   rK   num_classesrL   rM   r"   r"   r#   rN      s2   




z*MetricsReloadedCategorical._compute_tensor)r   r   r   r   r   r   r   r   rC   r1   r   r   rO   rP   r"   r"   r    r#   r      s    2)
__future__r   typingr   r(   monai.metrics.utilsr   r   monai.utilsr   r   r   r	   rM   r   r   _r   __all__r   r   r   r"   r"   r"   r#   <module>   s   
0g