#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import argparse
import pickle
import numpy as np
from astropy.io import fits
import re
import pandas as pd
from scipy.ndimage import binary_closing
from skimage.filters import threshold_otsu
import cupy as cp
from cupyx.scipy.ndimage import label
from sklearn.decomposition import IncrementalPCA
import pickle
from skimage.measure import label as sk_label
from skimage.measure import regionprops
from astropy.wcs import WCS 



def make_ellipsoid_struct(radii):
    """
    Build a boolean 3-D ellipsoidal structuring element.

    Parameters
    ----------
    radii : tuple
        Semi-axis lengths ``(rz, ry, rx)`` along z, y and x.

    Returns
    -------
    np.ndarray of bool
        Shape ``(2*rz+1, 2*ry+1, 2*rx+1)``, centred on index ``(rz, ry, rx)``;
        True where ``(z/rz)**2 + (y/ry)**2 + (x/rx)**2 <= 1``.  A zero radius
        gives a length-1 axis that places no constraint on the other two axes,
        instead of evaluating 0/0 (NaN) and producing an all-False element.
    """
    rz, ry, rx = radii
    z, y, x = np.ogrid[-rz:rz + 1, -ry:ry + 1, -rx:rx + 1]

    def _axis_term(coord, radius):
        # A zero radius means the axis only holds coordinate 0, so its
        # normalised squared distance is 0 rather than 0/0.
        if radius == 0:
            return np.zeros(coord.shape)
        return (coord / radius) ** 2

    return _axis_term(z, rz) + _axis_term(y, ry) + _axis_term(x, rx) <= 1
def load_fits(path):
    """
    Read the primary HDU of a FITS file.

    Returns
    -------
    tuple
        ``(data, header)`` of HDU 0; ``data`` is fetched while the file is
        still open.
    """
    with fits.open(path) as hdu_list:
        primary = hdu_list[0]
        data, header = primary.data, primary.header
    return data, header
def mask_to_bboxes(mask, min_voxels=5):
    """
    Return the inclusive bounding box of every connected component in ``mask``.

    Components are labelled with full connectivity (a 3x3x... all-ones
    structure, i.e. 26-connectivity in 3-D).  Components with fewer than
    ``min_voxels`` voxels are dropped.

    Parameters
    ----------
    mask : array-like of bool, shape (D, H, W)
        Foreground mask.  A CuPy array is accepted and copied to the host.
    min_voxels : int
        Minimum component size to keep.

    Returns
    -------
    list of list
        ``[[z1, y1, x1, z2, y2, x2], ...]`` with inclusive max corners, in
        label (raster-scan) order.
    """
    from scipy.ndimage import find_objects, label as nd_label

    # CuPy arrays expose .get() to copy to host; NumPy arrays do not.
    mask = mask.get() if hasattr(mask, "get") else np.asarray(mask)
    structure = np.ones((3,) * mask.ndim, dtype=np.int32)
    labeled_mask, _ = nd_label(mask, structure=structure)

    bboxes = []
    # find_objects gives each label's bounding slices, so every component is
    # inspected only inside its own box instead of rescanning the volume.
    for inst_id, box in enumerate(find_objects(labeled_mask), start=1):
        if box is None:
            continue
        n_vox = int(np.count_nonzero(labeled_mask[box] == inst_id))
        if n_vox < min_voxels:
            continue
        starts = [s.start for s in box]
        stops = [s.stop - 1 for s in box]  # inclusive max index
        bboxes.append(starts + stops)
    return bboxes

def compute_iou_3d(box1, box2):
    """
    Intersection-over-union of two inclusive 3-D boxes.

    Each box is ``[z1, y1, x1, z2, y2, x2]`` where the max corner is
    inclusive, so a box spanning ``a..b`` has extent ``b - a + 1``.
    Returns 0.0 when the union volume is zero.
    """
    def _volume(box):
        return (box[3] - box[0] + 1) * (box[4] - box[1] + 1) * (box[5] - box[2] + 1)

    overlap = 1
    for axis in range(3):
        low = max(box1[axis], box2[axis])
        high = min(box1[axis + 3], box2[axis + 3])
        overlap *= max(0, high - low + 1)

    union = _volume(box1) + _volume(box2) - overlap
    return 0.0 if union == 0 else overlap / union

def get_instances_coords(mask: np.ndarray, min_voxels: int = 10):
    """
    Label connected components of ``mask`` and return the voxel indices of each.

    Components are labelled with full connectivity (all-ones structure).
    Components with fewer than ``min_voxels`` voxels are dropped.

    Parameters
    ----------
    mask : np.ndarray
        Foreground mask of any dimensionality.  A CuPy array is accepted and
        copied to the host.
    min_voxels : int
        Minimum component size to keep.

    Returns
    -------
    coords_list : List[np.ndarray]
        One ``(N_i, mask.ndim)`` integer array per kept component, listing its
        voxel indices in row-major order, components in label order.
    """
    from scipy.ndimage import find_objects, label as nd_label

    # CuPy arrays expose .get() to copy to host; NumPy arrays do not.
    mask = mask.get() if hasattr(mask, "get") else np.asarray(mask)
    structure = np.ones((3,) * mask.ndim, dtype=int)
    labeled_mask, _ = nd_label(mask, structure=structure)

    coords_list = []
    # Work inside each component's bounding box instead of rescanning the
    # whole volume for every label.
    for inst_id, box in enumerate(find_objects(labeled_mask), start=1):
        if box is None:
            continue
        local = np.argwhere(labeled_mask[box] == inst_id)
        if local.shape[0] < min_voxels:
            continue
        offset = np.array([s.start for s in box])
        coords_list.append(local + offset)

    return coords_list

def get_instances_coords_with_variance_ratio_gpu(
    mask: cp.ndarray,
    variance: cp.ndarray,
    mean: cp.ndarray,
    min_voxels: int = 5,
    variance_threshold: float = 0.01,
    variance_ratio_threshold: float = 0.8,
    mean_threshold: float = 0.5
):
    """
    Label connected components of ``mask`` on the GPU and keep the confident ones.

    A component (full connectivity, all-ones structure) is kept when:
      - it has at least ``min_voxels`` voxels,
      - the fraction of its voxels with ``variance < variance_threshold`` is
        at least ``variance_ratio_threshold``, and
      - the average of ``mean`` over its voxels is at least ``mean_threshold``.

    Parameters
    ----------
    mask : array-like
        Boolean / 0-1 mask; converted with ``cp.asarray``.
    variance, mean : array-like
        Maps with the same shape as ``mask``; converted with ``cp.asarray``.

    Returns
    -------
    coords_res : list of numpy.ndarray
        Host copies (via ``.get()``) of each kept component's ``(N_i, ndim)``
        voxel indices.
    vars_res : list of float
        Average ``variance`` over each kept component, same order.
    """
    mask_gpu = cp.asarray(mask)
    var_gpu = cp.asarray(variance)
    mean_gpu = cp.asarray(mean)
    labels, n_labels = label(mask_gpu, structure=cp.ones((3,) * mask_gpu.ndim, dtype=cp.int32))

    kept_coords = []
    kept_vars = []
    for inst_id in range(1, int(n_labels) + 1):
        member = labels == inst_id
        coords = cp.argwhere(member)
        n_vox = coords.shape[0]
        if n_vox < min_voxels:
            continue

        member_var = var_gpu[member]
        member_mean = mean_gpu[member]
        low_ratio = cp.sum(member_var < variance_threshold) / n_vox
        avg_mean = cp.mean(member_mean)
        if low_ratio >= variance_ratio_threshold and avg_mean >= mean_threshold:
            kept_coords.append(coords.get())
            kept_vars.append(float(cp.mean(member_var).get()))

    return kept_coords, kept_vars
def get_instances_coords_with_variance_ratio(mask: np.ndarray, variance: np.ndarray, mean, min_voxels: int = 5,
                                             variance_threshold: float = 0.01, variance_ratio_threshold: float = 0.8,
                                             mean_threshold: float = 0.5):
    """
    CPU variant: keep connected components of ``mask`` that are mostly low-variance.

    Components are labelled with full connectivity (all-ones structure).  A
    component is kept when it has at least ``min_voxels`` voxels and the
    fraction of its voxels with ``variance < variance_threshold`` is at least
    ``variance_ratio_threshold``.

    ``mean`` and ``mean_threshold`` are accepted for signature parity with
    ``get_instances_coords_with_variance_ratio_gpu`` but are NOT used here:
    the mean filter is disabled in this CPU variant.

    Returns
    -------
    coords_list : list of np.ndarray
        ``(N_i, mask.ndim)`` voxel indices (row-major) of each kept component.
    var_list : list
        Average ``variance`` over each kept component, same order.
    """
    from scipy.ndimage import find_objects, label as nd_label

    # CuPy arrays expose .get() to copy to host; NumPy arrays do not.
    mask = mask.get() if hasattr(mask, "get") else np.asarray(mask)
    variance = variance.get() if hasattr(variance, "get") else np.asarray(variance)

    structure = np.ones((3,) * mask.ndim, dtype=int)
    labeled_mask, _ = nd_label(mask, structure=structure)

    coords_list = []
    var_list = []
    # Evaluate each component only inside its bounding box.
    for inst_id, box in enumerate(find_objects(labeled_mask), start=1):
        if box is None:
            continue
        member = labeled_mask[box] == inst_id
        n_vox = int(np.count_nonzero(member))
        if n_vox < min_voxels:
            continue
        inst_variance = variance[box][member]
        variance_ratio = np.sum(inst_variance < variance_threshold) / n_vox
        if variance_ratio >= variance_ratio_threshold:
            offset = np.array([s.start for s in box])
            coords_list.append(np.argwhere(member) + offset)
            var_list.append(np.mean(inst_variance))
    return coords_list, var_list

def compute_iou(mask1, mask2):
    """Voxel-wise intersection-over-union of two boolean masks; 0.0 if both are empty."""
    overlap = np.logical_and(mask1, mask2)
    combined = np.logical_or(mask1, mask2)
    n_union = combined.sum()
    if not n_union:
        return 0.0
    return overlap.sum() / n_union

def filter_mean_var_conf(mean):
    """
    Binarise the mean prediction map with Otsu's method.

    Prints the chosen threshold and returns a boolean mask that is True where
    ``mean >= threshold``.
    """
    keep = np.ones_like(mean, dtype=bool)
    threshold = threshold_otsu(mean)
    print(f"[INFO] Otsu 阈值: {threshold:.4f}")
    np.logical_and(keep, mean >= threshold, out=keep)
    return keep


def restore_multiple_coords(coords_lists, var_mean):
    """Wrap one instance's voxel coordinates and its mean variance in a record dict."""
    return dict(instance=coords_lists, var_mean=var_mean)


def estimate_angle(mask: np.ndarray):
    """
    Estimate an angle (degrees, in [0, 360)) from the principal axis of ``mask``.

    The first principal axis of the (axis0, axis1, axis2) positions of the
    non-zero voxels is computed exactly via an SVD of the centred positions.
    Its sign is fixed the way scikit-learn's ``svd_flip(u_based_decision=False)``
    does (largest-magnitude entry positive).  The angle is
    ``atan2(axis1, axis2)``, shifted by 180 when the axis0 component is
    positive (i.e. the direction is canonicalised to axis0 <= 0), then
    rotated by -90 and wrapped into [0, 360).
    NOTE(review): presumably this matches the catalogue's PA convention — confirm.

    Returns None when fewer than 4 voxels are set.
    """
    positions = np.argwhere(mask != 0)
    if len(positions) < 4:
        return None

    # With all components kept, incremental PCA reduces to this exact SVD;
    # doing it in one call avoids a Python loop over small batches.
    centered = positions - positions.mean(axis=0)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    major = vt[0]
    if major[np.argmax(np.abs(major))] < 0:
        major = -major

    angle = np.rad2deg(np.arctan2(major[1], major[2]))
    if major[0] > 0:
        angle += 180

    angle -= 90
    return angle % 360



def instance_df_with_metrics(
    coords_lists,
    id_list = None
) -> pd.DataFrame:
    """
    Turn instance records into one DataFrame row each, with centroid and extent.

    Parameters
    ----------
    coords_lists : iterable of dict
        Each record has key ``'instance'`` — an ``(N_i, 3)`` array-like of
        (z, y, x) voxel indices — and key ``'var_mean'`` (the shape built by
        ``restore_multiple_coords``).
    id_list : sequence, optional
        Id for each record, in order.  If None, ids 1, 2, ... are assigned.

    Returns
    -------
    pd.DataFrame
        Columns: 'id', 'coords', 'var_mean', 'z_centroid', 'y_centroid',
        'x_centroid', 'z_length' (z_max - z_min + 1), 'x_min', 'x_max',
        'y_min', 'y_max', 'z_min', 'z_max'.  A record with no voxels gets 0
        for the centroids, z_length and every bound.
    """
    records = []
    for pos, record in enumerate(coords_lists):
        coords_py = np.array(record["instance"])
        var_mean = record["var_mean"]
        if coords_py.size == 0:
            # Every field is defined here too, so an empty instance never
            # inherits bounds from the previous iteration.
            zc = yc = xc = 0.0
            z_len = 0
            z_min = y_min = x_min = z_max = y_max = x_max = 0
        else:
            zc, yc, xc = coords_py.mean(axis=0)
            z_min, y_min, x_min = coords_py.min(axis=0)
            z_max, y_max, x_max = coords_py.max(axis=0)
            z_len = int(z_max - z_min + 1)

        records.append({
            'id': pos + 1 if id_list is None else id_list[pos],
            'coords': coords_py,
            'var_mean': var_mean,
            'z_centroid': float(zc),
            'y_centroid': float(yc),
            'x_centroid': float(xc),
            'z_length': z_len,
            "x_min": x_min,
            'x_max': x_max,
            "y_min": y_min,
            'y_max': y_max,
            "z_min": z_min,
            'z_max': z_max,
        })

    return pd.DataFrame.from_records(records)


def _projected_ellipse_axes(submask):
    """Major/minor axis lengths of ``submask`` projected along axis 0; (0, 0) if nothing is set."""
    mask2d = (submask.max(axis=0) > 0).astype(np.uint8)
    props = regionprops(sk_label(mask2d))
    if not props:
        return 0, 0
    # major_axis_length / minor_axis_length are full axis lengths (diameters).
    return props[0].major_axis_length, props[0].minor_axis_length


def estimate_object_properties(cube: np.ndarray,
                               df: pd.DataFrame,
                               dilation_iterations: int = 2):
    """
    Measure per-object angle, size, flux and projected axes from ``cube``.

    Each row's voxels (``df['coords']``) are rasterised into a boolean mask
    covering only that row's bounding box (``z_min..z_max``, ``y_min..y_max``,
    ``x_min..x_max``).  From it:

    - ``ell_pa``: ``estimate_angle`` of the box mask (NaN when it returns None);
    - ``mask_size``: number of set voxels;
    - ``est_flux``: sum of ``cube`` over the set voxels, clipped at 0;
    - ``ell_maj`` / ``ell_min``: ``regionprops`` major/minor axis lengths of
      the mask projected along axis 0.

    Parameters
    ----------
    cube : np.ndarray
        3-D data cube, indexed in the same frame as ``df['coords']``.
    df : pd.DataFrame
        Needs 'coords', 'z_min', 'z_max', 'y_min', 'y_max', 'x_min', 'x_max'
        (as produced by ``instance_df_with_metrics``).  Not modified.
    dilation_iterations : int
        Unused (mask dilation is disabled); kept for backward compatibility.

    Returns
    -------
    pd.DataFrame
        A copy of ``df`` with 'ell_pa', 'mask_size', 'est_flux', 'ell_maj' and
        'ell_min' columns added.
    """
    new_cols = ("ell_pa", "mask_size", "est_flux", "ell_maj", "ell_min")
    results = {col: [] for col in new_cols}

    bound_cols = ['coords', 'z_min', 'z_max', 'y_min', 'y_max', 'x_min', 'x_max']
    rows = zip(*(df[c] for c in bound_cols)) if len(df) else ()
    for coords, z0, z1, y0, y1, x0, x1 in rows:
        z0, z1, y0, y1, x0, x1 = (int(v) for v in (z0, z1, y0, y1, x0, x1))
        subcube = cube[z0:z1 + 1, y0:y1 + 1, x0:x1 + 1]
        # Only the bounding box is allocated, not a full cube-sized mask.
        submask = np.zeros(subcube.shape, dtype=bool)
        local = np.asarray(coords, dtype=np.intp).reshape(-1, 3) - (z0, y0, x0)
        # Ignore voxels outside the box, as slicing a full-cube mask would.
        inside = np.all((local >= 0) & (local < submask.shape), axis=1)
        local = local[inside]
        submask[local[:, 0], local[:, 1], local[:, 2]] = True

        angle = estimate_angle(submask)
        results["ell_pa"].append(np.nan if angle is None else angle)
        results["mask_size"].append(submask.sum())
        # Boolean indexing (not subcube * submask) keeps NaNs elsewhere in the
        # box from turning the flux into NaN.
        results["est_flux"].append(max(subcube[submask].sum(), 0))
        maj, mnr = _projected_ellipse_axes(submask)
        results["ell_maj"].append(maj)
        results["ell_min"].append(mnr)

    df_out = df.copy()
    for col in new_cols:
        df_out[col] = results[col]
    return df_out


def compute_challenge_metrics(df, header):
    """
    Add catalogue columns derived from pixel measurements and the cube header.

    ``df`` is modified in place and also returned.  Each column is added only
    when its inputs are present (nothing is added for an empty ``df``):

    - 'ra', 'dec', 'central_freq': ``WCS(header).all_pix2world`` (origin 0)
      of ('x_centroid', 'y_centroid', 'z_centroid');
    - 'w20': z_length * c[km/s] * CDELT3 / RESTFREQ;
    - 'line_flux_integral': est_flux * CDELT3 / beam area in pixels;
    - 'hi_size': ell_maj * 2.8;
    - 'i': arccos(sqrt((q**2 - 0.2**2) / (1 - 0.2**2))) in degrees with
      q = ell_min / ell_maj; undefined values become 45;
    - 'pa': ell_pa with NaN replaced by 0.
    """
    speed_of_light_kms = 299792.458
    # NOTE(review): 2.8 looks like the pixel size and 7 the beam FWHM (both
    # presumably arcsec) — confirm against the header's CDELT1/2 and BMAJ.
    pixel_scale = 2.8
    beam_area_pix = np.pi * (7 / pixel_scale) ** 2 / (4 * np.log(2))
    q0_sq = .2 ** 2  # presumably the intrinsic disc thickness ratio squared

    wcs = WCS(header)
    if len(df) == 0:
        return df

    # float64: float32 would truncate pixel centroids before the transform.
    pix = df[['x_centroid', 'y_centroid', 'z_centroid']].to_numpy(dtype=np.float64)
    world = wcs.all_pix2world(pix, 0)
    df['ra'] = world[:, 0]
    df['dec'] = world[:, 1]
    df['central_freq'] = world[:, 2]

    if 'z_length' in df.columns:
        df['w20'] = df['z_length'] * speed_of_light_kms * header['CDELT3'] / header['RESTFREQ']

    if 'est_flux' in df.columns:
        df['line_flux_integral'] = df['est_flux'] * header['CDELT3'] / beam_area_pix

    if 'ell_maj' in df.columns:
        df['hi_size'] = df['ell_maj'] * pixel_scale

    # Both columns must exist ("'a' and 'b' in cols" only tested 'b').
    if 'ell_maj' in df.columns and 'ell_min' in df.columns:
        with np.errstate(divide='ignore', invalid='ignore'):
            cos_i = np.sqrt(((df['ell_min'] / df['ell_maj']) ** 2 - q0_sq) / (1 - q0_sq))
            df['i'] = np.rad2deg(np.arccos(cos_i)).fillna(45)

    if 'ell_pa' in df.columns:
        df['pa'] = df['ell_pa'].fillna(0)

    return df

import os
import numpy as np
from tqdm import tqdm

def _parse_args(argv=None):
    """Parse the command-line options of the catalogue-building script."""
    parser = argparse.ArgumentParser(
        description="Build a source catalogue CSV from mean/var prediction cubes.")
    parser.add_argument("--predict-path", required=True,
                        help="directory holding '*mean.fits' files and their '*var.fits' partners")
    parser.add_argument("--cube-path", required=True,
                        help="original data cube FITS; supplies fluxes and the WCS header")
    parser.add_argument("--output", required=True, help="path of the CSV catalogue to write")
    parser.add_argument("--variance-threshold", type=float, default=0.1)
    parser.add_argument("--variance-ratio-threshold", type=float, default=0.01)
    parser.add_argument("--mean-threshold", type=float, default=0.5)
    parser.add_argument("--min-voxels", type=int, default=3)
    return parser.parse_args(argv)


def main(argv=None):
    """Detect instances in every mean/var prediction pair and write the catalogue CSV."""
    args = _parse_args(argv)

    cube = fits.getdata(args.cube_path)
    header = fits.getheader(args.cube_path)

    mean_suffix = "mean.fits"
    instance_list = []
    print(f"处理 variance_ratio_threshold: {args.variance_ratio_threshold}")
    for fn in tqdm(sorted(os.listdir(args.predict_path))):
        if not fn.endswith(mean_suffix):
            continue
        # Swap only the trailing suffix; str.replace("mean", "var") would also
        # rewrite any earlier "mean" in the file name.
        var_file = fn[: -len(mean_suffix)] + "var.fits"
        mean_path = os.path.join(args.predict_path, fn)
        var_path = os.path.join(args.predict_path, var_file)
        if not os.path.exists(var_path):
            print(f"[WARN] 找不到对应 var 文件: {var_path}, 跳过")
            continue

        mean, _ = load_fits(mean_path)
        var, _ = load_fits(var_path)
        new_mask = filter_mean_var_conf(mean)
        pred_instances, var_list = get_instances_coords_with_variance_ratio_gpu(
            new_mask, var, mean,
            min_voxels=args.min_voxels,
            variance_threshold=args.variance_threshold,
            variance_ratio_threshold=args.variance_ratio_threshold,
            mean_threshold=args.mean_threshold,
        )
        print("预测框数量:", len(pred_instances))
        # NOTE(review): coordinates are used as-is to index `cube`, so each
        # prediction file is assumed to share the cube's pixel frame — confirm.
        instance_list.extend(
            restore_multiple_coords(coords, var_mean)
            for coords, var_mean in zip(pred_instances, var_list)
        )

    cols = [
        'id', 'ra', 'dec', 'hi_size',
        'line_flux_integral', 'central_freq',
        'pa', 'i', 'w20', 'z_centroid', 'y_centroid', 'x_centroid'
    ]
    df = instance_df_with_metrics(instance_list)
    if df.empty:
        # Downstream steps need the per-instance columns; write a header-only CSV.
        print("[WARN] no instances found; writing an empty catalogue")
        pd.DataFrame(columns=cols).to_csv(args.output, index=False)
        return

    df_out = estimate_object_properties(cube, df, dilation_iterations=2)
    df_out = compute_challenge_metrics(df_out, header)
    df_out[cols].to_csv(args.output, index=False)


if __name__ == "__main__":
    main()

