o
    i                     @  sV   d dl mZ d dlmZ d dlZd dlm  mZ d dl	m
Z
 G dd de
ZeZdS )    )annotations)UnionN)_Lossc                      sD   e Zd ZdZdd fddZddddZd ddZd!ddZ  ZS )"DeepSupervisionLossz
    Wrapper class around the main loss function to accept a list of tensors returned from a deeply
    supervised networks. The final loss is computed as the sum of weighted losses for each of deep supervision levels.
    expNlossr   weight_modestrweightslist[float] | NonereturnNonec                   s&   t    || _|| _|| _d| _dS )a  
        Args:
            loss: main loss instance, e.g DiceLoss().
            weight_mode: {``"same"``, ``"exp"``, ``"two"``}
                Specifies the weights calculation for each image level. Defaults to ``"exp"``.
                - ``"same"``: all weights are equal to 1.
                - ``"exp"``: exponentially decreasing weights by a power of 2: 1, 0.5, 0.25, 0.125, etc .
                - ``"two"``: equal smaller weights for lower levels: 1, 0.5, 0.5, 0.5, 0.5, etc
            weights: a list of weights to apply to each deeply supervised sub-loss, if provided, this will be used
                regardless of the weight_mode
        znearest-exactN)super__init__r   r   r
   interp_mode)selfr   r   r
   	__class__ V/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/losses/ds_loss.pyr      s
   

zDeepSupervisionLoss.__init__   levelsintlist[float]c                 C  s   t d|}| jdurt| j|kr| jd| }|S | jdkr&dg| }|S | jdkr6dd t|D }|S | jdkrFd	d t|D }|S dg| }|S )
zG
        Calculates weights for a given number of scale levels
        r   Nsame      ?r   c                 S  s   g | ]	}t d | dqS )      ?g      ?)max.0lr   r   r   
<listcomp>7   s    z3DeepSupervisionLoss.get_weights.<locals>.<listcomp>twoc                 S  s   g | ]
}|d kr
dndqS )r   r   r   r   r   r   r   r   r!   9   s    )r   r
   lenr   range)r   r   r
   r   r   r   get_weights-   s   






zDeepSupervisionLoss.get_weightsinputtorch.Tensortargetc                 C  sD   |j dd |j dd krtj||j dd | jd}| ||S )z
        Calculates a loss output accounting for differences in shapes,
        and downsizing targets if necessary (using nearest neighbor interpolation)
        Generally downsizing occurs for all level, except for the first (level==0)
           N)sizemode)shapeFinterpolater   r   )r   r&   r(   r   r   r   get_loss?   s   zDeepSupervisionLoss.get_loss-Union[None, torch.Tensor, list[torch.Tensor]]c                 C  s   t |ttfr4| jt|d}tjdtj|jd}t	t|D ]}||| | 
||  | 7 }q|S |d u r<td| | |S )N)r   r   )dtypedevicezinput shouldn't be None.)
isinstancelisttupler%   r#   torchtensorfloatr2   r$   r/   
ValueErrorr   )r   r&   r(   r
   r   r    r   r   r   forwardI   s   "zDeepSupervisionLoss.forward)r   N)r   r   r   r	   r
   r   r   r   )r   )r   r   r   r   )r&   r'   r(   r'   r   r'   )r&   r0   r(   r'   r   r'   )	__name__
__module____qualname____doc__r   r%   r/   r:   __classcell__r   r   r   r   r      s    

r   )
__future__r   typingr   r6   torch.nn.functionalnn
functionalr-   torch.nn.modules.lossr   r   ds_lossr   r   r   r   <module>   s   A