o
    i                     @  s   d dl mZ d dlZd dlZd dlmZ d dlmZ ed\Z	Z
G dd deZdddZd d!ddZd"ddZ	d#d$ddZdS )%    )annotationsN)Metric)optional_importscipyc                   @  s   e Zd ZdZd	ddZdS )
	FIDMetrica  
    Frechet Inception Distance (FID). The FID calculates the distance between two distributions of feature vectors.
    Based on: Heusel M. et al. "Gans trained by a two time-scale update rule converge to a local nash equilibrium."
    https://arxiv.org/abs/1706.08500. The inputs for this metric should be two groups of feature vectors (with format
    (number images, number of features)) extracted from a pretrained network.

    Originally, it was proposed to use the activations of the pool_3 layer of an Inception v3 pretrained with Imagenet.
    However, others networks pretrained on medical datasets can be used as well (for example, RadImageNwt for 2D and
    MedicalNet for 3D images). If the chosen model output is not a scalar, a global spatia average pooling should be
    used.
    y_predtorch.Tensoryreturnc                 C  s
   t ||S )N)get_fid_score)selfr   r	    r   S/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/metrics/fid.py__call__$   s   
zFIDMetric.__call__Nr   r   r	   r   r
   r   )__name__
__module____qualname____doc__r   r   r   r   r   r      s    r   r   r   r	   r
   c                 C  sf   |  }|   } | dkrtdtj| dd}t| dd}tj|dd}t|dd}t||||S )a  Computes the FID score metric on a batch of feature vectors.

    Args:
        y_pred: feature vectors extracted from a pretrained network run on generated images.
        y: feature vectors extracted from a pretrained network run on images from the real data distribution.
       z=Inputs should have (number images, number of features) shape.r   )dimF)rowvar)double
ndimension
ValueErrortorchmean_covcompute_frechet_distance)r   r	   Z	mu_y_predZsigma_y_predmu_ysigma_yr   r   r   r   (   s   r   T
input_datar   boolc                 C  sn   |   dk r| dd} |s| ddkr|  } d| dd  }| tj| ddd } || |    S )a	  
    Estimate a covariance matrix of the variables.

    Args:
        input_data: A 1-D or 2-D array containing multiple variables and observations. Each row of `m` represents a variable,
            and each column a single observation of all those variables.
        rowvar: If rowvar is True (default), then each row represents a variable, with observations in the columns.
            Otherwise, the relationship is transposed: each column represents a variable, while the rows contain
            observations.
    r      r   g      ?T)r   keepdim)r   viewsizetr   r   matmulsqueeze)r!   r   factorr   r   r   r   =   s   r   c                 C  s2   t jj|    tjdd\}}t	
|S )z$Compute the square root of a matrix.F)disp)r   linalgsqrtmdetachcpunumpyastypenpfloat64r   
from_numpy)r!   Z	scipy_res_r   r   r   _sqrtmS   s   (
r7   ư>mu_xsigma_xr   r    epsilonfloatc           	      C  s   | | }t ||}t| s4td| d tj|d| j| j	d| }t || || }t
|r]tjt|jtjdtjdddsZtdtt|j d	|j}t|}||t| t| d
|  S )z?The Frechet distance between multivariate normal distributions.z2FID calculation produces singular product; adding z$ to diagonal of covariance estimatesr   )devicedtype)r>   gMbP?)atolzImaginary component z
 too high.r   )r7   mmr   isfiniteallprinteyer'   r=   r>   
is_complexallclosediagonalimagtensorr   r   maxabsrealtracedot)	r9   r:   r   r    r;   diffZcovmeanoffsetZ
tr_covmeanr   r   r   r   Y   s   
$
&r   r   )T)r!   r   r   r"   r
   r   )r!   r   r
   r   )r8   )r9   r   r:   r   r   r   r    r   r;   r<   r
   r   )
__future__r   r1   r3   r   Zmonai.metrics.metricr   monai.utilsr   r   r6   r   r   r   r7   r   r   r   r   r   <module>   s   

