o
    i                     @  sf   d dl mZ d dlZd dlZd dlmZ d dlmZmZ d dl	Z	d dl
mZ dgZG dd deZdS )    )annotationsN)Callable)AnyOptional)_Loss
MaskedLossc                      s.   e Zd ZdZd fd	d
ZddddZ  ZS )r   z
    This is a wrapper class for the loss functions.  It allows for additional
    weighting masks to be applied to both input and target.

    See Also:
        - :py:class:`monai.losses.MaskedDiceLoss`
    loss<Callable[[torch.Tensor, torch.Tensor], torch.Tensor] | _Loss	loss_argsr   loss_kwargsreturnNonec                   s>   t    t|r||i |n|| _t| jstddS )a?  
        Args:
            loss: loss function to be wrapped, this could be a loss class or an instance of a loss class.
            loss_args: arguments to the loss function's constructor if `loss` is a class.
            loss_kwargs: keyword arguments to the loss function's constructor if `loss` is a class.
        z"The loss function is not callable.N)super__init__inspectisclassr   callable
ValueError)selfr   r
   r   	__class__ [/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/losses/spatial_mask.pyr   "   s   
	
zMaskedLoss.__init__Ninputtorch.TensortargetmaskOptional[torch.Tensor]c                 C  s   |du rt d | ||S | | kr%t d|j d|j d |jd |jd krC|jd dkrCtd|j d	|j d| dkru|jd dkrYtd
|j d|jdd |jdd krut d|j d|j d | || || S )z
        Args:
            input: the shape should be BNH[WD].
            target: the shape should be BNH[WD].
            mask: the shape should be B1H[WD] or 11H[WD].
        Nz+No mask value specified for the MaskedLoss.zDim of input (z) is different from mask (z).r      zBatch size of mask (z!) must be one or equal to input (zMask (z) must have only one channel.   zSpatial size of input ()warningswarnr   dimshaper   )r   r   r   r   r   r   r   forward2   s   
"zMaskedLoss.forward)r   r	   r
   r   r   r   r   r   )N)r   r   r   r   r   r   r   r   )__name__
__module____qualname____doc__r   r$   __classcell__r   r   r   r   r      s    )
__future__r   r   r    collections.abcr   typingr   r   torchtorch.nn.modules.lossr   __all__r   r   r   r   r   <module>   s   