o
    i(                     @  s^   d dl mZ d dlZd dlm  mZ d dlmZ d dl	m
Z
 d dlmZ G dd de
ZdS )    )annotationsN)fftn)_Loss)LossReductionc                      s@   e Zd ZdZddejfd fddZdddZdddZ  Z	S )JukeboxLossa  
    Calculate spectral component based on the magnitude of Fast Fourier Transform (FFT).

    Based on:
        Dhariwal, et al. 'Jukebox: A generative model for music.' https://arxiv.org/abs/2005.00341

    Args:
        spatial_dims: number of spatial dimensions.
        fft_signal_size: signal size in the transformed dimensions. See torch.fft.fftn() for more information.
        fft_norm: {``"forward"``, ``"backward"``, ``"ortho"``} Specifies the normalization mode in the fft. See
            torch.fft.fftn() for more information.

        reduction: {``"none"``, ``"mean"``, ``"sum"``}
            Specifies the reduction to apply to the output. Defaults to ``"mean"``.

            - ``"none"``: no reduction will be applied.
            - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
            - ``"sum"``: the output will be summed.
    Northospatial_dimsintfft_signal_sizetuple[int] | Nonefft_normstr	reductionLossReduction | strreturnNonec                   s>   t  jt|jd || _|| _ttd|d | _|| _	d S )Nr         )
super__init__r   valuer   r
   tuplerangefft_dimr   )selfr   r
   r   r   	__class__ \/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/losses/spectral_loss.pyr   +   s
   
zJukeboxLoss.__init__inputtorch.Tensortargetc                 C  sl   |  |}|  |}tj||dd}| jtjjkr| }|S | jtjjkr,|	 }|S | jtj
jkr4	 |S )Nnoner   )_get_fft_amplitudeFmse_lossr   r   MEANr   meanSUMsumNONE)r   r    r"   Zinput_amplitudeZtarget_amplitudelossr   r   r   forward9   s   

zJukeboxLoss.forwardimagesc                 C  s<   t || j| j| jd}tt|d t|d  }|S )z
        Calculate the amplitude of the fourier transformations representation of the images

        Args:
            images: Images that are to undergo fftn

        Returns:
            fourier transformation amplitude
        )sdimnormr   )r   r
   r   r   torchsqrtrealimag)r   r.   Zimg_fft	amplituder   r   r   r$   J   s   
"zJukeboxLoss._get_fft_amplitude)
r   r	   r
   r   r   r   r   r   r   r   )r    r!   r"   r!   r   r!   )r.   r!   r   r!   )
__name__
__module____qualname____doc__r   r'   r   r-   r$   __classcell__r   r   r   r   r      s    
r   )
__future__r   r2   torch.nn.functionalnn
functionalr%   Z	torch.fftr   torch.nn.modules.lossr   monai.utilsr   r   r   r   r   r   <module>   s   