U
    tPh                     @  sb   d dl mZ d dlmZ d dlZd dlm  mZ d dl	m
Z
 d dlmZ G dd de
ZeZdS )    )annotations)UnionN)_Loss)pytorch_afterc                      sd   e Zd ZdZdddddd fd	d
ZddddddZddddddZddddddZ  ZS )DeepSupervisionLossz
    Wrapper class around the main loss function to accept a list of tensors returned from a deeply
    supervised networks. The final loss is computed as the sum of weighted losses for each of deep supervision levels.
    expNr   strzlist[float] | NoneNone)lossweight_modeweightsreturnc                   s4   t    || _|| _|| _tddr*dnd| _dS )a  
        Args:
            loss: main loss instance, e.g DiceLoss().
            weight_mode: {``"same"``, ``"exp"``, ``"two"``}
                Specifies the weights calculation for each image level. Defaults to ``"exp"``.
                - ``"same"``: all weights are equal to 1.
                - ``"exp"``: exponentially decreasing weights by a power of 2: 1, 0.5, 0.25, 0.125, etc .
                - ``"two"``: equal smaller weights for lower levels: 1, 0.5, 0.5, 0.5, 0.5, etc
            weights: a list of weights to apply to each deeply supervised sub-loss, if provided, this will be used
                regardless of the weight_mode
              znearest-exactnearestN)super__init__r
   r   r   r   interp_mode)selfr
   r   r   	__class__ I/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/losses/ds_loss.pyr      s
    
zDeepSupervisionLoss.__init__r   intzlist[float])levelsr   c                 C  s   t d|}| jdk	r2t| j|kr2| jd| }n\| jdkrHdg| }nF| jdkrfdd t|D }n(| jdkrd	d t|D }n
dg| }|S )
zG
        Calculates weights for a given number of scale levels
        r   Nsame      ?r   c                 S  s   g | ]}t d | dqS )      ?g      ?)max.0lr   r   r   
<listcomp>9   s     z3DeepSupervisionLoss.get_weights.<locals>.<listcomp>twoc                 S  s   g | ]}|d krdndqS )r   r   r   r   r   r   r   r   r"   ;   s     )r   r   lenr   range)r   r   r   r   r   r   get_weights/   s    




zDeepSupervisionLoss.get_weightsztorch.Tensor)inputtargetr   c                 C  sD   |j dd |j dd kr8tj||j dd | jd}| ||S )z
        Calculates a loss output accounting for differences in shapes,
        and downsizing targets if necessary (using nearest neighbor interpolation)
        Generally downsizing occurs for all level, except for the first (level==0)
           N)sizemode)shapeFinterpolater   r
   )r   r'   r(   r   r   r   get_lossA   s    zDeepSupervisionLoss.get_lossz-Union[None, torch.Tensor, list[torch.Tensor]]c                 C  s   t |ttfrh| jt|d}tjdtj|jd}t	t|D ]$}||| | 
||  | 7 }q>|S |d krxtd| | |S )N)r   r   )dtypedevicezinput shouldn't be None.)
isinstancelisttupler&   r$   torchtensorfloatr1   r%   r/   
ValueErrorr
   )r   r'   r(   r   r
   r!   r   r   r   forwardK   s    "zDeepSupervisionLoss.forward)r   N)r   )	__name__
__module____qualname____doc__r   r&   r/   r9   __classcell__r   r   r   r   r      s
   
r   )
__future__r   typingr   r5   torch.nn.functionalnn
functionalr-   torch.nn.modules.lossr   monai.utilsr   r   ds_lossr   r   r   r   <module>   s   A