# -*- coding: utf-8 -*-
# @Time    : 2025/1/4 23:55
# @Author  : ljc
# @FileName: mregress_pytorch.py
# @Software: PyCharm


# 1. 简介
"""
目的：
    执行多元线性回归拟合, 计算回归系数, 进而得到勒让德多项式系数.
函数:
    1) mregress_batch_cholesky、mregress_batch_lu、mregress_batch_svd、mregress_batch_qr、mregress_batch_inv
解释:
    1) mregress_batch_cholesky、mregress_batch_lu、mregress_batch_svd、mregress_batch_qr、mregress_batch_inv 
    函数: 回归系数、以及拟合过程中的信息. 
    2) 对于一般应用, mregress_batch_cholesky 提供了速度和稳定性的良好平衡 (默认使用该方法).
    3) 对于大规模数据, mregress_batch_lu 的分批处理很有用, 效率可能最高.
    4) 对于数值稳定性要求高的情况, mregress_batch_svd 是最佳选择, 但效率可能最低.       
    5) 对于大规模数据, mregress_batch_qr 提供了较好的平衡, 效率相对于 mregress_batch_svd 更高.
    6) 对于大规模数据, mregress_batch_inv 提供了较好的平衡, 效率相对于 mregress_batch_qr 更高.
    7) 效率比较可参考 matrix_inverse_benchmark.py.
注意:
    1) LASP 的 IDL 版本、目前 Python 版本使用该方法计算回归系数.
    2) 该方法涉及矩阵逆运算, 因此在分组优化时, 可能导致效率较低. 目前提供 mregress_batch_cholesky、mregress_batch_lu、
    mregress_batch_svd、mregress_batch_qr、mregress_batch_inv 函数, 可根据需要选择合适的方法.
    3) 分组优化时, 也可将勒让德多项式系数作为待优化变量与 3 个恒星大气参数、2 个 kernal 系数一起优化, 进而避免逆运算、效率可能更高 (目前未实现).
"""


# 2. 导库
from config.config import default_set, set_all_seeds
import torch
# 2.1 设置随机种子
set_all_seeds()
# 2.2 调用 GPU、指定数据类型
dtype, device = default_set()
# 2.3 默认数据类型
torch.set_default_dtype(dtype)


# 3. Cholesky 分解实现的多元线性回归拟合函数
def mregress_batch_cholesky(x, y, measure_errors=None) -> torch.Tensor: 

    """
        拟合 xA=y.
        使用 Cholesky 分解的版本, 计算速度较快, 但可能存在数值不稳定问题.

        输入:
        -----------
        x : torch.Tensor
            输入 tensor, 形状为 (batch_size, npts, nterm).
        y : torch.Tensor
            目标 tensor, 形状为 (batch_size, npts).
        measure_errors : None
                        保持与原函数接口一致, 暂未使用.

        返回:
        --------
        a : torch.Tensor
            回归系数, 形状为 (batch_size, nterm).
    """

    # 3.1 检查输入维度
    batch_size, npts, nterm = x.shape
    if y.shape != (batch_size, npts):
        raise ValueError('X and Y have incompatible dimensions!')

    # 3.2 计算自由度
    nfree = (npts - 1.0)

    # 3.3 计算权重
    weights = torch.ones_like(y, dtype=dtype, device=device)
    sw = weights.sum(dim=1, keepdim=True) / npts
    weights = weights / sw

    # 3.4 计算加权特征矩阵
    wx = x * weights.unsqueeze(-1)

    # 3.5 计算标准差
    sigmax = torch.sqrt((x * wx).sum(dim=1) / nfree)
    sigmay = torch.sqrt((weights * y.pow(2)).sum(dim=1) / nfree)

    # 3.6 计算正规方程矩阵
    ar1 = torch.bmm(wx.transpose(1, 2), x)

    # 3.7 计算外积并标准化
    sigmax_outer = sigmax.unsqueeze(-1) * sigmax.unsqueeze(1)
    ar2 = ar1 / (nfree * sigmax_outer)

    # 3.8 计算相关性
    correlation = torch.bmm(wx.transpose(1, 2), y.unsqueeze(2)).squeeze(-1)
    correlation = correlation / (sigmax * sigmay.unsqueeze(1) * nfree)
    b = correlation.unsqueeze(2)

    # 3.9 添加小的对角项以提高数值稳定性
    try:
        eps = torch.finfo(ar2.dtype).eps
        ar3 = ar2 + torch.eye(nterm, device=ar2.device, dtype=ar2.dtype).unsqueeze(0) * eps
        # 3.9.1 对角线加入小量 eps 后, 使用 Cholesky 分解求解线性方程组, 如果失败, 则尝试加入更大的小量
        L = torch.linalg.cholesky(ar3)
    except:
        try:
            eps = 1e-6
            ar3 = ar2 + torch.eye(nterm, device=ar2.device, dtype=ar2.dtype).unsqueeze(0) * eps
            # 3.9.2 对角线加入小量 eps 后(eps=1e-6), 使用 Cholesky 分解求解线性方程组, 如果失败, 则尝试加入更大的小量
            L = torch.linalg.cholesky(ar3)
        except:
            try:
                eps = 1e-5
                ar3 = ar2 + torch.eye(nterm, device=ar2.device, dtype=ar2.dtype).unsqueeze(0) * eps
                # 3.9.3 对角线加入小量 eps 后(eps=1e-5), 使用 Cholesky 分解求解线性方程组, 如果失败, 则尝试加入更大的小量
                L = torch.linalg.cholesky(ar3)
            except:
                eps = 1e-4
                ar3 = ar2 + torch.eye(nterm, device=ar2.device, dtype=ar2.dtype).unsqueeze(0) * eps
                # 3.9.4 对角线加入小量 eps 后(eps=1e-4), 使用 Cholesky 分解求解线性方程组, 如果失败, 则尝试加入更大的小量
                L = torch.linalg.cholesky(ar3)

    # 3.10 解 Ly = b
    y1 = torch.linalg.solve_triangular(L, b, upper=False)
    
    # 3.11 解 L^Tx = y
    a = torch.linalg.solve_triangular(L.transpose(-2, -1), y1, upper=True)

    # 3.12 调整系数
    a = a.squeeze(-1) * (sigmay.unsqueeze(1) / sigmax)

    # 3.13 返回回归系数
    return a


# 4. LU 分解实现的多元线性回归拟合函数
def mregress_batch_lu(x, y, measure_errors=None, batch_size=100) -> torch.Tensor:

    """
        拟合 xA=y.
        使用 LU 分解的版本, 支持分批处理, 适合大规模数据集.

        输入:
        -----------
        x : torch.Tensor
            输入 tensor, 形状为 (batch_size, npts, nterm).
        y : torch.Tensor
            目标 tensor, 形状为 (batch_size, npts).
        measure_errors : None
        batch_size : int
            分批大小, 默认为 100.

        返回:
        --------
        a : torch.Tensor
            回归系数, 形状为 (batch_size, nterm).
    """

    # 4.1 计算总批次数
    total_size, npts, nterm = x.shape
    num_batches = (total_size + batch_size - 1) // batch_size

    # 4.2 准备存储最终结果
    results = []

    # 4.3 遍历每个批次
    for i in range(num_batches):
        # 4.3.1 获取当前批次的数据
        start_idx = i * batch_size
        end_idx = min(start_idx + batch_size, total_size)
        batch_x = x[start_idx:end_idx]
        batch_y = y[start_idx:end_idx]

        # 4.3.2 当前批次大小
        current_batch_size = end_idx - start_idx
        if batch_y.shape != (current_batch_size, npts):
            raise ValueError('X and Y have incompatible dimensions!')

        # 4.3.3 计算自由度
        nfree = (npts - 1.0)

        # 4.3.4 计算权重
        weights = torch.ones((current_batch_size, npts), dtype=dtype, device=device)
        sw = weights.sum(dim=1, keepdim=True) / npts
        weights = weights / sw

        # 4.3.5 计算加权特征矩阵
        wx = batch_x * weights.unsqueeze(-1)

        # 4.3.6 计算标准差
        sigmax = torch.sqrt((batch_x * wx).sum(dim=1) / nfree)
        sigmay = torch.sqrt((weights * batch_y.pow(2)).sum(dim=1) / nfree)

        # 4.3.7 计算正规方程矩阵
        ar = torch.bmm(wx.transpose(1, 2), batch_x)

        # 4.3.8 计算外积并标准化
        sigmax_outer = sigmax.unsqueeze(-1) * sigmax.unsqueeze(1)
        ar = ar / (nfree * sigmax_outer)

        # 4.3.9 计算相关性
        correlation = torch.bmm(wx.transpose(1, 2), batch_y.unsqueeze(2)).squeeze(-1)
        correlation = correlation / (sigmax * sigmay.unsqueeze(1) * nfree)
        b = correlation.unsqueeze(2)

        # 4.3.10 直接使用 torch.linalg.solve, 内部会使用 LU 分解
        a = torch.linalg.solve(ar, b)

        # 4.3.11 调整系数
        a = a.squeeze(-1) * (sigmay.unsqueeze(1) / sigmax)
        results.append(a)

    # 4.3.12 合并所有批次的结果
    return torch.cat(results, dim=0)


# 5. SVD 分解实现的多元线性回归拟合函数
def mregress_batch_svd(x, y, measure_errors=None) -> torch.Tensor:

    """
        拟合 xA=y
        使用 SVD 分解的版本, 优势在于可以处理非满秩矩阵, 但计算速度较慢.

        输入:
        -----------
        x : torch.Tensor
            输入 tensor, 形状为 (batch_size, npts, nterm).
        y : torch.Tensor
            目标 tensor, 形状为 (batch_size, npts).
        measure_errors : None
            保持与原函数接口一致, 暂未使用.

        返回:
        --------
        a : torch.Tensor
            回归系数, 形状为 (batch_size, nterm).
    """

    # 5.1 检查输入维度
    batch_size, npts, nterm = x.shape
    if y.shape != (batch_size, npts):
        raise ValueError('X and Y have incompatible dimensions!')

    # 5.2 计算自由度
    nfree = (npts - 1.0)

    # 5.3 计算权重
    weights = torch.ones_like(y, dtype=x.dtype, device=x.device)
    sw = weights.sum(dim=1, keepdim=True) / npts
    weights = weights / sw

    # 5.4 计算加权特征矩阵
    wx = x * weights.unsqueeze(-1)

    # 5.5 计算标准差
    sigmax = torch.sqrt((x * wx).sum(dim=1) / nfree)
    sigmay = torch.sqrt((weights * y.pow(2)).sum(dim=1) / nfree)

    # 5.6 计算正规方程矩阵
    ar = torch.bmm(wx.transpose(1, 2), x)

    # 5.7 计算外积并标准化
    sigmax_outer = sigmax.unsqueeze(-1) * sigmax.unsqueeze(1)
    ar = ar / (nfree * sigmax_outer)

    # 5.8 计算相关性
    correlation = torch.bmm(wx.transpose(1, 2), y.unsqueeze(2)).squeeze(-1)
    correlation = correlation / (sigmax * sigmay.unsqueeze(1) * nfree)
    b = correlation.unsqueeze(2)

    # 5.9 使用 SVD 分解求解
    U, S, Vh = torch.linalg.svd(ar)

    # 5.10 使用伪逆求解, 添加小值避免除零
    S_inv = 1.0 / (S + 1e-10)
    a = torch.bmm(U * S_inv.unsqueeze(1), torch.bmm(U.transpose(1, 2), b))

    # 5.11 调整系数
    a = a.squeeze(-1) * (sigmay.unsqueeze(1) / sigmax)

    # 5.12 返回回归系数
    return a


# 6. QR 分解实现的多元线性回归拟合函数
def mregress_batch_qr(x, y, measure_errors=None) -> torch.Tensor:

    """
        拟合 xA=y
        使用 QR 分解的版本, 避免矩阵逆运算, 计算速度较快.

        输入:
        -----------
        x : torch.Tensor
            输入 tensor, 形状为 (batch_size, npts, nterm).
        y : torch.Tensor
            目标 tensor, 形状为 (batch_size, npts).
        measure_errors : None
            保持与原函数接口一致, 暂未使用.

        返回:
        --------
        a : torch.Tensor
            回归系数, 形状为 (batch_size, nterm).
    """

    # 6.1 检查输入维度
    batch_size, npts, nterm = x.shape
    if y.shape != (batch_size, npts):
        raise ValueError('X and Y have incompatible dimensions!')

    # 6.2 计算自由度
    nfree = (npts - 1.0)

    # 6.3 计算权重
    weights = torch.ones_like(y, dtype=x.dtype, device=x.device)
    sw = weights.sum(dim=1, keepdim=True) / npts
    weights = weights / sw

    # 6.4 计算加权特征矩阵
    wx = x * weights.unsqueeze(-1)

    # 6.5 计算标准差
    sigmax = torch.sqrt((x * wx).sum(dim=1) / nfree)
    sigmay = torch.sqrt((weights * y.pow(2)).sum(dim=1) / nfree)

    # 6.6 计算正规方程矩阵
    ar = torch.bmm(wx.transpose(1, 2), x)

    # 6.7 计算外积并标准化
    sigmax_outer = sigmax.unsqueeze(-1) * sigmax.unsqueeze(1)
    ar = ar / (nfree * sigmax_outer)

    # 6.8 计算相关性
    correlation = torch.bmm(wx.transpose(1, 2), y.unsqueeze(2)).squeeze(-1)
    correlation = correlation / (sigmax * sigmay.unsqueeze(1) * nfree)
    b = correlation.unsqueeze(2)

    # 6.9 使用 QR 分解
    q, r = torch.linalg.qr(ar)
    # 6.10 使用三角矩阵求解
    a = torch.linalg.solve_triangular(r, torch.bmm(q.transpose(1, 2), b), upper=True)

    # 6.11 调整系数
    a = a.squeeze(-1) * (sigmay.unsqueeze(1) / sigmax)

    # 6.12 返回回归系数
    return a


# 7. 批量多元线性回归拟合函数
def mregress_batch_inv(x, y, measure_errors=None) -> torch.Tensor:

    """
        拟合 xA=y
        使用 torch.linalg.inv 版本的批量多元线性回归拟合 xA=y, 计算速度较慢.

        输入:
        -----------
        x : torch.Tensor
            输入 tensor, 形状为 (batch_size, npts, nterm).
        y : torch.Tensor
            目标 tensor, 形状为 (batch_size, npts).
        measure_errors : None
            保持与原函数接口一致, 暂未使用.

        返回:
        --------
        a : torch.Tensor
            回归系数, 形状为 (batch_size, nterm).
    """

    # 7.1 检查输入是否连续，如果不是则使用 view 而不是 contiguous()
    batch_size, npts, nterm = x.shape
    if not x.is_contiguous():
        x = x.view(batch_size, npts, nterm)
    if not y.is_contiguous():
        y = y.view(batch_size, npts)
    
    # 7.2 检查输入维度
    if y.shape != (batch_size, npts):
        raise ValueError('X and Y have incompatible dimensions!')

    # 7.3 计算自由度
    nfree = (npts - 1.0)

    # 7.4 计算权重
    weights = torch.ones_like(y, dtype=x.dtype, device=x.device)
    sw = weights.sum(dim=1, keepdim=True).div_(npts)
    weights.div_(sw)

    # 7.5 计算加权特征矩阵
    wx = x * weights.unsqueeze(-1)

    # 7.6 计算标准差
    sigmax = torch.sqrt_((x * wx).sum(dim=1).div_(nfree))
    sigmay = torch.sqrt_((weights * y.pow(2)).sum(dim=1).div_(nfree))

    # 7.7 计算正规方程矩阵
    ar = torch.bmm(wx.transpose(1, 2), x)

    # 7.8 计算外积并标准化
    sigmax_outer = sigmax.unsqueeze(-1) * sigmax.unsqueeze(1)
    ar.div_(nfree * sigmax_outer)

    # 7.9 使用批量求逆
    ar = torch.linalg.inv(ar)

    # 7.10 计算相关性
    correlation = torch.bmm(wx.transpose(1, 2), y.unsqueeze(2)).squeeze(-1)
    correlation.div_(sigmax * sigmay.unsqueeze(1) * nfree)

    # 7.11 最终计算
    a = torch.bmm(ar, correlation.unsqueeze(2)).squeeze(-1)
    # 7.12 调整系数
    a.mul_(sigmay.unsqueeze(1) / sigmax)

    # 7.13 返回回归系数
    return a


# 8. 使用示例
# if __name__ == "__main__":

#     # 8.1 创建测试数据集
#     X = torch.tensor([[2, 4, 6], [3, 5, 9], [10, 11, 3]], dtype=dtype, device=device)
#     y = torch.tensor([6, 8, 11], dtype=dtype, device=device)
#     # 8.1.1 将 X 和 y 扩展为 (batch_size, npts, nterm) 和 (batch_size, npts)
#     X = X.unsqueeze(0)         
#     y = y.unsqueeze(0)

#     # 8.2 设置测量误差参数, 保持与原函数接口一致, 暂未使用
#     measure_errors = torch.tensor([1, 1, 1], dtype=dtype, device=device)
#     # 8.2.1 同样为测量误差添加批次维度
#     measure_errors = measure_errors.unsqueeze(0) 

#     # 8.3 执行多元回归分析
#     a = mregress_batch_cholesky(X, y, measure_errors)
#     print(a)