import sys

import numpy as np
import pywt

sys.path.insert(0, "/home/cl/sdc2/HISourceFinder-master-l/src/")

import pandas as pd


def wavelet_7_9_stronger_denoise_3d(
    data, threshold_factor=3.0, levels=2, mode="soft", wavelet="bior2.2"
):
    """
    Denoise a 3D cube with a multi-level 1D wavelet transform along the Z-axis.

    The cube is decomposed along axis 0 with ``pywt.wavedec``. Every detail
    level is then thresholded, and the cube is reconstructed with
    ``pywt.waverec``. The approximation band is left untouched.

    Note: despite the function name, the default wavelet ``"bior2.2"`` is the
    5/3-tap biorthogonal filter pair, not the CDF 9/7 pair. In PyWavelets the
    CDF 9/7 wavelet is ``"bior4.4"``. The default is kept for backward
    compatibility; pass ``wavelet="bior4.4"`` for a true 9/7 transform.

    Parameters
    ----------
    data : array_like
        Input 3D array of shape (Z, Y, X). The transform runs along axis 0.
    threshold_factor : float, optional
        Multiplier applied to the standard deviation of each detail level to
        obtain that level's threshold. A higher value removes more.
        Default is 3.0.
    levels : int, optional
        Number of decomposition levels. Default is 2.
    mode : str, optional
        Thresholding mode passed to ``pywt.threshold`` ('soft' or 'hard').
        Default is 'soft'.
    wavelet : str, optional
        Wavelet name understood by PyWavelets. Default is "bior2.2".

    Returns
    -------
    filtered_data : np.ndarray
        Denoised array with the same shape as ``data``.

    Raises
    ------
    ValueError
        If ``data`` is not 3-dimensional. Errors raised by PyWavelets for an
        unsupported wavelet name or thresholding mode propagate unchanged.
    """
    data = np.asarray(data)
    if data.ndim != 3:
        raise ValueError(
            f"expected a 3D array of shape (Z, Y, X), got shape {data.shape}"
        )
    z_dim = data.shape[0]

    # Multi-level decomposition along Z: coeffs = [cA_n, cD_n, ..., cD_1].
    coeffs = pywt.wavedec(data, wavelet, axis=0, level=levels)

    # Skip coeffs[0] (approximation band); threshold only the detail bands.
    # The threshold is derived from the std of *all* coefficients at that
    # level, signal included, so bright sources raise it. A robust (MAD-based)
    # noise estimate is a common alternative if that becomes an issue.
    for i in range(1, len(coeffs)):
        threshold = threshold_factor * np.std(coeffs[i])
        coeffs[i] = pywt.threshold(coeffs[i], threshold, mode=mode)

    filtered_data = pywt.waverec(coeffs, wavelet, axis=0)

    # waverec may return one extra sample along the transformed axis when the
    # input length is odd; trim so the output shape matches the input shape.
    return filtered_data[:z_dim]


# Input paths (left empty here; fill them in before running).
fits_file = ""
csv_file = ""

# Read the source catalogue and keep only sources whose voxel position lies
# in the half-open window z in [0, 6668), y in [0, 608), x in [576, 1280).
labels = pd.read_csv(csv_file)
labels = labels.loc[
    labels["z"].ge(0)
    & labels["z"].lt(6668)
    & labels["y"].ge(0)
    & labels["y"].lt(608)
    & labels["x"].ge(576)
    & labels["x"].lt(1280)
]

# Keep only the catalogue columns that are written out; the z/y/x position
# columns used for the cut above are not part of this list.
cols = [
    "id", "ra", "dec", "hi_size", "line_flux_integral",
    "central_freq", "pa", "i", "w20",
]
labels = labels.loc[:, cols]
# NOTE(review): the output path is empty; set it before running.
labels.to_csv("", index=False)
print(len(labels))
# # print(len(s_labels))
# # labels = labels[labels['n_channels'] > 12]  # 筛选出较长的源\
# labels = labels[labels['line_flux_integral'] < 20]  # 筛选出较大的源
# # labels = labels[labels['line_flux_integral'] >= (labels['n_channels'] -12)*2]  # 筛选出较强的源
# # labels = labels[labels['n_channels'] < 12]  # 筛选出较长的源\
# # labels = labels[labels['line_flux_integral'] < 35]  # 筛选出较大的源
# # labels = labels[labels['line_flux_integral'] <= (labels['n_channels'] -12)*2]  # 筛选出较强的源
# # new_size = (192, 192)  # 目标 xy 维度
# data = fits.getdata(fits_file)
# header = fits.getheader(fits_file)
# # denoiser = SimpleStarletDenoiser()
# # data = (data - np.min(data))/ (np.max(data) - np.min(data))
# # 设置分块大小
# block_size = 128
# eps = 1e-6
# C, H, W = data.shape
# overlap = 32
# block_size_z = 128  # z 轴切块大小，你自己设
# block_size_y = 128  # y 轴切块大小
# block_size_x = 128  # x 轴切块大小
# stride_z = block_size_z - overlap
# stride_y = block_size_y - overlap
# stride_x = block_size_x - overlap

# num_blocks_z = (C - overlap + stride_z - 1) // stride_z
# num_blocks_y = (H - overlap + stride_y - 1) // stride_y
# num_blocks_x = (W - overlap + stride_x - 1) // stride_x
# means = data.mean(axis=(1,2), keepdims=True)   # 每个频率通道的均值 (bz,1,1)
# stds  = data.std(axis=(1,2), keepdims=True)    # 每个频率通道的标准差 (bz,1,1)
# block_std = (data - means) / (stds + eps)

# # —— 第二步：Min–Max 归一化 ——
# mins = block_std.min(axis=(1,2), keepdims=True)     # 标准化后每通道最小值 (bz,1,1)
# maxs = block_std.max(axis=(1,2), keepdims=True)     # 标准化后每通道最大值 (bz,1,1)
# data = (block_std - mins) / (maxs - mins + eps) # 每个频率通道的标准差 (bz,1,1)
# output_dir = '/home/cl/sdc2/dataset/dataset_128/'
# os.makedirs(output_dir, exist_ok=True)
# print(C, H, W)

# for z in range(num_blocks_z-1):
#     for y in range(num_blocks_y-1):
#         for x in range(num_blocks_x-1):
#             z_start = z * stride_z
#             y_start = y * stride_y
#             x_start = x * stride_x
#             z_end   = min(z_start + block_size_z, C)
#             y_end   = min(y_start + block_size_y, H)
#             x_end   = min(x_start + block_size_x, W)

#             # 获取原始子块
#             sub_block = data[z_start:z_end, y_start:y_end, x_start:x_end]  # shape = (bz, by, bx)

#             # # —— 第一步：Z-score 标准化 ——
#             # means = sub_block.mean(axis=(1,2), keepdims=True)   # 每个频率通道的均值 (bz,1,1)
#             # stds  = sub_block.std(axis=(1,2), keepdims=True)    # 每个频率通道的标准差 (bz,1,1)
#             # block_std = (sub_block - means) / (stds + eps)

#             # # —— 第二步：Min–Max 归一化 ——
#             # mins = block_std.min(axis=(1,2), keepdims=True)     # 标准化后每通道最小值 (bz,1,1)
#             # maxs = block_std.max(axis=(1,2), keepdims=True)     # 标准化后每通道最大值 (bz,1,1)
#             # block_norm = (block_std - mins) / (maxs - mins + eps)

#             # # 输出为 float32 写 FITS
#             # output_block = block_norm.astype(np.float32)

#             # 保存 FITS 子块
#             fits_dir = os.path.join(output_dir, 'fre_standardized_normalized_mask')
#             os.makedirs(fits_dir, exist_ok=True)
#             block_filename = f"block_{z}_{y}_{x}.fits"
#             block_filepath = os.path.join(fits_dir, block_filename)

#             fits.writeto(block_filepath, sub_block, overwrite=True, header=header)
#             print(f"Saved standardized+normalized block: {block_filepath}")
# for z in range(num_blocks_z-1):
#     for y in range(num_blocks_y-1):
#         for x in range(num_blocks_x-1):
#             z_start = z * stride_z
#             y_start = y * stride_y
#             x_start = x * stride_x
#             z_end = min(z_start + block_size_z, C)
#             y_end = min(y_start + block_size_y, H)
#             x_end = min(x_start + block_size_x, W)
#             # 获取子块
#             sub_block = data[z_start:z_end, y_start:y_end, x_start:x_end]
#             targets_in_cube = labels[
#                 (labels['z'] >= z_start) & (labels['z'] < z_end) &
#                 (labels['y'] >= y_start) & (labels['y'] < y_end) &
#                 (labels['x'] >= x_start) & (labels['x'] < x_end)
#             ]
#             block_filename = f"block_{z}_{y}_{x}.fits"
#             fits_filepath = os.path.join(output_dir, 'mask_low_light')
#             if not os.path.exists(fits_filepath):
#                 os.makedirs(fits_filepath)
#             block_filepath = os.path.join(fits_filepath, block_filename)

#             # 保存 FITS 文件
#             fits.writeto(block_filepath, sub_block, overwrite=True,header=header)

#             print(f"Saved: {block_filepath}")
