o
    i                     @  sJ   d dl mZ d dlmZ d dlZd dlmZ G dd deZdddZdS )    )annotations)CallableN)Metricc                      s.   e Zd ZdZdd fddZdddZ  ZS )	MMDMetrica  
    Unbiased Maximum Mean Discrepancy (MMD) is a kernel-based method for measuring the similarity between two
    distributions. It is a non-negative metric where a smaller value indicates a closer match between the two
    distributions.

    Gretton, A., et al,, 2012.  A kernel two-sample test. The Journal of Machine Learning Research, 13(1), pp.723-773.

    Args:
        y_mapping: Callable to transform the y tensors before computing the metric. It is usually a Gaussian or Laplace
            filter, but it can be any function that takes a tensor as input and returns a tensor as output such as a
            feature extractor or an Identity function., e.g. `y_mapping = lambda x: x.square()`.
    N	y_mappingCallable | NonereturnNonec                   s   t    || _d S N)super__init__r   )selfr   	__class__ S/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/metrics/mmd.pyr   #   s   

zMMDMetric.__init__ytorch.Tensory_predc                 C  s   t ||| jS r
   )compute_mmdr   )r   r   r   r   r   r   __call__'   s   zMMDMetric.__call__r
   )r   r   r   r	   )r   r   r   r   r   r   )__name__
__module____qualname____doc__r   r   __classcell__r   r   r   r   r      s    r   r   r   r   r   r   r   c                 C  s  |j d dks| j d dkrtd|dur|| } ||}|j | j kr0td|j  d| j  tt| j d ddD ]}| j|d} |j|d}q;| | j d d} ||j d d}t| |  }t|| }t||  }| j d }|j d }d||d   }	t	|t
t| }
d||d   }t	|t
t| }d	||  }t	|}|	|
 ||  ||  }|S )
a-  
    Args:
        y: first sample (e.g., the reference image). Its shape is (B,C,W,H) for 2D data and (B,C,W,H,D) for 3D.
        y_pred: second sample (e.g., the reconstructed image). It has similar shape as y.
        y_mapping: Callable to transform the y tensors before computing the metric.
    r      z9MMD metric requires at least two samples in y and y_pred.Nz[y_pred and y shapes dont match after being processed by their transforms, received y_pred: z and y: )dim   )shape
ValueErrorrangelensqueezeviewtorchmmtsumdiagdiagonal)r   r   r   dZy_yZy_pred_y_predZy_pred_ymnc1ac2bc3cmmdr   r   r   r   +   s>   


r   )r   r   r   r   r   r   r   r   )	
__future__r   collections.abcr   r&   Zmonai.metrics.metricr   r   r   r   r   r   r   <module>   s   