# -*- coding: utf-8 -*-
# @Time    : 2025/1/2 11:19
# @Author  : ljc
# @FileName: xrebin_pytorch.py
# @Software: PyCharm


# 1. 简介
"""
目的:
    线性插值光谱, 即波长重采样.
函数:
    1) xrebin
解释:
    1) xrebin 函数: 插值后的光谱.
注意: 
     为了尽可能保证插值前后, 光谱的总通量守恒, 波长重采样使用 3 步:
     1) 在原始积分波长区间范围, 将光谱累积成积分形式.
     2) 在待测光谱积分区间范围, 对积分光谱进行插值.
     3) 对插值后的光谱进行差分, 得到最终的重采样光谱.
     4) LASP 使用样条插值还是线性插值对参数推断的影响很小, 目前分组优化支持线性插值, 但后续可开发多种插值方法.
"""


# 2. 导库
from config.config import default_set, set_all_seeds
import torch
# 2.1 设置随机种子
set_all_seeds()
# 2.2 调用 GPU、指定数据类型
dtype, device = default_set()
# 2.3 默认数据类型
torch.set_default_dtype(dtype)


# 3. 波长重采样函数
def xrebin(xin, yin, xout) -> torch.Tensor:

    """
      对输入光谱进行线性插值, 即重采样.

      输入参数:
      -----------
      xin:
          输入光谱的积分波长.
      yin:
          输入光谱的流量.
      xout:
          待测光谱的积分波长.

      输出参数:
      -----------
      integr_interp:
                    重采样后的光谱流量.
    """

    # 3.1 找到 xout 中每个点在 xin 中的位置索引
    indices = torch.searchsorted(xin, xout)
    
    # 3.2 将光谱累积成积分形式
    flux = torch.cat([torch.zeros_like(yin[:, :1]), torch.cumsum(yin, dim=1)], dim=1)

    # 3.3 获取左侧和右侧的点
    w1_left = torch.gather(xin, 1, indices-1)
    w1_right = torch.gather(xin, 1, indices)
    flux_left = torch.gather(flux, 1, indices-1)
    flux_right = torch.gather(flux, 1, indices)

    # 3.4 线性插值计算
    flux_interpolated = flux_left + (flux_right - flux_left) * (xout - w1_left) / (w1_right - w1_left)
   
    # 3.5 对插值后的光谱进行差分, 得到最终的重采样输出
    integr_interp = (torch.roll(flux_interpolated, -1, dims=1) - flux_interpolated)[:, :-1]

    # 3.6 返回插值光谱
    return integr_interp


# 4. 测试
# xin = torch.tensor([1, 2, 3, 4]).unsqueeze(0)
# yin = torch.tensor([1, 2, 1]).unsqueeze(0)
# xout = torch.tensor([1.2, 2.1, 3.2, 3.5]).unsqueeze(0)  
# print(xrebin(xin, yin, xout))