import numpy as np
import pandas as pd
from scipy.ndimage import binary_dilation
from astropy.wcs import WCS 
from sklearn.decomposition import IncrementalPCA
import pickle
from skimage.measure import label, regionprops

def load_pickle_to_list(pkl_path):
    """
    Load a pickle file and return its contents as a list.

    Parameters
    ----------
    pkl_path : str
        Path of the pickle file.

    Returns
    -------
    list
        Depending on the unpickled object:
          - list: returned as-is (the same object)
          - dict: a list of its ``(key, value)`` pairs
          - any other iterable: converted with ``list()``
          - a non-iterable object: wrapped in a one-element list

    Notes
    -----
    ``pickle.load`` can execute arbitrary code; only load trusted files.
    """
    with open(pkl_path, 'rb') as handle:
        payload = pickle.load(handle)

    if isinstance(payload, list):
        return payload
    if isinstance(payload, dict):
        return [*payload.items()]
    try:
        return [*payload]
    except TypeError:
        # Not iterable: wrap the single object.
        return [payload]

def estimate_angle(mask: np.ndarray):
    """
    Estimate an orientation angle (degrees) from the non-zero voxels of `mask`.

    A 3-component IncrementalPCA is fitted to the indices of the non-zero
    voxels. With the leading component written as (c0, c1, c2) along mask
    axes (0, 1, 2), the angle is atan2(c1, c2) in degrees, plus 180 when
    c0 > 0, then minus 90, wrapped into [0, 360). Because the 180-degree shift
    flips together with the sign of c0, the result does not depend on the
    arbitrary sign PCA assigns to the component (except when c0 == 0).

    Parameters
    ----------
    mask : np.ndarray
        Expected to be 3-D (the PCA uses 3 components). In this file it is the
        (z, y, x) bounding-box sub-mask of one object.

    Returns
    -------
    float or None
        Angle in degrees in [0, 360), or None when fewer than 4 voxels are set.
    """
    voxel_idx = np.argwhere(mask != 0)
    if voxel_idx.shape[0] < 4:
        return None

    major_axis = IncrementalPCA(n_components=3, batch_size=30).fit(voxel_idx).components_[0]
    c0, c1, c2 = major_axis[0], major_axis[1], major_axis[2]

    theta = np.rad2deg(np.arctan2(c1, c2))
    theta = theta + 180 if c0 > 0 else theta
    return (theta - 90) % 360



def instance_df_with_metrics(
    coords_lists,
    id_list = None
) -> pd.DataFrame:
    """
    Build one DataFrame row per instance with its centroid, Z extent and bounding box.

    Parameters
    ----------
    coords_lists : iterable of dict
        Each element must provide the keys ``"instance"`` (voxel coordinates
        convertible to an ``(N_i, 3)`` array in ``(z, y, x)`` order) and
        ``"var_mean"`` (stored unchanged in the output).
    id_list : sequence, optional
        One id per instance, in the same order. If None, ids 1, 2, ... are used.

    Returns
    -------
    pd.DataFrame
        Columns, in order: 'id', 'coords', 'var_mean', 'z_centroid',
        'y_centroid', 'x_centroid', 'z_length', 'x_min', 'x_max', 'y_min',
        'y_max', 'z_min', 'z_max'. ``z_length`` is ``max(z) - min(z) + 1``.
        An instance with no voxels gets zero centroid, length and bounding box,
        and an empty ``(0, 3)`` integer coords array. An empty input still
        yields a DataFrame with all of these columns.

    Raises
    ------
    ValueError
        If ``id_list`` is given and its length differs from the number of instances.
    """
    columns = [
        'id', 'coords', 'var_mean',
        'z_centroid', 'y_centroid', 'x_centroid', 'z_length',
        'x_min', 'x_max', 'y_min', 'y_max', 'z_min', 'z_max',
    ]

    instances = list(coords_lists)
    if id_list is None:
        ids = range(1, len(instances) + 1)
    else:
        ids = list(id_list)
        if len(ids) != len(instances):
            raise ValueError(
                f"id_list has {len(ids)} ids but there are {len(instances)} instances")

    records = []
    for obj_id, inst in zip(ids, instances):
        coords_py = np.array(inst["instance"])
        var_mean = inst["var_mean"]

        if coords_py.size == 0:
            # Every bbox field must be (re)assigned here; otherwise an empty
            # instance would fail or inherit the previous instance's bounds.
            # A (0, 3) integer array keeps downstream coords[:, k] indexing valid.
            coords_py = np.empty((0, 3), dtype=np.intp)
            zc = yc = xc = 0.0
            z_len = 0
            z_min = y_min = x_min = 0
            z_max = y_max = x_max = 0
        else:
            zc, yc, xc = coords_py.mean(axis=0)
            z_min, y_min, x_min = coords_py.min(axis=0)
            z_max, y_max, x_max = coords_py.max(axis=0)
            # Z extent counts both end slices: max(z) - min(z) + 1.
            z_len = int(z_max - z_min + 1)

        records.append({
            'id': obj_id,
            'coords': coords_py,
            'var_mean': var_mean,
            'z_centroid': float(zc),
            'y_centroid': float(yc),
            'x_centroid': float(xc),
            'z_length': z_len,
            'x_min': x_min,
            'x_max': x_max,
            'y_min': y_min,
            'y_max': y_max,
            'z_min': z_min,
            'z_max': z_max,
        })

    return pd.DataFrame(records, columns=columns)


def _measure_object(cube, coords, z0, z1, y0, y1, x0, x1):
    """
    Measure one object inside its inclusive bounding box.

    Returns ``(mask_size, est_flux, ell_maj, ell_min, ell_pa)``, where
    ``ell_pa`` is NaN when ``estimate_angle`` returns None.
    """
    coords = np.asarray(coords)
    if coords.size == 0:
        # No voxels: nothing to measure (indexing coords[:, k] would fail).
        return 0, 0, 0, 0, np.nan

    z0, z1, y0, y1, x0, x1 = (int(v) for v in (z0, z1, y0, y1, x0, x1))
    subcube = cube[z0:z1 + 1, y0:y1 + 1, x0:x1 + 1]

    # Rasterise the voxels straight into a bounding-box-sized mask rather than
    # allocating a full-cube boolean mask per object. This is identical to slicing
    # a full mask as long as the bbox values were derived from these same coords
    # and lie inside the cube (assumed here).
    submask = np.zeros((z1 - z0 + 1, y1 - y0 + 1, x1 - x0 + 1), dtype=bool)
    submask[coords[:, 0] - z0, coords[:, 1] - y0, coords[:, 2] - x0] = True

    angle = estimate_angle(submask)
    ell_pa = np.nan if angle is None else angle

    mask_size = submask.sum()
    # Multiplying (rather than boolean-indexing) keeps the original NaN
    # propagation; negative integrated flux is clipped to 0.
    est_flux = max((subcube * submask).sum(), 0)

    # Axis lengths come from the projection of the mask along axis 0.
    footprint = (submask.max(axis=0) > 0).astype(np.uint8)
    props = regionprops(label(footprint))
    if not props:
        return mask_size, est_flux, 0, 0, ell_pa
    # The projected footprint can split into several 2-D components; use the
    # largest instead of whichever one happens to be labelled first.
    largest = max(props, key=lambda p: p.area)
    return mask_size, est_flux, largest.major_axis_length, largest.minor_axis_length, ell_pa


def estimate_object_properties(cube: np.ndarray,
                               df: pd.DataFrame,
                               dilation_iterations: int = 2):
    """
    Measure each object of `df` in `cube`: mask size, flux, projected axes and angle.

    Parameters
    ----------
    cube : np.ndarray
        Original 3-D data cube, indexed like the voxel coordinates.
    df : pd.DataFrame
        Must contain 'coords' (an ``(N, 3)`` integer array per object) and the
        inclusive bounding box 'z_min', 'z_max', 'y_min', 'y_max', 'x_min',
        'x_max'. It is not modified.
    dilation_iterations : int
        Currently unused (the dilation step is disabled); kept for API
        compatibility.

    Returns
    -------
    df_out : pd.DataFrame
        A copy of `df` with the added columns:
          - 'ell_pa': angle from ``estimate_angle`` (NaN when not estimable)
          - 'mask_size': number of object voxels
          - 'est_flux': sum of cube values over the object voxels, clipped at 0
          - 'ell_maj', 'ell_min': major/minor axis lengths of the largest
            connected component of the mask projected along axis 0 (0 if none)
    """
    bbox_cols = ['coords', 'z_min', 'z_max', 'y_min', 'y_max', 'x_min', 'x_max']
    # Iterate column values positionally, so a non-default index is fine.
    rows = zip(*(df[c] for c in bbox_cols)) if len(df) else ()
    results = [_measure_object(cube, *row) for row in rows]

    if results:
        mask_sizes, est_fluxes, ell_majs, ell_mins, ell_pas = map(list, zip(*results))
    else:
        mask_sizes, est_fluxes, ell_majs, ell_mins, ell_pas = [], [], [], [], []

    df_out = df.copy()
    # Always create 'ell_pa' so downstream 'pa' computation never misses the column.
    df_out['ell_pa'] = ell_pas
    df_out['mask_size'] = mask_sizes
    df_out['est_flux'] = est_fluxes
    df_out['ell_maj'] = ell_majs
    df_out['ell_min'] = ell_mins

    return df_out


def compute_challenge_metrics(df, header, *, pixel_scale=2.8, beam_fwhm=7.0,
                              q0=0.2, default_inclination=45.0):
    """
    Add catalogue columns derived from pixel-space measurements to `df` (in place).

    Parameters
    ----------
    df : pd.DataFrame
        Must contain 'x_centroid', 'y_centroid', 'z_centroid' (0-based pixel
        coordinates). The optional inputs 'z_length', 'est_flux', 'ell_maj',
        'ell_min' and 'ell_pa' enable the corresponding outputs.
    header
        Anything ``astropy.wcs.WCS`` accepts, describing a 3-axis WCS. 'CDELT3'
        and 'RESTFREQ' are read when 'z_length' / 'est_flux' are present.
    pixel_scale : float
        Angular size of one pixel (default 2.8, the value previously hard-coded).
        It must share units with `beam_fwhm`; presumably arcsec. TODO: confirm
        against the cube header.
    beam_fwhm : float
        Beam FWHM in the same units as `pixel_scale` (default 7.0).
    q0 : float
        Axis-ratio constant in the inclination formula (default 0.2).
    default_inclination : float
        Inclination in degrees used wherever the formula yields NaN (default 45).

    Returns
    -------
    pd.DataFrame
        The same `df` object. It gains 'ra', 'dec' and 'central_freq'. It also
        gains 'w20', 'line_flux_integral', 'hi_size', 'i' and 'pa' when their
        input columns exist. An empty `df` is returned unchanged.
    """
    wcs = WCS(header)

    if len(df) == 0:
        return df

    # float64 avoids the precision loss of the former float32 cast.
    pix = np.asarray(df[['x_centroid', 'y_centroid', 'z_centroid']], dtype=np.float64)
    world = wcs.all_pix2world(pix, 0)  # origin 0: pixel coordinates are 0-based
    df['ra'] = world[:, 0]
    df['dec'] = world[:, 1]
    df['central_freq'] = world[:, 2]

    if 'z_length' in df.columns:
        # Spectral width: channels * channel width, scaled by c (299792.458 km/s)
        # over the rest frequency. Assumes CDELT3 and RESTFREQ share units.
        df['w20'] = df['z_length'] * 299792.458 * header['CDELT3'] / header['RESTFREQ']

    if 'est_flux' in df.columns:
        # Divide by the Gaussian beam area in pixels: pi * (FWHM/pixel)^2 / (4 ln 2).
        beam_area_pix = np.pi * (beam_fwhm / pixel_scale) ** 2 / (4 * np.log(2))
        df['line_flux_integral'] = df['est_flux'] * header['CDELT3'] / beam_area_pix

    if 'ell_maj' in df.columns:
        df['hi_size'] = df['ell_maj'] * pixel_scale

    # Both columns are required (the former `'ell_maj' and 'ell_min' in ...`
    # only tested 'ell_min').
    if 'ell_maj' in df.columns and 'ell_min' in df.columns:
        # cos^2(i) = (q^2 - q0^2) / (1 - q0^2), with q = ell_min / ell_maj.
        # q < q0 or ell_maj == 0 produce NaN, which falls back to default_inclination.
        with np.errstate(divide='ignore', invalid='ignore'):
            q = df['ell_min'] / df['ell_maj']
            cos_i = np.sqrt((q ** 2 - q0 ** 2) / (1 - q0 ** 2))
            df['i'] = np.rad2deg(np.arccos(cos_i))
        df['i'] = df['i'].fillna(default_inclination)

    if 'ell_pa' in df.columns:
        df['pa'] = df['ell_pa'].fillna(0)

    return df
# Script entry point: fill in the paths below, then run this file directly.

# FITS cube providing both the data and its WCS header.
CUBE_FITS_PATH = ''
# Pickle with the list of instance dicts ({"instance": coords, "var_mean": ...}).
INSTANCES_PKL_PATH = ''
# Destination of the output catalogue.
OUTPUT_CSV_PATH = ''

OUTPUT_COLUMNS = [
    'id', 'ra', 'dec', 'hi_size',
    'line_flux_integral', 'central_freq',
    'pa', 'i', 'w20', 'z_centroid', 'y_centroid', 'x_centroid',
]


def main(cube_path=CUBE_FITS_PATH, pkl_path=INSTANCES_PKL_PATH,
         out_csv=OUTPUT_CSV_PATH):
    """Load the cube and instances, measure every object, and write the catalogue CSV."""
    # Only the script entry point reads FITS files, so import it here.
    from astropy.io import fits

    # One read returns the data together with the header of the same HDU.
    cube, header = fits.getdata(cube_path, header=True)

    instances = load_pickle_to_list(pkl_path)
    print(len(instances))

    df = instance_df_with_metrics(instances)
    df_out = estimate_object_properties(cube, df, dilation_iterations=2)
    df_out = compute_challenge_metrics(df_out, header)

    df_out[OUTPUT_COLUMNS].to_csv(out_csv, index=False)


if __name__ == "__main__":
    main()