# -*- coding: utf-8 -*-
# @Time    : 2025/3/3 21:37
# @Author  : ljc
# @FileName: clean_outliers.py
# @Software: PyCharm


# 1. 简介
"""
目的:
    对 TGM 模型光谱与 LAMOST 光谱流量残差进行异常点裁剪.
函数:
    1) clean_outliers
解释:
    1) clean_outliers 函数: 对 TGM 模型光谱与 LAMOST 光谱流量残差进行异常点裁剪.
    注意: 裁剪分为 3 层, 每层裁剪目的不同, 裁剪的点也不同.
    1.1) 第一层裁剪: 剔除明显离群值 (识别宇宙线、错误数据点等明显异常).
       1.1.1) 计算 TGM 模型光谱与实测光谱的流量残差 (loss) 的标准差 (rbst_sig).
       1.1.2) 使用流量残差阈值 (3 sigma, 4 sigma, 5 sigma, 7 sigma) 识别异常点. 判断标准: abs(loss) - modelgrd > clip_level * rbst_sig.
       1.1.3) 如果检测到的点超过 3%, 则提高阈值, 确保不过度剔除.
       1.1.4) 这里代码与 IDL 或另一 Python 版有些许差异, 但逻辑一致.
    1.2) 第二层裁剪: 在第一层获取的好像素点上进一步剔除离群值 (剔除与主要离群点相关的 "边缘" 像素. 处理如天空线残余, 它们通常影响多个相邻像素).
       1.2.1) 检查第一层好像素点的邻居 (左右 1 个像素) 是否为坏点.
       1.2.2) 使用较低阈值 (2 sigma) 判断第一层获取的好像素点是否轻微离群. 判断标准: abs(loss) - modelgrd > 2 * rbst_sig.
       1.2.3) 如果邻居为坏点, 且第一层的好点满足轻微离群, 则将该点剔除.
       1.2.4) 与 IDL 或另一 Python 版有些许差异:
          1.2.4.1) 这里是从第一层获取的 current_iter_mask 开始, 而不是从坏像素点构造的邻居索引 (near) 开始
          1.2.4.2) IDL 版与另一 Python 版, 从第一层的坏像素点构造 near 存在冗余, 如: near 中存在重复索引, near 中依然存在坏像素点
          1.2.4.3) 至于如何选择, 依据你个人判断吧
    1.3) 第三层裁剪: 在第二层获取的好像素点上进一步剔除离群值 (完整追踪和剔除整个光谱特征, 如发射线).
       1.3.1) 迭代方式追踪光谱特征 (最多 20 次迭代).
       1.3.2) 每次迭代重新计算 TGM 模型流量残差的标准差.
"""


# 2. 导库
from config.config import default_set, set_all_seeds
import torch
# 2.1 设置随机种子
set_all_seeds()
# 2.2 调用 GPU、指定数据类型
dtype, device = default_set()
# 2.3 默认数据类型
torch.set_default_dtype(dtype)


# 3. 定义 Clean 模式
def clean_outliers(j, bestfit, loss, npix, goodPixels0, goodPixels=None, noise=1., quiet=True) -> torch.Tensor:
    
    """
       并行处理 Clean 模式, 可 "累积" 或 "不累积" 记录 MASK 位置.
       注意: 
       1) 累积: 每次 Clean 的掩码与上一次 Clean 的掩码相乘, 计算 loss 时使用累积掩码.
       2) 不累积: 本次 Clean 的掩码, 计算 loss 时使用当前掩码.
       3) 默认不累积.

       输入参数:
       -----------
       j: 
         Adam 优化器推断函数极小值点时的迭代次数.
       bestfit: 
               TGM 模型光谱流量, 形状为 [batch_size, npix] 的 tensor.
       loss: 
            TGM 模型光谱与实测光谱流量残差, 形状为 [batch_size, npix] 的 tensor.
       npix: 
            每条光谱的流量维度.
       goodPixels0: 
                   Clean 前已经确定的好像素点 0 与 1 的掩码, 形状为 [batch_size, npix] 的布尔掩码 tensor.
                   # 注意: 
                   # 1) 0 表示坏像素点, 1 表示好像素点.
                   # 2) goodPixels0 为 Clean 前已经确定的好像素点, 需要个人根据实际情况确定, 这里默认全 1.
       goodPixels: 
                  之前迭代的掩码, 形状为 [batch_size, npix] 的 tensor.
       noise: 
             待测光谱流量不确定性, 默认值为 1.
       quiet: 
             是否安静模式, 默认值为 True.

       输出参数:
       -----------
       mask: 
            累积了所有迭代的异常点, 形状为 [batch_size, npix] 的 tensor.
       goodPixels_final: 
                        表示好像素, 形状为 [batch_size, npix] 的布尔掩码 tensor.
    """

    # 3.1 确保输入数据形状兼容
    if bestfit.dim() == 3:
        # 3.1.1 如果是 [batch_size, npix, 1] 形状, 转为 [batch_size, npix]
        bestfit = bestfit.squeeze(-1)
    if loss.dim() > 2:
        # 3.1.2 确保 loss 是 2D 的
        loss = loss.squeeze(-1)
    # 3.1.3 确保 npix 匹配实际数据维度
    if isinstance(npix, int) and bestfit.shape[1] != npix:
        npix = bestfit.shape[1]

    # 3.2 获取批次大小
    batch_size = bestfit.shape[0]

    # 3.3 初始化掩码, 如果有 goodPixels 则使用它
    # 注意:
    # 1) goodPixels 为上次 Clean 掩码
    if goodPixels is not None:
        # 3.3.1 如果是 [batch_size, npix, 1] 形状, 转为 [batch_size, npix]
        if goodPixels.dim() == 3:
            goodPixels = goodPixels.squeeze(-1)
        # 3.3.2 将之前的掩码转换为浮点型
        goodPixels = goodPixels.float()
    else:
        # 3.3.3 如果没有提供 goodPixels, 创建一个全 1 的初始掩码
        goodPixels = torch.ones_like(bestfit, dtype=dtype, device=device)

    """
    第一层裁剪: 剔除明显离群值 (识别宇宙线、错误数据点等明显异常)
    # 注意: goodPixels 为上次 Clean 的掩码结果
    # 1) 计算残差 (loss * goodPixels) 的标准差 (rbst_sig)
    # 2) 使用自适应 TGM 模型流量残差阈值 (3 sigma, 4 sigma, 5 sigma, 7 sigma) 识别异常点
    # 3) 判断标准: abs(loss) - modelgrd > clip_level * rbst_sig
    # 4) 如果检测到的点超过 3%, 则提高阈值, 确保不过度剔除
    """
    # 3.4 计算模型光谱的梯度谱  
    facsh = 0.5 if j == 0 else 0.2
    # 3.4.1 沿着像素维度计算梯度 (保持批次维度)
    # 3.4.1.1 向前\向后移动一个像素
    bestfit_left, bestfit_right = torch.roll(bestfit, -1, dims=1), torch.roll(bestfit, 1, dims=1)
    # 3.4.1.2 计算梯度谱
    modelgrd = torch.maximum(torch.abs(bestfit - bestfit_left),
                             torch.abs(bestfit - bestfit_right)) * facsh / noise

    # 3.5 为每个样本计算标准差 (保持正确的维度), 形状为 [batch_size, 1]
    # rbst0 = torch.std(loss * goodPixels, dim=1, keepdim=True, correction=1) 
    # 注意: 不同版本 pytorch 的参数有差异, 比如 correction 在其他版本中是 unbiased 参数
    rbst_sig = torch.std(loss * goodPixels, dim=1, keepdim=True, unbiased=True) 

    # 3.5.1 为不同裁剪级别创建掩码
    # 注意:
    # 1) mask3: 3 sigma 裁剪
    # 2) mask4: 4 sigma 裁剪
    # 3) mask5: 5 sigma 裁剪
    # 4) mask7: 7 sigma 裁剪
    mask3 = (torch.abs(loss) - modelgrd) <= (3 * rbst_sig)
    mask4 = (torch.abs(loss) - modelgrd) <= (4 * rbst_sig)
    mask5 = (torch.abs(loss) - modelgrd) <= (5 * rbst_sig)
    mask7 = (torch.abs(loss) - modelgrd) <= (7 * rbst_sig)

    # 3.7 计算每个样本在不同阈值下的异常点百分比
    npix_goodPixels = torch.sum(goodPixels, dim=1)
    outlier_percent3 = 1.0 - torch.sum(mask3.float() * goodPixels, dim=1) / npix_goodPixels
    outlier_percent4 = 1.0 - torch.sum(mask4.float() * goodPixels, dim=1) / npix_goodPixels
    outlier_percent5 = 1.0 - torch.sum(mask5.float() * goodPixels, dim=1) / npix_goodPixels
    # outlier_percent7 = 1.0 - torch.sum(mask7.float() * goodPixels, dim=1) / npix_goodPixels
    
    # 3.8 创建一个新的 0 与 1 的掩码用于本次迭代
    current_iter_mask = torch.ones((batch_size, npix), dtype=dtype, device=device)

    # 3.9 基于异常点百分比选择裁剪级别
    threshold = 0.03

    # 3.9.1 默认使用级别 3, 如果 3 倍标准差剔除了超过 3% 的流量点, 则设置更宽松的剔除级别、即 4 倍标准差
    # 注意: 这里记录了 3 倍标准差剔除了不超过 3% 的流量点的样本
    condition = (outlier_percent3 <= threshold).unsqueeze(1).to(device)
    mask3 = (torch.abs(loss) - modelgrd) <= (3 * rbst_sig)
    current_iter_mask = torch.where(condition, mask3.float(), current_iter_mask)
    current_iter_mask = current_iter_mask * goodPixels0

    # 3.9.2 需要级别 4 的样本, 如果 4 倍标准差剔除了超过 3% 的流量点, 则设置更宽松的剔除级别、即 5 倍标准差
    # 注意: 这里保留了 4 倍标准差剔除了不超过 3% 的流量点的样本
    condition = (outlier_percent3 > threshold) & (outlier_percent4 <= threshold)
    condition = condition.unsqueeze(1).to(device)
    mask4 = (torch.abs(loss) - modelgrd) <= (4 * rbst_sig)
    current_iter_mask = torch.where(condition, mask4.float(), current_iter_mask)
    current_iter_mask = current_iter_mask * goodPixels0

    # 3.9.3 需要级别 5 的样本, 如果 5 倍标准差剔除了超过 3% 的流量点, 则设置更宽松的剔除级别、即 7 倍标准差
    # 注意: 这里保留了 5 倍标准差剔除了不超过 3% 的流量点的样本
    condition = (outlier_percent3 > threshold) & (outlier_percent4 > threshold) & (outlier_percent5 <= threshold)
    condition = condition.unsqueeze(1).to(device)
    mask5 = (torch.abs(loss) - modelgrd) <= (5 * rbst_sig)
    current_iter_mask = torch.where(condition, mask5.float(), current_iter_mask)
    current_iter_mask = current_iter_mask * goodPixels0

    # 3.9.4 需要级别 7 的样本, 这里保留了 7 倍标准差剔除了不超过 3% 的流量点的样本
    condition = (outlier_percent3 > threshold) & (outlier_percent4 > threshold) & (outlier_percent5 > threshold)
    condition = condition.unsqueeze(1).to(device)
    mask7 = (torch.abs(loss) - modelgrd) <= (7 * rbst_sig)
    current_iter_mask = torch.where(condition, mask7.float(), current_iter_mask)
    current_iter_mask = current_iter_mask * goodPixels0

    # 3.9.5 输出统计信息
    if not quiet:
        # 3.9.5.1 计算每个样本中被剔除的点数（值为 0 的点）
        total_outliers = torch.sum(current_iter_mask == 0, dim=1)
        new_outliers = torch.sum((goodPixels == 1) & (current_iter_mask == 0), dim=1)

        # 3.9.5.2 打印每个样本的剔除点数
        for i in range(batch_size):
            print(f'第一层裁剪: 样本 {i + 1}: 总计剔除 {total_outliers[i].item()} 个点, 本次新增 {new_outliers[i].item()} 个点')


    """
    第二层裁剪: 第二层裁剪是对第一层裁剪后的好像素点进行裁剪 (剔除与主要离群点相关的 "边缘" 像素. 处理如天空线残余, 它们通常影响多个相邻像素)
    # 注意: 
    # 1) 检查第一层剔除点的直接相邻像素 (左右各一个) 是否为坏点
    # 2) 使用较低阈值 (2 sigma) 判断是否轻微离群. 判断标准: abs(loss) - modelgrd > 2 * rbst_sig
    # 3) 与 IDL 或另一 Python 版有些许差异:
    #    3.1) 这里是从第一层得到的 current_iter_mask 开始, 而不是从坏像素点构造的邻居索引 (near) 开始
    #    3.2) IDL 版与另一 Python 版, 从第一层的坏像素点构造 near 存在冗余, 如: near 中存在重复索引, near 中依然存在坏像素点
    #    3.3) 至于如何选择, 依据你个人判断吧
    """
    # 3.10 相邻像素
    mask_left, mask_right = torch.roll(current_iter_mask, 1, dims=1), torch.roll(current_iter_mask, -1, dims=1)

    # 3.10.1 左右邻居中有异常的位置为 True, 如果无异常, 则返回 False
    neighbor_outliers = (mask_left == 0) | (mask_right == 0)

    # 3.10.2 检查这些邻居是否满足异常条件
    # 注意: 第一层裁剪剩下的正常点 1), 满足下述条件 2) 与 3), 才为第二层异常点
    # 1) current_iter_mask == 1, 即第一层裁剪获取的正常点 True
    # 2) torch.abs(loss - modelgrd) > 2 * rbst_sig, 即异常点为 True
    # 3) neighbor_outliers, 即第一层的邻居有异常点为 True
    # 4) 如果第一层有些样本都是好像素, 那么第二层裁剪不会剔除那些样本的任何点
    neighbor_condition = (current_iter_mask == 1.) & ((torch.abs(loss) - modelgrd) > 2 * rbst_sig) & neighbor_outliers
    # neighbor_condition = (current_iter_mask == 1.) & (((torch.abs(torch.roll(loss, 1, dims=1)) - torch.roll(modelgrd, 1, dims=1)) > 2 * rbst_sig) | ((torch.abs(torch.roll(loss, -1, dims=1)) - torch.roll(modelgrd, -1, dims=1)) > 2 * rbst_sig)) & neighbor_outliers

    # 3.10.3 更新掩码, 即第一、二层裁剪获取的正常点
    # 注意:
    # 1) current_iter_mask == 1 为第一层裁剪的正常点, 各元素为 1 或 0. 1 表示正常点, 0 表示异常点
    # 2) (~neighbor_condition).float() 为第二层裁剪的正常点, 各元素为 1 或 0. 1 表示正常点, 0 表示异常点
    current_iter_mask = current_iter_mask * (~neighbor_condition).float()

    # 3.10.4 输出统计信息
    if not quiet:
        # 3.10.4.1 计算每个样本中被剔除的点数 (值为 0 的点)
        total_outliers = torch.sum(current_iter_mask == 0, dim=1)
        new_outliers = torch.sum((goodPixels == 1) & (current_iter_mask == 0), dim=1)

        # 3.10.4.2 打印每个样本的剔除点数
        for i in range(batch_size):
            print(f'前两层裁剪: 样本 {i + 1}: 总计剔除 {total_outliers[i].item()} 个点，本次新增 {new_outliers[i].item()} 个点')

    """
    第三层裁剪: 在第二层获取的好像素点上进一步剔除离群值 (完整追踪和剔除整个光谱特征, 如发射线)
    # 注意:
    # 1) 迭代方式追踪光谱特征 (最多 20 次迭代)
    # 2) 每次迭代重新计算 TGM 模型流量残差的标准差
    # 3) 查找满足五个特定条件的点: 
    #    3.1) 前两层裁剪后的好像素点, refining_mask == 1
    #    3.2) 相邻点已被剔除, mask_shifted_back=0, 该条件多余, 但与 IDL 版保持一致
    #    3.3) 当前点与相邻点残差符号相同, loss_forward * loss > 0
    #    3.4) 当前点残差超过标准差, np.abs(loss) - modelgrd > r_sig
    #    3.5) 当前点残差小于相邻已剔除点残差, np.abs(loss) <= np.abs(loss_forward), 不理解该条件, 仅与 IDL 版保持一致
    """
    # 3.11 迭代细化掩码
    # r_sig = torch.std(loss * current_iter_mask, dim=1, keepdim=True, correction=1)
    refining_mask = current_iter_mask.clone()

    # 3.12 迭代细化掩码, 最多 20 次迭代
    for k in range(20):
        
        # 3.12.1 保存上一次迭代后的掩码
        prev_refining_mask = refining_mask.clone()

        # 3.12.2 计算有效像素的标准差, 注意: 不同版本 pytorch 的参数有差异, 比如 correction 参数在其他版本中是 unbiased 参数
        r_sig = torch.std(loss * prev_refining_mask, dim=1, keepdim=True, unbiased=True)

        # 3.12.3 向前和向后移动的掩码和拟合的流量残差
        mask_forward, loss_forward = torch.roll(refining_mask, 1, dims=1), torch.roll(loss, 1, dims=1)
        mask_backward, loss_backward = torch.roll(refining_mask, -1, dims=1), torch.roll(loss, -1, dims=1)

        # 3.12.4 向前检查条件
        cond1_forward = (refining_mask == 1) & (mask_forward == 0)
        cond2_forward = (loss_forward * loss > 0)
        cond3 = (torch.abs(loss) - modelgrd > r_sig)
        cond4_forward = (torch.abs(loss) <= torch.abs(loss_forward))
        # 前两层裁剪后的好像素点、但右侧邻居为坏点、且二者同号、且 "TGM 模型流量残差" 与 "流量残差梯度" 之差大于 "TGM 模型流量残差" 标准差、且 "TGM 模型流量残差" 绝对值小于相邻点残差绝对值
        forward_outliers = cond1_forward & cond2_forward & cond3 & cond4_forward

        # 3.12.5 向后检查条件
        cond1_backward = (refining_mask == 1) & (mask_backward == 0)
        cond2_backward = (loss_backward * loss > 0)
        cond4_backward = (torch.abs(loss) <= torch.abs(loss_backward))
        # 前两层裁剪后的好像素点、但左侧邻居为坏点、且二者同号、且 "TGM 模型流量残差" 与 "流量残差梯度" 之差大于 "TGM 模型流量残差" 标准差、且 "TGM 模型流量残差" 绝对值小于相邻点残差绝对值
        backward_outliers = cond1_backward & cond2_backward & cond3 & cond4_backward

        # 3.12.6 更新掩码, 即第一、二、三层裁剪获取的正常点
        # 注意:
        # 1) refining_mask 为前两层裁剪的正常点, 各元素为 1 或 0. 1 表示正常点, 0 表示异常点
        # 2) (~(forward_outliers | backward_outliers)).float() 为第三层裁剪的正常点, 各元素为 1 或 0. 1 表示正常点, 0 表示异常点
        refining_mask = refining_mask * (~(forward_outliers | backward_outliers)).float()

        # 3.12.7 检查是否收敛
        if torch.all(refining_mask == prev_refining_mask):
            break

    # 3.13 本次迭代的最终掩码
    current_iter_final_mask = refining_mask
    # 3.13.1 如果本次迭代的最终掩码与初始掩码相同, 则退出循环
    # if torch.all(current_iter_final_mask == goodPixels):
        # break

    # 3.14 掩码: 1 表示正常点, 0 表示异常点
    # 注意:
    # 1) 累积: 每次 Clean 的掩码与上一次 Clean 的掩码相乘, 计算 loss 时使用累积掩码
    # 2) 不累积: 本次 Clean 的掩码, 计算 loss 时使用当前掩码
    # 3) 默认不累积
    # 3.14.1 累积
    # final_mask = goodPixels * current_iter_final_mask
    # 3.14.2 不累积
    final_mask = current_iter_final_mask

    # 3.15 创建好像素的布尔掩码
    # goodPixels_final = (final_mask == 1.)

    # 3.16 输出统计信息
    if not quiet:
        # 3.16.1 计算每个样本中被剔除的点数（值为0的点）
        total_outliers = torch.sum(final_mask == 0, dim=1)
        new_outliers = torch.sum((goodPixels == 1) & (current_iter_final_mask == 0), dim=1)

        # 3.16.2 打印每个样本的剔除点数
        for i in range(batch_size):
            print(f'前三层裁剪: 样本 {i + 1}: 总计剔除 {total_outliers[i].item()} 个点, 本次新增 {new_outliers[i].item()} 个点')

    # 3.17 返回最终的掩码
    return final_mask
