o
    i-                     @  s   d dl mZ d dlZd dlmZ d dlmZ d dlZd dlm	  m
Z d dlmZ d dlmZ d dlmZ G dd	 d	eZ	ddddZ	ddddZdS )    )annotationsN)Sequence)Optional)_Loss)one_hot)LossReductionc                      s>   e Zd ZdZdddddejdfd fddZdddZ  ZS )	FocalLossa  
    FocalLoss is an extension of BCEWithLogitsLoss that down-weights loss from
    high confidence correct predictions.

    Reimplementation of the Focal Loss described in:

        - ["Focal Loss for Dense Object Detection"](https://arxiv.org/abs/1708.02002), T. Lin et al., ICCV 2017
        - "AnatomyNet: Deep learning for fast and fully automated whole-volume segmentation of head and neck anatomy",
          Zhu et al., Medical Physics 2018

    Example:
        >>> import torch
        >>> from monai.losses import FocalLoss
        >>> from torch.nn import BCEWithLogitsLoss
        >>> shape = B, N, *DIMS = 2, 3, 5, 7, 11
        >>> input = torch.rand(*shape)
        >>> target = torch.rand(*shape)
        >>> # Demonstrate equivalence to BCE when gamma=0
        >>> fl_g0_criterion = FocalLoss(reduction='none', gamma=0)
        >>> fl_g0_loss = fl_g0_criterion(input, target)
        >>> bce_criterion = BCEWithLogitsLoss(reduction='none')
        >>> bce_loss = bce_criterion(input, target)
        >>> assert torch.allclose(fl_g0_loss, bce_loss)
        >>> # Demonstrate "focus" by setting gamma > 0.
        >>> fl_g2_criterion = FocalLoss(reduction='none', gamma=2)
        >>> fl_g2_loss = fl_g2_criterion(input, target)
        >>> # Mark easy and hard cases
        >>> is_easy = (target > 0.7) & (input > 0.7)
        >>> is_hard = (target > 0.7) & (input < 0.3)
        >>> easy_loss_g0 = fl_g0_loss[is_easy].mean()
        >>> hard_loss_g0 = fl_g0_loss[is_hard].mean()
        >>> easy_loss_g2 = fl_g2_loss[is_easy].mean()
        >>> hard_loss_g2 = fl_g2_loss[is_hard].mean()
        >>> # Gamma > 0 causes the loss function to "focus" on the hard
        >>> # cases.  IE, easy cases are downweighted, so hard cases
        >>> # receive a higher proportion of the loss.
        >>> hard_to_easy_ratio_g2 = hard_loss_g2 / easy_loss_g2
        >>> hard_to_easy_ratio_g0 = hard_loss_g0 / easy_loss_g0
        >>> assert hard_to_easy_ratio_g2 > hard_to_easy_ratio_g0
    TF       @Ninclude_backgroundboolto_onehot_ygammafloatalphafloat | Noneweight3Sequence[float] | float | int | torch.Tensor | None	reductionLossReduction | struse_softmaxreturnNonec                   sb   t  jt|jd || _|| _|| _|| _|| _|| _	|dur%t
|nd}| d| |  dS )a  
        Args:
            include_background: if False, channel index 0 (background category) is excluded from the loss calculation.
                If False, `alpha` is invalid when using softmax.
            to_onehot_y: whether to convert the label `y` into the one-hot format. Defaults to False.
            gamma: value of the exponent gamma in the definition of the Focal loss. Defaults to 2.
            alpha: value of the alpha in the definition of the alpha-balanced Focal loss.
                The value should be in [0, 1]. Defaults to None.
            weight: weights to apply to the voxels of each class. If None no weights are applied.
                The input can be a single value (same weight for all classes), a sequence of values (the length
                of the sequence should be the same as the number of classes. If not ``include_background``,
                the number of classes should not include the background category class 0).
                The value/values should be no less than 0. Defaults to None.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

            use_softmax: whether to use softmax to transform the original logits into probabilities.
                If True, softmax is used. If False, sigmoid is used. Defaults to False.

        Example:
            >>> import torch
            >>> from monai.losses import FocalLoss
            >>> pred = torch.tensor([[1, 0], [0, 1], [1, 0]], dtype=torch.float32)
            >>> grnd = torch.tensor([[0], [1], [0]], dtype=torch.int64)
            >>> fl = FocalLoss(to_onehot_y=True)
            >>> fl(pred, grnd)
        )r   Nclass_weight)super__init__r   valuer
   r   r   r   r   r   torch	as_tensorregister_buffer)selfr
   r   r   r   r   r   r   	__class__ Y/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/losses/focal_loss.pyr   D   s   )zFocalLoss.__init__inputtorch.Tensortargetc                 C  s*  |j d }| jr|dkrtd nt||d}| js9|dkr%td n|ddddf }|ddddf }|j |j krLtd|j  d|j  dd}| }| }| jrs| jsi| j	durid| _	td	 t
||| j| j	}n	t||| j| j	}|j d }| jdur|dkr| jjd
krt| jg| | _n| jj d
 |krtd| j d
k rtd| j|| _dgdgt|j dd   }| j|| _| j| }| jtjjkrd}|r|jttdt|j d}| }|S | jtjjkr| }|S | jtjjkr	 |S td| j d)a  
        Args:
            input: the shape should be BNH[WD], where N is the number of classes.
                The input should be the original logits since it will be transformed by
                a sigmoid/softmax in the forward function.
            target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

        Raises:
            ValueError: When input and target (after one hot transform if set)
                have different shapes.
            ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
            ValueError: When ``self.weight`` is a sequence and the length is not equal to the
                number of classes.
            ValueError: When ``self.weight`` is/contains a value that is less than 0.

           z6single channel prediction, `to_onehot_y=True` ignored.)num_classesz>single channel prediction, `include_background=False` ignored.Nz"ground truth has different shape (z) from input ()z?`include_background=False`, `alpha` ignored when using softmax.r   zthe length of the `weight` sequence should be the same as the number of classes.
                        If `include_background=False`, the weight should not include
                        the background category class 0.z:the value/values of the `weight` should be no less than 0.   T)dimzUnsupported reduction: z0, available options are ["mean", "sum", "none"].)shaper   warningswarnr   r
   
ValueErrorr   r   r   softmax_focal_lossr   sigmoid_focal_lossr   ndimr   r   mintolenviewr   r   SUMr   meanlistrangesumMEANNONE)r   r$   r&   	n_pred_chlossnum_of_classesbroadcast_dimsZaverage_spatial_dimsr"   r"   r#   forwardx   s`   



zFocalLoss.forward)r
   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   )r$   r%   r&   r%   r   r%   )	__name__
__module____qualname____doc__r   r=   r   rC   __classcell__r"   r"   r    r#   r      s    +4r   r	   r$   r%   r&   r   r   r   Optional[float]r   c                 C  s   |  d}d|  | | | }|durDtd| g|g|jd d   |}dgdgt|jdd   }||}|| }|S )z
    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)

    where p_i = exp(s_i) / sum_j exp(s_j), t is the target (ground truth) class, and
    s_j is the unnormalized score for class j.
    r'   Nr*   r+   )	log_softmaxexppowr   tensorr-   r5   r6   r7   )r$   r&   r   r   Zinput_lsr@   Z	alpha_facrB   r"   r"   r#   r1      s   
	*
r1   c                 C  sj   | | |  t |  }t |  |d d  }||  | }|dur3|| d| d|   }|| }|S )z|
    FL(pt) = -alpha * (1 - pt)**gamma * log(pt)

    where p = sigmoid(x), pt = p if label is 1 or 1 - p if label is 0
    r+   r'   N)F
logsigmoidrK   )r$   r&   r   r   r@   ZinvprobsZalpha_factorr"   r"   r#   r2      s   r2   )r	   N)
r$   r%   r&   r%   r   r   r   rI   r   r%   )
__future__r   r.   collections.abcr   typingr   r   torch.nn.functionalnn
functionalrN   torch.nn.modules.lossr   monai.networksr   monai.utilsr   r   r1   r2   r"   r"   r"   r#   <module>   s    5