# -*- coding: utf-8 -*-
# @Time    : 12/12/2024 11.21
# @Author  : ljc
# @FileName: mregress.py
# @Software: PyCharm


# 1. 简介
"""
 Python conversion of the IDL mregress.pro .
目的:
    执行多元线性回归拟合, 计算回归系数.
函数:
    1) mregress
解释:
    1) mregress 函数: 计算回归系数及提供拟合过程中的信息.
"""


# 2. 导库
import numpy as np
# import scipy as sc
import warnings
# 2.1 忽略不必要的警告信息
warnings.filterwarnings("ignore")


# 3. 多元线性回归函数
def mregress(x, y, measure_errors=None, sigma=None, inv=None, status=None) -> tuple[np.ndarray, dict]:

    """
    y[i] = A[0]*x[0,i] + A[1]*x[1,i] + ... + A[Nterms-1]*x[Nterms-1,i].
    执行多元线性回归拟合 xA=y, 求解系数矩阵 A.
    
    对于加权最小二乘法, 系数矩阵 A 的显式解为:
        A = (X^T·W·X)^(-1)·X^T·W·y
    其中:
        X 是自变量矩阵
        W 是权重对角矩阵, 对角元素为 1/(measure_errors^2)
        y 是因变量向量
        X^T 表示 X 的转置
        ^(-1) 表示矩阵求逆

    输入参数:
    -----------
    x:
        自变量数据数组. x 维度必须为 (Npoints, Nterms), 其中 Nterms 是待求解的系数(自变量)数量,
        Npoints 是样本点数. 注意: 此数组相对于 IDL 程序 regress 的输入是转置关系.
    y:
        因变量数据向量. y 必须包含 Npoints 个元素, 与 x 的第一维度匹配.
    measure_errors:
        包含每个点 y[i] 标准测量误差的向量. 向量长度必须与 x 和 y 相同.
    sigma:
        命名变量, 用于接收返回的系数误差估计值.
    inv:
        命名变量, 用于接收包含协方差矩阵和其他可重用中间计算结果的结构.
    status:
        命名变量, 用于接收操作执行状态. 可能的状态值包括:
        1) 0: 成功完成计算;
        2) 1: 遇到奇异矩阵 (无法求逆);
        3) 2: 警告使用了较小的主元素, 可能导致精度显著降低.

    输出参数:
    -----------
    a:
        方程 xA=y 中求解得到的系数矩阵 A.
    inv:
        包含拟合详细信息的字典, 可用于进一步分析或诊断.
    """

    # 3.1 数据处理
    # 3.1.1 确保输入为 numpy 数组
    x = np.asarray(x)
    y = np.asarray(y)

    # 3.1.2 获取数组维度信息
    # 3.1.2.1 xA=y 中，x 形状为（流量维度，样本）, y 形状为（流量维度，）, y 的形状为（流量维度，）
    sx = x.shape
    ndimX = len(sx)
    sy = y.shape

    # 3.1.3 维度分析
    # 3.1.3.1 度数，x 的第 2 个维度的大小
    nterm = 1 if ndimX == 1 else sx[ndimX - 1]

    # 3.1.3.2 光谱流量维度检查
    nptsX = sx[0]
    npts = sy[0]
    if nptsX != npts:
        raise ValueError('X and Y have incompatible dimensions!')

    # 3.1.4 数组预处理
    # 3.1.4.1 确保 x 为二维数组
    if ndimX == 1:
        x = x.reshape(npts, 1)
    # 3.1.4.2 计算自由度
    nfree = npts - 1

    # 3.2 回归计算
    # 3.2.1 初始化计算状态
    invert = False
    if inv is not None:
        if len(getattr(inv, 'wx', [])) != len(y):
            invert = True
    with np.errstate(under='ignore'):
        # 3.2.2 权重计算及处理
        if inv is None or invert:
            # 3.2.2.1 基于测量误差计算权重
            weights = 1 / (measure_errors ** 2)
            sw = np.sum(weights) / npts
            weights = weights / sw
            wgt = np.tile(weights, (nterm, 1)).T

            # 3.2.2.2 数据加权
            wx = wgt * x
            sigmax = np.sqrt(np.sum(x * wx, axis=0) / nfree)

            # 3.2.3 系数矩阵计算
            # 3.2.3.1 构建待求逆的系数矩阵
            ar = np.dot(wx.T, x) / (nfree * np.outer(sigmax, sigmax))
            try:
                # 3.2.3.2 计算矩阵的伪逆
                # 注意：IDL 与 Python 计算逆矩阵时可能存在数值精度差异，但对最终参数推断影响很小
                ar = np.linalg.pinv(ar)
                # ar = sc.linalg.lu_solve(sc.linalg.lu_factor(ar), np.eye(ar.shape[0]))
                status = 0
            except np.linalg.LinAlgError:
                # 3.2.3.3 矩阵求逆失败处理
                if status is None:
                    raise ValueError("Inversion failed due to singular array!")
                status = 1

            # 3.2.4 误差分析
            # 3.2.4.1 计算 sigma
            sigma = np.diag(ar) / (sw * nfree * sigmax ** 2)
            # 3.2.4.2 处理负 sigma
            neg = np.where(sigma < 0)[0].tolist()
            if len(neg) > 0:
                sigma[neg] = 0
                status = 2
                warnings.warn("Pseudo-continuum is less than 0, please check the spectrum quality!")
            sigma = np.sqrt(sigma)

            # 3.2.5 保存计算结果与中间值
            inv = {'a': ar, 'ww': weights, 'wx': wx, 'sx': sigmax, 'sigma': sigma, 'status': status}

        # 3.2.6 相关性计算
        sigmay = np.sqrt(np.sum(inv['ww'] * (y ** 2)) / nfree)
        correlation = np.dot(inv['wx'].T, y) / (inv['sx'] * sigmay * nfree)

        # 3.2.7 回归系数计算
        # a = correlation @ inv['a'] * (sigmay / inv['sx'])
        a = np.dot(inv['a'], correlation) * (sigmay / inv['sx'])
    
    # 3.3 返回计算结果
    return a, inv


# 4. 使用示例
# if __name__ == "__main__":

#     # 4.1 创建测试数据集
#     X = np.array([[2, 4, 6], [3, 5, 9], [10, 11, 3]])
#     y = np.array([6, 8, 11])

#     # 4.2 设置测量误差参数
#     measure_errors = np.array([1, 1, 1])

#     # 4.3 执行多元回归分析
#     a = mregress(X, y, measure_errors)
#     print(a)