# -*- coding: utf-8 -*-
# @Time    : 10/12/2024 16.20
# @Author  : ljc
# @FileName: uly_spect_extract.py
# @Software: PyCharm


# 1. 简介
"""
 Python conversion of the IDL uly_spect_extract.pro .
目的:
    提取光谱的一部分并返回一个 spect 结构.
函数:
    1) uly_spect_extract
解释:
    1) uly_spect_extract 函数: 提取光谱的一部分并返回一个 spect 字典结构.
"""


# 2. 导库
import numpy as np
from uly_read_lms.uly_spect_get import uly_spect_get
from uly_read_lms.uly_spect_alloc import uly_spect_alloc
import warnings
warnings.filterwarnings("ignore")


# 3. 提取指定波长范围的光谱
def uly_spect_extract(SignalIn, pos=None, waverange=None, overwrite=False) -> dict:

    """
      提取光谱的一部分并返回一个 spect 结构.

      输入参数:
      -----------
      SignalIn:
               输入光谱字典结构.
      pos:
          该参数仅适用于长缝光谱, 但 LAMOST 光谱不是长缝光谱, 因此该参数不适用 LASP.
          关于长缝光谱的截取, 请参考 IDL 代码.
      waverange:
                指定要提取的波长范围.
      overwrite:
                overwrite 为真, 表示 SignalExtr 与 SignalIn 是同一个变量.

      输出参数:
      -----------
      SignalOut:
                提取的光谱 (spect 结构, 参见 uly_spect_alloc).
    """

    # 3.1 若设置了 overwrite, 则 SignalOut 为 SignalIn; 否则, 创建新的光谱结构
    SignalOut = SignalIn if overwrite else uly_spect_alloc(SPECTRUM=SignalIn)

    # 3.2 获取流量维度
    ndim = SignalOut["data"].shape[0]

    # 3.3 提取长缝光谱
    # 注意: LAMOST 光谱不是长缝光谱, 如有需要, 请参考 IDL 代码
    """
    if pos is not None and ndim == 1:
        if pos[0] != 0:
            print("Invalid ONED position")
            return -1
    # 3.3.2 提取多维光谱
    elif pos is not None:
        # 获取流量维度
        ndim = SignalOut["data"].shape[0]
        if len(str(pos)) >= ndim and ndim > 1:
            print("Invalid ONED position")
            return -1

        # 获取光谱数据的维度
        dim = SignalOut["data"].shape
        # invalid_pos = np.where((np.array(pos) < 0) | (np.array(pos) >= np.array(dim[1:])), 1)
        # if np.any(invalid_pos):
        #     print("Invalid ONED position")
        #     return -1

        # 计算从指定位置索引 pos 提取的数据
        # sdim = [1] + list(np.cumsum(np.log(dim)))
        ntot = SignalOut["data"].size
        # ind = np.sum(np.array(pos) * np.array(sdim))
        ind = 0

        # 重新整形数据以提取
        SignalOut['data'] = np.reshape(SignalOut['data'], (dim[0], ntot // dim[0]))[:, ind]
        if 'err' in SignalOut:
            if len(SignalOut["err"]) == ntot:
                SignalOut['err'] = np.reshape(SignalOut['err'], (dim[0], ntot // dim[0]))[:, ind]

        # 更新掩码信息
        msk = uly_spect_get(SignalOut, MASK=True)

        if msk is not None:
            if len(msk) == ntot:
                msk = np.reshape(msk, (ntot // dim[0]))[:, ind]
                SignalOut['goodpix'] = np.where(msk == 1)[0].tolist()
    """

    # 3.4 裁剪波长范围
    if waverange is not None:
        # 3.4.1 检查波长格式是否正确
        if len(waverange) > 2:
            raise ValueError("WAVERANGE must be a 1 or 2 elements list!")

        # 3.4.2 计算采样
        # 注意: sampling 为 0 或 1 时, 表示光谱为线性或 ln 对数采样
        if (SignalOut['sampling'] == 0) | (SignalOut['sampling'] == 1):
            wr = waverange
            # 3.4.2.1 将波长范围转换为 ln 对数波长
            if SignalOut['sampling'] == 1:
                wr = np.log(waverange)
            
            # 3.4.2.2 TGM 光谱左侧 MASK 多少点, 根据 ln 对数波长计算
            # 注意: 
            # 1) step 为 LAMOST 的 log10 对数波长间隔转为了 ln 对数波长间隔, 即 ln_step = ln(10) * log10_step
            # 2) 为了避免多删除流量点, 使用 np.floor 向下取整, 而不是 np.ceil 向上取整
            # 3) np.floor 向下取整, 为了避免计算误差, 形如 1.999999 被取为 1 的情况, 加上 0.01 作为补偿
            nummin = int(np.floor((wr[0] - SignalOut["start"]) / SignalOut["step"] + 0.01))
            # 若 nummin 小于 0, 则将 nummin 设置为 1
            if nummin < 0:
                nummin = 1

            # 3.4.2.3 计算 TGM 光谱右侧 MASK 多少点
            # 注意: 
            # 1) 为了避免多删除流量点, 使用 np.ceil 向上取整, 而不是 np.floor 向下取整
            # 2) np.ceil 向上取整, 为了避免计算误差, 形如 2.0000001 被取为 3 的情况, 减去 0.01 作为补偿
            npix = SignalOut['data'].shape[0]
            if len(waverange) == 2:
                nummax = int(np.ceil((wr[1] - SignalOut["start"]) / SignalOut["step"] - 0.01))
                # 若 nummax 大于 npix, 则将 nummax 设置为 npix - 1
                if nummax > npix:
                    nummax = npix - 1
            # 若 nummax 小于 0, 则将 nummax 设置为 0
            else:
                nummax = npix - 1

            # 3.4.2.4 更新光谱的起始位置
            SignalOut['start'] += nummin * SignalOut['step']

            # 3.4.2.5 更新掩码
            if 'goodpix' in SignalOut:
                if len(SignalOut["goodpix"]) > 0:
                    # [3] 即只返回 msk 的值, 不返回 WAVERANGE, GOODPIX, HDR
                    msk = uly_spect_get(SignalOut,   # 光谱字典结构
                                        MASK=True    # 是否返回掩码
                                        )[3][nummin: nummax + 1]
                    SignalOut['goodpix'] = np.where(msk == 1)[0].tolist()

            # 3.4.2.6 裁剪流量和流量误差
            SignalOut['data'] = SignalOut['data'][nummin: nummax + 1]
            if 'err' in SignalOut:
                if len(SignalOut["err"]) > 0:
                    SignalOut['err'] = SignalOut['err'][nummin: nummax + 1]

        # 3.4.3 提取指定波长范围的数据, 非均匀采样, 根据波长的实际值提取
        # 注意: LASP-MPFit 不使用该方法, 如有需要, 请参考 IDL 代码
        elif SignalOut['sampling'] == 2:
            # 3.4.3.1 提取指定波长范围的数据
            lmn = min(SignalOut['wavelen']) if len(waverange) == 0 else waverange[0]
            lmx = max(SignalOut['wavelen']) if len(waverange) <= 1 else waverange[1]

            # 3.4.3.2 获取波长范围内的数据
            extr = np.where((SignalOut['wavelen'] >= lmn) & (SignalOut['wavelen'] <= lmx))[0].tolist()
            if extr.size == 0:
                raise ValueError("No data left in extraction!")

            # 3.4.3.3 更新波长信息
            SignalOut['wavelen'] = SignalOut['wavelen'][extr]
            if 'goodpix' in SignalOut:
                if len(SignalOut["goodpix"]) > 0:
                    mask = np.zeros(SignalOut['data'].size, dtype=int)
                    mask[SignalOut['goodpix']] = 1
                    SignalOut['goodpix'] = np.where(mask[extr] == 1)[0].tolist()

            # 3.4.3.4 裁剪流量和流量误差
            SignalOut['data'] = SignalOut['data'][extr]
            if 'err' in SignalOut:
                if len(SignalOut["err"]) > 0:
                    SignalOut['err'] = SignalOut['err'][extr]

    # 3.5 返回所提取的部分数据
    return SignalOut