o
    i'                     @  st   d dl mZ d dlZd dlZd dlmZ d dlmZ d dlm	Z	 G dd deZ
G dd	 d	eZG d
d deZdS )    )annotationsN)_Loss)one_hot)LossReductionc                      s:   e Zd ZdZddddejfd fddZdddZ  ZS )AsymmetricFocalTverskyLossa  
    AsymmetricFocalTverskyLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.

    Actually, it's only supported for binary image segmentation now.

    Reimplementation of the Asymmetric Focal Tversky Loss described in:

    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
    Michael Yeung, Computerized Medical Imaging and Graphics
    Fffffff?g      ?Hz>to_onehot_ybooldeltafloatgammaepsilon	reductionLossReduction | strreturnNonec                   0   t  jt|jd || _|| _|| _|| _dS )a  
        Args:
            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
            delta : weight of the background. Defaults to 0.7.
            gamma : value of the exponent gamma in the definition of the Focal loss  . Defaults to 0.75.
            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
        r   Nsuper__init__r   valuer	   r   r   r   selfr	   r   r   r   r   	__class__ a/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/losses/unified_focal_loss.pyr   #   
   
z#AsymmetricFocalTverskyLoss.__init__y_predtorch.Tensory_truec                 C  sN  |j d }| jr|dkrtd nt||d}|j |j kr+td|j  d|j  dt|| jd| j }t	t
dt|j }tj|| |d	}tj|d|  |d	}tjd| | |d	}|| j || j|  d| j |  | j  }d|d d d
f  }	d|d d df  td|d d df  | j  }
ttj|	|
gdd	}|S )N   6single channel prediction, `to_onehot_y=True` ignored.num_classes"ground truth has different shape () from input ()      ?   dimr   )shaper	   warningswarnr   
ValueErrortorchclampr   listrangelensumr   powr   meanstack)r   r    r"   	n_pred_chaxistpfnfpZ
dice_classZ	back_diceZ	fore_dicelossr   r   r   forward8   s"   
,4z"AsymmetricFocalTverskyLoss.forward)r	   r
   r   r   r   r   r   r   r   r   r   r   r    r!   r"   r!   r   r!   	__name__
__module____qualname____doc__r   MEANr   rB   __classcell__r   r   r   r   r          r   c                      s:   e Zd ZdZddddejfd fddZdddZ  ZS )AsymmetricFocalLossa  
    AsymmetricFocalLoss is a variant of FocalTverskyLoss, which attentions to the foreground class.

    Actually, it's only supported for binary image segmentation now.

    Reimplementation of the Asymmetric Focal Loss described in:

    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
    Michael Yeung, Computerized Medical Imaging and Graphics
    Fr   r+   r   r	   r
   r   r   r   r   r   r   c                   r   )a  
        Args:
            to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
            delta : weight of the background. Defaults to 0.7.
            gamma : value of the exponent gamma in the definition of the Focal loss  . Defaults to 0.75.
            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
        r   Nr   r   r   r   r   r   c   r   zAsymmetricFocalLoss.__init__r    r!   r"   r   c                 C  s   |j d }| jr|dkrtd nt||d}|j |j kr+td|j  d|j  dt|| jd| j }| t	| }t
d|d d df  | j|d d df  }d| j | }|d d df }| j| }ttjtj||gdd	dd	}|S )
Nr#   r$   r%   r'   r(   r)   r*   r   r,   )r/   r	   r0   r1   r   r2   r3   r4   r   logr9   r   r   r:   r8   r;   )r   r    r"   r<   cross_entropyZback_ceZfore_cerA   r   r   r   rB   x   s   
.
"zAsymmetricFocalLoss.forward)
r	   r
   r   r   r   r   r   r   r   r   rC   rD   r   r   r   r   rL   W   rK   rL   c                      s<   e Zd ZdZdddddejfd fddZdddZ  ZS )AsymmetricUnifiedFocalLossa  
    AsymmetricUnifiedFocalLoss is a variant of Focal Loss.

    Actually, it's only supported for binary image segmentation now

    Reimplementation of the Asymmetric Unified Focal Tversky Loss described in:

    - "Unified Focal Loss: Generalising Dice and Cross Entropy-based Losses to Handle Class Imbalanced Medical Image Segmentation",
    Michael Yeung, Computerized Medical Imaging and Graphics
    Fr+   g      ?r   r	   r
   r&   intweightr   r   r   r   r   c                   sZ   t  jt|jd || _|| _|| _|| _|| _t	| j| jd| _
t| j| jd| _dS )a  
        Args:
            to_onehot_y : whether to convert `y` into the one-hot format. Defaults to False.
            num_classes : number of classes, it only supports 2 now. Defaults to 2.
            delta : weight of the background. Defaults to 0.7.
            gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.75.
            epsilon : it defines a very small number each time. simmily smooth value. Defaults to 1e-7.
            weight : weight for each loss function, if it's none it's 0.5. Defaults to None.

        Example:
            >>> import torch
            >>> from monai.losses import AsymmetricUnifiedFocalLoss
            >>> pred = torch.ones((1,1,32,32), dtype=torch.float32)
            >>> grnd = torch.ones((1,1,32,32), dtype=torch.int64)
            >>> fl = AsymmetricUnifiedFocalLoss(to_onehot_y=True)
            >>> fl(pred, grnd)
        r   )r   r   N)r   r   r   r   r	   r&   r   r   rQ   rL   asy_focal_lossr   asy_focal_tversky_loss)r   r	   r&   rQ   r   r   r   r   r   r   r      s   z#AsymmetricUnifiedFocalLoss.__init__r    r!   r"   r   c                 C  sX  |j |j krtd|j  d|j  dt|j dkr)t|j dkr)td|j  |j d dkr>t|| jd}t|| jd}t|| jd krRtd	| jd  |j d }| jrj|dkrdt	d
 nt||d}| 
||}| ||}| j| d| j |  }| jtjjkrt|S | jtjjkr|S | jtjjkrt|S td| j d)a  
        Args:
            y_pred : the shape should be BNH[WD], where N is the number of classes.
                It only supports binary segmentation.
                The input should be the original logits since it will be transformed by
                    a sigmoid in the forward function.
            y_true : the shape should be BNH[WD], where N is the number of classes.
                It only supports binary segmentation.

        Raises:
            ValueError: When input and target are different shape
            ValueError: When len(y_pred.shape) != 4 and len(y_pred.shape) != 5
            ValueError: When num_classes
            ValueError: When the number of classes entered does not match the expected number
        r'   r(   r)         z$input shape must be 4 or 5, but got r#   r%   z*Please make sure the number of classes is r$   zUnsupported reduction: z0, available options are ["mean", "sum", "none"].)r/   r2   r7   r   r&   r3   maxr	   r0   r1   rR   rS   rQ   r   r   SUMr   r8   NONErI   r:   )r   r    r"   r<   rR   rS   rA   r   r   r   rB      s0   


z"AsymmetricUnifiedFocalLoss.forward)r	   r
   r&   rP   rQ   r   r   r   r   r   r   r   rC   rD   r   r   r   r   rO      s    $rO   )
__future__r   r0   r3   torch.nn.modules.lossr   monai.networksr   monai.utilsr   r   rL   rO   r   r   r   r   <module>   s   @: