# -*- coding: utf-8 -*-
# @Time    : 2025/1/13 21:22
# @Author  : ljc
# @FileName: ulyss_pytorch.py
# @Software: PyCharm


# 1. 简介
"""
目的:
    通过 ulyss 函数设置待测光谱文件位置、流量、流量误差、恒星参数初始值、积分波长节点等信息. 并使用 data_to_pt/data_to_pt.ipynb 将光谱数据存为 .pt 文件
    将光谱信息存为 pt 文件, 便于 Adam 优化器进行分组优化.
函数：
    1) uly_cmp_read
    2) ulyss
解释：
    1) uly_cmp_read 函数: 返回 cmp 初始结构信息.
    2) ulyss 函数: 调用 ulf_fit 函数获取待测光谱流量、流量误差、恒星参数初始值、积分波长节点等信息.
"""


# 2. 调库
from astropy.io import fits
from uly_tgm.uly_tgm import uly_tgm
from uly_read_lms.uly_spect_extract import uly_spect_extract
import numpy as np
from WRS.uly_spect_logrebin import uly_spect_logrebin
from uly_read_lms.uly_spect_get import uly_spect_get
from uly_fit.uly_fit_init import uly_fit_init
import torch
import warnings
warnings.filterwarnings("ignore")


# 3. 读取 TGM 模型文件, 并返回 cmp 初始结构信息
def uly_cmp_read(model_file, t_guess=None, l_guess=None, z_guess=None) -> dict:

    """
        获取 cmp 初始结构信息.

        输入参数:
        -----------
        model_file:
                   TGM 模型文件位置.
        t_guess:
                Teff 初始值.
        l_guess:
                log g 初始值.
        z_guess:
                [Fe/H] 初始值.

        输出参数:
        -----------
        cmp:
            cmp 初始字典结构信息.
    """

    # 3.1 如果 model_file 不是字符串, 则抛出类型错误
    if not isinstance(model_file, str):
        raise TypeError("Argument model_file must be a filename!")

    # 3.2 读取 TGM 模型文件
    try:
        with fits.open(model_file) as hdul:
            header = hdul[0].header
    except Exception as e:
        raise IOError(f"Error reading file: {str(e)}!")

    # 3.3 检查表头中的 ULY_TYPE 键并据此分支处理
    # 注意: 
    # 1) LASP 使用的 uly_type="TGM"
    # 2) 如有需要其他类型, 请查看 IDL 代码
    # 3) 关于非 TGM 模型, 如 SSP 和 STAR, 留给后续版本
    # 3.3.1 获取 ULY_TYPE 键值
    uly_type = header.get('ULY_TYPE', '').strip().upper()
    # 3.3.2 如果 ULY_TYPE 键值为 TGM, 则返回 TGM 字典结构
    if uly_type == 'TGM':
        # 3.3.2.1 返回 TGM 字典结构
        return uly_tgm(model_file=model_file, # 模型光谱文件
                       t_guess=t_guess,       # Teff 初始值
                       l_guess=l_guess,       # log g 初始值
                       z_guess=z_guess        # [Fe/H] 初始值
                       )
    # 3.3.3 如有需要其他类型, 请查看 IDL 代码
    else:
        raise ValueError(f"Current version only supports TGM model. Invalid ULY_TYPE: {uly_type}!")


# 4. 调用 ulf_fit 函数获取待测光谱信息, 并存为 pt 文件
def ulyss(spectrum=None, cmp=None, model_file=None,
          snr=None, velscale=None, t_guess=None, l_guess=None, 
          z_guess=None) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:

    """
       返回待存为 pt 文件的光谱信息, 便于 LASP-Adam-GPU 快速推断待测参数.

       输入参数:
       -----------
       spectrum:
                包含 LAMOST 光谱信息的 SignalOut 字典, 详情参考 uly_spect_read_lms 函数.
       cmp:
           TGM 字典, 存储 TGM、待测参数初始值、勒让德多项式默认值等信息.
           cmp 是 1 个字典而不是字典列表, 因此 cmp 的长度记为 1.
       model_file:
                  TGM 模型文件地址.
       snr:
           待测光谱的信噪比.
       velscale:
                ln 对数波长空间在每个步长的速度, 单位为 km/s.
                注意: velscale = ln_step * c = log10_step * c * ln(10) = 69.02976447828436.
       t_guess:
               Teff 初始值.
       l_guess:
               log g 初始值.
       z_guess:
               [Fe/H] 初始值.
      
       输出参数:
       -----------
        ELODIE_wave, borders, NewBorders, flat, lamrange, flux_lamost, goodpix, spec_coef.
        分别对应:
                ELODIE_wave: ELODIE 模型光谱波长;
                borders: ELODIE 模型光谱的积分波长节点;
                NewBorders: 待测光谱的积分波长节点;
                flat: $\Delta \lambda_{1} / (\lambda_{2(i+1)}^{\prime\prime} - \lambda_{2i}^{\prime\prime})$, 请参考论文 LASP-Adam-GPU 方法章节步骤 3;
                lamrange: 待测光谱的波长;
                flux_lamost: 待测光谱的流量值;
                goodpix: 为待拟合光谱的好像素点索引, 可自行加入官方 MASK 字段;
                         由于 No Clean MASK 区域与 lamrange 无交集, 因此该字段不记录在 pt 文件.
                         目前该字段支持输出为好像素点索引值, 建议将索引转为 0 与 1 MASK 表, 便于优化目标函数时批量 MASK.
                spec_coef: ELODIE 多项式光谱模拟器的各项系数值.
    """

    # 4.1 定义光速
    c = 299792.458

    # 4.2 如果没有设置 cmp, 则读取 TGM 模型文件, 初始化 cmp
    if cmp is None:
        if model_file is not None:
            cmp = uly_cmp_read(model_file,        # 模型光谱文件
                               t_guess=t_guess,   # Teff 初始值
                               l_guess=l_guess,   # log g 初始值
                               z_guess=z_guess    # [Fe/H] 初始值
                               )
    

    # 4.3 获取光谱字典数据
    # 注意:
    # 1) LASP 仅对模型光谱波长重采样到 LAMOST 波长, 而不对 LAMOST 光谱重采样
    SignalLog = uly_spect_extract(SignalIn=spectrum     # 待测光谱字典结构
                                  )
    # 4.3.1 LAMOST 光谱的 sampling=1, 因此不需要重采样.
    # 注意: 这行代码可删除, 不起作用, 但这里保持与 IDL 一致
    SignalLog = uly_spect_logrebin(SignalIn=SignalLog,  # 待测光谱字典结构
                                   vsc=velscale,        # 对数波长采样的速度
                                   overwrite=True       # 是否覆盖
                                   )
        
    # 4.4 获取待测光谱的好像素点索引
    if len(SignalLog["goodpix"]) > 0:
        gp = SignalLog["goodpix"]
    else:
        gp = np.arange(SignalLog["data"].shape[0])

    # 4.5 如果没有提供流量误差, 但提供了信噪比, 则可以根据信噪比计算待测光谱的流量误差
    # 注意: 
    # 1) LAMOST 光谱提供了流量倒方差, 因此该参数也是可以被设置的
    # 2) 但 LASP 设置流量误差为 1, 且没有设置信噪比
    # 3) 这部分代码块 LASP 没有使用
    if (len(SignalLog["err"]) == 0) & (snr is not None):
        mean_error = np.mean(SignalLog["data"][gp]) / snr
        if not np.isfinite(mean_error).all():
            raise ValueError("Cannot compute the mean of the signal!")
        # 4.5.1 设置流量误差
        SignalLog["err"] = SignalLog["data"] * 0 + mean_error
    # 4.6 如果提供了流量误差
    if len(SignalLog["err"]) != 0:
        negerr = np.where(SignalLog["err"][gp] <= 0)[0].tolist()
        poserr = np.where(SignalLog["err"][gp] > 0)[0].tolist()
        if len(negerr) == len(SignalLog["err"][gp]):
            raise ValueError("The noise is negative or null!")
        # 4.6.1 如果流量误差存在负值, 则将负值设置为最小正值
        if (len(negerr) > 0) & (len(negerr) < len(SignalLog["err"][gp])):
            SignalLog["err"][gp[negerr]] = np.min(SignalLog["err"][gp[poserr]])

    # 4.7 获取 LAMOST 的线性波长
    lamrange = uly_spect_get(SignalIn=SignalLog,   # 待测光谱字典结构
                             WAVERANGE=True        # 是否获取波长
                             )[0]
    velscale = SignalLog["step"] * c               # 69.02976447828436
    # 4.7.1 更新 cmp, 确保 TGM 模型与 LAMOST 光谱波长范围一致
    status, cmp = uly_fit_init(cmp,                # TGM 模型字典结构
                               lamrange=lamrange,  # 波长范围
                               velscale=velscale   # 对数波长采样的速度
                               )

    # 4.8 由于模型光谱与待测光谱区间端点可能存在偏差, 因此待测光谱左右端点缩放, 保证模型光谱区间稍覆盖待测光谱区间
    model_range = np.exp([cmp["start"] + 0.5 * cmp["step"], cmp["start"] + (cmp["npix"] - 1.5) * cmp["step"]])
    # 4.9 更新 SignalLog: 更新待测光谱波长范围, 流量值等信息, 保证待测光谱波长位于 model_range 范围内
    SignalLog = uly_spect_extract(SignalIn=SignalLog,     # 待测光谱字典结构
                                  waverange=model_range,  # 波长范围
                                  overwrite=True          # 是否覆盖
                                  )

    # 4.10 lamrange 为待拟合的 LAMOST 光谱波长序列: model_range 范围内
    lamrange = np.exp([SignalLog["start"] + np.arange(SignalLog["data"].size) * SignalLog["step"]])[0]
    # 4.11 flux_lamost 为待拟合的 LAMOST 光谱流量
    flux_lamost = SignalLog["data"]
    # 4.12 模型光谱在指定波长范围内的像素点数量 (np.arange(7506)), cmp['eval_data'] 为模型光谱信息
    indxarr = np.arange(cmp['eval_data']["spec_coef"].shape[1])
    # 4.13 ELODIE_wave 为模型光谱的波长序列: 端点区间包含 lamrange
    ELODIE_wave = cmp['eval_data']["mod_start"] + np.array(indxarr) * cmp['eval_data']["mod_step"]
    # 4.14 borders 为模型光谱的积分波长节点
    borders = cmp['eval_data']["mod_start"] + np.array([-0.5, *(indxarr + 0.5)]) * cmp['eval_data']["mod_step"]
    # 4.15 NewBorders 为待测光谱的积分波长节点, cmp["npix"] 为待测光谱在指定波长中的流量点数量
    NewBorders = np.exp([cmp["start"] + (np.arange(cmp["npix"] + 1) - 0.5) * cmp["step"]])[0]
    # 4.16 logScale 为对数波长采样的速度, c 为光速, 69.02976447828436 为对数波长采样的速度
    logScale = 69.02976447828436 / c
    # 4.17 flat 为 $\Delta \lambda_{1} / (\lambda_{2(i+1)}^{\prime\prime} - \lambda_{2i}^{\prime\prime})$, 请参考论文 LASP-Adam-GPU 方法章节步骤 3
    flat = np.exp(cmp["start"] + np.arange(cmp["npix"]) * logScale) * logScale / cmp['eval_data']["mod_step"]
    # 4.18 goodpix 为待拟合的 LAMOST 光谱的好像素点索引
    goodpix = SignalLog["goodpix"]
    # 4.19 spec_coef 为 ELODIE 多项式光谱模拟器的各项系数
    spec_coef = cmp["eval_data"]["spec_coef"][:23, :, :]

    # 4.20 返回待存为 pt 文件的光谱信息
    return ELODIE_wave, borders, NewBorders, flat, lamrange, flux_lamost, goodpix, spec_coef