# -*- coding: utf-8 -*-
# @Time    : 08/12/2024 12.16
# @Author  : ljc
# @FileName: xrebin.py
# @Software: PyCharm


# 1. 简介
"""
 Python conversion of the IDL xrebin.pro .
目的:
    设置不同插值方法, 对输入光谱进行波长重采样.
函数:
    1) xrebin
解释:
    1) xrebin 函数: 插值后的光谱.
注意:
     为了尽可能保证插值前后, 光谱的总通量守恒, 波长重采样使用 3 步:
     1) 在原始积分波长区间范围, 将光谱累积成积分形式.
     2) 在待测光谱积分区间范围, 对积分光谱进行插值.
     3) 对插值后的光谱进行差分, 得到最终的重采样光谱.
"""


# 2. 导库
import numpy as np
from scipy.interpolate import interp1d, CubicSpline
import warnings
warnings.filterwarnings("ignore")


# 3. 波长重采样函数
def xrebin(xin, yin, xout, sampling_function, nan=False) -> np.ndarray:

    """
      根据需要选择插值方法, 对输入光谱进行波长重采样.

      输入参数:
      -----------
      xin:
          输入光谱的波长范围.
      yin:
          输入光谱的流量.
      xout:
           需要采样到的波长范围, 即波长插值范围.
      sampling_function:
                       插值方法. 可输入 "splinf", "cubic", "slinear", "quadratic", "linear". 默认使用 "linear" 插值方法.
                       注意: 不同插值方法对参数推断影响很小, 无论使用哪种插值方法, 最终结果几乎一致. 
                             如 Teff 差异不足 1 K, 而 Rv 差异不足 0.05 km/s, log g 差异不足 0.005 dex, [Fe/H] 差异不足 0.005 dex.
                             因此, 我们认为不应该纠结于插值方法的选择.
      nan:
          如果为 True, 将 yin 中的缺失值视为 0.

      输出参数:
      -----------
      integr_interp:
                   波长重采样后的光谱流量.
    """

    # 3.1 获取 yin 的维度信息, 光谱形状为 (流量维度, 样本数量)
    s = yin.shape
    # 3.2 是否有 1 条样本
    if s[1] < 1:
        raise ValueError("yin must be a spectrum!")

    # 3.3 检查 xin 的长度是否比 yin 的第一维度长 1, 即波长与流量维度要一致
    if s[0] + 1 != len(xin):
        raise ValueError("xin must be one element longer than yin!")

    # 3.4 一次计算 1 条样本的波长重采样结果
    if s[1] == 1:
        # 3.4.1 将 yin 累积成积分形式
        integr = np.insert(np.nancumsum(yin, dtype=np.float64), 0, 0)
        # 3.4.2 根据插值方法进行处理
        if sampling_function == "splinf":
            # 3.4.2.1 使用样条插值(快速版本)
            y2 = CubicSpline(xin, integr, bc_type='natural')
            integr_interp = y2(xout)
        elif sampling_function == "cubic":
            # 3.4.2.2 使用三次样条插值方法
            cubic_interp = interp1d(xin, integr, kind='cubic', fill_value="extrapolate")
            integr_interp = cubic_interp(xout)
        else:
            # 3.4.2.3 一阶样条插值
            if sampling_function == "slinear":
                kind = 'slinear'
            # 3.4.2.4 二次样条插值
            elif sampling_function == "quadratic":
                kind = 'quadratic'
            else:
                # 3.4.2.5 默认使用线性插值方法, 即两点之间使用线性插值
                kind = 'linear'
            interp_func = interp1d(xin, integr, kind=kind, fill_value="extrapolate")
            integr_interp = interp_func(xout)

        # 3.4.3 对累积分布进行差分, 得到最终的重采样输出
        integr_interp = (np.roll(integr_interp, -1) - integr_interp)[:len(integr_interp) - 1]

        # 3.4.4 返回插值光谱
        return integr_interp

    # 3.5 多维数组的情况 (LASP 没有使用该方案, 因此下述代码不测试, 如有需要请参考原始 IDL 代码)
    """
    else:
        # 3.5.1 获取数据的形状
        dim = yin.shape
        # 3.5.2 将 yin 累积成积分形式
        integr = np.concatenate([np.zeros((1, *dim[1:])), np.cumsum(yin, axis=0, dtype=np.float64)], axis=0)

        # 3.5.3 使用不同的插值方法进行处理
        if splinf:
            # 3.5.3.1 使用快速样条插值
            integr_interp = np.zeros((len(xout), dim[1]), dtype=np.float64)
            for i in range(dim[1]):
                spline_func = CubicSpline(xin, integr[:, i], bc_type='natural')
                integr_interp[:, i] = spline_func(xout)
        elif cubic:
            # 3.5.3.2 使用三次插值
            integr_interp = np.zeros((len(xout), dim[1]), dtype=np.float64)
            for i in range(dim[1]):
                cubic_interp = interp1d(xin, integr[:, i], kind='cubic', fill_value="extrapolate")
                integr_interp[:, i] = cubic_interp(xout)
        else:
            # 3.5.3.3 默认使用线性或二次插值
            kind = 'slinear' if spline else 'quadratic' if (lsquadratic or quadratic) else 'linear'
            integr_interp = np.zeros((len(xout), dim[1]), dtype=np.float64)
            for i in range(dim[1]):
                interp_func = interp1d(xin, integr[:, i], kind=kind, fill_value="extrapolate")
                integr_interp[:, i] = interp_func(xout)

        # 3.5.4 对累积分布进行差分, 得到最终的重采样输出
        yout = np.diff(integr_interp, axis=0)
       
        # 3.5.5 返回插值光谱
        return yout
    """


# 4. 测试
# xin = np.array([1, 2, 3, 4]).reshape(-1)
# yin = np.array([1, 2, 1]).reshape(-1, 1)
# xout = np.array([1.2, 2.1, 3.2, 3.5]).reshape(-1)
# yout = xrebin(xin, yin, xout)
# print(yout)

# xin = np.array([1, 2, 3, 5]).reshape(-1)
# yin = np.array([2, 1, 1]).reshape(-1, 1)
# xout = np.array([1, 2, 3, 6]).reshape(-1)
# yout = xrebin(xin, yin, xout, splinf=True)
# print(yout)

# xin = np.array([1, 2, 3, 5]).reshape(-1)
# yin = np.array([2, 1, 1]).reshape(-1, 1)
# xout = np.array([1.1, 2, 3, 4.5]).reshape(-1)
# yout = xrebin(xin, yin, xout, splinf=False)
# print(yout)

# xin = np.array([1., 2., 3., 4.]).reshape(-1)
# yin = np.array([1., 2., 1.]).reshape(-1, 1)
# xout = np.array([1.2, 2.1, 3.2, 3.9]).reshape(-1)
# yout = xrebin(xin, yin, xout, splinf=False)
# print(yout)