import sys

import numpy as np
import pywt

# Prepend a hard-coded, machine-specific source directory to the import search path.
# NOTE(review): nothing in the visible code imports from this directory — confirm it is still needed.
sys.path.insert(0, "/home/cl/sdc2/HISourceFinder-master-l/src/")

import pandas as pd


def wavelet_7_9_stronger_denoise_3d(
    data,
    threshold_factor=3.0,
    levels=2,
    mode="soft",
    wavelet="bior2.2",
):
    """
    Denoise a 3D cube with a multi-level wavelet transform along the Z-axis.

    The cube is decomposed along axis 0, every detail (high-frequency) band is
    thresholded, and the cube is reconstructed from the approximation band
    plus the thresholded detail bands.

    NOTE(review): despite the "7_9" in the function name, the default wavelet
    is 'bior2.2'. The 9/7 biorthogonal wavelet is usually exposed as
    'bior4.4' in PyWavelets — confirm which one was intended. The default is
    left unchanged so existing callers get the same result.

    Parameters
    ----------
    data : np.ndarray
        Input 3D array of shape (Z, Y, X).
    threshold_factor : float, optional
        Multiplier applied to the standard deviation of each detail band to
        obtain that band's threshold. A higher value removes more noise.
        Default is 3.0.
    levels : int, optional
        Number of decomposition levels for the wavelet transform.
        Default is 2.
    mode : str, optional
        Thresholding mode passed to ``pywt.threshold`` ('soft' or 'hard').
        Default is 'soft'.
    wavelet : str, optional
        Name of the wavelet passed to ``pywt.wavedec`` / ``pywt.waverec``.
        Default is 'bior2.2'.

    Returns
    -------
    filtered_data : np.ndarray
        Reconstructed 3D array with the same shape as ``data``.

    Raises
    ------
    ValueError
        If ``data`` is not 3-dimensional.
    """
    if len(data.shape) != 3:
        raise ValueError(
            f"data must be a 3D array of shape (Z, Y, X), got shape {data.shape}"
        )
    z_dim = data.shape[0]

    # Multi-level decomposition along Z: coeffs[0] is the approximation band,
    # coeffs[1:] are the detail bands.
    coeffs = pywt.wavedec(data, wavelet, axis=0, level=levels)

    # Threshold only the detail bands. Each band gets one global threshold
    # derived from the standard deviation of all of its coefficients.
    coeffs[1:] = [
        pywt.threshold(detail, threshold_factor * np.std(detail), mode=mode)
        for detail in coeffs[1:]
    ]

    filtered_data = pywt.waverec(coeffs, wavelet, axis=0)

    # The reconstruction can come back one sample longer than an odd-length
    # input along the transformed axis; trim so the output shape matches.
    return filtered_data[:z_dim]

import os
import numpy as np
from astropy.io import fits
import cupy as cp
# Path of the input FITS cube opened by the tiling script below
# (hard-coded, machine-specific).
fits_file = '/home/dell461/cl/sdc2/dataset/sky_ldev.fits'
# Alternative inputs kept for reference:
# fits_file = "/home/ska/sdc2/sky_full_v2.fits"
# csv_file = "/home/cl/sdc2/dataset/sdc2_l.csv"


def compute_starts(L, block, stride):
    """Return the start offsets of windows of length ``block`` covering an axis.

    Windows are placed every ``stride`` samples starting at 0. If the last
    regularly spaced window does not end exactly at ``L``, one extra window
    starting at ``L - block`` is appended so the tail of the axis is covered.

    Parameters
    ----------
    L : int
        Length of the axis.
    block : int
        Window length.
    stride : int
        Step between consecutive window starts; must be positive whenever
        more than one window is needed.

    Returns
    -------
    list of int
        Strictly increasing start offsets. ``[0]`` when ``block >= L``.

    Raises
    ------
    ValueError
        If ``block < L`` and ``stride`` is not positive (for example when the
        requested overlap is at least as large as the block).
    """
    if block >= L:
        # A single window already spans the whole axis.
        return [0]
    if stride <= 0:
        raise ValueError(
            f"stride must be positive, got {stride} (is the overlap >= block size?)"
        )

    last = L - block  # start of the window that ends exactly at L
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts

with fits.open(fits_file, mode='readonly', memmap=True, do_not_scale_image_data=True) as hdul:
    # Tile the primary-HDU cube into overlapping (z, y, x) blocks, min-max
    # normalise each XY tile over its full Z extent, and write every block to
    # fits_dir/<y_idx>/block_<z_idx>_<y_idx>_<x_idx>.fits.
    hdu = hdul[0]
    header = hdu.header
    data = hdu.data  # numpy.memmap, loaded lazily
    C, H, W = data.shape  # axis order used below: (z, y, x)

    eps = 1e-8  # lower bound for (max - min) so a constant tile cannot divide by zero
    overlap = 32  # samples shared by neighbouring blocks along every axis
    block_size_z = 128  # block length along z
    block_size_y = 128  # block length along y
    block_size_x = 128  # block length along x
    stride_z = block_size_z - overlap
    stride_y = block_size_y - overlap
    stride_x = block_size_x - overlap

    # Start offsets along each axis; they do not depend on the tile, so they
    # are computed once up front.
    z_starts = compute_starts(C, block_size_z, stride_z)
    y_starts = compute_starts(H, block_size_y, stride_y)
    x_starts = compute_starts(W, block_size_x, stride_x)

    # Output directory
    fits_dir = '/home/dell461/cl/sdc2/last_ska_mid/data'

    # Append processing notes to a copy of the input header.
    # The first HISTORY line describes what the code below actually does
    # (no Z-score step is applied anywhere in this script).
    hdr = header.copy() if header is not None else fits.Header()
    hdr.add_history("Pipeline: tile XY -> per-tile Min-Max (over the whole Z x H x W tile) -> slice along Z.")
    # NOTE(review): the output dtype is not enforced here; it follows from the
    # input dtype — confirm the float32 claim holds for the input cube.
    hdr.add_history("Each output block shape = (bz, by, bx), dtype=float32.")

    # Indices come from enumerate() rather than start // stride: the final
    # start appended by compute_starts is generally not a multiple of the
    # stride, so floor division would give it the same index as the previous
    # block and its file would overwrite that block's file.
    for y_idx, y_start in enumerate(y_starts):
        y_end = min(y_start + block_size_y, H)

        # One sub-directory per y index; created once per y row.
        out_y_dir = os.path.join(fits_dir, str(y_idx))
        os.makedirs(out_y_dir, exist_ok=True)

        for x_idx, x_start in enumerate(x_starts):
            x_end = min(x_start + block_size_x, W)

            # XY tile spanning all of Z: shape (C, by, bx), a slice of the memmap.
            tile = data[:, y_start:y_end, x_start:x_end]

            # Min-max normalisation using the single min / max of the whole
            # tile (no axis argument, so the reduction covers z, y and x).
            tile_min = tile.min(keepdims=True)
            tile_max = tile.max(keepdims=True)
            span = np.maximum(tile_max - tile_min, eps)  # guard against a constant tile
            tile_norm = (tile - tile_min) / span  # (C, by, bx)

            # Cut the normalised tile into blocks along z and save each one.
            for z_idx, z_start in enumerate(z_starts):
                z_end = min(z_start + block_size_z, C)

                sub_block = tile_norm[z_start:z_end, :, :]  # (bz, by, bx)

                block_filename = f"block_{z_idx}_{y_idx}_{x_idx}.fits"
                block_filepath = os.path.join(out_y_dir, block_filename)
                fits.writeto(block_filepath, sub_block, overwrite=True, header=hdr)
                print(f"Saved block: {block_filepath}  shape={sub_block.shape}")
    # data_cp = cp.asarray(data, dtype=cp.float32)    # ★ 改：numpy -> cupy

    # for y_start in y_starts:
    #     y_end = min(y_start + block_size_y, H)
    #     by = y_end - y_start

    #     for x_start in x_starts:
    #         x_end = min(x_start + block_size_x, W)
    #         bx = x_end - x_start

    #         # ★ 直接在 GPU 上切 tile：形状 (C, by, bx)
    #         tile = data[:, y_start:y_end, x_start:x_end]   # ★ 改：保持切片语义
    #         tile = cp.asarray(tile, dtype=cp.float32) 
    #         # ★ 每个 z 在 (by,bx) 上做 min-max（GPU 上）
    #         mins = tile.min( keepdims=True)       # (C,1,1)
    #         maxs = tile.max( keepdims=True)
    #         denom = cp.maximum(maxs - mins, eps)              # 避免除零
    #         tile_norm = (tile - mins) / denom                 # (C,by,bx)
    #         z_starts = compute_starts(C, block_size_z, stride_z)
    #         # —— z 轴切块并写出 —— #
    #         for z_start in z_starts:
    #             z_end = min(z_start + block_size_z, C)
    #             bz = z_end - z_start

    #             sub_block = tile_norm[z_start:z_end, :, :]    # (bz,by,bx)

    #             z_idx = z_start // stride_z
    #             y_idx = y_start // stride_y
    #             x_idx = x_start // stride_x

    #             block_filename = f"block_{z_idx}_{y_idx}_{x_idx}.fits"
    #             out_y_dir = os.path.join(fits_dir, str(y_idx))
    #             os.makedirs(out_y_dir, exist_ok=True)
    #             block_filepath = os.path.join(out_y_dir, block_filename)

    #             # ★ 回到 CPU（numpy）再写 FITS
    #             sub_block_np = cp.asnumpy(sub_block)          # ★ 改：cupy -> numpy
    #             fits.writeto(block_filepath, sub_block_np, overwrite=True, header=hdr)
    #             print(f"Saved block: {block_filepath}  shape={sub_block_np.shape}")
# for z in range(num_blocks_z-1):
#     for y in range(num_blocks_y-1):
#         for x in range(num_blocks_x-1):
#             z_start = z * stride_z
#             y_start = y * stride_y
#             x_start = x * stride_x
#             z_end = min(z_start + block_size_z, C)
#             y_end = min(y_start + block_size_y, H)
#             x_end = min(x_start + block_size_x, W)
#             # 获取子块
#             sub_block = data[z_start:z_end, y_start:y_end, x_start:x_end]
#             targets_in_cube = labels[
#                 (labels['z'] >= z_start) & (labels['z'] < z_end) &
#                 (labels['y'] >= y_start) & (labels['y'] < y_end) &
#                 (labels['x'] >= x_start) & (labels['x'] < x_end)
#             ]
#             block_filename = f"block_{z}_{y}_{x}.fits"
#             fits_filepath = os.path.join(output_dir, 'mask_low_light')
#             if not os.path.exists(fits_filepath):
#                 os.makedirs(fits_filepath)
#             block_filepath = os.path.join(fits_filepath, block_filename)

#             # 保存 FITS 文件
#             fits.writeto(block_filepath, sub_block, overwrite=True,header=header)

#             print(f"Saved: {block_filepath}")

# if not targets_in_cube.empty:

#     # npy_file = os.path.join(output_dir, f"cube_{x}_{y}_{z}.npy")
#     # small_cube_resized = resize(small_cube, (192, 192, 128), mode='reflect', anti_aliasing=True)
#     # np.save(npy_file, point_cloud_cpu.astype(np.float32))
#     # small_cube = cp.asnumpy(small_cube_resized)
#     # small_cube = np.transpose(small_cube, (2, 1, 0))
#     # fits.writeto(npy_file.replace('.npy', '.fits'), small_cube, overwrite=True)
#     txt_output_dir = os.path.join(output_dir, "label_txt")
#     if not os.path.exists(txt_output_dir):
#         os.makedirs(txt_output_dir, exist_ok=True)
#     label_file = os.path.join(txt_output_dir, f"cube_{z}_{x}.txt")
# with open(label_file, "w") as f:
#     for _, row in targets_in_cube.iterrows():
#         # local_x, local_y, local_z = row['x'] - x_min, row['y'] - y_min, row['z'] - z_min
#         # length = (row['x_upper'] - row['x_lower']) / 2
#         # width = (row['y_upper'] - row['y_lower']) / 2
#         # height = (row['z_upper'] - row['z_lower']) / 2
#         local_x = (row['x'] - x_start)   # x 方向放大 3 倍
#         local_y = (row['y'] - y_start)   # y 方向放大 3 倍
#         local_z = row['z']  # z 方向不变
#         # length = max(int(row['major_radius_pixels'] * 3), 8)  # 长度放大 3 倍
#         # width = max(int(row['major_radius_pixels'] * 3), 8)  # 宽度放大 3 倍
#         # height = max(int(row['n_channels'] - 10), 3)  # 高度不变
#         length = int(row['major_radius_pixels'])  # 长度放大 3 倍
#         width = int(row['major_radius_pixels'])  # 宽度放大 3 倍
#         height = int(row['n_channels']) -10 # 高度不变
#         # length = int(row['major_radius_pixels'])
#         # width = int(row['major_radius_pixels'])
#         # height = int(row['n_channels']-10)
#         # if length<3:
#         #     length=3
#         # if width<3:
#         #     width=3
#         if local_z + height/2 > block_size:
#             height = (block_size - local_z)*2
#         elif local_z - height/2 < 0:
#             height = local_z*2
#         f.write(f"{local_x:.2f} {local_y:.2f} {local_z:.2f} {length:.2f} {width:.2f} {height:.2f} 0 radio_source\n")

# 归一化处理，避免除零错误
# sub_block = wavelet_7_9_stronger_denoise_3d(sub_block, threshold_factor=1, levels=1, mode='soft')
# sub_block_denoised = denoiser.denoise(sub_block, method='simple', threshold_level=1,
#     threshold_increment_high_freq=1, num_scales=None,)
# denoised_cube = (sub_block_denoised- np.min(sub_block_denoised)) / (np.max(sub_block_denoised)- np.min(sub_block_denoised))
# sub_block_mean = np.mean(sub_block)
# sub_block_std = np.std(sub_block)
# sub_block = (sub_block - sub_block_mean) / sub_block_std
# if np.max(sub_block) > np.min(sub_block):
#     sub_block = (sub_block - np.min(sub_block)) / (np.max(sub_block) - np.min(sub_block))
# else:
#     sub_block = np.zeros_like(sub_block)

# 生成文件名，记录块的位置
# block_filename = f"block_{z}_{y}_{x}.fits"
# fits_filepath = os.path.join(output_dir, 'mask_low_light')
# if not os.path.exists(fits_filepath):
#     os.makedirs(fits_filepath)
# block_filepath = os.path.join(fits_filepath, block_filename)

# # 保存 FITS 文件
# fits.writeto(block_filepath, sub_block, overwrite=True,header=header)

# print(f"Saved: {block_filepath}")

# from spectral_cube import SpectralCube

# # 读入 FITS 立方体（会自动提取 WCS，包括空间+频率轴）
# cube = SpectralCube.read('/home/cl/sdc2/dataset/sky_ldev.fits')

# # 按像素索引裁剪（axis order: spectral, y, x）
# subcube = cube[  :,    # 频谱通道 50–150
#                0:608,    # y 方向像素 200–400
#                0:608 ]   # x 方向像素 300–500

# # 或者，按世界坐标裁剪：
# # subcube = cube.subcube(xlo, xhi, ylo, yhi)

# # 写出新的子立方体
# subcube.write('/home/cl/sdc2/dataset/dataset_128/test/test_608.fits', overwrite=True)
