o
    i
                     @  s2   d dl mZ d dlZd dlmZ 	ddddZdS )    )annotationsNTinputtorch.Tensortargetreduce_axis	list[int]ordint
soft_labelbool	decoupledreturn/tuple[torch.Tensor, torch.Tensor, torch.Tensor]c                 C  s  |dkr;|s;t j| | |d}|r$t j| |d| }t j||d| }nZt j| d|  |d}t jd|  | |d}nCtj| ||d}	tj|||d}
tj| | ||d}|dkrnt j|	|d}	t j|
|d}
t j||d}|	|
 | d }|	| }|
| }|||fS )a  
    Args:
        input: the shape should be BNH[WD], where N is the number of classes.
        target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.
        reduce_axis: the axis to be reduced.
        ord: the order of the vector norm.
        soft_label: whether the target contains non-binary values (soft labels) or not.
            If True a soft label formulation of the loss will be used.
        decoupled: whether the input and the target should be decoupled when computing fp and fn.
            Only for the original implementation when soft_label is False.

    Adapted from:
        https://github.com/zifuwanggg/JDTLosses
       )dim)r   r   )exponent   )torchsumLAvector_normpow)r   r   r   r   r
   r   tpfpfnZpred_oground_o
difference r   T/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/losses/utils.pycompute_tp_fp_fn   s$   
r   )T)r   r   r   r   r   r   r   r	   r
   r   r   r   r   r   )
__future__r   r   Ztorch.linalglinalgr   r   r   r   r   r   <module>   s
   	