U
    Phjf                     @  s  d dl mZ d dlZd dlmZ d dlmZmZ d dlm	Z	 d dl
mZ d dlZd dlm  mZ d dlmZ d dlmZmZmZmZ d d	lmZ d
dlmZ G dd deZG dd deZG dd deZG dd deZG dd deZ dddddddZ!G dd deZ"G dd deZ#ddd d!dd"d#d$Z$d%e"j%d&d'fdddd d!d(d)d(d(d*d+
d,d-Z&G d.d/ d/eZ'd%e"j%d0d1d&d'd2fdddd(d)d3d4d(d(d!dd5d6d7Z(dS )8    )annotationsN)abstractmethod)CallableSequence)partial)Any)do_metric_reduction)MetricReductionStrEnumconvert_data_typeensure_tuple_rep)convert_to_dst_type   )CumulativeIterationMetricc                      s~   e Zd ZdZejdfdddd fddZdd
ddddZddddddZe	ddddddZ
ddddddZ  ZS )RegressionMetrica  
    Base class for regression metrics.
    Input `y_pred` is compared with ground truth `y`.
    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.
    `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]).

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
            Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.

    FMetricReduction | strboolNone	reductionget_not_nansreturnc                   s   t    || _|| _d S N)super__init__r   r   selfr   r   	__class__ M/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/metrics/regression.pyr   0   s    
zRegressionMetric.__init__NzMetricReduction | str | Nonez0torch.Tensor | tuple[torch.Tensor, torch.Tensor])r   r   c                 C  sB   |   }t|tjstdt||p(| j\}}| jr>||fS |S )ao  
        Args:
            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
                available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
                ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
        z-the data to aggregate must be PyTorch Tensor.)
get_buffer
isinstancetorchTensor
ValueErrorr   r   r   )r   r   datafnot_nansr   r   r    	aggregate5   s
    	zRegressionMetric.aggregatetorch.Tensory_predyr   c                 C  s@   |j |j kr&td|j  d|j  dt|j dk r<tdd S )Nz2y_pred and y shapes dont match, received y_pred: [z
] and y: []   zIeither channel or spatial dimensions required, found only batch dimension)shaper%   lenr   r,   r-   r   r   r    _check_shapeE   s    zRegressionMetric._check_shapec                 C  s   t d| jj dd S )Nz	Subclass z must implement this method.)NotImplementedErrorr   __name__r2   r   r   r    _compute_metricM   s    z RegressionMetric._compute_metricc                 C  s8   t |tjrt |tjs td| || | ||S )Nz$y_pred and y must be PyTorch Tensor.)r"   r#   r$   r%   r3   r6   r2   r   r   r    _compute_tensorQ   s    z RegressionMetric._compute_tensor)N)r5   
__module____qualname____doc__r	   MEANr   r)   r3   r   r6   r7   __classcell__r   r   r   r    r      s    r   c                      sD   e Zd ZdZejdfdddd fddZd	d	d	d
ddZ  ZS )	MSEMetrica  Compute Mean Squared Error between two tensors using function:

    .. math::
        \operatorname {MSE}\left(Y, \hat{Y}\right) =\frac {1}{n}\sum _{i=1}^{n}\left(y_i-\hat{y_i} \right)^{2}.

    More info: https://en.wikipedia.org/wiki/Mean_squared_error

    Input `y_pred` is compared with ground truth `y`.
    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).

    Fr   r   r   r   c                   s$   t  j||d ttjdd| _d S Nr   r   g       @)exponentr   r   r   r#   powsq_funcr   r   r   r    r   m   s    zMSEMetric.__init__r*   r+   c                 C  s   t ||| jdS Nfunc)compute_mean_error_metricsrC   r2   r   r   r    r6   q   s    zMSEMetric._compute_metric	r5   r8   r9   r:   r	   r;   r   r6   r<   r   r   r   r    r=   X   s   r=   c                      sD   e Zd ZdZejdfdddd fddZd	d	d	d
ddZ  ZS )	MAEMetrica  Compute Mean Absolute Error between two tensors using function:

    .. math::
        \operatorname {MAE}\left(Y, \hat{Y}\right) =\frac {1}{n}\sum _{i=1}^{n}\left|y_i-\hat{y_i}\right|.

    More info: https://en.wikipedia.org/wiki/Mean_absolute_error

    Input `y_pred` is compared with ground truth `y`.
    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).

    Fr   r   r   r   c                   s   t  j||d tj| _d S Nr?   )r   r   r#   absabs_funcr   r   r   r    r      s    zMAEMetric.__init__r*   r+   c                 C  s   t ||| jdS rD   )rG   rL   r2   r   r   r    r6      s    zMAEMetric._compute_metricrH   r   r   r   r    rI   u   s   rI   c                      sD   e Zd ZdZejdfdddd fddZd	d	d	d
ddZ  ZS )
RMSEMetrica/  Compute Root Mean Squared Error between two tensors using function:

    .. math::
        \operatorname {RMSE}\left(Y, \hat{Y}\right) ={ \sqrt{ \frac {1}{n}\sum _{i=1}^{n}\left(y_i-\hat{y_i}\right)^2 } } \
        = \sqrt {\operatorname{MSE}\left(Y, \hat{Y}\right)}.

    More info: https://en.wikipedia.org/wiki/Root-mean-square_deviation

    Input `y_pred` is compared with ground truth `y`.
    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).

    Fr   r   r   r   c                   s$   t  j||d ttjdd| _d S r>   rA   r   r   r   r    r      s    zRMSEMetric.__init__r*   r+   c                 C  s   t ||| jd}t|S rD   )rG   rC   r#   sqrtr   r,   r-   Zmse_outr   r   r    r6      s    zRMSEMetric._compute_metricrH   r   r   r   r    rM      s   rM   c                      sF   e Zd ZdZejdfddddd fdd	Zd
d
ddddZ  ZS )
PSNRMetrica_  Compute Peak Signal To Noise Ratio between two tensors using function:

    .. math::
        \operatorname{PSNR}\left(Y, \hat{Y}\right) = 20 \cdot \log_{10} \left({\mathit{MAX}}_Y\right) \
        -10 \cdot \log_{10}\left(\operatorname{MSE\left(Y, \hat{Y}\right)}\right)

    More info: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Help taken from:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/image_ops_impl.py line 4139

    Input `y_pred` is compared with ground truth `y`.
    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        max_val: The dynamic range of the images/volumes (i.e., the difference between the
            maximum and the minimum allowed values e.g. 255 for a uint8 image).
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).

    Fzint | floatr   r   r   )max_valr   r   r   c                   s*   t  j||d || _ttjdd| _d S r>   )r   r   rQ   r   r#   rB   rC   )r   rQ   r   r   r   r   r    r      s    zPSNRMetric.__init__r*   r   r+   c                 C  s.   t ||| jd}dt| j dt|  S )NrE      
   )rG   rC   mathlog10rQ   r#   rO   r   r   r    r6      s    zPSNRMetric._compute_metricrH   r   r   r   r    rP      s
    rP   r*   r   )r,   r-   rF   r   c                 C  s*   t tjdd}tj||||  dddS )Nr   )	start_dimT)dimkeepdim)r   r#   flattenmean)r,   r-   rF   fltr   r   r    rG      s    rG   c                   @  s   e Zd ZdZdZdS )
KernelTypegaussianuniformN)r5   r8   r9   GAUSSIANUNIFORMr   r   r   r    r]      s   r]   c                      s`   e Zd ZdZdejddddejdfdd	d
ddd	d	dddd
 fddZddddddZ	  Z
S )
SSIMMetrica  
    Computes the Structural Similarity Index Measure (SSIM).

    .. math::
        \operatorname {SSIM}(x,y) =\frac {(2 \mu_x \mu_y + c_1)(2 \sigma_{xy} + c_2)}{((\mu_x^2 + \
                \mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)}

    For more info, visit
        https://vicuesoft.com/glossary/term/ssim-ms-ssim/

    SSIM reference paper:
        Wang, Zhou, et al. "Image quality assessment: from error visibility to structural
        similarity." IEEE transactions on image processing 13.4 (2004): 600-612.

    Args:
        spatial_dims: number of spatial dimensions of the input images.
        data_range: value range of input images. (usually 1.0 or 255)
        kernel_type: type of kernel, can be "gaussian" or "uniform".
        win_size: window size of kernel
        kernel_sigma: standard deviation for Gaussian kernel.
        k1: stability constant used in the luminance denominator
        k2: stability constant used in the contrast denominator
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans)
          ?         ?{Gz?Q?FintfloatKernelType | strint | Sequence[int]float | Sequence[float]r   r   r   )
spatial_dims
data_rangekernel_typewin_sizekernel_sigmak1k2r   r   r   c
           
        sf   t  j||	d || _|| _|| _t|ts6t||}|| _t|tsPt||}|| _	|| _
|| _d S rJ   )r   r   rm   rn   ro   r"   r   r   kernel_sizerq   rr   rs   )
r   rm   rn   ro   rp   rq   rr   rs   r   r   r   r   r    r     s    



zSSIMMetric.__init__r*   r+   c                 C  s   |  }| jdkr2|dkr2td| j d| d| jdkr\|dkr\td| j d| dt||| j| j| j| j| j| j| j	d		\}}|
|jd
 djddd}|S )aR  
        Args:
            y_pred: Predicted image.
                It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].
            y: Reference image.
                It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].

        Raises:
            ValueError: when `y_pred` is not a 2D or 3D image.
        r/      Ky_pred should have 4 dimensions (batch, channel, height, width) when using  spatial dimensions, got .      zRy_pred should have 5 dimensions (batch, channel, height, width, depth) when using 	r,   r-   rm   rn   ro   rt   rq   rr   rs   r   rW   r   TrY   )
ndimensionrm   r%   compute_ssim_and_csrn   ro   rt   rq   rr   rs   viewr0   r[   )r   r,   r-   dimsssim_value_full_image_Zssim_per_batchr   r   r    r6     s2    
 zSSIMMetric._compute_metricr5   r8   r9   r:   r]   r`   r	   r;   r   r6   r<   r   r   r   r    rb      s   &rb   rh   zSequence[int]Sequence[float])rm   num_channelsrt   rq   r   c           
   	   C  s   dddddd}||d |d }||d |d }t | |}|d|d |d f}| d	kr||d
 |d
 d }	t |ddd|d
 |	|d |d |d
 }|d|d |d |d
 f}||S )a  Computes 2D or 3D gaussian kernel.

    Args:
        spatial_dims: number of spatial dimensions of the input images.
        num_channels: number of channels in the image
        kernel_size: size of kernel
        kernel_sigma: standard deviation for Gaussian kernel.
    rh   ri   r*   )rt   sigmar   c                 S  sP   t jd|  d d|  d dd}t t || d d }||  jddS )zComputes 1D gaussian kernel.

        Args:
            kernel_size: size of the gaussian kernel
            sigma: Standard deviation of the gaussian kernel
        r   r/   )startendstepr   rX   )r#   arangeexprB   sum	unsqueeze)rt   r   distgaussr   r   r    gaussian_1dU  s     z%_gaussian_kernel.<locals>.gaussian_1dr   r   ry   r/   r   rW   )r#   matmultmulr   repeatexpand)
rm   r   rt   rq   r   Zgaussian_kernel_xZgaussian_kernel_ykernelZkernel_dimensionsZgaussian_kernel_zr   r   r    _gaussian_kernelI  s    r   rc   rf   rg   ri   rj   z!tuple[torch.Tensor, torch.Tensor])
r,   r-   rm   rt   rq   rn   ro   rr   rs   r   c	                 C  s  |j | j kr&td| j  d|j  dt| tjtjdd } t|tjtjdd }| d}	|tjkrvt	||	||}
n,|tj
krt|	df|tt| }
t|
| dd }
|| d }|| d }ttd	| d
}|| |
|	d}|||
|	d}|| |  |
|	d}||| |
|	d}|| | |
|	d}|||  }|||  }|||  }d| | || |  }d| | | |d |d  |  | }||fS )a  
    Function to compute the Structural Similarity Index Measure (SSIM) and Contrast Sensitivity (CS) for a batch
    of images.

    Args:
        y_pred: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
        y: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
        kernel_size: the size of the kernel to use for the SSIM computation.
        kernel_sigma: the standard deviation of the kernel to use for the SSIM computation.
        spatial_dims: number of spatial dimensions of the images (2, 3)
        data_range: the data range of the images.
        kernel_type: the type of kernel to use for the SSIM computation. Can be either "gaussian" or "uniform".
        k1: the first stability constant.
        k2: the second stability constant.

    Returns:
        ssim: the Structural Similarity Index Measure score for the batch of images.
        cs: the Contrast Sensitivity for the batch of images.
    z*y_pred and y should have same shapes, got z and rx   )output_typedtyper   r   )srcdstr/   convd)groups)r0   r%   r   r#   r$   ri   sizer]   r`   r   ra   onesprodtensorr   getattrF)r,   r-   rm   rt   rq   rn   ro   rr   rs   r   r   c1c2Zconv_fnmu_xmu_yZmu_xxZmu_yyZmu_xysigma_xsigma_yZsigma_xyZcontrast_sensitivityr   r   r   r    r~   q  s0    


"(r~   c                      sd   e Zd ZdZdejdddddejdf	d	d
dddd
d
ddddd fddZddddddZ	  Z
S )MultiScaleSSIMMetricaL  
    Computes the Multi-Scale Structural Similarity Index Measure (MS-SSIM).

    MS-SSIM reference paper:
        Wang, Z., Simoncelli, E.P. and Bovik, A.C., 2003, November. "Multiscale structural
        similarity for image quality assessment." In The Thirty-Seventh Asilomar Conference
        on Signals, Systems & Computers, 2003 (Vol. 2, pp. 1398-1402). IEEE

    Args:
        spatial_dims: number of spatial dimensions of the input images.
        data_range: value range of input images. (usually 1.0 or 255)
        kernel_type: type of kernel, can be "gaussian" or "uniform".
        kernel_size: size of kernel
        kernel_sigma: standard deviation for Gaussian kernel.
        k1: stability constant used in the luminance denominator
        k2: stability constant used in the contrast denominator
        weights: parameters for image similarity and contrast sensitivity at different resolution scores.
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans)
    rc   rd   re   rf   rg   gǺ?g48EG?ga4?g??g9EGr?Frh   ri   rj   rk   rl   r   r   r   r   )rm   rn   ro   rt   rq   rr   rs   weightsr   r   r   c                   sl   t  j|	|
d || _|| _|| _t|ts6t||}|| _t|tsPt||}|| _	|| _
|| _|| _d S rJ   )r   r   rm   rn   ro   r"   r   r   rt   rq   rr   rs   r   )r   rm   rn   ro   rt   rq   rr   rs   r   r   r   r   r   r    r     s    



zMultiScaleSSIMMetric.__init__r*   r+   c                 C  s,   t ||| j| j| j| j| j| j| j| jd
S )N)
r,   r-   rm   rn   ro   rt   rq   rr   rs   r   )	compute_ms_ssimrm   rn   ro   rt   rq   rr   rs   r   r2   r   r   r    r6     s    z$MultiScaleSSIMMetric._compute_metricr   r   r   r   r    r     s   (r   rd   re   r   rk   rl   )r,   r-   rm   rn   ro   rt   rq   rr   rs   r   r   c
                 C  s   |   }
|dkr.|
dkr.td| d|
 d|dkrT|
dkrTtd| d|
 dt|tsht||}t|ts|t||}td	t|	d	 d }| jdd
 }tt|D ]L}|| | || d	 krtdt|	 d||  d|| d	 |  dqt	j
|	| jt	jd}ttd| d}g }tt|D ]d}t| ||||||||d	\}}||jd dd	}|t	| || dd} ||dd}q0||jd dd	}t	||d< t	|}t	j||dd	 dd}||jd djd	dd}|S )aQ  
    Args:
        y_pred: Predicted image.
            It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].
        y: Reference image.
            It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].
        spatial_dims: number of spatial dimensions of the input images.
        data_range: value range of input images. (usually 1.0 or 255)
        kernel_type: type of kernel, can be "gaussian" or "uniform".
        kernel_size: size of kernel
        kernel_sigma: standard deviation for Gaussian kernel.
        k1: stability constant used in the luminance denominator
        k2: stability constant used in the contrast denominator
        weights: parameters for image similarity and contrast sensitivity at different resolution scores.
    Raises:
        ValueError: when `y_pred` is not a 2D or 3D image.
    r/   ru   rv   rw   rx   ry   rz   zRy_pred should have 4 dimensions (batch, channel, height, width, depth) when using r   Nz+For a given number of `weights` parameters z and kernel size z', the image height must be larger than )devicer   avg_poolr   r{   r   rW   )rt   r   Tr|   )r}   r%   r"   r   r   maxr1   r0   ranger#   r   r   ri   r   r   r~   r   r[   appendrelustackr   )r,   r-   rm   rn   ro   rt   rq   rr   rs   r   r   Zweights_divZy_pred_spatial_dimsiweights_tensorr   Zmultiscale_listr   ZssimcsZcs_per_batchZmultiscale_list_tensorZms_ssim_value_full_imageZms_ssim_per_batchr   r   r    r     s`    



*

 r   ))
__future__r   rT   abcr   collections.abcr   r   	functoolsr   typingr   r#   torch.nn.functionalnn
functionalr   monai.metrics.utilsr   monai.utilsr	   r
   r   r   monai.utils.type_conversionr   metricr   r   r=   rI   rM   rP   rG   r]   rb   r   r`   r~   r   r   r   r   r   r    <module>   sD   :'e."AJ