# -*- coding: utf-8 -*-
# @Time    : 10/12/2024 14.31
# @Author  : ljc
# @FileName: ulyss.py
# @Software: PyCharm


# 1. 简介
"""
 Python conversion of the IDL ulyss.pro .
目的:
    通过 uluss 函数设置待测光谱文件位置、恒星参数初始值、波长范围、多项式阶数等信息, 传入 uly_fit_IDL_method 函
    数获取待测光谱的参数、以及误差推断值.
函数:
    1) uly_cmp_read
    2) ulyss
解释:
    1) uly_cmp_read 函数: 返回 cmp 初始结构信息.
    2) ulyss 函数: 调用 uly_fit_IDL_method 函数获取待测光谱的参数、以及误差推断值.
"""


# 2. 调库
import numpy as np
from astropy.io import fits
from uly_tgm.uly_tgm import uly_tgm
from uly_read_lms.uly_spect_extract import uly_spect_extract
from WRS.uly_spect_logrebin import uly_spect_logrebin
from uly_read_lms.uly_spect_get import uly_spect_get
from uly_fit.uly_fit_init import uly_fit_init
from uly_tgm_alone.ulyss_fit_IDL_method import uly_fit_IDL_method
import warnings
warnings.filterwarnings("ignore")


# 3. 读取 TGM 模型文件, 并返回 cmp 初始结构信息
def uly_cmp_read(model_file, t_guess=None, l_guess=None, z_guess=None) -> dict:

    """
        获取 cmp 初始结构信息.

        输入参数:
        -----------
        model_file:
                   TGM 模型文件位置.
        t_guess:
                Teff 初始值.
        l_guess:
                log g 初始值.
        z_guess:
                [Fe/H] 初始值.

        输出参数:
        -----------
        cmp:
            cmp 初始字典结构信息.
    """

    # 3.1 如果 model_file 不是字符串, 则抛出类型错误
    if not isinstance(model_file, str):
        raise TypeError("Argument model_file must be a filename!")

    # 3.2 读取 TGM 模型文件
    try:
        with fits.open(model_file) as hdul:
            header = hdul[0].header
    except Exception as e:
        raise IOError(f"Error reading file: {str(e)}!")

    # 3.3 检查表头中的 ULY_TYPE 键并据此分支处理
    # 注意: 
    # 1) LASP 使用的 uly_type="TGM"
    # 2) 如有需要其他类型, 请查看 IDL 代码
    # 3) 关于非 TGM 模型, 如 SSP 和 STAR, 留给后续版本
    # 3.3.1 获取 ULY_TYPE 键值
    uly_type = header.get('ULY_TYPE', '').strip().upper()
    # 3.3.2 如果 ULY_TYPE 键值为 TGM, 则返回 TGM 字典结构
    if uly_type == 'TGM':
        # 3.3.2.1 返回 TGM 字典结构
        return uly_tgm(model_file=model_file, # 模型光谱文件
                       t_guess=t_guess,       # Teff 初始值
                       l_guess=l_guess,       # log g 初始值
                       z_guess=z_guess        # [Fe/H] 初始值
                       )
    # 3.3.3 如有需要其他类型, 请查看 IDL 代码
    else:
        raise ValueError(f"Current version only supports TGM model. Invalid ULY_TYPE: {uly_type}!")


# 4. 调用 ulf_fit 函数获取待测光谱的参数、以及误差推断值
def ulyss(inspectr=None, cmp=None, model_file=None, position=None, 
          err_sp=None, snr=None, shift_guess=None, kmoment=None, 
          sigma_guess=None, sigma_limits=None, velscale=None, 
          adegree=None, mdegree=None, polpen=None, lmin=None, lmax=None,
          kfix=None, kpen=None, clean=None, quiet=False, 
          modecvg=None, allow_polynomial_reduction=False, sampling_function=None,
          t_guess=None, l_guess=None, z_guess=None, full_output=False, 
          IDL_fit_method=True, plot_fitting=False) -> tuple[float, float, float, float, float, float, float, float, float, float]:

    """
       调用 ulf_fit 函数获取待测光谱的参数、以及误差推断值.

       输入参数:
       -----------
       inspectr:
                包含 LAMOST 光谱信息的 SignalOut 字典, 详情参考 uly_spect_read_lms 函数.
       cmp:
           TGM 字典, 存储 TGM、待测参数初始值、勒让德多项式默认值等信息.
           cmp 是 1 个字典而不是字典列表, 因此 cmp 的长度记为 1.
       model_file:
                  TGM 模型文件地址.
       position:
                SignalLog 可能是二维的(比如长狭缝), 在这种情况下提取一个扫描.
                注意: LASP 没有使用该参数, 如有需要请参考 IDL 代码.
       err_sp:
              待测光谱的流量误差.
              注意: 
              1) LASP 将待测光谱流量误差设置为 1, 因此该参数无效. 考虑到参数推断是基于卡方函数, 将流量误差设置为
                1 导致参数误差不合理, 因此需要对流量误差进行无偏估计, 并改正参数误差.
              2) IDL 代码单独做了流量误差的无偏估计, 并改正了参数推断误差. 但 LASP 是基于重复观测估计了参数误差. 
                 Python 版延续了 IDL 代码的思路, 使用流量误差的无偏估计改正参数推断误差.
              3) LAMOST 光谱提供了流量倒方差, 因此该参数也是可以被设置的.
       snr:
           待测光谱的信噪比.
       shift_guess:
                   losvd 第 1 个参数的初始猜测.
       kmoment:
               高斯-厄米特矩的阶数.
               设置为 2 时仅拟合 [cz, sigma].
               设置为 4 时同时拟合 [cz, sigma, h3, h4].
               设置为 6 时同时拟合 [cz, sigma, h3, h4, h5, h6].
               注意: 
               1) LASP 设置为 2.
       sigma_guess:
                   losvd 第 2 个参数的初始猜测.
       sigma_limits:
                    losvd 第 2 个参数的边界.
       velscale:
                ln 对数波长空间在每个步长的速度, 单位为 km/s.
                注意: velscale = ln_step * c = log10_step * c * ln(10).
       adegree:
               用于修正 TGM 模板光谱形状的加法勒让德多项式的阶数. 
               注意:
               1) 在拟合过程中, 默认不使用任何加法多项式.
               2) 如果要禁用加法多项式, 请设置 adegree=-1, LASP 设置为 -1.
       mdegree:
               勒让德多项式阶数.
               注意: 
               1) ULySS 默认是 10.
               2) LASP 设置为 50.
               3) 不同任务, mdegree 的设置不同, 可参考 best_mdegree.ipynb 文件对 mdegree 进行设置.
       polpen:
              乘法多项式的偏置水平. 此关键词可用于减少乘法多项式中不重要项的影响.
              注意: 
              1) 默认情况下不应用偏置. 如果某些系数的绝对值小于 polpen 倍的统计
                 误差, 这些系数会通过因子 (abs(coef)/(polpen*err))^2 被抑制.
              2) 该功能仅在 mdegree>0 时有效, polpen=2 是一个合理的选择.
       lmin:
            待拟合的波长范围最小值.
       lmax:
            待拟合的波长范围最大值.
       kfix:
            kfix 是一个数组, 用来指定 losvd 参数(包括系统速度、速度弥散以及高阶赫莫尼克项)是否被固定.
            注意: 
            1) 0 表示自由 (待优化), 1 表示固定 (不优化).
            2) LASP 中同时优化这些参数.
       kpen:
            此参数会将 (h3, h4, ...) 的测量值偏向零, 除非其包含项显著减少了流量拟合误差.
            注意: 
            1) 默认情况下, kpen=0, 表示未启用惩罚项 (LASP 中默认设置为 0).
            2) 如果设置为严格正值, 解 (包括 cz 和 sigma) 会减少噪声. 使用惩罚时, 建议使用蒙特卡洛模拟测试 kpen 的选择.
            3) kpen 的值范围应在 0.5 到 1.0 之间, 作为好的初始猜测.
       clean:
             是否以迭代方式检测并剪切 TGM 光谱与待测光谱流量残差的离群值.
             注意:
             1) Clean=True, 表示以迭代方式检测并剪切 TGM 光谱与待测光谱流量残差的离群值.
             2) Clean=False, 表示不进行离群值检测与剪切.
       quiet:
             是否抑制屏幕上打印的消息.
             注意:
             1) quiet=True, 表示抑制屏幕上打印的消息.
             2) quiet=False, 表示不抑制屏幕上打印的消息.
       modecvg:
               指定拟合方法的收敛模式. LASP 不使用, 但我们保留了该参数. 如有需要, 请参考 IDL 代码.
               注意:
               1) modecvg=0 (默认选项), 这是最快的, 但如果解错失, 问题可能发生.
               2) modecvg=1 (每次迭代仅完成一次), 为每个 LM 迭代计算导数.
               3) modecvg=2, uly_fit_lin 始终收敛, 但速度较慢.
       allow_polynomial_reduction:
                                  是否允许多项式阶数减少. 默认值为 False, 即不允许多项式阶数减少.
                                  如果设置为 True, 则允许多项式阶数减少.
                                  注意: 伪连续谱不应该小于 0, 如果小于 0, 提供两种处理方法:
                                  1) 为避免伪连续谱存在负值, 循环减少多项式阶数, LASP IDL 版本采用这种方法.
                                  2) 直接认为该光谱质量较差, 参数推断失败, Python 版本默认采用这种方法 (因为连续谱为负值, 大概率由于光谱流量存在负值).
       sampling_function:
                        插值方法. 可输入 "splinf", "cubic", "slinear", "quadratic", "linear". 默认使用 "linear" 插值方法.
       t_guess:
               Teff 初始值.
       l_guess:
               log g 初始值.
       z_guess:
               [Fe/H] 初始值.
       full_output:
                   是否返回所有参数的拟合结果, 以及拟合信息.
                   注意:
                   1) full_output=True, 表示返回所有参数的拟合结果, 以及拟合信息.
                   2) full_output=False, 表示仅返回 Rv、Teff、log g、[Fe/H] 以及误差推断值.
       IDL_fit_method:
                      是否使用 IDL 的拟合方法.
       plot_fitting:
                   是否绘制光谱的拟合流量残差图.
                   注意:
                   1) plot_fitting=True, 表示绘制拟合流量残差图.
                   2) plot_fitting=False, 表示不绘制拟合流量残差图.

       输出参数:
       -----------
       res:
           Rv, Teff, logg, FeH, Rv_err, Teff_err, logg_err, FeH_err, used_time, loss
           注意:
           1) used_time 表示推断 1 条光谱的参数所用的时间, 单位为秒.
           2) loss 表示拟合的流量残差的均方根误差.
    """


    """
    第 1 部分: 检查输入参数
    """
    # 4.1 查看是否指定了待测光谱文件
    if inspectr is None:
        raise ValueError("No LAMOST fits!")
    else:
        spectrum = inspectr
        if shift_guess is not None:
            raise ValueError("ULYSS: When <spectrum> is a structure SG=sg should not be specified!")
        if lmin is not None:
            raise ValueError("ULYSS: When <spectrum> is a structure LMIN and LMAX should not be specified!")
        if err_sp is not None:
            raise ValueError("ULYSS: When <spectrum> is a structure ERR_SP should not be specified!")

    # 4.2 查看是否指定 TGM 模型
    if cmp is not None:
        if model_file is not None:
            raise ValueError("Invalid arguments, <cmp> and MODEL_FILE are exclusive!")

    # 4.3 设置光速、加型多项式度数、勒让德乘法多项式度数、losvd 参数数量
    # 4.3.1 光速
    c = 299792.458
    # 4.3.2 默认不使用加型多项式
    if adegree is None:
        adegree = -1
    # 4.3.3 勒让德乘法多项式默认设置为 10
    if mdegree is None:
        mdegree = 10
    # 4.3.4 losvd 参数默认使用 2 个
    if kmoment is None:
        kmoment = 2
    # 4.3.5 kfix 长度必须与 losvd 参数数量一致
    if kfix is not None:
        if len(kfix) > kmoment:
            raise ValueError("The number of elements of KFIX should not exceed" + " " + str(kmoment) + "!")

    # 4.4 如果没有设置 cmp, 则读取 TGM 模型文件, 初始化 cmp
    if cmp is None:
        if model_file is not None:
            cmp = uly_cmp_read(model_file,        # 模型光谱文件
                               t_guess=t_guess,   # Teff 初始值
                               l_guess=l_guess,   # log g 初始值
                               z_guess=z_guess    # [Fe/H] 初始值
                               )

    # 4.5 quiet 设置为 True 时, 显示拟合过程
    if quiet is True:
        print('--------------------------------------------------------------------')
        print('一、参数设置')
        print('--------------------------------------------------------------------')
        print('1. 勒让德乘法多项式度数:', mdegree)
        if allow_polynomial_reduction:
            print('2. 允许多项式阶数减少, 伪连续谱小于 0 时, 会减少多项式阶数')
        else:
            print('2. 不允许多项式阶数减少, 伪连续谱小于 0 时, 参数将推断失败')
        if adegree == -1:
            print('3. 不设置可加型多项式')
        else:
            print('3. 可加型多项式度数:', adegree)
        print("4. 恒星大气物理参数初始值:", np.exp(cmp["para"][0]["guess"]), cmp["para"][1]["guess"], cmp["para"][2]["guess"])
        print('--------------------------------------------------------------------')


    """
    第 2 部分: 读取、重采样
    """
    # 4.6 SignalLog 可能是比如长狭缝, 在这种情况下提取一个扫描
    # 注意: LASP 没有使用该参数, 如有需要请参考 IDL 代码
    if position is not None:
        pos = position
    else:
        pos = 0

    # 4.7 获取光谱字典数据
    # 注意: LAMOST 光谱不需要波长重采样!!!!!!
    if spectrum is not None:
        SignalLog = uly_spect_extract(SignalIn=spectrum,    # 待测光谱字典结构
                                      pos=pos               # 位置
                                      )
        # 4.7.1 LAMOST 光谱的 sampling=1, 因此不需要重采样
        SignalLog = uly_spect_logrebin(SignalIn=SignalLog,                  # 待测光谱字典结构
                                       vsc=velscale,                        # 对数波长采样的速度
                                       sampling_function=sampling_function, # 插值方法
                                       overwrite=True                       # 是否覆盖
                                       )
    # 4.7.2 spectrum 就是 inspectr, LASP 没有使用这部分代码, 如有需要请参考 IDL 代码
    # else:
    #     SignalLog = uly_spect_read(file, lmin, lmax, VELSCALE=velscale, ERR_SP=err_sp, SG=shift_guess, quiet=quiet)
    #     SignalLog = uly_spect_extract(SignalIn=SignalLog, pos=pos, overwrite=True, status=True)
    #     if SignalLog["sampling"] != 1:
    #         SignalLog = uly_spect_logrebin(SignalIn=SignalLog, vsc=velscale, sampling_function=sampling_function, overwrite=True)

    # 4.8 获取待测光谱的好像素点索引
    if len(SignalLog["goodpix"]) > 0:
        gp = SignalLog["goodpix"]
    else:
        gp = np.arange(SignalLog["data"].shape[0])

    # 4.9 如果没有提供流量误差, 但提供了信噪比, 则可以根据信噪比计算待测光谱的流量误差
    # 注意: 
    # 1) LAMOST 光谱提供了流量倒方差, 因此该参数也是可以被设置的
    # 2) 但 LASP 设置流量误差为 1, 且没有设置信噪比
    # 3) 这部分代码块 LASP 没有使用
    if (len(SignalLog["err"]) == 0) & (snr is not None):
        mean_error = np.mean(SignalLog["data"][gp]) / snr
        if not np.isfinite(mean_error).all():
            raise ValueError("Cannot compute the mean of the signal!")
        # 4.9.1 设置流量误差
        SignalLog["err"] = SignalLog["data"] * 0 + mean_error
    # 4.10 如果提供了流量误差
    if len(SignalLog["err"]) != 0:
        negerr = np.where(SignalLog["err"][gp] <= 0)[0].tolist()
        poserr = np.where(SignalLog["err"][gp] > 0)[0].tolist()
        if len(negerr) == len(SignalLog["err"][gp]):
            raise ValueError("The noise is negative or null!")
        # 4.10.1 如果流量误差存在负值, 则将负值设置为最小正值
        if (len(negerr) > 0) & (len(negerr) < len(SignalLog["err"][gp])):
            SignalLog["err"][gp[negerr]] = np.min(SignalLog["err"][gp[poserr]])
        # 4.10.2 设置权重
        weight = 1 / (SignalLog["err"][gp]) ** 2
        # 4.10.3 如果权重存在大于 100 倍的平均权重, 则打印警告
        large_weight = np.where(weight > 100 * np.mean(weight))[0].tolist()
        if len(large_weight) > 0:
            print("Some pixels have more than 100 times the average weight ..." + "it may be an error (up to "
                  + np.max(weight) / np.mean(weight) + ")!")


    """
    第 3 部分: 初始化 cmp 光谱字典结构
    """
    # 4.11 获取 LAMOST 的线性波长
    lamrange = uly_spect_get(SignalIn=SignalLog,   # 待测光谱字典结构
                             WAVERANGE=True        # 是否获取波长
                             )[0]
    velscale = SignalLog["step"] * c
    # 4.12 更新 cmp, 确保 TGM 模型与 LAMOST 光谱波长范围一致
    status, cmp = uly_fit_init(cmp,                # TGM 模型字典结构
                               lamrange=lamrange,  # 波长范围
                               velscale=velscale   # 对数波长采样的速度
                               )
    if status != 0:
        raise ValueError("Fit initialization failed, abort!")


    """
    第 4 部分: 准备 & 调用拟合过程
    """
    # 4.13 TGM 模型光谱的线性波长范围
    model_range = np.exp([cmp["start"] + 0.5 * cmp["step"], cmp["start"] + (cmp["npix"] - 1.5) * cmp["step"]])
    # 4.14 更新 SignalLog, 确保待测光谱的波长范围也要在该区间内
    SignalLog = uly_spect_extract(SignalIn=SignalLog,     # 待测光谱字典结构
                                  waverange=model_range,  # 波长范围
                                  overwrite=True          # 是否覆盖
                                  )
   
    # 4.15 如果没有提供初始 losvd 参数, 则这里设置
    if shift_guess is None:
        sigma_guess = SignalLog["step"] * c
    # 4.15.1 设置 losvd 参数初始值
    cz_guess = 0
    kguess = [cz_guess, sigma_guess]
    # 4.15.2 如果 kmoment 大于 2, 则设置 losvd 参数初始值
    if kmoment is not None:
        if kmoment > 2:
            kguess = kguess + [0] * (kmoment - 2)
    if sigma_limits is not None:
        if len(sigma_limits) == 2:
            klim = [kguess[0] - 2000, kguess[0] + 2000, sigma_limits]

    # 4.16 如果 quiet=True, 则打印推断过程
    if quiet is True:
        velscale = SignalLog["step"] * c
        lamrange = np.exp([SignalLog["start"], SignalLog["start"] + SignalLog["data"].size * SignalLog["step"]])
        print('--------------------------------------------------------------------')
        print('二、传递给 uly_fit_IDL_method 函数的参数')
        print('--------------------------------------------------------------------')
        print('1. 所用波长范围\t\t:', lamrange[0], lamrange[1], '[Å]')
        print('2. 对数波长采样\t\t:', velscale, '[km/s]')
        print('3. 信号中独立像素数\t:', int(np.ceil(SignalLog["data"].size) / SignalLog["dof_factor"]))
        print('4. 拟合像素数\t\t:', SignalLog["data"].size)
        print('5. DOF factor\t\t:', SignalLog["dof_factor"])
        print('--------------------------------------------------------------------')

    # 4.17 拟合
    # 调用 uly_fit 函数, 使用 scipy 中的数值优化方法推断待测光谱的恒星参数、以及误差、耗时、流量残差的均方根误差
    Rv, Teff, logg, FeH, \
    Rv_err, Teff_err, logg_err, FeH_err, \
    used_time, loss = uly_fit_IDL_method(signalLog=SignalLog,                                     # 待测光谱字典结构
                                         cmp=cmp,                                                 # TGM 模型字典结构
                                         kmoment=kmoment,                                         # losvd 参数数量
                                         kguess=kguess,                                           # losvd 参数初始值
                                         kfix=kfix,                                               # losvd 参数是否固定
                                         klim=None,                                               # losvd 参数边界
                                         kpen=kpen,                                               # losvd 参数惩罚
                                         adegree=adegree,                                         # 加型多项式阶数
                                         mdegree=mdegree,                                         # 勒让德多项式阶数
                                         polpen=polpen,                                           # 乘法多项式偏置水平
                                         clean=clean,                                             # 是否剪切离群值, 即 Clean 模式与 No Clean 模式
                                         modecvg=modecvg,                                         # 拟合收敛模式, 即 0, 1, 2. LASP 不使用, 但我们保留了该参数. 如有需要, 请参考 IDL 代码
                                         allow_polynomial_reduction=allow_polynomial_reduction,   # 是否允许多项式阶数减少
                                         sampling_function=sampling_function,                     # 插值方法
                                         quiet=quiet,                                             # 是否抑制屏幕上打印的消息
                                         full_output=full_output,                                 # 是否返回所有参数的拟合结果, 以及拟合信息
                                         plot_fitting=plot_fitting                                # 是否绘制光谱的拟合流量残差图
                                         )

    # 4.18 输出 Rv、Teff、log g、[Fe/H] 以及误差、耗时、流量残差的均方根误差
    return Rv, Teff, logg, FeH, Rv_err, Teff_err, logg_err, FeH_err, used_time, loss