o
    
i}                     @  sp   d dl mZ d dlZd dlmZ d dlmZ d dlmZ d dl	m
Z
 d dlmZ d dlmZ G d	d
 d
eZdS )    )annotationsN)CrossEntropyLoss)
functional)_Loss)DiceLoss)SobelGradients)HoVerNetBranchc                      sN   e Zd ZdZ						dd fddZdddZdddZdddZ  ZS ) HoVerNetLossa:  
    Loss function for HoVerNet pipeline, which is combination of losses across the three branches.
    The NP (nucleus prediction) branch uses Dice + CrossEntropy.
    The HV (Horizontal and Vertical) distance from centroid branch uses MSE + MSE of the gradient.
    The NC (Nuclear Class prediction) branch uses Dice + CrossEntropy
    The result is a weighted sum of these losses.

    Args:
        lambda_hv_mse: Weight factor to apply to the HV regression MSE part of the overall loss
        lambda_hv_mse_grad: Weight factor to apply to the MSE of the HV gradient part of the overall loss
        lambda_np_ce: Weight factor to apply to the nuclei prediction CrossEntropyLoss part
            of the overall loss
        lambda_np_dice: Weight factor to apply to the nuclei prediction DiceLoss part of overall loss
        lambda_nc_ce: Weight factor to apply to the nuclei class prediction CrossEntropyLoss part
            of the overall loss
        lambda_nc_dice: Weight factor to apply to the nuclei class prediction DiceLoss part of the
            overall loss

           @      ?lambda_hv_msefloatlambda_hv_mse_gradlambda_np_celambda_np_dicelambda_nc_celambda_nc_dicereturnNonec                   sn   || _ || _|| _|| _|| _|| _t   tdddddd| _	t
dd| _tddd	| _tdd
d	| _d S )NTgMbP?sum)softmaxZ	smooth_drZ	smooth_nr	reductionbatchmean)r      r   )kernel_sizespatial_axes   )r   r   r   r   r   r   super__init__r   dicer   cer   sobel_vsobel_h)selfr   r   r   r   r   r   	__class__ k/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/pathology/losses/hovernet_loss.pyr   -   s   	
zHoVerNetLoss.__init__imagetorch.Tensorc                 C  s>   |  |dddf }| |dddf }tj||gddS )ai  Compute the Sobel gradients of the horizontal vertical map (HoVerMap).
        More specifically, it will compute horizontal gradient of the input horizontal gradient map (channel=0) and
        vertical gradient of the input vertical gradient map (channel=1).

        Args:
            image: a tensor with the shape of BxCxHxW representing HoVerMap

        Nr   r   dim)r#   r"   torchstack)r$   r)   Zresult_hZresult_vr'   r'   r(   _compute_sobelC   s   	zHoVerNetLoss._compute_sobel
predictiontargetfocusc                 C  sb   |  |}|  |}|| }|ddddf }t||fd}|||  }| | d  }|S )z[Compute the MSE loss of the gradients of the horizontal and vertical centroid distance mapsN.r   g:0yE>)r/   r-   catr   )r$   r0   r1   r2   Z	pred_gradZ	true_gradlossr'   r'   r(   _mse_gradient_lossP   s   

zHoVerNetLoss._mse_gradient_lossdict[str, torch.Tensor]c                 C  s  t jj|v rt jj|v stdt jj|v rt jj|v s tdt jj|vr0t jj|v r0tdt jj|v r@t jj|vr@td| |t jj |t jj | j }|t jj jdd}| 	|t jj || j
 }|| }t|t jj |t jj | j }| |t jj |t jj |t jj dddf | j }|| }	d}
t jj|v r| |t jj |t jj | j }|t jj jdd}| 	|t jj || j }|| }
|	| |
 }|S )a!  
        Args:
            prediction: dictionary of predicted outputs for three branches,
                each of which should have the shape of BNHW.
            target: dictionary of ground truths for three branches,
                each of which should have the shape of BNHW.
        zrnucleus prediction (NP) and horizontal_vertical (HV) branches must be present for prediction and target parametersz_type_prediction (NC) must be present in both or neither of the prediction and target parametersr   r+   Nr   )r   NPvalueHV
ValueErrorNCr    r   argmaxr!   r   Fmse_lossr   r5   r   r   r   )r$   r0   r1   Zdice_loss_npZargmax_targetZ
ce_loss_npZloss_npZloss_hv_mseZloss_hv_mse_gradZloss_hvZloss_ncZdice_loss_ncZ
ce_loss_ncr4   r'   r'   r(   forwardb   sR   	  

 zHoVerNetLoss.forward)r
   r   r   r   r   r   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   )r)   r*   r   r*   )r0   r*   r1   r*   r2   r*   r   r*   )r0   r6   r1   r6   r   r*   )	__name__
__module____qualname____doc__r   r/   r5   r?   __classcell__r'   r'   r%   r(   r	      s    

r	   )
__future__r   r-   torch.nnr   r   r=   Ztorch.nn.modules.lossr   monai.lossesr   monai.transformsr   monai.utils.enumsr   r	   r'   r'   r'   r(   <module>   s   