o
    i                     @  s`   d dl mZ d dlmZ d dlZd dlmZ d dlmZm	Z	 d dl
mZmZ G dd deZdS )	    )annotations)SequenceN)_Loss)
KernelType
SSIMMetric)LossReductionensure_tuple_repc                      s^   e Zd ZdZdejddddejfd" fddZe	d#ddZ
e
jd$ddZ
d%d d!Z  ZS )&SSIMLossa}  
    Compute the loss function based on the Structural Similarity Index Measure (SSIM) Metric.

    For more info, visit
        https://vicuesoft.com/glossary/term/ssim-ms-ssim/

    SSIM reference paper:
        Wang, Zhou, et al. "Image quality assessment: from error visibility to structural
        similarity." IEEE transactions on image processing 13.4 (2004): 600-612.
    g      ?   g      ?g{Gz?gQ?spatial_dimsint
data_rangefloatkernel_typeKernelType | strwin_sizeint | Sequence[int]kernel_sigmafloat | Sequence[float]k1k2	reductionLossReduction | strc	           	   	     s   t  jt|jd || _|| _|| _t|tst	||}|| _
t|ts*t	||}|| _|| _|| _t| j| j| j| j
| j| j| jd| _dS )ab  
        Args:
            spatial_dims: number of spatial dimensions of the input images.
            data_range: value range of input images. (usually 1.0 or 255)
            kernel_type: type of kernel, can be "gaussian" or "uniform".
            win_size: window size of kernel
            kernel_sigma: standard deviation for Gaussian kernel.
            k1: stability constant used in the luminance denominator
            k2: stability constant used in the contrast denominator
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the output. Defaults to ``"mean"``.
                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.

        )r   )r   r   r   r   r   r   r   N)super__init__r   valuer   _data_ranger   
isinstancer   r   kernel_sizer   r   r   r   ssim_metric)	selfr   r   r   r   r   r   r   r   	__class__ X/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/losses/ssim_loss.pyr   #   s*   



zSSIMLoss.__init__returnc                 C  s   | j S N)r   )r    r#   r#   r$   r   X   s   zSSIMLoss.data_ranger   Nonec                 C  s   || _ || j_d S r&   )r   r   r   )r    r   r#   r#   r$   r   \   s   inputtorch.Tensortargetc                 C  sV   | j ||dd}d| }| jtjjkrt|}|S | jtj	jkr)t
|}|S )a  
        Args:
            input: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
            target: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])

        Returns:
            1 minus the ssim index (recall this is meant to be a loss function)

        Example:
            .. code-block:: python

                import torch

                # 2D data
                x = torch.ones([1,1,10,10])/2
                y = torch.ones([1,1,10,10])/2
                print(1-SSIMLoss(spatial_dims=2)(x,y))

                # pseudo-3D data
                x = torch.ones([1,5,10,10])/2  # 5 could represent number of slices
                y = torch.ones([1,5,10,10])/2
                print(1-SSIMLoss(spatial_dims=2)(x,y))

                # 3D data
                x = torch.ones([1,1,10,10,10])/2
                y = torch.ones([1,1,10,10,10])/2
                print(1-SSIMLoss(spatial_dims=3)(x,y))
           )r   Z_compute_tensorviewr   r   MEANr   torchmeanSUMsum)r    r(   r*   Z
ssim_valuelossr#   r#   r$   forwarda   s   

zSSIMLoss.forward)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   )r%   r   )r   r   r%   r'   )r(   r)   r*   r)   r%   r)   )__name__
__module____qualname____doc__r   GAUSSIANr   r.   r   propertyr   setterr4   __classcell__r#   r#   r!   r$   r	      s    5r	   )
__future__r   collections.abcr   r/   torch.nn.modules.lossr   Zmonai.metrics.regressionr   r   monai.utilsr   r   r	   r#   r#   r#   r$   <module>   s   