# -*- coding: utf-8 -*-
# @Time    : 08/12/2024 11.13
# @Author  : ljc
# @FileName: uly_spect_logrebin.py
# @Software: PyCharm


# 1. 简介
"""
 Python conversion of the IDL uly_spect_logrebin.pro .
目的:
    对 TGM 模型光谱进行波长重采样, 使得模型光谱的流量维度与待测光谱一致.
函数:
    1) uly_spect_logrebin
解释:
    1) uly_spect_logrebin 函数: 要么返回输入光谱结构, 要么返回 TGM 模型光谱波长重采样后的光谱结构.
注意: 参数推断过程, 该函数输入变量 SignalIn 分别以实测光谱、模型光谱作为输入.
    1) 实测光谱: 调用 1 次, 无需重采样, 初始化时直接返回输入光谱结构.
    2) 模型光谱: 调用 m 次, 需要重采样, m 为使用 curve_fit 迭代计算恒星参数的次数.
"""


# 2. 导库
import numpy as np
from WRS.xrebin import xrebin
import warnings
warnings.filterwarnings("ignore")


# 3. 重采样到待测光谱波长
def uly_spect_logrebin(SignalIn, vsc=None, waverange=None, sampling_function=None, flux_conserve=False, exact=False, overwrite=False) -> dict:

    """
      对输入光谱结构体按对数波长进行重采样.

      输入参数:
      -----------
      SignalIn:
               输入数据为字典结构, 包含流量、波长、采样信息等.
      vsc:
          每个像素的速度尺度 (km/s), vsc = ln_step * c = log10_step * c * ln(10).
      waverange:
                插值波长范围.
      sampling_function:
                       插值方法. 可输入 "splinf", "cubic", "slinear", "quadratic", "linear". 默认使用 "linear" 插值方法.
      flux_conserve:
                    设置此关键字可跳过对局部像素大小变化的强度校正.
                    如果通量保持不变, 强度则会改变, 可除以像素大小变化因子 flat 来校正.
      exact:
            是否强制输出与 waverange 和 vsc 完全对齐.
      overwrite:
                是否覆盖输入光谱结构.

      输出参数:
      -----------
      SignalOut:
                包含对数重采样数据的光谱结构.
    """

    # 3.1 检查输入数据是否为光谱结构并包含光谱流量
    if not isinstance(SignalIn, dict) or 'data' not in SignalIn:
        raise ValueError("Input must be a spectrum structure with 'data' field!")

    # 3.2 获取光谱维度 (维度, 样本量), 采样步长, 光谱的起始波长, 采样模式 (0 为线性, 1 为 ln 对数, 2 为非均匀)
    # 3.2.1 获取光谱的样本量
    npix = SignalIn['data'].shape[0]
    # 3.2.2 获取采样步长
    step = SignalIn.get('step', 1)
    # 3.2.3 获取光谱的起始波长
    start = SignalIn.get('start', 0)
    # 3.2.4 获取采样模式
    sampling = SignalIn.get('sampling', 0)
    # 3.2.5 定义光速常量 (km/s)
    C = 299792.458

    # 3.3 决定使用样条插值或线性插值, 如果设置了 linear, 则 splinf 为 False
    # splinf = not linear

    # 3.4 如果输入光谱为 ln 对数采样且不需要重采样 (vsc 已被设置)
    # 注意: 3.4 代码块用于 LAMOST 光谱, 无需重采样
    # sampling == 0: 表示输入光谱结构的波长为线性采样.
    # sampling == 1: 表示输入光谱结构的波长为 ln 对数采样.
    # sampling == 2: 非均匀采样.
    # 3.4.1 如果输入光谱为 ln 对数采样
    if sampling == 1:
        # 3.4.1.1 标记是否可以直接复制输出
        cpout = False
        # 3.4.1.2 如果 vsc 已赋值
        if vsc is not None:
            # 3.4.1.2.1 检查 vsc 是否与步长一致 (允许少量误差), 如果 vsc 与步长一致, 则 cpout 为 True
            if abs(1.0 - vsc / (C * step)) * npix < 0.001:
                cpout = True
        # 3.4.1.3 如果 vsc 未设置, 则 cpout 为 True
        elif vsc is None:
            cpout = True

        # 3.4.1.4 检查输入光谱的起始波长与输入波长的最小值是否精确对齐
        # 注意: LASP 中, exact=False, cpout=True, waverange=None, 表示输入光谱的起始波长与输入波长的最小值精确对齐
        if exact and cpout and waverange is not None:
            # 3.4.1.4.1 检查输入光谱的起始波长与输入波长的最小值是否精确对齐, 取余数
            nshift = (start - np.log(waverange[0])) / step % 1
            # 3.4.1.4.2 如果余数大于 0.001, 即, 输入光谱的起始波长与输入波长的最小值不精确对齐, cpout 为 False
            if abs(nshift) > 0.001:
                cpout = False
            # 3.4.1.4.2 如果 cpout 为 True, 且 waverange 为长度为 2 的列表, 检查输入光谱的末端波长与输入波长的最大值是否精确对齐
            if cpout and len(waverange) == 2:
                # 3.4.1.4.2.1 检查输入光谱的末端波长与输入波长的最大值是否精确对齐, 取余数   
                # 注意:
                # 1) 我认为 IDL 在计算右侧 nshift 时, 存在 bug, 因此修改如下: waverange[0] 改为 waverange[1]
                # nshift = (start + (npix - 1) * step - np.log(waverange[0])) / step % 1
                nshift = (start + (npix - 1) * step - np.log(waverange[1])) / step % 1
                # 3.4.1.4.2.2 如果余数大于 0.001, 即, 输入光谱的末端波长与输入波长的最大值不精确对齐, cpout 为 False
                if abs(nshift) > 0.001:
                    cpout = False

        # 3.4.1.5 如果 cpout 为 True, 则不需要重新采样, 直接返回输入光谱结构
        if cpout:
            # 3.4.1.5.1 已经是对数采样, 且波长已经对齐, 则不需要重新采样, 直接返回输入光谱结构
            return SignalIn if overwrite else SignalIn.copy()

    # 3.5 TGM 模型光谱重采样
    # 注意: 3.5 代码块用于 TGM 模型光谱, 需重采样
    # sampling == 0: 表示输入光谱结构的波长为线性采样.
    # sampling == 1: 表示输入光谱结构的波长为 ln 对数采样.
    # sampling == 2: 非均匀采样.
    # 波长重采样分为 4 步: 尽可能保证采样前后总通量守恒
    # 1) 计算模型光谱的积分波长与待测光谱的积分波长
    # 2) 计算原始波长的流量积分: 构建累积通量函数, 将离散数据转换为连续表示 (见 xrebin.py 中的 xrebin 函数)
    # 3) 将流量积分插值到待测光谱的积分波长: 在新的波长网格边界上采样累积通量函数 (见 xrebin.py 中的 xrebin 函数)
    # 4) 差分得到重采样光谱: 计算相邻边界点间的通量差值并除以波长间隔, 得到新网格上的光谱强度 (见 xrebin.py 中的 xrebin 函数)
    # 3.5.1 如果输入光谱为线性采样
    if sampling == 0:
        # 3.5.1.1 生成 [0, 1, 2, ..., npix] 数组
        indxarr = np.arange(npix)
        # 3.5.1.2 计算模型光谱的积分波长. 这是为了对区间流量积分再采样, 而不是对中心波长流量采样. 尽可能保证重采样前后总通量守恒
        # 如 [4199.2 4199.4 4199.6 ... 5699.8 5700.  5700.2] 变为了 [4199.1 4199.3 4199.5 ... 5699.9 5700.1 5700.3]
        # 即 [4199.2 4199.4 4199.6 ... 5699.8 5700.  5700.2] 各个波长点的左右区间
        borders = start + np.array([-0.5, *(indxarr + 0.5)]) * step
        # 3.5.1.3 线性波长转为 ln 对数波长, 如 [8.3426255  8.34267312 8.34272075 ... 8.64820391 8.648239   8.64827408]
        bordersLog = np.log(borders)
    # 3.5.2 波长是 ln 对数的
    elif sampling == 1:
        indxarr = np.arange(npix)
        bordersLog = start + np.array([-0.5, *(indxarr + 0.5)]) * step
        borders = np.exp(bordersLog)
    # 3.5.3 非均匀采样, LASP 没有使用该采样, 如有需要, 请参考原始 IDL 代码
    elif sampling == 2:
        wavelen = SignalIn['wavelen']
        borders = (wavelen[:-1] + wavelen[1:]) / 2
        borders = np.concatenate(([2 * wavelen[0] - borders[0]], borders, [2 * wavelen[-1] - borders[-1]]))
        bordersLog = np.log(borders)
    else:
        raise ValueError(f"Invalid sampling mode: {sampling}!")

    # 3.5.4 若未指定 vsc, 则自动计算
    # 注意: 无论是 ln 对数采样还是非均匀采样, 都需要转为 ln 对数采样计算 vsc
    if vsc is None:
        # 3.5.4.1 如果输入光谱为 ln 对数采样
        if sampling == 1:
            # 3.5.4.1.1 vsc 为 ln_step (ln 对数采样步长) 倍光速
            vsc = step * C
        # 3.5.4.2 如果输入光谱为非均匀采样(LASP 没有使用该采样, 如有需要, 请参考原始 IDL 代码)
        elif sampling == 2:
            wrange = wavelen[waverange] if waverange else wavelen[[0, -1]]
            vsc = np.log(wrange[1] / wrange[0]) / npix * C
        # 3.5.4.3 如果输入光谱为线性采样
        else:
            # 3.5.4.3.1 线性采样转为 ln 对数采样
            wrange = np.log([start, start + step * (npix - 1)])
            # 3.5.4.3.2 vsc = 平均对数波长间隔 * 光速
            vsc = (wrange[1] - wrange[0]) / (npix - 1) * C

    # 3.5.5 计算对数步长
    logScale = vsc / C

    # 3.5.6 计算对数波长边界
    # 3.5.6.1 计算 logStart (线性波长起点与终点), 并转为 ln 对数波长
    logRange = start + np.array([-0.5, npix - 0.5]) * step
    if sampling == 0:
        logRange = np.log(logRange)
    elif sampling == 2:
        logRange = np.log([wavelen[0], wavelen[-1]])
    # 3.5.6.2 将 TGM 模型光谱的 ln 对数波长左右收缩 0.5 倍的 LAMOST ln 对数波长间隔
    logRange += np.array([0.5, -0.5]) * logScale

    # 3.5.7 计算新的对数波长起点, 并进行光谱流量重采样
    # 3.5.7.1 如果输入了 waverange, 则计算新的对数波长起点, 匹配 logRange 与 waverange 的共同波长范围
    if waverange is not None:
        # 3.5.7.1.1 计算输入光谱结构的对数波长起点与输入波长起点的差异, 并取交集部分
        nshift = np.ceil(np.max([0, (logRange[0] - np.log(waverange[0]))]) / logScale - 1e-7)
        logRange[0] = np.log(waverange[0]) + logScale * nshift
        if len(waverange) == 2:
            logRange[1] = np.min([np.log(waverange[1]), logRange[1]])
        if logRange[1] < logRange[0]:
            raise ValueError("waverange is not valid!")

    # 3.5.7.2 计算新的对数波长起点
    nout = round((logRange[1] - logRange[0]) / logScale) + 1
    logStart = logRange[0]
    # 3.5.7.3 ln 对数波长转为线性波长, 并计算 dof_factor
    if sampling < 2:
        if sampling == 0:
            logRange = np.exp(logRange)
        nin = np.round((logRange[1] - logRange[0]) / step + 1)
    else:
        nin = len(wavelen)
    dof_factor = nout / nin
    # 3.5.7.4 判断 logStart 是否在 TGM 模型的波长范围内
    if logStart - logScale/2 > bordersLog[npix]:
        raise "start value is not in the valid range!"

    # 3.5.7.5 待测光谱的积分波长
    NewBorders = np.exp(logStart + (np.arange(nout + 1) - 0.5) * logScale)
    # 3.5.7.6 获取光谱的维度
    dim = SignalIn["data"].shape
    n_data = SignalIn["data"].shape[0] * SignalIn["data"].shape[1]
    # 3.5.7.7 获取光谱流量误差
    # 注意:
    # 1) LASP 中对于 LAMOST 光谱设置 err=1, 但需要注意 LAMOST 光谱提供了流量倒方差
    # 2) TGM 模型光谱流量没有误差, 即为 None
    err_ = SignalIn.get("err", None)
    # 3.5.7.8 如果误差不为 None, 则获取误差的维度
    if (err_ is not None) and (hasattr(err_, "shape")) and (len(err_.shape) > 0):
        n_err = SignalIn["err"].shape[0]
    else:
        n_err = 0
    # 注意: .shape[0] 表示光谱流量维度, .shape[1] 表示光谱样本数量, 如: SignalIn["data"].shape = (7506, 1)
    n_dim = SignalIn["data"].shape[1]
    # 3.5.7.9 如果输入光谱为线性采样, 则把上述的 ln 对数波长转为线性波长, 并计算 flat
    # 注意: flat 确定用于处理像素大小变化的转换因子/向量
    # 1) 在对数重采样过程中, 像素的波长范围会发生变化. 在对数采样中, 像素的波长范围是指数增长的, 而在线性采样中, 像素的波长范围是线性增长的.
    # 因此, 在对数重采样过程中, 需要使用 flat 来处理像素大小变化的转换因子/向量. flat 变量就是用来校正这种像素大小变化导致的强度变化的转换因子.
    if sampling == 0:
        flat = np.exp(logStart + np.arange(nout) * logScale) * logScale / step
        if n_dim > 1:
            flat = flat.reshape(nout, dim[1])
    # 3.5.7.10 如果输入光谱为 ln 对数采样, 则计算 flat
    elif sampling == 1:
        flat = logScale / step
    # 3.5.7.11 如果输入光谱为非均匀采样, 则计算 flat
    else:
        flat = np.ones(dim[0], dtype=np.float64)
        # 3.5.7.11.1 重采样
        # flat = xrebin(borders, flat, NewBorders, sampling_function="splinf")
        flat = xrebin(borders,                             # 原始波长区间范围
                      flat,                                # 流量
                      NewBorders,                          # 待采样的波长区间范围
                      sampling_function=sampling_function  # 插值函数
                      )
        if len(dim) > 1:
            flat = flat.reshape(nout, dim[1])

    # 3.5.7.12 初始化输出
    SignalOut = SignalIn if overwrite else SignalIn.copy()

    # 3.5.7.13 光谱流量重采样
    if flux_conserve:
        # SignalOut["data"] = xrebin(borders, SignalIn["data"], NewBorders, sampling_function="splinf")
        SignalOut["data"] = xrebin(borders,                             # 原始波长区间范围
                                   SignalIn["data"],                    # 流量
                                   NewBorders,                          # 待采样的波长区间范围
                                   sampling_function=sampling_function  # 插值函数
                                   )
    else:
        # SignalOut["data"] = xrebin(borders, SignalIn["data"], NewBorders, sampling_function="splinf") / flat
        SignalOut["data"] = xrebin(borders,                             # 原始波长区间范围
                                   SignalIn["data"],                    # 流量
                                   NewBorders,                          # 待采样的波长区间范围
                                   sampling_function=sampling_function  # 插值函数
                                   ) / flat

    # 3.5.8 光谱误差重采样
    # 注意: TGM 模型光谱流量没有误差, 即为 None, 因此不对光谱误差进行重采样
    if n_err == n_data:
        # err = xrebin(borders, SignalIn["err"] ** 2, NewBorders, sampling_function="splinf")
        err = xrebin(borders,                             # 原始波长区间范围
                     SignalIn["err"] ** 2,                # 流量误差
                     NewBorders,                          # 待采样的波长区间范围
                     sampling_function=sampling_function  # 插值函数
                     )
        if len(SignalIn["goodpix"]) != 0:
            minerr = np.min(SignalIn["err"][SignalIn["goodpix"]]) ** 2 * SignalIn["err"].shape[0] / err.shape[0]
        else:
            minerr = np.min(SignalIn["err"]) ** 2 * SignalIn["err"].shape[0] / err.shape[0]
        negative = np.where(err <= minerr)[0].tolist()
        if len(negative) != 0:
            err[negative] = minerr
        if flux_conserve:
            SignalOut["err"] = err ** 0.5
        else:
            SignalOut["err"] = err ** 0.5 / flat
        # 3.5.8.2 如果 dof_factor > 1, 则更新光谱误差
        if dof_factor > 1:
            SignalOut["err"] = SignalOut["err"] / dof_factor ** 0.5
        # 3.5.8.3 如果 SignalIn["dof_factor"] > 1, 则更新光谱误差
        if SignalIn["dof_factor"] > 1:
            d = dof_factor
            if d * SignalIn["dof_factor"] < 1:
                d = 1 / SignalIn["dof_factor"]
            SignalOut["err"] = SignalOut["err"] / d ** 0.5
    
    # 3.5.9 光谱好像素点重采样
    # 注意: LASP 中 TGM 模型光谱并没有设置, 而 LAMOST 光谱设置了 goodpix
    if SignalIn.get("goodpix") is not None:
        if len(SignalIn["goodpix"]) != 0:
            maskI = np.zeros(n_data, dtype=np.uint8)
            maskI[SignalIn["goodpix"]] = 1
            # maskO = xrebin(borders, maskI, NewBorders, sampling_function="splinf") / flat
            maskO = xrebin(borders,                             # 原始波长区间范围
                           maskI,                               # 流量
                           NewBorders,                          # 待采样的波长区间范围
                           sampling_function=sampling_function  # 插值函数
                           ) / flat
            goodpix = np.where(abs(maskO - 1) < 0.1)[0].tolist()
            SignalOut["goodpix"] = goodpix

    # 3.6 更新输出数据
    # 3.6.1 更新输出数据标题
    SignalOut["title"] = SignalIn["title"]
    # 3.6.2 更新输出数据头
    SignalOut["hdr"] = SignalIn["hdr"]
    # 3.6.3 更新输出数据起点
    SignalOut['start'] = logStart
    # 3.6.4 更新输出数据步长
    SignalOut['step'] = logScale
    # 3.6.5 更新输出数据采样模式
    SignalOut['sampling'] = 1
    # 3.6.6 更新输出数据自由度因子
    SignalOut['dof_factor'] = np.max([1.0, SignalIn.get('dof_factor', 1.0) * dof_factor])

    # 3.7 输出 ln 对数波长下的光谱字典数据
    return SignalOut
