U
    uPh                     @  s`   d dl mZ d dlZd dlmZ d dlZd dlmZ d dlm	Z	 d dl
mZ G dd deZdS )	    )annotationsN)Callable)_Loss)one_hot)LossReductionc                      sh   e Zd ZdZdddddddejdddfdddddd	d	d
d	d	ddd fddZddddddZ  ZS )TverskyLossaC  
    Compute the Tversky loss defined in:

        Sadegh et al. (2017) Tversky loss function for image segmentation
        using 3D fully convolutional deep networks. (https://arxiv.org/abs/1706.05721)

    Adapted from:
        https://github.com/NifTK/NiftyNet/blob/v0.6.0/niftynet/layer/loss_segmentation.py#L631

    TFNg      ?gh㈵>boolzCallable | NonefloatzLossReduction | strNone)include_backgroundto_onehot_ysigmoidsoftmax	other_actalphabeta	reduction	smooth_nr	smooth_drbatchreturnc                   s   t  jt|jd |dk	r:t|s:tdt|j dt|t| t|dk	 dkrbt	d|| _
|| _|| _|| _|| _|| _|| _t|	| _t|
| _|| _dS )a  
        Args:
            include_background: If False channel index 0 (background category) is excluded from the calculation.
            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
            sigmoid: If True, apply a sigmoid function to the prediction.
            softmax: If True, apply a softmax function to the prediction.
            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
                other activation layers, Defaults to ``None``. for example:
                `other_act = torch.tanh`.
            alpha: weight of false positives
            beta: weight of false negatives
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                Defaults to False, a Dice loss value is computed independently from each item in the batch
                before any `reduction`.

        Raises:
            TypeError: When ``other_act`` is not an ``Optional[Callable]``.
            ValueError: When more than 1 of [``sigmoid=True``, ``softmax=True``, ``other_act is not None``].
                Incompatible values.

        )r   Nz*other_act must be None or callable but is .   zXIncompatible values: more than 1 of [sigmoid=True, softmax=True, other_act is not None].)super__init__r   valuecallable	TypeErrortype__name__int
ValueErrorr   r   r   r   r   r   r   r	   r   r   r   )selfr   r   r   r   r   r   r   r   r   r   r   	__class__ I/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/losses/tversky.pyr   $   s    - 

zTverskyLoss.__init__ztorch.Tensor)inputtargetr   c                 C  s  | j rt |}|jd }| jr@|dkr4td nt|d}| jdk	rT| |}| jrz|dkrntd nt||d}| j	s|dkrtd n(|ddddf }|ddddf }|j|jkrt
d|j d|j d	|}d| }|}d| }td
t|j }| jr"dg| }t|| |}	| jt|| | }
| jt|| | }|	| j }|	|
 | | j }d||  }| jtjjkrt|S | jtjjkr|S | jtjjkrt|S td| j ddS )z
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].

        Raises:
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].

        r   z2single channel prediction, `softmax=True` ignored.Nz6single channel prediction, `to_onehot_y=True` ignored.)num_classesz>single channel prediction, `include_background=False` ignored.z"ground truth has differing shape (z) from input ()   r   g      ?zUnsupported reduction: z0, available options are ["mean", "sum", "none"].)r   torchshaper   warningswarnr   r   r   r   AssertionErrorarangelentolistr   sumr   r   r   r   r   r   SUMr   NONEMEANmeanr!   )r"   r'   r(   	n_pred_chp0p1Zg0g1reduce_axistpfpfn	numeratordenominatorscorer%   r%   r&   forwarda   sP    








zTverskyLoss.forward)	r   
__module____qualname____doc__r   r7   r   rD   __classcell__r%   r%   r#   r&   r      s   *=r   )
__future__r   r.   collections.abcr   r,   torch.nn.modules.lossr   monai.networksr   monai.utilsr   r   r%   r%   r%   r&   <module>   s   