# -*- coding: utf-8 -*-
# @Time    : 2025/2/25 15:54
# @Author  : ljc
# @FileName: loss_reduced.py
# @Software: PyCharm


# 1. 简介
"""
目的：
    计算每组 LAMOST 光谱的流量拟合残差, 用于计算每个样本的雅克比矩阵、黑塞矩阵、参数误差.
函数:
    1) loss_reduced
解释:
    1) loss_reduced 函数: 计算每组 LAMOST 光谱的流量拟合残差, 用于计算每个样本的雅克比矩阵、黑塞矩阵、参数误差. 
注意:
    1) 计算雅克比矩阵时, loss_reduced 函数返回的 loss 为真实光谱-模型光谱.
    2) 计算黑塞矩阵时, loss_reduced 函数返回的 loss 为真实光谱-模型光谱的平方和.
    3) 计算参数误差时, loss_reduced 函数返回的 loss 为真实光谱-模型光谱的平方和.
    4) 默认使用 No Clean 模式, 即不剔除异常流量点. 也可设置 goodPixels_final=True, 使用 Clean 模式, 即剔除异常流量点.
"""


# 2. 调包
from config.config import default_set, set_all_seeds
import torch
from legendre_polynomial.mregress_pytorch import mregress_batch_cholesky
from resolution_reduction.convol_pytorch import convol
from uly_tgm_eval.uly_tgm_eval_pytorch import uly_tgm_eval
from WRS.xrebin_pytorch import xrebin
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)
# 2.1 设置随机种子、调用 GPU、数据类型、匹配的波长范围
# 2.1.1 设置随机种子
set_all_seeds()
# 2.1.2 调用 GPU、指定数据类型
dtype, device = default_set()
# 2.1.3 默认数据类型
torch.set_default_dtype(dtype)


# 3. 计算每组 LAMOST 光谱的流量拟合残差, 用于计算每个样本的雅克比行列式、黑塞矩阵、参数误差
def loss_reduced(best_params, specs, spec_coef, borders_, NewBorders_, flat, leg_array_, goodPixels_final=False, Jacobian=True, Hessian=False) -> torch.Tensor:

    """
       生成指定 Teff、log g、[Fe/H]下的 TGM 光谱.

       输入参数：
          -----------
          best_params:
                     使用 Adam 获取的最佳参数(注意: 还没有转化为恒星大气物理参数).
          specs:
               待测光谱数据.
          spec_coef:
                    TGM 系数.
          borders_、NewBorders_、flat:
                                    待测光谱插值到 TGM 光谱的波长设置.
          leg_array_:
                    勒让德多项式值.
          goodPixels_final=False
                    默认迭代优化时, 不剔除异常流量点.
          Jacobian=True
                       默认使用雅克比矩阵计算参数误差.
          Hessian=False
                       默认不使用黑塞矩阵计算参数误差.

          输出参数:
          -----------
          loss:
               每组 LAMOST 光谱的流量拟合残差, 用于计算每个样本的雅克比行列式、黑塞矩阵、参数误差.
    """

    # 3.1 生成 TGM 模型光谱
    TGM_model_predict_spectra = uly_tgm_eval(spec_coef.to(device, dtype=dtype),               # TGM 模型光谱系数
                                             best_params[:, :3]                               # 待测恒星参数
                                             )

    # 3.2 插值到 LAMOST 波长
    TGM_model_predict_spectra_xrebin = (xrebin(borders_,                                      # 输入 TGM 模型光谱的波长区间范围
                                               TGM_model_predict_spectra,                     # 输入 TGM 模型光谱的流量
                                               NewBorders_                                    # 输出 LAMOST 光谱的波长区间范围
                                               ) / flat).to(device, dtype=dtype)

    # 3.3 降低分辨率
    low_resolution_spec = convol(TGM_model_predict_spectra_xrebin,                            # 输入重采样后的 TGM 模型光谱的流量
                                 best_params[:, 3].reshape(-1, 1),                            # 输入 losvd 第 1 个参数
                                 best_params[:, 4].reshape(-1, 1)                             # 输入 losvd 第 2 个参数
                                 ).to(device, dtype=dtype)

    # 3.4 乘以多项式
    coefs_pol = mregress_batch_cholesky(leg_array_[:, :-2, :] * low_resolution_spec[:, :-2],  # 输入勒让德多项式值与重采样后的 TGM 模型光谱的乘积
                                        specs[:, :-2]                                         # 输入 LAMOST 光谱的流量
                                        ).unsqueeze(1)
    poly1 = torch.matmul(coefs_pol, leg_array_.transpose(1, 2)).squeeze(1).unsqueeze(2)
    polynomial_multiply_TGM_model_predict_spectra = low_resolution_spec * poly1

    # 3.5 每组 LAMOST 与 TGM 模型光谱的流量残差
    # 3.5.1 如果使用 Jacobian 矩阵计算参数误差(默认使用)
    if Jacobian is True:
        if goodPixels_final is False:
            # 3.5.1.1 不迭代剔除异常流量(默认不剔除)
            loss = (specs[:, :-2] - polynomial_multiply_TGM_model_predict_spectra.squeeze(-1)[:, :-2])
        else:
            # 3.5.1.2 迭代剔除异常流量(后续版本, 可开发更合适的异常剔除方案!)
            loss = (specs[:, :-2] - polynomial_multiply_TGM_model_predict_spectra.squeeze(-1)[:, :-2]) * goodPixels_final
    # 3.5.2 如果使用 Hessian 矩阵计算参数误差
    if Hessian is True:
        if goodPixels_final is False:
            # 3.5.2.1 不迭代剔除异常流量(默认不剔除)
            loss = torch.sum((specs[:, :-2] - polynomial_multiply_TGM_model_predict_spectra.squeeze(-1)[:, :-2]) ** 2, dim=1)
        else:
            # 3.5.2.2 迭代剔除异常流量(后续版本, 可开发更合适的异常剔除方案!)
            loss = torch.sum((specs[:, :-2] - polynomial_multiply_TGM_model_predict_spectra.squeeze(-1)[:, :-2]) * goodPixels_final, dim=1)

    # 3.6 返回各待测光谱的流量拟合残差(真实光谱-拟合光谱), 用于计算每个样本的雅克比行列式、黑塞矩阵、参数误差.
    return loss