o
     im?                     @  sx   d dl mZ d dlZd dlmZ d dlmZmZ ddlm	Z	 g dZ
G dd	 d	e	Z	
	
	ddddZG dd dZdS )    )annotationsN)do_metric_reduction)MetricReductiondeprecated_arg   )CumulativeIterationMetric)
DiceMetriccompute_dice
DiceHelperc                      sJ   e Zd ZdZdejddddfd fddZdddZ	ddddZ  Z	S ) r   aQ  
    Computes Dice score for a set of pairs of prediction-groundtruth labels. It supports single-channel label maps
    or multi-channel images with class segmentations per channel. This allows the computation for both multi-class
    and multi-label tasks.

    If either prediction ``y_pred`` or ground truth ``y`` have shape BCHW[D], it is expected that these represent one-
    hot segmentations for C number of classes. If either shape is B1HW[D], it is expected that these are label maps
    and the number of classes must be specified by the ``num_classes`` parameter. In either case for either inputs,
    this metric applies no activations and so non-binary values will produce unexpected results if this metric is used
    for binary overlap measurement (ie. either was expected to be one-hot formatted). Soft labels are thus permitted by
    this metric. Typically this implies that raw predictions from a network must first be activated and possibly made
    into label maps, eg. for a multi-class prediction tensor softmax and then argmax should be applied over the channel
    dimensions to produce a label map.

    The ``include_background`` parameter can be set to `False` to exclude the first category (channel index 0) which
    is by convention assumed to be background. If the non-background segmentations are small compared to the total
    image size they can get overwhelmed by the signal from the background. This assumes the shape of both prediction
    and ground truth is BCHW[D].

    The typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Further information can be found in the official
    `MONAI Dice Overview <https://github.com/Project-MONAI/tutorials/blob/main/modules/dice_loss_metric_notes.ipynb>`.

    Example:

    .. code-block:: python

        import torch
        from monai.metrics import DiceMetric
        from monai.losses import DiceLoss
        from monai.networks import one_hot

        batch_size, n_classes, h, w = 7, 5, 128, 128

        y_pred = torch.rand(batch_size, n_classes, h, w)  # network predictions
        y_pred = torch.argmax(y_pred, 1, True)  # convert to label map

        # ground truth as label map
        y = torch.randint(0, n_classes, size=(batch_size, 1, h, w))

        dm = DiceMetric(
            reduction="mean_batch", return_with_label=True, num_classes=n_classes
        )

        raw_scores = dm(y_pred, y)
        print(dm.aggregate())

        # now compute the Dice loss which should be the same as 1 - raw_scores
        dl = DiceLoss(to_onehot_y=True, reduction="none")
        loss = dl(one_hot(y_pred, n_classes), y).squeeze()

        print(1.0 - loss)  # same as raw_scores


    Args:
        include_background: whether to include Dice computation on the first channel/category of the prediction and
            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
        reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
            available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is
            selected, the metric will not do reduction.
        get_not_nans: whether to return the `not_nans` count. If True, aggregate() returns `(metric, not_nans)` where
            `not_nans` counts the number of valid values in the result, and will have the same shape.
        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
            set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
            are also empty.
        num_classes: number of input channels (always including the background). When this is ``None``,
            ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
            single-channel class indices and the number of classes is not automatically inferred from data.
        return_with_label: whether to return the metrics with label, only works when reduction is "mean_batch".
            If `True`, use "label_{index}" as the key corresponding to C channels; if ``include_background`` is True,
            the index begins at "0", otherwise at "1". It can also take a list of label names.
            The outcome will then be returned as a dictionary.

    TFNinclude_backgroundbool	reductionMetricReduction | strget_not_nansignore_emptynum_classes
int | Nonereturn_with_labelbool | list[str]returnNonec                   sP   t    || _|| _|| _|| _|| _|| _t| jt	j
dd| j| jd| _d S )NFr   r   r   apply_argmaxr   r   )super__init__r   r   r   r   r   r   r
   r   NONEdice_helper)selfr   r   r   r   r   r   	__class__ X/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/metrics/meandice.pyr   e   s   
	zDiceMetric.__init__y_predtorch.Tensoryc                 C  s.   |  }|dk rtd| d| j||dS )a  
        Compute the dice value using ``DiceHelper``.

        Args:
            y_pred: prediction value, see class docstring for format definition.
            y: ground truth label.

        Raises:
            ValueError: when `y_pred` has fewer than three dimensions.
           zHy_pred should have at least 3 dimensions (batch, channel, spatial), got .r"   r$   )
ndimension
ValueErrorr   )r   r"   r$   dimsr    r    r!   _compute_tensor~   s   zDiceMetric._compute_tensorMetricReduction | str | None0torch.Tensor | tuple[torch.Tensor, torch.Tensor]c           
      C  s   |   }t|tjstdt| dt||p| j\}}| jtj	krg| j
rgi }t| j
trQt|D ]\}}| jsAd|d  nd| }t| d||< q3nt| j
|D ]\}	}t| d||	< qW|}| jrn||fS |S )a  
        Execute reduction and aggregation logic for the output of `compute_dice`.

        Args:
            reduction: defines mode of reduction as enumerated in :py:class:`monai.utils.enums.MetricReduction`.
                By default this will do no reduction.
        z2the data to aggregate must be PyTorch Tensor, got r&   label_r      )
get_buffer
isinstancetorchTensorr)   typer   r   r   
MEAN_BATCHr   r   	enumerater   rounditemzipr   )
r   r   datafnot_nans_fivZ
_label_keykeyr    r    r!   	aggregate   s   
zDiceMetric.aggregate)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r"   r#   r$   r#   r   r#   N)r   r,   r   r-   )
__name__
__module____qualname____doc__r   MEANr   r+   rA   __classcell__r    r    r   r!   r      s    N
r   Tr"   r#   r$   r   r   r   r   r   r   c                 C  s   t |tjdd||d| |dS )a  
    Computes Dice score metric for a batch of predictions. This performs the same computation as
    :py:class:`monai.metrics.DiceMetric`, which is preferrable to use over this function. For input formats, see the
    documentation for that class .

    Args:
        y_pred: input data to compute, typical segmentation model output.
        y: ground truth to compute mean dice metric.
        include_background: whether to include Dice computation on the first channel/category of the prediction and
            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
            set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
            are also empty.
        num_classes: number of input channels (always including the background). When this is ``None``,
            ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
            single-channel class indices and the number of classes is not automatically inferred from data.

    Returns:
        Dice scores per batch and per class, (shape: [batch_size, num_classes]).

    Fr   r'   )r
   r   r   )r"   r$   r   r   r   r    r    r!   r	      s   r	   c                   @  sj   e Zd ZdZeddddddedddd	d
ddddddejddddf
d$ddZd%dd Zd&d"d#Z	dS )'r
   a	  
    Compute Dice score between two tensors ``y_pred`` and ``y``. This is used by :py:class:`monai.metrics.DiceMetric`,
    see the documentation for that class for input formats.

    Example:

    .. code-block:: python

        import torch
        from monai.metrics import DiceHelper

        n_classes, batch_size = 5, 16
        spatial_shape = (128, 128, 128)

        y_pred = torch.rand(batch_size, n_classes, *spatial_shape).float()  # predictions
        y = torch.randint(0, n_classes, size=(batch_size, 1, *spatial_shape)).long()  # ground truth

        score, not_nans = DiceHelper(include_background=False, sigmoid=True, softmax=True)(y_pred, y)
        print(score, not_nans)

    Args:
        include_background: whether to include Dice computation on the first channel/category of the prediction and
            ground truth. Defaults to ``True``, use ``False`` to exclude the background class.
        threshold: if ``True`, ``y_pred`` will be thresholded at a value of 0.5. Defaults to False.
        apply_argmax: whether ``y_pred`` are softmax activated outputs. If True, `argmax` will be performed to
            get the discrete prediction. Defaults to the value of ``not threshold``.
        activate: if this and ``threshold` are ``True``, sigmoid activation is applied to ``y_pred`` before
            thresholding. Defaults to False.
        get_not_nans: whether to return the number of not-nan values.
        reduction: defines mode of reduction to the metrics, this will only apply reduction on `not-nan` values. The
            available reduction modes are enumerated by :py:class:`monai.utils.enums.MetricReduction`. If "none", is
            selected, the metric will not do reduction.
        ignore_empty: whether to ignore empty ground truth cases during calculation. If `True`, the `NaN` value will be
            set for an empty ground truth cases, otherwise 1 will be set if the predictions of empty ground truth cases
            are also empty.
        num_classes: number of input channels (always including the background). When this is ``None``,
            ``y_pred.shape[1]`` will be used. This option is useful when both ``y_pred`` and ``y`` are
            single-channel class indices and the number of classes is not automatically inferred from data.
    softmaxz1.5z1.7zUse `apply_argmax` instead.r   )new_namesigmoidzUse `threshold` instead.	thresholdNFTr   bool | Noner   activater   r   r   r   r   r   r   r   c                 C  sf   |	d ur|	}|
d ur|
}|| _ || _|| _|d u r|n|| _|d u r%| n|| _|| _|| _|| _d S rC   )rM   r   r   r   r   rO   r   r   )r   r   rM   r   rO   r   r   r   r   rL   rJ   r    r    r!   r      s   
zDiceHelper.__init__r"   r#   r$   c                 C  s   t |}|dkrdt t || |t |  S | jr(t jtd|jdS |t | }|dkr;t jd|jdS t jd|jdS )a6  
        Compute the dice metric for binary inputs which have only spatial dimensions. This method is called separately
        for each batch item and for each channel of those items.

        Args:
            y_pred: input predictions with shape HW[D].
            y: ground truth with shape HW[D].
        r   g       @nan)deviceg      ?g        )r2   summasked_selectr   tensorfloatrQ   )r   r"   r$   y_oZdenormr    r    r!   compute_channel  s   
	$zDiceHelper.compute_channelr-   c                 C  sz  | j | j}}| jdu r|jd }n| j}|jd dkr%| jdkr%d }}|r4|dkr4tj|ddd}n|rB| jr>t|}|dk}| jrGdnd}g }t	|jd D ]P}g }	|dkr_t	||ndgD ]7}
|jd dkrs||df |
kn|||
f 
 }|jd dkr||df |
kn|||
f }|	| || qb|t|	 qRtj|dd }t|| j\}}| jr||fS |S )	a  
        Compute the metric for the given prediction and ground truth.

        Args:
            y_pred: input predictions with shape (batch_size, num_classes or 1, spatial_dims...).
                the number of channels is inferred from ``y_pred.shape[1]`` when ``num_classes is None``.
            y: ground truth with shape (batch_size, num_classes or 1, spatial_dims...).
        Nr   FT)dimkeepdimg      ?r   )rX   )r   rM   r   shaper2   argmaxrO   rL   r   ranger   appendrW   stack
contiguousr   r   r   )r   r"   r$   Z_apply_argmax
_threshold	n_pred_chZfirst_chr:   bZc_listcZx_predxr;   r<   r    r    r!   __call__,  s0   	

.*zDiceHelper.__call__)r   rN   rM   r   r   rN   rO   r   r   r   r   r   r   r   r   r   rL   rN   rJ   rN   r   r   rB   )r"   r#   r$   r#   r   r-   )
rD   rE   rF   rG   r   r   r5   r   rW   re   r    r    r    r!   r
      s"    (
r
   )TTN)r"   r#   r$   r#   r   r   r   r   r   r   r   r#   )
__future__r   r2   monai.metrics.utilsr   monai.utilsr   r   metricr   __all__r   r	   r
   r    r    r    r!   <module>   s    &