o
    i                     @  sn   d dl mZ d dlmZ d dlZd dlmZ d dlm  mZ	 d dl
mZ d dlmZmZ G dd deZdS )    )annotations)AnyN)_Loss)GaussianFilter
MeanFilterc                      sB   e Zd ZdZ					dd fddZdddZd ddZ  ZS )!NACLLossa  
    Neighbor-Aware Calibration Loss (NACL) is primarily developed for developing calibrated models in image segmentation.
    NACL computes standard cross-entropy loss with a linear penalty that enforces the logit distributions
    to match a soft class proportion of surrounding pixel.

    Murugesan, Balamurali, et al.
    "Trust your neighbours: Penalty-based constraints for model calibration."
    International Conference on Medical Image Computing and Computer-Assisted Intervention, MICCAI 2023.
    https://arxiv.org/abs/2303.06268

    Murugesan, Balamurali, et al.
    "Neighbor-Aware Calibration of Segmentation Networks with Penalty-Based Constraints."
    https://arxiv.org/abs/2401.14487
       meanl1皙?      ?classesintdimkernel_size
kernel_opsstrdistance_typealphafloatsigmareturnNonec                   s   t    |dvrtd|dvrtd| d|dvr$td| || _|| _t | _|| _|| _	|| _
|  |dkrOt||d	| _| jj||  | j_|d
kr\t||d| _dS dS )am  
        Args:
            classes: number of classes
            dim: dimension of data (supports 2d and 3d)
            kernel_size: size of the spatial kernel
            distance_type: l1/l2 distance between spatial kernel and predicted logits
            alpha: weightage between cross entropy and logit constraint
            sigma: sigma of gaussian
        )r	   gaussianz*Kernel ops must be either mean or gaussian)   r   zSupport 2d and 3d, got dim=.)r
   l2z+Distance type must be either L1 or L2, got r	   )spatial_dimssizer   )r   r   N)super__init__
ValueErrorncr   nnCrossEntropyLosscross_entropyr   r   ksr   
svls_layerfilterr   )selfr   r   r   r   r   r   r   	__class__ X/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/losses/nacl_loss.pyr    (   s(   

zNACLLoss.__init__masktorch.Tensorc                 C  s   | j dkr tj|tj| jddddd 	 }| 
|}| j dkrAtj|tj| jdddddd 	 }| 
|}|S )a  
        Converts the mask to one hot represenation and is smoothened with the selected spatial filter.

        Args:
            mask: the shape should be BH[WD].

        Returns:
            torch.Tensor: the shape would be BNH[WD], N being number of classes.
        r   )num_classesr   r         )r   Fone_hottotorchint64r"   permute
contiguousr   r'   )r)   r.   Z	oh_labelsrmaskr,   r,   r-   get_constr_targetU   s   
,

.
zNACLLoss.get_constr_targetinputstargetsc                 C  sh   |  ||}| |}| jdkr||  }n| jdkr+||d  }|| j|  }|S )a  
        Computes standard cross-entropy loss and constraints it neighbor aware logit penalty.

        Args:
            inputs: the shape should be BNH[WD], where N is the number of classes.
            targets: the shape should be BH[WD].

        Returns:
            torch.Tensor: value of the loss.

        Example:
            >>> import torch
            >>> from monai.losses import NACLLoss
            >>> B, N, H, W = 8, 3, 64, 64
            >>> input = torch.rand(B, N, H, W)
            >>> target = torch.randint(0, N, (B, H, W))
            >>> criterion = NACLLoss(classes = N, dim = 2)
            >>> loss = criterion(input, target)
        r
   r   r   )r%   r;   r   subabs_r	   pow_r   )r)   r<   r=   Zloss_ceZutargetsZ	loss_conflossr,   r,   r-   forwardk   s   


zNACLLoss.forward)r   r	   r
   r   r   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   )r.   r/   r   r/   )r<   r/   r=   r/   r   r/   )__name__
__module____qualname____doc__r    r;   rB   __classcell__r,   r,   r*   r-   r      s    
-r   )
__future__r   typingr   r6   torch.nnr#   torch.nn.functional
functionalr3   torch.nn.modules.lossr   monai.networks.layersr   r   r   r,   r,   r,   r-   <module>   s   