U
    PhI                      @  sv   d dl mZ d dlZd dlmZmZ d dlmZmZm	Z	 ddl
mZ G dd deZd	ejfd
d
ddd
dddZdS )    )annotationsN)do_metric_reductionignore_background)MetricReductionWeightlook_up_option   )CumulativeIterationMetricc                      s\   e Zd ZdZdejejfddddd fdd	Zd
d
d
dddZ	ddd
dddZ
  ZS )GeneralizedDiceScorea  Compute the Generalized Dice Score metric between tensors, as the complement of the Generalized Dice Loss defined in:

    Sudre, C. et. al. (2017) Generalised Dice overlap as a deep learning
        loss function for highly unbalanced segmentations. DLMIA 2017.

    The inputs `y_pred` and `y` are expected to be one-hot, binarized channel-first
    or batch-first tensors, i.e., CHW[D] or BCHW[D].

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        include_background (bool, optional): whether to include the background class (assumed to be in channel 0), in the
            score computation. Defaults to True.
        reduction (str, optional): define mode of reduction to the metrics. Available reduction modes:
            {``"none"``, ``"mean_batch"``, ``"sum_batch"``}. Default to ``"mean_batch"``. If "none", will not do reduction.
        weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
            ground truth volume into a weight factor. Defaults to ``"square"``.

    Raises:
        ValueError: when the `weight_type` is not one of {``"none"``, ``"mean"``, ``"sum"``}.
    TboolzMetricReduction | strWeight | strNone)include_background	reductionweight_typereturnc                   sT   t    || _dddtjtjtjg}|| _| j|krDtd| t	|t
| _d S )Nnone
mean_batch	sum_batchreduction must be one of )super__init__r   r   NONE
MEAN_BATCH	SUM_BATCHr   
ValueErrorr   r   r   )selfr   r   r   reduction_options	__class__ S/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/metrics/generalized_dice.pyr   -   s    

zGeneralizedDiceScore.__init__torch.Tensor)y_predyr   c                 C  s   t ||| j| jdS )af  Computes the Generalized Dice Score and returns a tensor with its per image values.

        Args:
            y_pred (torch.Tensor): binarized segmentation model output. It must be in one-hot format and in the NCHW[D] format,
                where N is the batch dimension, C is the channel dimension, and the remaining are the spatial dimensions.
            y (torch.Tensor): binarized ground-truth. It must be in one-hot format and have the same shape as `y_pred`.

        Raises:
            ValueError: if `y_pred` and `y` have less than 3 dimensions, or `y_pred` and `y` don't have the same shape.
        )r#   r$   r   r   )compute_generalized_dicer   r   )r   r#   r$   r    r    r!   _compute_tensorB   s       z$GeneralizedDiceScore._compute_tensorNzMetricReduction | str | None)r   r   c                 C  s`   |   }t|tjstd|dk	rHdddddg}||krHtd| t||pT| j\}}|S )	a  
        Execute reduction logic for the output of `compute_generalized_dice`.

        Args:
            reduction (Union[MetricReduction, str, None], optional): define mode of reduction to the metrics.
                Available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``}.
                Defaults to ``"mean"``. If "none", will not do reduction.
        z/The data to aggregate must be a PyTorch Tensor.Nr   meansumr   r   r   )
get_buffer
isinstancetorchTensorr   r   r   )r   r   datar   f_r    r    r!   	aggregateQ   s    	zGeneralizedDiceScore.aggregate)N)__name__
__module____qualname____doc__r   r   r   SQUAREr   r&   r0   __classcell__r    r    r   r!   r
      s   r
   Tr"   r   r   )r#   r$   r   r   r   c                 C  s  |   }|dk r td| d|j| jkrFtd| j d|j d|sZt| |d\} }ttd|   }tj||  |d	}tj||d	}tj| |d	}|| }	t|t	}|t	j
krt| }
n0|t	jkrt| |  }
nt| }
|
D ]$}t|}d
||< t|||< qd||
 jdd	 }|	|
 jdd	}|| }|jdd	}|d
k}t|d
k| tjd|jdtjd|jd||< |S )a  Computes the Generalized Dice Score and returns a tensor with its per image values.

    Args:
        y_pred (torch.Tensor): binarized segmentation model output. It should be binarized, in one-hot format
            and in the NCHW[D] format, where N is the batch dimension, C is the channel dimension, and the
            remaining are the spatial dimensions.
        y (torch.Tensor): binarized ground-truth. It should be binarized, in one-hot format and have the same shape as `y_pred`.
        include_background (bool, optional): whether to include score computation on the first channel of the
            predicted output. Defaults to True.
        weight_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to
            transform ground truth volume into a weight factor. Defaults to ``"square"``.

    Returns:
        torch.Tensor: per batch and per class Generalized Dice Score, i.e., with the shape [batch_size, num_classes].

    Raises:
        ValueError: if `y_pred` or `y` are not PyTorch tensors, if `y_pred` and `y` have less than three dimensions,
            or `y_pred` and `y` don't have the same shape.
       zHy_pred should have at least 3 dimensions (batch, channel, spatial), got .z	y_pred - z - and y - z - should have the same shapes.)r#   r$      )dimr   g       @r   g      ?)deviceg        )r:   r   shaper   listranger+   r(   r   r   SIMPLE
reciprocalfloatr5   	ones_likeisinfmaxwheretensorr<   )r#   r$   r   r   dimsreduce_axisintersectionZy_oZy_pred_odenominatorwbinfsnumerdenomZgeneralized_dice_scoreZdenom_zerosr    r    r!   r%   j   sB    




r%   )
__future__r   r+   monai.metrics.utilsr   r   monai.utilsr   r   r   metricr	   r
   r5   r%   r    r    r    r!   <module>   s   U 