o
    ijf                     @  sh  d dl mZ d dlZd dlmZ d dlmZmZ d dlm	Z	 d dl
mZ d dlZd dlm  mZ d dlmZ d dlmZmZmZmZ d d	lmZ d
dlmZ G dd deZG dd deZG dd deZG dd deZG dd deZ dAddZ!G dd deZ"G d d! d!eZ#dBd)d*Z$d+e"j%d,d-fdCd5d6Z&G d7d8 d8eZ'd+e"j%d9d:d,d-d;fdDd?d@Z(dS )E    )annotationsN)abstractmethod)CallableSequence)partial)Any)do_metric_reduction)MetricReductionStrEnumconvert_data_typeensure_tuple_rep)convert_to_dst_type   )CumulativeIterationMetricc                      sZ   e Zd ZdZejdfd fd	d
Z	ddddZdddZe	dddZ
dddZ  ZS )RegressionMetrica  
    Base class for regression metrics.
    Input `y_pred` is compared with ground truth `y`.
    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.
    `y_preds` and `y` can be a list of channel-first Tensor (CHW[D]) or a batch-first Tensor (BCHW[D]).

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).
            Here `not_nans` count the number of not nans for the metric, thus its shape equals to the shape of the metric.

    F	reductionMetricReduction | strget_not_nansboolreturnNonec                   s   t    || _|| _d S N)super__init__r   r   selfr   r   	__class__ Z/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/metrics/regression.pyr   0   s   

zRegressionMetric.__init__NMetricReduction | str | None0torch.Tensor | tuple[torch.Tensor, torch.Tensor]c                 C  sB   |   }t|tjstdt||p| j\}}| jr||fS |S )ao  
        Args:
            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
                available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
                ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
        z-the data to aggregate must be PyTorch Tensor.)
get_buffer
isinstancetorchTensor
ValueErrorr   r   r   )r   r   datafnot_nansr   r   r   	aggregate5   s
   	zRegressionMetric.aggregatey_predtorch.Tensoryc                 C  s@   |j |j krtd|j  d|j  dt|j dk rtdd S )Nz2y_pred and y shapes dont match, received y_pred: [z
] and y: []   zIeither channel or spatial dimensions required, found only batch dimension)shaper&   lenr   r+   r-   r   r   r   _check_shapeE   s
   zRegressionMetric._check_shapec                 C  s   t d| jj d)Nz	Subclass z must implement this method.)NotImplementedErrorr   __name__r2   r   r   r   _compute_metricM   s   z RegressionMetric._compute_metricc                 C  s8   t |tjrt |tjstd| || | ||S )Nz$y_pred and y must be PyTorch Tensor.)r#   r$   r%   r&   r3   r6   r2   r   r   r   _compute_tensorQ   s   z RegressionMetric._compute_tensorr   r   r   r   r   r   r   )r   r    r   r!   )r+   r,   r-   r,   r   r   r+   r,   r-   r,   r   r,   )r5   
__module____qualname____doc__r	   MEANr   r*   r3   r   r6   r7   __classcell__r   r   r   r   r      s    
r   c                      4   e Zd ZdZejdfd fd	d
ZdddZ  ZS )	MSEMetrica  Compute Mean Squared Error between two tensors using function:

    .. math::
        \operatorname {MSE}\left(Y, \hat{Y}\right) =\frac {1}{n}\sum _{i=1}^{n}\left(y_i-\hat{y_i} \right)^{2}.

    More info: https://en.wikipedia.org/wiki/Mean_squared_error

    Input `y_pred` is compared with ground truth `y`.
    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).

    Fr   r   r   r   r   r   c                   $   t  j||d ttjdd| _d S Nr   r   g       @)exponentr   r   r   r$   powsq_funcr   r   r   r   r   m      zMSEMetric.__init__r+   r,   r-   c                 C     t ||| jdS Nfunc)compute_mean_error_metricsrG   r2   r   r   r   r6   q      zMSEMetric._compute_metricr8   r9   	r5   r:   r;   r<   r	   r=   r   r6   r>   r   r   r   r   r@   X       r@   c                      r?   )	MAEMetrica  Compute Mean Absolute Error between two tensors using function:

    .. math::
        \operatorname {MAE}\left(Y, \hat{Y}\right) =\frac {1}{n}\sum _{i=1}^{n}\left|y_i-\hat{y_i}\right|.

    More info: https://en.wikipedia.org/wiki/Mean_absolute_error

    Input `y_pred` is compared with ground truth `y`.
    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).

    Fr   r   r   r   r   r   c                   s   t  j||d tj| _d S NrC   )r   r   r$   absabs_funcr   r   r   r   r      s   zMAEMetric.__init__r+   r,   r-   c                 C  rI   rJ   )rM   rT   r2   r   r   r   r6      rN   zMAEMetric._compute_metricr8   r9   rO   r   r   r   r   rQ   u   rP   rQ   c                      r?   )
RMSEMetrica/  Compute Root Mean Squared Error between two tensors using function:

    .. math::
        \operatorname {RMSE}\left(Y, \hat{Y}\right) ={ \sqrt{ \frac {1}{n}\sum _{i=1}^{n}\left(y_i-\hat{y_i}\right)^2 } } \
        = \sqrt {\operatorname{MSE}\left(Y, \hat{Y}\right)}.

    More info: https://en.wikipedia.org/wiki/Root-mean-square_deviation

    Input `y_pred` is compared with ground truth `y`.
    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).

    Fr   r   r   r   r   r   c                   rA   rB   rE   r   r   r   r   r      rH   zRMSEMetric.__init__r+   r,   r-   c                 C  s   t ||| jd}t|S rJ   )rM   rG   r$   sqrtr   r+   r-   Zmse_outr   r   r   r6      s   
zRMSEMetric._compute_metricr8   r9   rO   r   r   r   r   rU      s    rU   c                      s4   e Zd ZdZejdfd fddZdddZ  ZS )
PSNRMetrica_  Compute Peak Signal To Noise Ratio between two tensors using function:

    .. math::
        \operatorname{PSNR}\left(Y, \hat{Y}\right) = 20 \cdot \log_{10} \left({\mathit{MAX}}_Y\right) \
        -10 \cdot \log_{10}\left(\operatorname{MSE\left(Y, \hat{Y}\right)}\right)

    More info: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio

    Help taken from:
    https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/image_ops_impl.py line 4139

    Input `y_pred` is compared with ground truth `y`.
    Both `y_pred` and `y` are expected to be real-valued, where `y_pred` is output from a regression model.

    Example of the typical execution steps of this metric class follows :py:class:`monai.metrics.metric.Cumulative`.

    Args:
        max_val: The dynamic range of the images/volumes (i.e., the difference between the
            maximum and the minimum allowed values e.g. 255 for a uint8 image).
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction.
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans).

    Fmax_valint | floatr   r   r   r   r   r   c                   s*   t  j||d || _ttjdd| _d S rB   )r   r   rY   r   r$   rF   rG   )r   rY   r   r   r   r   r   r      s   zPSNRMetric.__init__r+   r,   r-   r   c                 C  s.   t ||| jd}dt| j dt|  S )NrK      
   )rM   rG   mathlog10rY   r$   rW   r   r   r   r6      s   zPSNRMetric._compute_metric)rY   rZ   r   r   r   r   r   r   )r+   r,   r-   r,   r   r   rO   r   r   r   r   rX      s
    rX   r+   r,   r-   rL   r   r   c                 C  s*   t tjdd}tj||||  dddS )Nr   )	start_dimT)dimkeepdim)r   r$   flattenmean)r+   r-   rL   fltr   r   r   rM      s   rM   c                   @  s   e Zd ZdZdZdS )
KernelTypegaussianuniformN)r5   r:   r;   GAUSSIANUNIFORMr   r   r   r   rf      s    rf   c                      sB   e Zd ZdZdejddddejdfd! fddZd"dd Z	  Z
S )#
SSIMMetrica  
    Computes the Structural Similarity Index Measure (SSIM).

    .. math::
        \operatorname {SSIM}(x,y) =\frac {(2 \mu_x \mu_y + c_1)(2 \sigma_{xy} + c_2)}{((\mu_x^2 + \
                \mu_y^2 + c_1)(\sigma_x^2 + \sigma_y^2 + c_2)}

    For more info, visit
        https://vicuesoft.com/glossary/term/ssim-ms-ssim/

    SSIM reference paper:
        Wang, Zhou, et al. "Image quality assessment: from error visibility to structural
        similarity." IEEE transactions on image processing 13.4 (2004): 600-612.

    Args:
        spatial_dims: number of spatial dimensions of the input images.
        data_range: value range of input images. (usually 1.0 or 255)
        kernel_type: type of kernel, can be "gaussian" or "uniform".
        win_size: window size of kernel
        kernel_sigma: standard deviation for Gaussian kernel.
        k1: stability constant used in the luminance denominator
        k2: stability constant used in the contrast denominator
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans)
          ?         ?{Gz?Q?Fspatial_dimsint
data_rangefloatkernel_typeKernelType | strwin_sizeint | Sequence[int]kernel_sigmafloat | Sequence[float]k1k2r   r   r   r   r   r   c
           
        sf   t  j||	d || _|| _|| _t|tst||}|| _t|ts(t||}|| _	|| _
|| _d S rR   )r   r   rq   rs   ru   r#   r   r   kernel_sizery   r{   r|   )
r   rq   rs   ru   rw   ry   r{   r|   r   r   r   r   r   r     s   




zSSIMMetric.__init__r+   r,   r-   c                 C  s   |  }| jdkr|dkrtd| j d| d| jdkr.|dkr.td| j d| dt||| j| j| j| j| j| j| j	d		\}}|
|jd
 djddd}|S )aR  
        Args:
            y_pred: Predicted image.
                It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].
            y: Reference image.
                It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].

        Raises:
            ValueError: when `y_pred` is not a 2D or 3D image.
        r/      Ky_pred should have 4 dimensions (batch, channel, height, width) when using  spatial dimensions, got .      zRy_pred should have 5 dimensions (batch, channel, height, width, depth) when using 	r+   r-   rq   rs   ru   r}   ry   r{   r|   r   r`   r   Trb   )
ndimensionrq   r&   compute_ssim_and_csrs   ru   r}   ry   r{   r|   viewr0   rd   )r   r+   r-   dimsssim_value_full_image_Zssim_per_batchr   r   r   r6     s8   


zSSIMMetric._compute_metric)rq   rr   rs   rt   ru   rv   rw   rx   ry   rz   r{   rt   r|   rt   r   r   r   r   r   r   r9   r5   r:   r;   r<   rf   ri   r	   r=   r   r6   r>   r   r   r   r   rk      s    rk   rq   rr   num_channelsr}   Sequence[int]ry   Sequence[float]c           
   	   C  s   ddd}||d	 |d	 }||d
 |d
 }t | |}|d
|d	 |d
 f}| dkr`||d |d d }	t |dd
d
|d |	|d	 |d
 |d }|d
|d	 |d
 |d f}||S )a  Computes 2D or 3D gaussian kernel.

    Args:
        spatial_dims: number of spatial dimensions of the input images.
        num_channels: number of channels in the image
        kernel_size: size of kernel
        kernel_sigma: standard deviation for Gaussian kernel.
    r}   rr   sigmart   r   r,   c                 S  sP   t jd|  d d|  d dd}t t || d d }||  jddS )zComputes 1D gaussian kernel.

        Args:
            kernel_size: size of the gaussian kernel
            sigma: Standard deviation of the gaussian kernel
        r   r/   )startendstepr   ra   )r$   arangeexprF   sum	unsqueeze)r}   r   distgaussr   r   r   gaussian_1dU  s    z%_gaussian_kernel.<locals>.gaussian_1dr   r   r   r/   r   r`   N)r}   rr   r   rt   r   r,   )r$   matmultmulr   repeatexpand)
rq   r   r}   ry   r   Zgaussian_kernel_xZgaussian_kernel_ykernelZkernel_dimensionsZgaussian_kernel_zr   r   r   _gaussian_kernelI  s   

r   rl   ro   rp   rs   rt   ru   rv   r{   r|   !tuple[torch.Tensor, torch.Tensor]c	                 C  s  |j | j krtd| j  d|j  dt| tjtjdd } t|tjtjdd }| d}	|tjkr;t	||	||}
n|tj
krRt|	dg|R tt| }
t|
| dd }
|| d }|| d }ttd	| d
}|| |
|	d}|||
|	d}|| |  |
|	d}||| |
|	d}|| | |
|	d}|||  }|||  }|||  }d| | || |  }d| | | |d |d  |  | }||fS )a  
    Function to compute the Structural Similarity Index Measure (SSIM) and Contrast Sensitivity (CS) for a batch
    of images.

    Args:
        y_pred: batch of predicted images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
        y: batch of target images with shape (batch_size, channels, spatial_dim1, spatial_dim2[, spatial_dim3])
        kernel_size: the size of the kernel to use for the SSIM computation.
        kernel_sigma: the standard deviation of the kernel to use for the SSIM computation.
        spatial_dims: number of spatial dimensions of the images (2, 3)
        data_range: the data range of the images.
        kernel_type: the type of kernel to use for the SSIM computation. Can be either "gaussian" or "uniform".
        k1: the first stability constant.
        k2: the second stability constant.

    Returns:
        ssim: the Structural Similarity Index Measure score for the batch of images.
        cs: the Contrast Sensitivity for the batch of images.
    z*y_pred and y should have same shapes, got z and r   )output_typedtyper   r   )srcdstr/   convd)groups)r0   r&   r   r$   r%   rt   sizerf   ri   r   rj   onesprodtensorr   getattrF)r+   r-   rq   r}   ry   rs   ru   r{   r|   r   r   c1c2conv_fnmu_xmu_yZmu_xxZmu_yyZmu_xysigma_xsigma_yZsigma_xyZcontrast_sensitivityr   r   r   r   r   q  s0   


$(r   c                	      sD   e Zd ZdZdejdddddejdf	d$ fddZd%d"d#Z	  Z
S )&MultiScaleSSIMMetricaL  
    Computes the Multi-Scale Structural Similarity Index Measure (MS-SSIM).

    MS-SSIM reference paper:
        Wang, Z., Simoncelli, E.P. and Bovik, A.C., 2003, November. "Multiscale structural
        similarity for image quality assessment." In The Thirty-Seventh Asilomar Conference
        on Signals, Systems & Computers, 2003 (Vol. 2, pp. 1398-1402). IEEE

    Args:
        spatial_dims: number of spatial dimensions of the input images.
        data_range: value range of input images. (usually 1.0 or 255)
        kernel_type: type of kernel, can be "gaussian" or "uniform".
        kernel_size: size of kernel
        kernel_sigma: standard deviation for Gaussian kernel.
        k1: stability constant used in the luminance denominator
        k2: stability constant used in the contrast denominator
        weights: parameters for image similarity and contrast sensitivity at different resolution scores.
        reduction: define the mode to reduce metrics, will only execute reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to ``"mean"``. if "none", will not do reduction
        get_not_nans: whether to return the `not_nans` count, if True, aggregate() returns (metric, not_nans)
    rl   rm   rn   ro   rp   gǺ?g48EG?ga4?g??g9EGr?Frq   rr   rs   rt   ru   rv   r}   rx   ry   rz   r{   r|   weightsr   r   r   r   r   r   r   c                   sl   t  j|	|
d || _|| _|| _t|tst||}|| _t|ts(t||}|| _	|| _
|| _|| _d S rR   )r   r   rq   rs   ru   r#   r   r   r}   ry   r{   r|   r   )r   rq   rs   ru   r}   ry   r{   r|   r   r   r   r   r   r   r     s   




zMultiScaleSSIMMetric.__init__r+   r,   r-   c                 C  s,   t ||| j| j| j| j| j| j| j| jd
S )N)
r+   r-   rq   rs   ru   r}   ry   r{   r|   r   )	compute_ms_ssimrq   rs   ru   r}   ry   r{   r|   r   r2   r   r   r   r6     s   z$MultiScaleSSIMMetric._compute_metric)rq   rr   rs   rt   ru   rv   r}   rx   ry   rz   r{   rt   r|   rt   r   r   r   r   r   r   r   r   r9   r   r   r   r   r   r     s    r   rm   rn   r   rx   rz   r   c
                 C  s  |   }
|dkr|
dkrtd| d|
 d|dkr*|
dkr*td| d|
 dt|ts4t||}t|ts>t||}td	t|	d	 d }| jdd
 }tt|D ]&}|| | || d	 kr|tdt|	 d||  d|| d	 |  dqVt	j
|	| jt	jd}ttd| d}g }tt|D ]1}t| ||||||||d	\}}||jd dd	}|t	| || dd} ||dd}q||jd dd	}t	||d< t	|}t	j||dd	 dd}||jd djd	dd}|S )aQ  
    Args:
        y_pred: Predicted image.
            It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].
        y: Reference image.
            It must be a 2D or 3D batch-first tensor [B,C,H,W] or [B,C,H,W,D].
        spatial_dims: number of spatial dimensions of the input images.
        data_range: value range of input images. (usually 1.0 or 255)
        kernel_type: type of kernel, can be "gaussian" or "uniform".
        kernel_size: size of kernel
        kernel_sigma: standard deviation for Gaussian kernel.
        k1: stability constant used in the luminance denominator
        k2: stability constant used in the contrast denominator
        weights: parameters for image similarity and contrast sensitivity at different resolution scores.
    Raises:
        ValueError: when `y_pred` is not a 2D or 3D image.
    r/   r~   r   r   r   r   r   zRy_pred should have 4 dimensions (batch, channel, height, width, depth) when using r   Nz+For a given number of `weights` parameters z and kernel size z', the image height must be larger than )devicer   avg_poolr   r   r   r`   )r}   r   Tr   )r   r&   r#   r   r   maxr1   r0   ranger$   r   r   rt   r   r   r   r   rd   appendrelustackr   )r+   r-   rq   rs   ru   r}   ry   r{   r|   r   r   Zweights_divZy_pred_spatial_dimsiweights_tensorr   Zmultiscale_listr   ZssimcsZcs_per_batchZmultiscale_list_tensorZms_ssim_value_full_imageZms_ssim_per_batchr   r   r   r     sp   





r   )r+   r,   r-   r,   rL   r   r   r,   )
rq   rr   r   rr   r}   r   ry   r   r   r,   )r+   r,   r-   r,   rq   rr   r}   r   ry   r   rs   rt   ru   rv   r{   rt   r|   rt   r   r   )r+   r,   r-   r,   rq   rr   rs   rt   ru   rv   r}   rx   ry   rz   r{   rt   r|   rt   r   r   r   r,   ))
__future__r   r]   abcr   collections.abcr   r   	functoolsr   typingr   r$   torch.nn.functionalnn
functionalr   monai.metrics.utilsr   monai.utilsr	   r
   r   r   monai.utils.type_conversionr   metricr   r   r@   rQ   rU   rX   rM   rf   rk   r   ri   r   r   r   r   r   r   r   <module>   sF   :
'
e.AJ