# -*- coding: utf-8 -*-
# @Time    : 2025/1/5 15:16
# @Author  : ljc
# @FileName: uly_tgm_eval_pytorch.py
# @Software: PyCharm


# 1. 简介
"""
目的:
     由多项式函数 TGM 批量生成一组 N 条模型光谱:
        F = a1*Teff+a2*logg+a3[Fe/H]+a4*Teff*Teff+.... = torch.matmul(param, spec_coef)
        F 为多项式光谱模拟器计算得到的 N 条模型光谱;
        param 为多项式函数中的 Teff, log g, [Fe/H] 组成的基函数值 Teff, logg, [Fe/H], Teff*Teff,...;
        spec_coef 为多项式函数的系数矩阵 a1, a2, a3,....
函数:
    1) spec_param_version_2_warm、spec_param_version_2_hot、spec_param_version_2_cold (2 指的是 ELODIE 版本 3.2)
    2) uly_tgm_eval
解释:
    1) spec_param_version_2_warm、spec_param_version_2_hot、spec_param_version_2_cold 函数: 批量计算一组 N 条数据的基函数值 Teff, logg, [Fe/H], Teff*Teff,....
    2) uly_tgm_eval 函数: 批量计算 1 组 N 条模型光谱.
注意:
    1) uly_tgm_eval 函数: 由于每组 N 条 LAMOST 光谱可能包含不同 Teff 范围的样本, 因此针对每组光谱, 首先按 Teff 范围分组
     (<= 4000, 4000-4550, 4550-7000, 7000-9000, >=9000)、然后在特定 Teff 范围内批量计算多项式函数的基函数值.
    2) 由于 TGM 模型按 Teff 范围分段, 因此在区间端点处可能不可导, 进而导致部分位于区间端点的样本参数推断不够稳定.
       如: LASP-Adam-GPU 对初始 Teff 具有方向敏感性.
"""


# 2. 导库
from config.config import default_set, set_all_seeds
import torch
# 2.1 设置随机种子
set_all_seeds()
# 2.2 调用 GPU、指定数据类型
dtype, device = default_set()
# 2.3 默认数据类型
torch.set_default_dtype(dtype)


# 3. 批量计算 warm 恒星 Teff、log g、[Fe/H] 样本矩阵组成的基函数值 Teff, logg, [Fe/H], Teff*Teff,....
def spec_param_version_2_warm(labels_batch, group_size) -> torch.Tensor:

    """
      批量计算 warm 恒星 Teff、log g、[Fe/H] 样本矩阵组成的基函数值 Teff, logg, [Fe/H], Teff*Teff,....

      输入参数：
      -----------
      labels_batch:
                   Teff、log g、[Fe/H] 数组.
      group_size:
                 一组多少样本.

      输出参数:
      -----------
      param:
            由 Teff、log g、[Fe/H] 样本矩阵计算的基函数值 Teff, logg, [Fe/H], Teff*Teff,....
    """

    # 3.1 初始化 param 并使用矢量化的方式计算所有参数组成
    teff = labels_batch[:, 0]
    gravi = labels_batch[:, 1]
    feh = labels_batch[:, 2]
    # 3.2 初始化参数组合矩阵
    param = torch.zeros((23, group_size))
    # 3.3 填充 param 的各项
    tt = teff / 0.2
    tt2 = tt ** 2 - 1.0
    param[0, :] = 1.0
    param[1, :] = tt
    param[2, :] = feh
    param[3, :] = gravi
    param[4, :] = tt ** 2
    param[5, :] = tt * tt2
    param[6, :] = tt2 ** 2
    param[7, :] = tt * feh
    param[8, :] = tt * gravi
    param[9, :] = tt2 * gravi
    param[10, :] = tt2 * feh
    param[11, :] = gravi ** 2
    param[12, :] = feh ** 2
    param[13, :] = tt * (tt2 ** 2)
    param[14, :] = tt * (gravi ** 2)
    param[15, :] = gravi ** 3
    param[16, :] = feh ** 3
    param[17, :] = tt * (feh ** 2)
    param[18, :] = gravi * feh
    param[19, :] = (gravi ** 2) * feh
    param[20, :] = gravi * (feh ** 2)
    param[21, :] = torch.exp(tt) - 1 - tt * (1 + tt / 2 + tt ** 2 / 6 + tt ** 3 / 24 + tt ** 4 / 120)
    param[22, :] = torch.exp(2 * tt) - 1 - 2 * tt * (1 + tt + 2 / 3 * tt ** 2 + tt ** 3 / 3 + tt ** 4 * 2 / 15)

    # 3.4 返回 param
    return param.to(device)


# 4. 批量计算 hot 恒星 Teff、log g、[Fe/H] 样本矩阵组成的基函数值 Teff, logg, [Fe/H], Teff*Teff,....
def spec_param_version_2_hot(labels_batch, group_size) -> torch.Tensor:

    """
      批量计算 hot 恒星 Teff、log g、[Fe/H] 样本矩阵组成的基函数值 Teff, logg, [Fe/H], Teff*Teff,.....

      输入参数：
      -----------
      labels_batch:
                   Teff、log g、[Fe/H] 数组.
      group_size:
                 一组多少样本.

      输出参数:
      -----------
      param:
            由 Teff、log g、[Fe/H] 样本矩阵计算的基函数值 Teff, logg, [Fe/H], Teff*Teff,....
    """

    # 4.1 初始化 param 并使用矢量化的方式计算所有参数组成
    teff = labels_batch[:, 0]
    gravi = labels_batch[:, 1]
    feh = labels_batch[:, 2]
    # 4.2 初始化参数组合矩阵
    param = torch.zeros((23, group_size))
    # 4.3 填充 param 的各项
    tt = teff / 0.2
    tt2 = tt ** 2 - 1.0
    param[0, :] = 1.0
    param[1, :] = tt
    param[2, :] = feh
    param[3, :] = gravi
    param[4, :] = tt ** 2
    param[5, :] = tt * tt2
    param[6, :] = tt2 ** 2
    param[7, :] = tt * feh
    param[8, :] = tt * gravi
    param[9, :] = tt2 * gravi
    param[10, :] = tt2 * feh
    param[11, :] = gravi ** 2
    param[12, :] = feh ** 2
    param[13, :] = tt * (tt2 ** 2)
    param[14, :] = tt * (gravi ** 2)
    param[15, :] = gravi ** 3
    param[16, :] = feh ** 3
    param[17, :] = tt * (feh ** 2)
    param[18, :] = gravi * feh
    param[19, :] = (gravi ** 2) * feh
    param[20, :] = gravi * (feh ** 2)
    param[21, :] = torch.exp(tt) - 1 - tt * (1 + tt / 2 + tt ** 2 / 6 + tt ** 3 / 24 + tt ** 4 / 120)
    param[22, :] = torch.exp(2 * tt) - 1 - 2 * tt * (1 + tt + 2 / 3 * tt ** 2 + tt ** 3 / 3 + tt ** 4 * 2 / 15)

    # 4.4 返回 param
    return param.to(device)


# 5. 批量计算 cold 恒星 Teff、log g、[Fe/H] 样本矩阵组成的基函数值 Teff, logg, [Fe/H], Teff*Teff,....
def spec_param_version_2_cold(labels_batch, group_size) -> torch.Tensor:

    """
      批量计算 cold 恒星 Teff、log g、[Fe/H] 样本矩阵组成的基函数值 Teff, logg, [Fe/H], Teff*Teff,....

      输入参数：
      -----------
      labels_batch:
                   Teff、log g、[Fe/H] 数组.
      group_size:
                 一组多少样本.

      输出参数:
      -----------
      param:
            由 Teff、log g、[Fe/H] 样本矩阵计算的基函数值 Teff, logg, [Fe/H], Teff*Teff,....
    """

    # 5.1 初始化 param 并使用矢量化的方式计算所有参数组成
    teff = labels_batch[:, 0]
    gravi = labels_batch[:, 1]
    feh = labels_batch[:, 2]
    # 5.2 初始化参数组合矩阵
    param = torch.zeros((23, group_size))
    # 5.3 填充 param 的各项
    # ELODIE 版本 2, 冷星参数矩阵计算时需要对 Teff 加 0.1
    tt = (teff + 0.1) / 0.2
    tt2 = tt ** 2 - 1.0
    param[0, :] = 1.0
    param[1, :] = tt
    param[2, :] = feh
    param[3, :] = gravi
    param[4, :] = tt ** 2
    param[5, :] = tt * tt2
    param[6, :] = tt2 ** 2
    param[7, :] = tt * feh
    param[8, :] = tt * gravi
    param[9, :] = tt2 * gravi
    param[10, :] = tt2 * feh
    param[11, :] = gravi ** 2
    param[12, :] = feh ** 2
    param[13, :] = tt * (tt2 ** 2)
    param[14, :] = tt * (gravi ** 2)
    param[15, :] = gravi ** 3
    param[16, :] = feh ** 3
    param[17, :] = tt * (feh ** 2)
    param[18, :] = gravi * feh
    param[19, :] = (gravi ** 2) * feh
    param[20, :] = gravi * (feh ** 2)
    param[21, :] = torch.exp(tt) - 1 - tt * (1 + tt / 2 + tt ** 2 / 6 + tt ** 3 / 24 + tt ** 4 / 120)
    param[22, :] = torch.exp(2 * tt) - 1 - 2 * tt * (1 + tt + 2 / 3 * tt ** 2 + tt ** 3 / 3 + tt ** 4 * 2 / 15)

    # 5.4 返回 param
    return param.to(device)


# 6. 批量生成指定 Teff、log g、[Fe/H] 下的一组 N 条模型光谱
def uly_tgm_eval(spec_coef, para) -> torch.Tensor:

    """
      批量生成指定 Teff、log g、[Fe/H] 下的一组 N 条模型光谱.

      输入参数：
      -----------
      spec_coef:
                TGM 模型系数矩阵.
      para:
           spec_param_version_2_warm、spec_param_version_2_hot、spec_param_version_2_cold 返回基函数值 Teff, logg, [Fe/H], Teff*Teff,....

      输出参数:
      -----------
      tgm_model_evalhc:
                       一组 N 条模型光谱.
    """

    # 6.1 获取每组 LAMOST 光谱的数量 N、以及多项式光谱模拟器的系数矩阵大小 (流量维度)
    group_size = para.shape[0]
    n_wavelengths = spec_coef.shape[1]

    # 6.2 获取每组 LAMOST 光谱在每轮迭代时的 Teff 值, 依此分段计算
    teff = para[:, 0]

    # 6.3 预计算所有的临界点
    t_4000 = torch.log10(torch.tensor(4000.)) - 3.7617
    t_4550 = torch.log10(torch.tensor(4550.)) - 3.7617
    t_7000 = torch.log10(torch.tensor(7000.)) - 3.7617
    t_9000 = torch.log10(torch.tensor(9000.)) - 3.7617

    # 6.4 初始化 N 条模型光谱
    result = torch.zeros((group_size, n_wavelengths), device=device)

    # 6.5 区域一: 低温区域 (<=4000)
    cold_only_mask = (teff <= t_4000)
    if cold_only_mask.any():
        param_cold = spec_param_version_2_cold(labels_batch=para[cold_only_mask],    # 输入大气参数
                                               group_size=cold_only_mask.sum()       # 每组待测光谱有多少样本
                                               )
        # 光谱模拟器批量计算模型光谱
        t3 = torch.matmul(param_cold.T, spec_coef[:, :, 2])
        result[cold_only_mask, :] = t3

    # 6.6 区域二: 过渡区域 (4000-4550)
    trans1_mask = ((teff > t_4000) & (teff <= t_4550))
    if trans1_mask.any():
        para_trans1 = para[trans1_mask]
        param_cold = spec_param_version_2_cold(labels_batch=para_trans1,     # 输入大气参数
                                               group_size=trans1_mask.sum()  # 每组待测光谱有多少样本
                                               )
        param_warm = spec_param_version_2_warm(labels_batch=para_trans1,     # 输入大气参数
                                               group_size=trans1_mask.sum()  # 每组待测光谱有多少样本
                                               )
        # 光谱模拟器批量计算模型光谱
        t1 = torch.matmul(param_warm.T, spec_coef[:, :, 0])
        t3 = torch.matmul(param_cold.T, spec_coef[:, :, 2])
        q = ((teff[trans1_mask] - t_4000) / (t_4550 - t_4000)).unsqueeze(1)
        result[trans1_mask, :] = q * t1 + (1. - q) * t3
        # result[trans1_mask, :] = torch.lerp(t3, t1, q)

    # 6.7 区域三: 暖区域 (4550-7000)
    warm_only_mask = ((teff > t_4550) & (teff <= t_7000))
    if warm_only_mask.any():
        param_warm = spec_param_version_2_warm(labels_batch=para[warm_only_mask],    # 输入大气参数
                                               group_size=warm_only_mask.sum()       # 每组待测光谱有多少样本
                                               )
        # 光谱模拟器批量计算模型光谱
        t1 = torch.matmul(param_warm.T, spec_coef[:, :, 0])
        result[warm_only_mask, :] = t1

    # 6.8 区域四: 过渡区域 (7000-9000)
    trans2_mask = ((teff > t_7000) & (teff < t_9000))
    if trans2_mask.any():
        para_trans2 = para[trans2_mask]
        param_warm = spec_param_version_2_warm(labels_batch=para_trans2,     # 输入大气参数
                                               group_size=trans2_mask.sum()  # 每组待测光谱有多少样本
                                               )
        param_hot = spec_param_version_2_hot(labels_batch=para_trans2,       # 输入参数
                                             group_size=trans2_mask.sum()    # 每组待测光谱有多少样本
                                             )
        # 光谱模拟器批量计算模型光谱
        t1 = torch.matmul(param_warm.T, spec_coef[:, :, 0])
        t2 = torch.matmul(param_hot.T, spec_coef[:, :, 1])
        q = ((teff[trans2_mask] - t_7000) / (t_9000 - t_7000)).unsqueeze(1)
        result[trans2_mask, :] = q * t2 + (1. - q) * t1
        # result[trans2_mask, :] = torch.lerp(t1, t2, q)

    # 6.9 区域五: 高温区域 (>=9000)
    hot_only_mask = (teff >= t_9000)
    if hot_only_mask.any():
        param_hot = spec_param_version_2_hot(labels_batch=para[hot_only_mask],    # 输入大气参数
                                             group_size=hot_only_mask.sum()       # 每组待测光谱有多少样本
                                             )
        # 光谱模拟器批量计算模型光谱
        t2 = torch.matmul(param_hot.T, spec_coef[:, :, 1])
        result[hot_only_mask, :] = t2

    # 6.10 返回光谱模拟器生成的一组 N 条模型光谱
    return result.to(device)