o
    i!                     @  sp   d dl mZ d dlZd dlmZmZ d dlmZmZm	Z	m
Z
 ddlmZ G dd deZd	ejd
fdddZdS )    )annotationsN)do_metric_reductionignore_background)MetricReductionWeightdeprecated_arglook_up_option   )CumulativeIterationMetricc                      sT   e Zd ZdZdejejfd fddZdddZ	e
dddddddddZ  ZS )GeneralizedDiceScorea  
    Compute the Generalized Dice Score metric between tensors.

    This metric is the complement of the Generalized Dice Loss defined in:
    Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
    loss function for highly unbalanced segmentations. DLMIA 2017.

    The inputs `y_pred` and `y` are expected to be one-hot, binarized batch-first tensors, i.e., NCHW[D].

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        include_background: Whether to include the background class (assumed to be in channel 0) in the
            score computation. Defaults to True.
        reduction: Define mode of reduction to the metrics. Available reduction modes:
            {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
            Default value is changed from `MetricReduction.MEAN_BATCH` to `MetricReduction.MEAN` in v1.5.0.
            Old versions computed `mean` when `mean_batch` was provided due to bug in reduction.
        weight_type: {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
            ground truth volume into a weight factor. Defaults to ``"square"``.

    Raises:
        ValueError: When the `reduction` is not one of MetricReduction enum.
    Tinclude_backgroundbool	reductionMetricReduction | strweight_typeWeight | strreturnNonec                   sH   t    || _t|t| _t|t| _| jtjtj	tj
tjhv | _d S N)super__init__r   r   r   r   r   r   SUMMEANMEAN_CHANNELSUM_CHANNELsum_over_classes)selfr   r   r   	__class__ `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/metrics/generalized_dice.pyr   1   s   
zGeneralizedDiceScore.__init__y_predtorch.Tensoryc                 C  s   t ||| j| j| jdS )a  
        Computes the Generalized Dice Score and returns a tensor with its per image values.

        Args:
            y_pred (torch.Tensor): Binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
                where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
            y (torch.Tensor): Binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.

        Returns:
            torch.Tensor: Generalized Dice Score averaged across batch and class

        Raises:
            ValueError: If `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
        )r!   r#   r   r   r   )compute_generalized_dicer   r   r   )r   r!   r#   r   r   r    _compute_tensorB   s   z$GeneralizedDiceScore._compute_tensorz1.3.3z1.7.0zYReduction will be ignored. Set reduction during init. as gen.dice needs it during compute)sinceremoved
msg_suffixNMetricReduction | str | Nonec                 C  s0   |   }t|tjstdt|| j\}}|S )z
        Execute reduction logic for the output of `compute_generalized_dice`.

        Returns:
            torch.Tensor: Aggregated metric value.

        Raises:
            ValueError: If the data to aggregate is not a PyTorch Tensor.
        z/The data to aggregate must be a PyTorch Tensor.)
get_buffer
isinstancetorchTensor
ValueErrorr   r   )r   r   dataf_r   r   r    	aggregateY   s
   zGeneralizedDiceScore.aggregate)r   r   r   r   r   r   r   r   )r!   r"   r#   r"   r   r"   r   )r   r)   r   r"   )__name__
__module____qualname____doc__r   r   r   SQUAREr   r%   r   r2   __classcell__r   r   r   r    r      s    
r   TFr!   r"   r#   r   r   r   r   r   r   c                 C  s  |   }|dk rtd| d|j| jkr#td| j d|j d|s-t| |d\} }ttd|   }tj||  |d	}tj||d	}tj| |d	}	||	 }
t|t	}|t	j
krct| }n|t	jkrtt| |  }nt| }|D ]}t|}d
||< t|||< q}|rd|| jddd }|
| jddd}|	jddd}	nd||  }|
| }|	}	|| }|d
k}t|	d
k| tjd|jdtjd|jd||< |S )a  
    Computes the Generalized Dice Score and returns a tensor with its per image values.

    Args:
        y_pred (torch.Tensor): Binarized segmentation model output. It should be binarized, in one-hot format
            and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
            remaining are the spatial dimensions.
        y (torch.Tensor): Binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
        include_background: Whether to include score computation on the first channel of the
            predicted output. Defaults to True.
        weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
            transform ground truth volume into a weight factor. Defaults to ``"square"``.
        sum_over_labels (bool): Whether to sum the numerator and denominator across all labels before the final computation.

    Returns:
        torch.Tensor: Per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].

    Raises:
        ValueError: If `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
            or `y_pred` and `y` don't have the same shape.
       zHy_pred should have at least 3 dimensions (batch, channel, spatial), got .z	y_pred - z - and y - z - should have the same shapes.)r!   r#      )dimr   g       @r	   T)r<   keepdimg      ?)deviceg        )r<   r.   shaper   listranger,   sumr   r   SIMPLE
reciprocalfloatr7   	ones_likeisinfmaxwheretensorr?   )r!   r#   r   r   r   dimsreduce_axisintersectionZy_oZy_pred_odenominatorwbinfsnumerdenomZgeneralized_dice_scoreZdenom_zerosr   r   r    r$   s   sJ   




r$   )r!   r"   r#   r"   r   r   r   r   r   r   r   r"   )
__future__r   r,   monai.metrics.utilsr   r   monai.utilsr   r   r   r   metricr
   r   r7   r$   r   r   r   r    <module>   s   `