U
    tPh-                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlZd dlm	  m
Z d dlmZ d dlmZ d dlmZ G dd	 d	eZdddddddddZdddddddddZdS )    )annotationsN)Sequence)Optional)_Loss)one_hot)LossReductionc                
      sX   e Zd ZdZdddddejdfddddd	d
ddd fddZddddddZ  ZS )	FocalLossa  
    FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from
    high confidence correct predictions.

    Reimplementation of the Focal Loss described in:

        - ["Focal Loss for Dense Object Detection"](https://arxiv.org/abs/1708.02002), T. Lin et al., ICCV 2017
        - "AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy",
          Zhu et al., Medical Physics 2018

    Example:
        >>> import torch
        >>> from monai.losses import FocalLoss
        >>> from torch.nn import BCEWithLogitsLoss
        >>> shape = B, N, *DIMS = 2, 3, 5, 7, 11
        >>> input = torch.rand(*shape)
        >>> target = torch.rand(*shape)
        >>> # Demonstrate equivalence to BCE when gamma=0
        >>> fl_g0_criterion = FocalLoss(reduction='none', gamma=0)
        >>> fl_g0_loss = fl_g0_criterion(input, target)
        >>> bce_criterion = BCEWithLogitsLoss(reduction='none')
        >>> bce_loss = bce_criterion(input, target)
        >>> assert torch.allclose(fl_g0_loss, bce_loss)
        >>> # Demonstrate "focus" by setting gamma > 0.
        >>> fl_g2_criterion = FocalLoss(reduction='none', gamma=2)
        >>> fl_g2_loss = fl_g2_criterion(input, target)
        >>> # Mark easy and hard cases
        >>> is_easy = (target > 0.7) & (input > 0.7)
        >>> is_hard = (target > 0.7) & (input < 0.3)
        >>> easy_loss_g0 = fl_g0_loss[is_easy].mean()
        >>> hard_loss_g0 = fl_g0_loss[is_hard].mean()
        >>> easy_loss_g2 = fl_g2_loss[is_easy].mean()
        >>> hard_loss_g2 = fl_g2_loss[is_hard].mean()
        >>> # Gamma > 0 causes the loss function to "focus" on the hard
        >>> # cases.  IE, easy cases are downweighted, so hard cases
        >>> # receive a higher proportion of the loss.
        >>> hard_to_easy_ratio_g2 = hard_loss_g2 / easy_loss_g2
        >>> hard_to_easy_ratio_g0 = hard_loss_g0 / easy_loss_g0
        >>> assert hard_to_easy_ratio_g2 > hard_to_easy_ratio_g0
    TF       @Nboolfloatzfloat | Nonez3Sequence[float] | float | int | torch.Tensor | NonezLossReduction | strNone)include_backgroundto_onehot_ygammaalphaweight	reductionuse_softmaxreturnc                   sb   t  jt|jd || _|| _|| _|| _|| _|| _	|dk	rJt
|nd}| d| |  dS )a  
        Args:
            include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
                If False, `alpha` is invalid when using softmax.
            to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
            alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
                The value should be in [0, 1]. Defaults to None.
            weight: weights to apply to the voxels of each class. If None no weights are applied.
                The input can be a single value (same weight for all classes), a sequence of values (the length
                of the sequence should be the same as the number of classes. If not ``include_background``,
                the number of classes should not include the background category class 0).
                The value/values should be no less than 0. Defaults to None.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            use_softmax: whether to use softmax to transform the original logits into probabilities.
                If True, softmax is used. If False, sigmoid is used. Defaults to False.

        Example:
            >>> import torch
            >>> from monai.losses import FocalLoss
            >>> pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32)
            >>> grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64)
            >>> fl = FocalLoss(to_onehot_y=True)
            >>> fl(pred, grnd)
        )r   Nclass_weight)super__init__r   valuer   r   r   r   r   r   torch	as_tensorregister_buffer)selfr   r   r   r   r   r   r   	__class__ L/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/losses/focal_loss.pyr   D   s    )zFocalLoss.__init__torch.Tensor)inputtargetr   c                 C  s4  |j d }| jr0|dkr$td nt||d}| jsr|dkrJtd n(|ddddf }|ddddf }|j |j krtd|j  d|j  dd}| }| }| jr| js| j	dk	rd| _	td	 t
||| j| j	}nt||| j| j	}|j d }| jdk	r|dkr| jjd
kr<t| jg| | _n| jj d
 |krVtd| j d
k rntd| j|| _dgdgt|j dd   }| j|| _| j| }| jtjjkrd}|r|jttdt|j d}| }n>| jtjjkr| }n$| jtjjkrntd| j d|S )a  
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
                The input should be the original logits since it will be transformed by
                a sigmoid/softmax in the forward function.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            ValueError: When input and target (after one hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
            ValueError: When ``self.weight`` is a sequence and the length is not equal to the
                number of classes.
            ValueError: When ``self.weight`` is/contains a value that is less than 0.

           z6single channel prediction, `to_onehot_y=True` ignored.)num_classesz>single channel prediction, `include_background=False` ignored.Nz"ground truth has different shape (z) from input ()z?`include_background=False`, `alpha` ignored when using softmax.r   zthe length of the `weight` sequence should be the same as the number of classes.
                        If `include_background=False`, the weight should not include
                        the background category class 0.z:the value/values of the `weight` should be no less than 0.   T)dimzUnsupported reduction: z0, available options are ["mean", "sum", "none"].)shaper   warningswarnr   r   
ValueErrorr   r   r   softmax_focal_lossr   sigmoid_focal_lossr   ndimr   r   mintolenviewr   r   SUMr   meanlistrangesumMEANNONE)r   r"   r#   	n_pred_chlossnum_of_classesbroadcast_dimsZaverage_spatial_dimsr   r   r    forwardx   s\    





zFocalLoss.forward)	__name__
__module____qualname____doc__r   r:   r   r@   __classcell__r   r   r   r    r      s   +"4r   r	   r!   r   zOptional[float])r"   r#   r   r   r   c                 C  s   |  d}d|  | | | }|dk	rtd| g|g|jd d   |}dgdgt|jdd   }||}|| }|S )z
    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)

    where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
    s_j is the unnormalized score for class j.
    r$   Nr'   r(   )	log_softmaxexppowr   tensorr*   r2   r3   r4   )r"   r#   r   r   Zinput_lsr=   Z	alpha_facr?   r   r   r    r.      s    	
*
r.   c                 C  sj   | | |  t |  }t |  |d d  }||  | }|dk	rf|| d| d|   }|| }|S )z|
    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)

    where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0
    r(   r$   N)F
logsigmoidrG   )r"   r#   r   r   r=   ZinvprobsZalpha_factorr   r   r    r/      s    r/   )r	   N)r	   N)
__future__r   r+   collections.abcr   typingr   r   torch.nn.functionalnn
functionalrJ   torch.nn.modules.lossr   monai.networksr   monai.utilsr   r   r.   r/   r   r   r   r    <module>   s     5      