U
    tPh4                     @  sz   d dl mZ d dlZd dlmZ d dlmZmZ d dlm	Z	 dddd	d
Z
dddddZe
edZG dd deZdS )    )annotationsN)_Loss)gaussian_1dseparable_filtering)LossReductioninttorch.Tensor)sigmareturnc                 C  s,   | dkrt d|  tt| ddddS )Nr   $expecting positive sigma, got sigma=   sampledF)r	   	truncatedapprox	normalize)
ValueErrorr   torchtensorr	    r   M/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/losses/multi_scale.pymake_gaussian_kernel   s    r   c                   sb    dkrt d  t d }t fddt| |d D }t|}|t| }|S )Nr   r      c                   s   g | ]}|  d  d qS )      r   ).0xr   r   r   
<listcomp>   s     z&make_cauchy_kernel.<locals>.<listcomp>r   )r   r   r   r   range
reciprocalsum)r	   tailkr   r   r   make_cauchy_kernel   s    $
r#   )gaussiancauchyc                      sJ   e Zd ZdZddejfdddddd	 fd
dZddddddZ  ZS )MultiScaleLossz
    This is a wrapper class.
    It smooths the input and target at different scales before passing them into the wrapped loss function.

    Adapted from:
        DeepReg (https://github.com/DeepRegNet/DeepReg)
    Nr$   r   zlist | NonestrzLossReduction | strNone)lossscaleskernel	reductionr
   c                   sF   t  jt|jd |tkr,td| dt| | _|| _|| _dS )z
        Args:
            loss: loss function to be wrapped
            scales: list of scalars or None, if None, do not apply any scaling.
            kernel: gaussian or cauchy.
        )r,   zgot unsupported kernel type: z only support gaussian and cauchyN)	super__init__r   valuekernel_fn_dictr   	kernel_fnr)   r*   )selfr)   r*   r+   r,   	__class__r   r   r.   1   s    
zMultiScaleLoss.__init__r   )y_truey_predr
   c                 C  s   | j d kr| ||}ng }| j D ]n}|dkrB|| || q"|| t|| ||g|jd  t|| ||g|jd   q"tj|dd}| j	t
jjkrt|}n:| j	t
jjkrt|}n | j	t
jjkrtd| j	 d|S )Nr   r   )dimzUnsupported reduction: z0, available options are ["mean", "sum", "none"].)r*   r)   appendr   r1   tondimr   stackr,   r   MEANr/   meanSUMr    NONEr   )r2   r5   r6   r)   	loss_listsr   r   r   forwardE   s(    

  zMultiScaleLoss.forward)	__name__
__module____qualname____doc__r   r<   r.   rB   __classcell__r   r   r3   r   r&   (   s   r&   )
__future__r   r   torch.nn.modules.lossr   monai.networks.layersr   r   monai.utilsr   r   r#   r0   r&   r   r   r   r   <module>   s   

