o
    i4                     @  sn   d dl mZ d dlZd dlmZ d dlmZmZ d dlm	Z	 dd
dZ
dddZe
edZG dd deZdS )    )annotationsN)_Loss)gaussian_1dseparable_filtering)LossReductionsigmaintreturntorch.Tensorc                 C  s,   | dkrt d|  tt| ddddS )Nr   $expecting positive sigma, got sigma=   sampledF)r   	truncatedapprox	normalize)
ValueErrorr   torchtensorr    r   Z/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/losses/multi_scale.pymake_gaussian_kernel   s   r   c                   sb    dkrt d  t d }t fddt| |d D }t|}|t| }|S )Nr   r      c                   s   g | ]
}|  d  d qS )      r   ).0xr   r   r   
<listcomp>   s    z&make_cauchy_kernel.<locals>.<listcomp>r   )r   r   r   r   range
reciprocalsum)r   tailkr   r   r   make_cauchy_kernel   s   $
r#   )gaussiancauchyc                      s6   e Zd ZdZddejfd fddZdddZ  ZS )MultiScaleLossz
    This is a wrapper class.
    It smooths the input and target at different scales before passing them into the wrapped loss function.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    Nr$   lossr   scaleslist | Nonekernelstr	reductionLossReduction | strr	   Nonec                   sF   t  jt|jd |tvrtd| dt| | _|| _|| _dS )z
        Args:
            loss: loss function to be wrapped
            scales: list of scalars or None, if None, do not apply any scaling.
            kernel: gaussian or cauchy.
        )r,   zgot unsupported kernel type: z only support gaussian and cauchyN)	super__init__r   valuekernel_fn_dictr   	kernel_fnr'   r(   )selfr'   r(   r*   r,   	__class__r   r   r0   1   s   

zMultiScaleLoss.__init__y_truer
   y_predc                 C  s   | j d u r| ||}nDg }| j D ]7}|dkr!|| || q|| t|| ||g|jd  t|| ||g|jd   qtj|dd}| j	t
jjkr^t|}|S | j	t
jjkrlt|}|S | j	t
jjkr|td| j	 d|S )Nr   r   )dimzUnsupported reduction: z0, available options are ["mean", "sum", "none"].)r(   r'   appendr   r3   tondimr   stackr,   r   MEANr1   meanSUMr    NONEr   )r4   r7   r8   r'   	loss_listsr   r   r   forwardE   s,   

  

zMultiScaleLoss.forward)
r'   r   r(   r)   r*   r+   r,   r-   r	   r.   )r7   r
   r8   r
   r	   r
   )	__name__
__module____qualname____doc__r   r>   r0   rD   __classcell__r   r   r5   r   r&   (   s    r&   )r   r   r	   r
   )
__future__r   r   torch.nn.modules.lossr   monai.networks.layersr   r   monai.utilsr   r   r#   r2   r&   r   r   r   r   <module>   s   



