# -*- coding: utf-8 -*-
# @Time    : 04/12/2024 10.44
# @Author  : ljc
# @FileName: uly_tgm_eval.py
# @Software: PyCharm


# 1. 简介
"""
 Python conversion of the IDL uly_tgm_eval.pro .
目的:
    生成指定 Teff、log g、[Fe/H]下的 TGM 光谱.
函数:
    1) uly_tgm_model_param
    2) uly_tgm_eval
解释:
    1) uly_tgm_model_param 函数: 生成指定 Teff、log g、[Fe/H] 下的 TGM 多项式参数组合.
    2) uly_tgm_eval 函数: 生成指定 Teff、log g、[Fe/H] 下的 1 条模型光谱, 并重采样到 LAMOST 波长.
"""


# 2. 掉包
import numpy as np
from uly_read_lms.uly_spect_alloc import uly_spect_alloc
from WRS.uly_spect_logrebin import uly_spect_logrebin


# 3. 生成指定 Teff、log g、[Fe/H] 下的 TGM 多项式参数组合.
# 注意:
# 1) 由于 IDL 与 Python 的浮点数精度有差异, 因此计算结果会存在一些差异.
# 2) Teff 取对数 (log10(*)).
def uly_tgm_model_param(version, teff, gravi, fehi) -> np.ndarray:

    """
      生成指定 Teff、log g、[Fe/H] 下的 TGM 多项式参数组合.

      输入参数:
      -----------
      version:
              TGM 模型版本号.
      teff:
           log10 对数下的 Teff.
      gravi:
            表面重力 log g.
      fehi:
            金属丰度 [Fe/H].

      输出参数:
      -----------
      param:
            指定 Teff、log g、[Fe/H] 下的 TGM 多项式参数组合.
    """

    # 3.1 根据不同版本获取 TGM 的参数矩阵.
    # V2 的 warm、hot、cold 中 TGM 光谱模拟器系数均为 14501*26 的矩阵, 参数矩阵均为 26 行 3 列, 注意: 多项式系数有全 0 项
    if (version == 1) or (version == 2):
        np_s = 23
    if version == 3:
        np_s = 26
    # 3.1.1 创建 (np_s, 3) 的参数矩阵
    param = np.zeros((np_s, 3), dtype=float)

    # 3.2 遍历每一组参数, 并计算各多项式项
    for i in range(3):
        teffc = teff
        if (version != 1) & (i == 2):
            # 如果是第 3 组参数 (即 cold 恒星), teffc 增加 0.1
            teffc = teff + 0.1
        grav = gravi
        feh = fehi

        # 3.2.1 计算缩放温度
        tt = teffc / 0.2
        # 3.2.2 计算 tt 的平方减去 1
        tt2 = tt ** 2 - 1.0
        # 3.2.3 填充参数矩阵的每个项
        param[0, i] = 1.0
        param[1, i] = tt
        param[2, i] = feh
        param[3, i] = grav
        param[4, i] = tt ** 2
        param[5, i] = tt * tt2
        param[6, i] = tt2 ** 2
        param[7, i] = tt * feh
        param[8, i] = tt * grav
        param[9, i] = tt2 * grav
        param[10, i] = tt2 * feh
        param[11, i] = grav ** 2
        param[12, i] = feh ** 2
        param[13, i] = tt * (tt2 ** 2)
        param[14, i] = tt * (grav ** 2)
        param[15, i] = grav ** 3
        param[16, i] = feh ** 3
        param[17, i] = tt * (feh ** 2)
        param[18, i] = grav * feh
        param[19, i] = (grav ** 2) * feh
        param[20, i] = grav * (feh ** 2)
        # 3.2.4 版本 1
        if version == 1:
            param[21, i] = np.exp(tt)
            param[22, i] = (np.exp(tt ** 2))
        # 3.2.5 版本 2
        if version == 2:
            param[21, i] = np.exp(tt) - 1 - tt * (1 + tt / 2 + tt ** 2 / 6 + tt ** 3 / 24 + tt ** 4 / 120)
            param[22, i] = np.exp(tt * 2) - 1 - 2 * tt * (1 + tt + 2 / 3 * tt ** 2 + tt ** 3 / 3 + tt ** 4 * 2 / 15)
        # 3.2.6 版本 3
        if version == 3:
            param[21, i] = tt * tt2 * grav
            param[22, i] = tt2 * tt2 * grav
            param[23, i] = tt2 * tt * feh
            param[24, i] = tt2 * (grav ** 2)
            param[25, i] = tt2 * (grav ** 3)

    # 3.3 返回参数矩阵
    return param


# 4. 生成指定 Teff、log g、[Fe/H] 下的 1 条模型光谱, 并重采样到 LAMOST 波长.
def uly_tgm_eval(eval_data, para, sampling_function=None) -> np.ndarray:

    """
      生成指定 Teff、log g、[Fe/H] 下的 1 条模型光谱, 并重采样到 LAMOST 波长.

      输入参数:
      -----------
      eval_data:
                TGM 多项式系数矩阵.
      para:
           uly_tgm_model_param 返回的参数矩阵.
      sampling_function:
                       插值方法. 可输入 "splinf", "cubic", "slinear", "quadratic", "linear". 默认使用 "linear" 插值方法.

      输出参数:
      -----------
      tgm_model_evalhc:
                       指定 Teff、log g、[Fe/H] 下的 TGM 光谱.
    """

    # 4.1 获取 TGM 模型系数矩阵, 形状为 (26, 7506, 3)
    spec_coef = eval_data["spec_coef"]
    # 4.1.1 获取 TGM 模型系数矩阵的流量维度
    spec_npix = spec_coef.shape[1]
    
    # 4.2 计算指定 Teff、log g、[Fe/H] 下的 TGM 多项式参数组合
    # 4.2.1 计算缩放温度
    teff = np.log10(np.exp(para[0]))-3.7617
    # 4.2.2 计算表面重力
    grav = para[1] - 4.44
    # 4.2.3 计算指定 Teff、log g、[Fe/H]下的 TGM 光谱系数
    param = uly_tgm_model_param(int(eval_data["version"]),  # TGM 模型版本号
                                teff,                       # 缩放的 Teff
                                grav,                       # 缩放的 log g
                                para[2]                     # [Fe/H]
                                )

    # 4.3 根据指定 Teff、log g、[Fe/H] 下的 TGM 多项式参数组合与 TGM 模型系数矩阵, 计算指定 Teff、log g、[Fe/H]下的 TGM 光谱
    # 4.3.1 获取 TGM 多项式系数矩阵维度
    np_s = param.shape[0]
    # 4.3.2 计算指定 Teff、log g、[Fe/H] 下的 TGM 光谱
    if teff <= np.log10(9000)-3.7617:
        # warm 恒星
        t1 = spec_coef[:np_s, :, 0].T
        t1 = np.dot(t1, param[:, 0].reshape(-1, 1))
    if teff >= np.log10(7000)-3.7617:
        # hot 恒星
        t2 = spec_coef[:np_s, :, 1].T
        t2 = np.dot(t2, param[:, 1].reshape(-1, 1))
    if teff <= np.log10(4550)-3.7617:
        # cold 恒星
        t3 = spec_coef[:np_s, :, 2].T
        t3 = np.dot(t3, param[:, 2].reshape(-1, 1))
    # 4.3.3 插值或选择合适的光谱结果
    if teff <= np.log10(7000)-3.7617:
        if teff > np.log10(4550)-3.7617:
            tgm_model_evalhc = t1
        elif teff > np.log10(4000)-3.7617:
            # 计算插值因子 q
            q = (teff - np.log10(4000)+3.7617) / (np.log10(4550) - np.log10(4000))
            tgm_model_evalhc = q * t1 + (1. - q) * t3
        else:
            tgm_model_evalhc = t3
    elif teff >= np.log10(9000)-3.7617:
        tgm_model_evalhc = t2
    else:
        # 计算插值因子 q
        q = (teff - np.log10(7000)+3.7617) / (np.log10(9000) - np.log10(7000))
        tgm_model_evalhc = q * t2 + (1. - q) * t1

    # 4.4 将模型光谱重采样到 LAMOST 光谱的波长
    # eval_data["start"] 为 LAMOST 光谱的波长起点、eval_data["mod_start"] 为 TGM 光谱的波长起点
    # 如果模型光谱与待测光谱波长一致, 则不需要重采样, 否则将模型光谱重采样到待测光谱波长
    if (eval_data["sampling"] != eval_data["mod_samp"]) | (eval_data["start"] != eval_data["mod_start"]) | (
            eval_data["step"] != eval_data["mod_step"]) | (eval_data["npix"] != spec_npix):
        # 4.4.1 更新 LAMOST 光谱字典结构中的数据
        spec = uly_spect_alloc(DATA=tgm_model_evalhc,          # TGM 光谱流量
                               START=eval_data["mod_start"],   # TGM 光谱波长起点
                               STEP=eval_data["mod_step"],     # TGM 光谱波长步长
                               SAMPLING=eval_data["mod_samp"]  # TGM 光谱波长采样方式
                               )
        # 4.4.2 计算 LAMOST 光谱的波长范围
        wrange = [eval_data["start"], eval_data["start"] + eval_data["npix"] * eval_data["step"]]
        # 4.4.3 如果 LAMOST 光谱的波长采样方法为 ln 对数波长, 则将 LAMOST 光谱的波长范围转换为线性波长
        if eval_data["sampling"] == 1:
            wrange = np.exp(wrange)
        # 4.4.4 计算 LAMOST 光谱的 velscale, velscale = ln_step * c = log10_step * c * ln(10)
        c = 299792.458
        velscale = eval_data["step"] * c
        # 4.4.5 将 TGM 光谱波长重采样到 LAMOST 光谱的波长
        if eval_data["sampling"] == 1:
            # TGM 光谱波长重采样到 LAMOST 光谱的波长, 如: 7506 维度插值到 1327 个维度
            spec = uly_spect_logrebin(spec,                                # TGM 光谱字典结构
                                      velscale,                            # ln 对数采样下的速度
                                      waverange=wrange,                    # 波长范围
                                      sampling_function=sampling_function, # 插值方法
                                      overwrite=True                       # 是否覆盖
                                      )
        # else:
            # spec = uly_spect_linrebin(spec, eval_data["step"], sampling_function=sampling_function, WAVERANGE=wrange, /OVER)
        # 4.4.6 获取波长重采样后的 TGM 光谱
        tgm_model_evalhc = spec["data"]

    # 4.5 下属代码块为 IDL 代码, LASP 没有使用, 如果需要, 请参考 IDL 代码
    """
    if eval_data["calibration"] == "C":
        n = tgm_model_evalhc.size
        if eval_data["sampling"] == 1:
            wavelength = np.exp(eval_data["start"] + np.arange(n, dtype=np.float64) * eval_data["step"])
        else:
            wavelength = eval_data["start"] + np.arange(n, dtype=np.float64) * eval_data["step"]

        w, c3, c1 = 5550., 1.43883e8 / 5550 / np.exp(np.log(para[0])), 3.74185e19 / 5550 ** 5
        if c3 < 50:
            bbm = c1 / (np.exp(c3) - 1)
        else:
            bbm = c1 * np.exp(-c3)
        c3, c1 = 1.43883e8 / wavelength / np.exp(np.log(para[0])), 3.74185e19 / wavelength ** 5 / bbm
        n1, n2 = np.where(c3 < 50)[0].tolist(), np.where(c3 >= 50)[0].tolist()
        if len(n1) > 0:
            tgm_model_evalhc[n1] = tgm_model_evalhc[n1] * (c1[n1] / (np.exp(c3[n1]) - 1))
        if len(n1) < n:
            tgm_model_evalhc[n2] = tgm_model_evalhc[n2] * (c1[n2] / np.exp(-c3[n2]))

    if eval_data.get("lsf") is not None:
        if eval_data["lsf"] != 'no_lsf':
            spec = uly_spect_alloc(DATA=tgm_model_evalhc, START=eval_data["start"], STEP=eval_data["step"], SAMPLING=1)
            tgm_model_evalhc = spec["data"]
    """

    # 4.6 返回 1 条经过波长重采样后的模型光谱
    return tgm_model_evalhc