# -*- coding: utf-8 -*-
# @Time    : 2025/1/20 17:20
# @Author  : ljc
# @FileName: matrix_inverse_benchmark.py
# @Software: PyCharm


# 1. 简介
"""
 不同矩阵求逆方法效率对比.
目的:
    执行多元线性回归拟合, 计算回归系数, 进而得到勒让德多项式系数
函数:
    1) mregress_batch_cholesky、mregress_batch_lu、mregress_batch_svd、mregress_batch_qr、mregress_batch_inv
解释：
    1) mregress_batch_cholesky、mregress_batch_lu、mregress_batch_svd、mregress_batch_qr、mregress_batch_inv
     等函数: 回归系数、以及拟合过程中的信息.
注意:
    不同方法效率有差异, 可根据需要选择合适的方法.
"""


# 2. 导库
import time
import torch
from mregress_pytorch import mregress_batch_cholesky, mregress_batch_lu, mregress_batch_svd, mregress_batch_qr, mregress_batch_inv


# 3. 性能测试函数
def benchmark_methods_batched(x, y, batch_size=10, num_runs=5) -> tuple[dict, dict]:
    
    """
        不同矩阵求逆方法效率对比.
        
        输入:
        -----------
        x : torch.Tensor
            输入 tensor, 形状为 (batch_size, npts, nterm).
        y : torch.Tensor
            目标 tensor, 形状为 (batch_size, npts).
        batch_size : int
            分块大小.
        num_runs : int
            运行次数.

        返回:
        --------
        times : dict
            不同方法的平均运行时间.
        results : dict
            不同方法的回归系数的最大差异.
    """

    methods = {'Cholesky': mregress_batch_cholesky,   # 使用 Cholesky 分解法
               'LU': mregress_batch_lu,               # 使用 LU 分解法
               'SVD': mregress_batch_svd,             # 使用 SVD 分解法
               'QR': mregress_batch_qr,               # 使用 QR 分解法
               'inv': mregress_batch_inv              # 使用逆矩阵法
               }

    # 3.1 使用 chunk 直接分割数据
    x_chunks = torch.chunk(x, x.size(0) // batch_size)
    y_chunks = torch.chunk(y, y.size(0) // batch_size)

    # 3.2 结果存储
    results = {}
    times = {}

    # 3.3 遍历方法
    for name, method in methods.items():
        print(f"\nTesting {name} method:")
        total_time = 0
        for run in range(num_runs):
            start = time.perf_counter()
            # 3.3.1 直接处理所有分块
            results_chunks = [method(x_chunk, y_chunk) for x_chunk, y_chunk in zip(x_chunks, y_chunks)]
            # 3.3.2 合并结果
            result = torch.cat(results_chunks)
            end = time.perf_counter()
            run_time = end - start
            total_time += run_time
            print(f"Run {run + 1}: {run_time * 1000:.2f} ms")

        # 3.3.3 计算平均时间
        avg_time = total_time / num_runs
        times[name] = avg_time
        results[name] = result

    # 3.4 打印总体时间对比
    print("\nTotal Performance comparison:")
    for name, t in times.items():
        print(f"{name:8s}: {t * 1000:.2f} ms average per full dataset")

    # 3.5 验证结果一致性
    base_result = results['Cholesky']
    print("\nResult differences from Cholesky:")
    for name, result in results.items():
        if name != 'Cholesky':
            diff = torch.abs(base_result - result).max().item()
            print(f"{name:8s}: {diff:.2e}")

    return times, results


# 4. 使用示例
# 4.1 创建测试数据
# x = torch.randn(2500, 1325, 51)
# y = torch.randn(2500, 1325)
# 4.2 运行分批性能测试
# times, results = benchmark_methods_batched(x, y, batch_size=30)