# -*- coding: utf-8 -*-
# @Time    : 2025/1/4 16:05
# @Author  : ljc
# @FileName: convol_pytorch.py
# @Software: PyCharm


# 1. 简介
"""
 Python conversion of the IDL convol .
目的:
    降低 TGM 光谱的分辨率.
函数:
    1) convol
解释:
    1) convol 函数: 降低 TGM 光谱的分辨率.
注意:
    1) LASP 的 IDL 版本、目前 Python 版本动态降低分辨率, 即: 同一条光谱, 每一次迭代优化时卷积核均值、标准差、卷积核大小不一样. 不同光谱, 卷积核也不一样.
    2) 目前, 对每轮迭代时的核函数大小分组处理, 即: 核函数大小去重、具有相同核函数大小的 TGM 光谱一起卷积、合并具有不同核函数大小的所有卷积结果.
    3) 可自定义所有光谱使用相同的卷积核, 即卷积核均值、标准差、大小都一样. 这样会更高效, 但参数推断精度可能降低、或存在系统差 ([Fe/H] 可能更敏感).
"""


# 2. 导库
from config.config import default_set, set_all_seeds
import torch
import torch.nn.functional as F
# 2.1 设置随机种子
set_all_seeds()
# 2.2 调用 GPU、指定数据类型
dtype, device = default_set()
# 2.3 默认数据类型
torch.set_default_dtype(dtype)


# 3. 批量降低光谱分辨率
def convol(models, kernal_mu, kernal_stds) -> torch.Tensor:

    """
        按不同 kernel_size 分别卷积.

        输入:
        -----------
        models: torch.Tensor [batch_size, spectrum_length]
               TGM 光谱数据, 比如 shape 为 [1000, 7000].
        kernal_mu: torch.Tensor [batch_size, 1]
                  每个样本的 mu 值, shape 为 [1000, 1].
        kernal_stds: torch.Tensor [batch_size, 1]
                   每个样本的 std 值, shape 为 [1000, 1].

        Returns:
        --------
        torch.Tensor [batch_size, spectrum_length, 1]
               降低分辨率后的光谱.
    """

    # 3.1 计算用于降低每条 TGM 生成谱的 kernal size, kernal 的标准差最小值限制为 0.1
    kernal_stds = torch.clamp(kernal_stds, min=0.1)
    dx = torch.ceil(torch.abs(kernal_mu) + 5.0 * kernal_stds)
    kernel_sizes = (2 * dx + 1).long().view(-1)

    # 3.2 对于每组 LAMOST 光谱, 对所有 kernel sizes 去重、排序
    unique_sizes = torch.unique(kernel_sizes)

    # 3.3 创建结果张量
    result = torch.zeros_like(models)

    # 3.4 对每个 kernel size 分别处理
    for size in unique_sizes:
        # 3.4.1 找到当前 size 对应的样本索引
        indices = torch.where(kernel_sizes == size)[0]
        current_batch_size = len(indices)
        if current_batch_size == 0:
            continue

        # 3.4.2 获取当前组的 TGM 光谱流量, 形状为 [current_batch_size, spectrum_length]
        current_models = models[indices, :]
        # 3.4.3 kernel size 为 size 时的核函数期望值的形状为 [current_batch_size, 1]
        current_mu = kernal_mu[indices, :]
        # 3.4.4 kernel size 为 size 时的核函数标准差的形状为 [current_batch_size, 1]
        current_std = kernal_stds[indices, :]
        current_dx = dx[indices, :]

        # 3.4.5 生成当前 size 的 x 值, 即 kernal 值, 形状为 [1, size]
        x = current_dx - torch.arange(size, device=device)
        # 3.4.6 计算 kernal, mask 掉异常 kernal 值
        w = (x - current_mu) / current_std
        w2 = w * w
        mask = torch.abs(w) <= 5.0
        kernal = torch.exp(-0.5 * w2) / (torch.sqrt(torch.tensor(2.0 * torch.pi)) * current_std)
        kernal = kernal * mask
        # 3.4.7 kernal 归一化
        kernal = kernal / torch.sum(kernal, dim=1, keepdim=True)

        # 3.4.8 改变 TGM 光谱流量、kernal 形状, 便于卷积, TGM 光谱流量形状为 [1, current_batch_size, spectrum_length]
        current_models = current_models.view(1, current_batch_size, -1)
        # 3.4.9 卷积 TGM 光谱的 kernal 形状为 [current_batch_size, 1, size]
        kernal = kernal.view(current_batch_size, 1, -1)

        # 3.4.10 填充 TGM 光谱流量两端
        pad_size = (size - 1) // 2
        padded = F.pad(current_models, (pad_size, pad_size), mode='replicate')

        # 3.4.11 批量卷积, 形状为 [1, current_batch_size, spectrum_length]
        conv_result = F.conv1d(
            padded,                    # TGM 光谱流量形状为 [1, current_batch_size, padded_length]
            kernal,                    # 卷积 TGM 光谱的 kernal 形状为 [current_batch_size, 1, kernel_size]
            groups=current_batch_size  # 分组卷积
        )

        # 3.4.12 保存结果
        result[indices, :] = conv_result.squeeze(0).squeeze(1)

    # 3.4.13 低分辨率光谱形状为 [batch_size, spectrum_length, 1]
    return result.unsqueeze(-1)