# -*- coding: utf-8 -*-
# @Time    : 2025/1/5 15:16
# @Author  : ljc
# @FileName: uly_tgm_eval_pytorch.py
# @Software: PyCharm
# Update:  2025/11/26 22:12:14


# ======================================================================
# 1. Introduction
# ======================================================================
"""PyTorch implementation of the spectral emulator for LASP-Adam-GPU.

1.1 Purpose
-----------
Generate batches of N model spectra at specified Teff, log g, and
[Fe/H] using polynomial functions:
    F = a0 + a1*Teff + a2*[Fe/H] + a3*logg + a4*Teff*Teff + ...
      = torch.matmul(param.T, spec_coef)
where:
    - F: N model spectra generated by the polynomial spectral emulator
    - param: Basis function values composed of Teff, log g, [Fe/H]
             (1, Teff, [Fe/H], logg, Teff*Teff, ...)
    - spec_coef: Polynomial coefficient matrix (a0, a1, a2, ...)

1.2 Functions
-------------
1) spec_param_version_2_warm: Compute basis functions for warm stars.
2) spec_param_version_2_hot: Compute basis functions for hot stars.
3) spec_param_version_2_cold: Compute basis functions for cold stars.
4) uly_tgm_eval: Generate a group of N model spectra at specified
   stellar parameters.

1.3 Explanation
---------------
This module provides batch processing functions for PyLASP-Adam-GPU.
Steps:
    1) Compute polynomial basis function values for a group of N spectra
       at specified Teff, log g, and [Fe/H].
    2) Group spectra by Teff range (<= 4000 K, 4000-4550 K,
       4550-7000 K, 7000-9000 K, >= 9000 K), since each batch may
       contain samples from different temperature regimes.
    3) Compute model spectra using polynomial coefficients within each
       temperature regime.
    4) Interpolate between temperature regimes in transition zones.
    5) Return batch of N model spectra.

"""


# ======================================================================
# 2. Import libraries
# ======================================================================
from config.config import default_set, set_all_seeds
import torch
import math

# 2.1 Set random seed
# NOTE(review): set_all_seeds() comes from config.config; presumably it seeds
# the Python/NumPy/PyTorch RNGs for reproducibility -- confirm there. This
# runs once, at import time.
set_all_seeds()
# 2.2 Call GPU and specify data type
# default_set() returns the (dtype, device) pair that the functions in this
# module use when allocating their output tensors. Presumably it selects a
# CUDA device when one is available -- confirm in config.config.
dtype, device = default_set()
# 2.3 Set default data type
# Process-wide side effect: floating-point tensors created without an
# explicit dtype (anywhere in the process, not only in this module) will
# use ``dtype`` from now on.
torch.set_default_dtype(dtype)


# ======================================================================
# 3. Type definitions for better code readability
# ======================================================================
# Alias for torch.Tensor used in the function annotations of this module.
TensorLike = torch.Tensor


# ======================================================================
# 4. Warm stars polynomial basis function computation
# ======================================================================
def spec_param_version_2_warm(labels_batch: TensorLike, group_size: int) -> TensorLike:
    r"""Build the 23-term polynomial basis for the warm-star emulator.

    Parameters
    ----------
    labels_batch : torch.Tensor
        shape (group_size, 3)
        One row per star: [Teff, log g, [Fe/H]]. Teff is the scaled
        log-temperature log10(Teff) - 3.7617; log g and [Fe/H] are in
        their usual units.
    group_size : int
        Number of columns of the returned matrix; expected to equal
        ``labels_batch.shape[0]``.

    Returns
    -------
    param : torch.Tensor
        shape (23, group_size)
        Row ``k`` holds basis term ``k`` evaluated for every star. The
        matrix is allocated with the module-level ``device`` and ``dtype``.

    Notes
    -----
    - The temperature variable is ``x = Teff / 0.2`` (no offset), and
      ``x2m1 = x**2 - 1``.
    - Rows 21 and 22 are exp(x) and exp(2x) minus their Taylor series
      through the fifth-order term.
    - ``uly_tgm_eval`` uses the warm coefficients alone for
      4550 K < T <= 7000 K. It also blends them with the cold set for
      4000-4550 K and with the hot set for 7000-9000 K.

    Examples
    --------
    >>> labels = torch.tensor([[0.25, -0.44, 0.0], [0.30, 4.0, -0.5]],
    ...                       dtype=dtype, device=device)
    >>> param = spec_param_version_2_warm(labels, 2)
    >>> print(param.shape)
    torch.Size([23, 2])

    """
    # Split the label columns.
    log_t = labels_batch[:, 0]
    logg = labels_batch[:, 1]
    met = labels_batch[:, 2]

    # Scaled temperature and its shifted square.
    x = log_t / 0.2
    x2m1 = x**2 - 1.0

    # Basis terms, in the row order the coefficient matrix expects.
    terms = (
        1.0,  # 0: constant
        x,  # 1
        met,  # 2
        logg,  # 3
        x**2,  # 4
        x * x2m1,  # 5
        x2m1**2,  # 6
        x * met,  # 7
        x * logg,  # 8
        x2m1 * logg,  # 9
        x2m1 * met,  # 10
        logg**2,  # 11
        met**2,  # 12
        x * (x2m1**2),  # 13
        x * (logg**2),  # 14
        logg**3,  # 15
        met**3,  # 16
        x * (met**2),  # 17
        logg * met,  # 18
        (logg**2) * met,  # 19
        logg * (met**2),  # 20
        # 21: exp(x) minus its degree-5 Taylor polynomial
        torch.exp(x) - 1 - x * (1 + x / 2 + x**2 / 6 + x**3 / 24 + x**4 / 120),
        # 22: exp(2x) minus its degree-5 Taylor polynomial
        torch.exp(2 * x)
        - 1
        - 2 * x * (1 + x + 2 / 3 * x**2 + x**3 / 3 + x**4 * 2 / 15),
    )

    # Write each term into its row of the output matrix.
    param = torch.zeros((23, group_size), device=device, dtype=dtype)
    for row, term in enumerate(terms):
        param[row, :] = term
    return param


# ======================================================================
# 5. Hot stars polynomial basis function computation
# ======================================================================
def spec_param_version_2_hot(labels_batch: TensorLike, group_size: int) -> TensorLike:
    r"""Build the 23-term polynomial basis for the hot-star emulator.

    Parameters
    ----------
    labels_batch : torch.Tensor
        shape (group_size, 3)
        One row per star: [Teff, log g, [Fe/H]]. Teff is the scaled
        log-temperature log10(Teff) - 3.7617; log g and [Fe/H] are in
        their usual units.
    group_size : int
        Number of columns of the returned matrix; expected to equal
        ``labels_batch.shape[0]``.

    Returns
    -------
    param : torch.Tensor
        shape (23, group_size)
        Row ``k`` holds basis term ``k`` evaluated for every star. The
        matrix is allocated with the module-level ``device`` and ``dtype``.

    Notes
    -----
    - The temperature variable is ``x = Teff / 0.2`` (no offset), and
      ``x2m1 = x**2 - 1``.
    - Rows 21 and 22 are exp(x) and exp(2x) minus their Taylor series
      through the fifth-order term.
    - ``uly_tgm_eval`` uses the hot coefficients alone for T >= 9000 K and
      blends them with the warm set for 7000 K < T < 9000 K.

    Examples
    --------
    >>> labels = torch.tensor([[0.50, 2.0, 0.0], [0.55, 1.5, -0.3]],
    ...                       dtype=dtype, device=device)
    >>> param = spec_param_version_2_hot(labels, 2)
    >>> print(param.shape)
    torch.Size([23, 2])

    """
    # Split the label columns.
    log_t = labels_batch[:, 0]
    logg = labels_batch[:, 1]
    met = labels_batch[:, 2]

    # Scaled temperature and its shifted square.
    x = log_t / 0.2
    x2m1 = x**2 - 1.0

    # Basis terms, in the row order the coefficient matrix expects.
    terms = (
        1.0,  # 0: constant
        x,  # 1
        met,  # 2
        logg,  # 3
        x**2,  # 4
        x * x2m1,  # 5
        x2m1**2,  # 6
        x * met,  # 7
        x * logg,  # 8
        x2m1 * logg,  # 9
        x2m1 * met,  # 10
        logg**2,  # 11
        met**2,  # 12
        x * (x2m1**2),  # 13
        x * (logg**2),  # 14
        logg**3,  # 15
        met**3,  # 16
        x * (met**2),  # 17
        logg * met,  # 18
        (logg**2) * met,  # 19
        logg * (met**2),  # 20
        # 21: exp(x) minus its degree-5 Taylor polynomial
        torch.exp(x) - 1 - x * (1 + x / 2 + x**2 / 6 + x**3 / 24 + x**4 / 120),
        # 22: exp(2x) minus its degree-5 Taylor polynomial
        torch.exp(2 * x)
        - 1
        - 2 * x * (1 + x + 2 / 3 * x**2 + x**3 / 3 + x**4 * 2 / 15),
    )

    # Write each term into its row of the output matrix.
    param = torch.zeros((23, group_size), device=device, dtype=dtype)
    for row, term in enumerate(terms):
        param[row, :] = term
    return param


# ======================================================================
# 6. Cold stars polynomial basis function computation
# ======================================================================
def spec_param_version_2_cold(labels_batch: TensorLike, group_size: int) -> TensorLike:
    r"""Build the 23-term polynomial basis for the cold-star emulator.

    Parameters
    ----------
    labels_batch : torch.Tensor
        shape (group_size, 3)
        One row per star: [Teff, log g, [Fe/H]]. Teff is the scaled
        log-temperature log10(Teff) - 3.7617; log g and [Fe/H] are in
        their usual units.
    group_size : int
        Number of columns of the returned matrix; expected to equal
        ``labels_batch.shape[0]``.

    Returns
    -------
    param : torch.Tensor
        shape (23, group_size)
        Row ``k`` holds basis term ``k`` evaluated for every star. The
        matrix is allocated with the module-level ``device`` and ``dtype``.

    Notes
    -----
    - Unlike the warm/hot bases, the temperature variable is shifted:
      ``x = (Teff + 0.1) / 0.2``, and ``x2m1 = x**2 - 1``.
    - Rows 21 and 22 are exp(x) and exp(2x) minus their Taylor series
      through the fifth-order term.
    - ``uly_tgm_eval`` uses the cold coefficients alone for T <= 4000 K and
      blends them with the warm set for 4000 K < T <= 4550 K.

    Examples
    --------
    >>> labels = torch.tensor([[-0.10, 4.5, 0.0], [-0.05, 4.0, -0.5]],
    ...                       dtype=dtype, device=device)
    >>> param = spec_param_version_2_cold(labels, 2)
    >>> print(param.shape)
    torch.Size([23, 2])

    """
    # Split the label columns.
    log_t = labels_batch[:, 0]
    logg = labels_batch[:, 1]
    met = labels_batch[:, 2]

    # Scaled temperature (with the cold-regime offset) and its shifted square.
    x = (log_t + 0.1) / 0.2
    x2m1 = x**2 - 1.0

    # Basis terms, in the row order the coefficient matrix expects.
    terms = (
        1.0,  # 0: constant
        x,  # 1
        met,  # 2
        logg,  # 3
        x**2,  # 4
        x * x2m1,  # 5
        x2m1**2,  # 6
        x * met,  # 7
        x * logg,  # 8
        x2m1 * logg,  # 9
        x2m1 * met,  # 10
        logg**2,  # 11
        met**2,  # 12
        x * (x2m1**2),  # 13
        x * (logg**2),  # 14
        logg**3,  # 15
        met**3,  # 16
        x * (met**2),  # 17
        logg * met,  # 18
        (logg**2) * met,  # 19
        logg * (met**2),  # 20
        # 21: exp(x) minus its degree-5 Taylor polynomial
        torch.exp(x) - 1 - x * (1 + x / 2 + x**2 / 6 + x**3 / 24 + x**4 / 120),
        # 22: exp(2x) minus its degree-5 Taylor polynomial
        torch.exp(2 * x)
        - 1
        - 2 * x * (1 + x + 2 / 3 * x**2 + x**3 / 3 + x**4 * 2 / 15),
    )

    # Write each term into its row of the output matrix.
    param = torch.zeros((23, group_size), device=device, dtype=dtype)
    for row, term in enumerate(terms):
        param[row, :] = term
    return param


# ======================================================================
# 7. Batch model spectrum emulator
# ======================================================================
def _emulate_regime(basis_fn, labels: TensorLike, coef: TensorLike) -> TensorLike:
    """Evaluate one temperature regime's polynomial emulator on ``labels``.

    ``basis_fn`` is one of the ``spec_param_version_2_*`` builders. It
    returns a (23, n) basis matrix for ``n`` label rows, whose transpose is
    multiplied by ``coef`` (23, n_wavelengths) to give (n, n_wavelengths).
    """
    # labels.shape[0] is a plain Python int. The previous mask.sum() was a
    # 0-dim tensor, which torch.zeros had to convert implicitly (a
    # device->host copy when the mask lives on the GPU).
    basis = basis_fn(labels_batch=labels, group_size=labels.shape[0])
    return torch.matmul(basis.T, coef)


def _blend(
    low: TensorLike,
    high: TensorLike,
    teff_sub: TensorLike,
    t_low: float,
    t_high: float,
) -> TensorLike:
    """Linearly blend two regimes' spectra across ``[t_low, t_high]``.

    The weight is 0 at ``t_low`` (pure ``low``) and 1 at ``t_high`` (pure
    ``high``). It is linear in the scaled log-temperature ``teff_sub``.
    """
    q = ((teff_sub - t_low) / (t_high - t_low)).unsqueeze(1)
    return q * high + (1.0 - q) * low


def uly_tgm_eval(spec_coef: TensorLike, para: TensorLike) -> TensorLike:
    r"""Generate a group of N model spectra at specified parameters.

    Parameters
    ----------
    spec_coef : torch.Tensor
        shape (23, n_wavelengths, 3)
        Model polynomial coefficient matrix. The last axis selects the
        regime: index 0 = warm, 1 = hot, 2 = cold.
    para : torch.Tensor
        shape (group_size, 3)
        Stellar parameters array for batch of N spectra.
        Columns are [Teff, log g, [Fe/H]].
        Teff is in scaled logarithmic form: (log10(Teff) - 3.7617).
        log g is in standard form.
        [Fe/H] is in standard form.

    Returns
    -------
    result : torch.Tensor
        shape (group_size, n_wavelengths)
        Batch of N model spectra generated by the polynomial emulator,
        allocated with the module-level ``device`` and ``dtype``.

    Notes
    -----
    - Spectra are processed in five temperature regimes:
      1) T <= 4000 K            -> cold coefficients
      2) 4000 K < T <= 4550 K   -> blend cold -> warm
      3) 4550 K < T <= 7000 K   -> warm coefficients
      4) 7000 K < T < 9000 K    -> blend warm -> hot
      5) T >= 9000 K            -> hot coefficients
    - Blending is linear in the scaled log-temperature. The weight is 0 at
      the lower boundary and 1 at the upper one, so the model is continuous
      across regimes.
    - A batch may mix regimes, so rows are selected by Teff mask before the
      basis functions are computed.
    - Rows whose Teff is NaN pass none of the region tests and are returned
      as all-zero spectra.

    Examples
    --------
    >>> spec_coef = torch.rand(23, 7506, 3, dtype=dtype, device=device)
    >>> para = torch.tensor([[0.25, 4.44, 0.0], [0.50, 2.0, -0.5]],
    ...                     dtype=dtype, device=device)
    >>> model = uly_tgm_eval(spec_coef, para)
    >>> print(model.shape)
    torch.Size([2, 7506])

    """

    # ------------------------------------------------------------------
    # 7.1 Batch size, spectral dimension and Teff column
    # ------------------------------------------------------------------
    group_size = para.shape[0]
    n_wavelengths = spec_coef.shape[1]
    teff = para[:, 0]

    # ------------------------------------------------------------------
    # 7.2 Regime boundaries in the same scaled log-temperature as ``para``
    # ------------------------------------------------------------------
    t_4000, t_4550, t_7000, t_9000 = (
        math.log10(4000.0) - 3.7617,
        math.log10(4550.0) - 3.7617,
        math.log10(7000.0) - 3.7617,
        math.log10(9000.0) - 3.7617,
    )

    # Coefficient sets per regime (last axis: 0 = warm, 1 = hot, 2 = cold).
    coef_warm = spec_coef[:, :, 0]
    coef_hot = spec_coef[:, :, 1]
    coef_cold = spec_coef[:, :, 2]

    # ------------------------------------------------------------------
    # 7.3 Output buffer; rows matching no region (NaN Teff) stay zero
    # ------------------------------------------------------------------
    result = torch.zeros((group_size, n_wavelengths), device=device, dtype=dtype)

    # ------------------------------------------------------------------
    # 7.4 Region 1: cold only (T <= 4000 K)
    # ------------------------------------------------------------------
    mask = teff <= t_4000
    if mask.any():
        result[mask, :] = _emulate_regime(
            spec_param_version_2_cold, para[mask], coef_cold
        )

    # ------------------------------------------------------------------
    # 7.5 Region 2: cold -> warm transition (4000 K < T <= 4550 K)
    # ------------------------------------------------------------------
    mask = (teff > t_4000) & (teff <= t_4550)
    if mask.any():
        subset = para[mask]
        cold = _emulate_regime(spec_param_version_2_cold, subset, coef_cold)
        warm = _emulate_regime(spec_param_version_2_warm, subset, coef_warm)
        result[mask, :] = _blend(cold, warm, teff[mask], t_4000, t_4550)

    # ------------------------------------------------------------------
    # 7.6 Region 3: warm only (4550 K < T <= 7000 K)
    # ------------------------------------------------------------------
    mask = (teff > t_4550) & (teff <= t_7000)
    if mask.any():
        result[mask, :] = _emulate_regime(
            spec_param_version_2_warm, para[mask], coef_warm
        )

    # ------------------------------------------------------------------
    # 7.7 Region 4: warm -> hot transition (7000 K < T < 9000 K)
    # ------------------------------------------------------------------
    mask = (teff > t_7000) & (teff < t_9000)
    if mask.any():
        subset = para[mask]
        warm = _emulate_regime(spec_param_version_2_warm, subset, coef_warm)
        hot = _emulate_regime(spec_param_version_2_hot, subset, coef_hot)
        result[mask, :] = _blend(warm, hot, teff[mask], t_7000, t_9000)

    # ------------------------------------------------------------------
    # 7.8 Region 5: hot only (T >= 9000 K)
    # ------------------------------------------------------------------
    mask = teff >= t_9000
    if mask.any():
        result[mask, :] = _emulate_regime(
            spec_param_version_2_hot, para[mask], coef_hot
        )

    # ------------------------------------------------------------------
    # 7.9 Return batch of N generated model spectra
    # ------------------------------------------------------------------
    return result
