U
    tPh                     @  s`   d dl mZ d dlmZ d dlZd dlmZ d dlmZm	Z	 d dl
mZmZ G dd deZdS )	    )annotations)SequenceN)_Loss)
KernelType
SSIMMetric)LossReductionensure_tuple_repc                
      s   e Zd ZdZdejddddejfddd	d
ddddd fddZe	ddddZ
e
jdddddZ
ddddddZ  ZS )SSIMLossa}  
    Compute the loss function based on the Structural Similarity Index Measure (SSIM) Metric.

    For more info, visit
        https://vicuesoft.com/glossary/term/ssim-ms-ssim/

    SSIM reference paper:
        Wang, Zhou, et al. "Image quality assessment: from error visibility to structural
        similarity." IEEE transactions on image processing 13.4 (2004): 600-612.
    g      ?   g      ?g{Gz?gQ?intfloatzKernelType | strzint | Sequence[int]zfloat | Sequence[float]zLossReduction | str)spatial_dims
data_rangekernel_typewin_sizekernel_sigmak1k2	reductionc	           	   	     s   t  jt|jd || _|| _|| _t|ts:t	||}|| _
t|tsTt	||}|| _|| _|| _t| j| j| j| j
| j| j| jd| _dS )ab  
        Args:
            spatial_dims: number of spatial dimensions of the input images.
            data_range: value range of input images. (usually 1.0 or 255)
            kernel_type: type of kernel, can be "gaussian" or "uniform".
            win_size: window size of kernel
            kernel_sigma: standard deviation for Gaussian kernel.
            k1: stability constant used in the luminance denominator
            k2: stability constant used in the contrast denominator
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

        )r   )r   r   r   r   r   r   r   N)super__init__r   valuer   _data_ranger   
isinstancer   r   kernel_sizer   r   r   r   ssim_metric)	selfr   r   r   r   r   r   r   r   	__class__ K/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/losses/ssim_loss.pyr   #   s*    



zSSIMLoss.__init__)returnc                 C  s   | j S N)r   )r   r   r   r    r   X   s    zSSIMLoss.data_rangeNone)r   r!   c                 C  s   || _ || j_d S r"   )r   r   r   )r   r   r   r   r    r   \   s    ztorch.Tensor)inputtargetr!   c                 C  sT   | j ||dd}d| }| jtjjkr8t|}n| jtj	jkrPt
|}|S )a  
        Args:
            input: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
            target: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])

        Returns:
            1 minus the ssim index (recall this is meant to be a loss function)

        Example:
            .. code-block:: python

                import torch

                # 2D data
                x = torch.ones([1,1,10,10])/2
                y = torch.ones([1,1,10,10])/2
                print(1-SSIMLoss(spatial_dims=2)(x,y))

                # pseudo-3D data
                x = torch.ones([1,5,10,10])/2  # 5 could represent number of slices
                y = torch.ones([1,5,10,10])/2
                print(1-SSIMLoss(spatial_dims=2)(x,y))

                # 3D data
                x = torch.ones([1,1,10,10,10])/2
                y = torch.ones([1,1,10,10,10])/2
                print(1-SSIMLoss(spatial_dims=3)(x,y))
           )r   Z_compute_tensorviewr   r   MEANr   torchmeanSUMsum)r   r$   r%   Z
ssim_valuelossr   r   r    forwarda   s    
zSSIMLoss.forward)__name__
__module____qualname____doc__r   GAUSSIANr   r)   r   propertyr   setterr/   __classcell__r   r   r   r    r	      s   "5r	   )
__future__r   collections.abcr   r*   torch.nn.modules.lossr   Zmonai.metrics.regressionr   r   monai.utilsr   r   r	   r   r   r   r    <module>   s   