import numpy as np
import cupy as cp
import cv2
from cupyx.scipy.ndimage import map_coordinates


def parse_header(header):
    """Extract the TAN-SIP WCS parameters used by the ``*_pipe`` functions.

    Parameters
    ----------
    header : mapping
        FITS header (anything supporting ``in``, ``[]`` and ``.get``).

    Returns
    -------
    tuple
        ``(crpix, matrix, crval, a_coeffs, b_coeffs, ap_coeffs, bp_coeffs)``.

        - ``crpix`` and ``crval`` are 2-element lists of floats.
        - ``matrix`` is the 2x2 CD matrix as nested lists: ``CDi_j`` when
          present, otherwise ``CDELTi * PCi_j`` (CDELT defaults to 1.0).
        - Each ``*_coeffs`` is a NumPy array of shape
          ``(order + 1, order + 1)``. Entry ``[p, q]`` holds the header value
          ``<PREFIX>_p_q``; it is 0.0 when that keyword is absent or when
          ``p + q > order``.

    Raises
    ------
    ValueError
        If the header holds neither a complete CD nor a complete PC matrix.
    KeyError
        If CRPIX1/2 or CRVAL1/2 is missing.
    """
    crpix = [float(header['CRPIX1']), float(header['CRPIX2'])]

    cd_keys = ('CD1_1', 'CD1_2', 'CD2_1', 'CD2_2')
    pc_keys = ('PC1_1', 'PC1_2', 'PC2_1', 'PC2_2')
    if all(key in header for key in cd_keys):
        matrix = [[float(header['CD1_1']), float(header['CD1_2'])],
                  [float(header['CD2_1']), float(header['CD2_2'])]]
    elif all(key in header for key in pc_keys):
        # FITS WCS: CDi_j = CDELTi * PCi_j. The PC matrix on its own is
        # dimensionless, so it must be scaled to degrees per pixel before it
        # can stand in for the CD matrix.
        cdelt1 = float(header.get('CDELT1', 1.0))
        cdelt2 = float(header.get('CDELT2', 1.0))
        matrix = [[cdelt1 * float(header['PC1_1']), cdelt1 * float(header['PC1_2'])],
                  [cdelt2 * float(header['PC2_1']), cdelt2 * float(header['PC2_2'])]]
    else:
        raise ValueError("Neither CD matrix nor PC matrix found in header.")

    crval = [float(header['CRVAL1']), float(header['CRVAL2'])]

    a_coeffs = _read_sip_coeffs(header, 'A')
    b_coeffs = _read_sip_coeffs(header, 'B')
    ap_coeffs = _read_sip_coeffs(header, 'AP')
    bp_coeffs = _read_sip_coeffs(header, 'BP')

    return crpix, matrix, crval, a_coeffs, b_coeffs, ap_coeffs, bp_coeffs


def _read_sip_coeffs(header, prefix):
    """Build the SIP coefficient matrix for ``prefix`` (``A``, ``B``, ``AP`` or ``BP``).

    The order is read from ``<prefix>_ORDER`` (0 when absent). The result has
    shape ``(order + 1, order + 1)``. Entry ``[p, q]`` is ``<prefix>_p_q`` for
    ``p + q <= order``; it is 0.0 when the keyword is missing or outside that
    triangle.
    """
    order = int(header.get(f'{prefix}_ORDER', 0))
    coeffs = np.zeros((order + 1, order + 1))
    for p in range(order + 1):
        for q in range(order + 1 - p):
            coeffs[p, q] = float(header.get(f'{prefix}_{p}_{q}', 0.0))
    return coeffs

def apply_sip_forward(pixcrd, crpix, a_coeffs_matrix, b_coeffs_matrix):
    """
    Apply the SIP forward distortion polynomials to pixel coordinates.

    Let ``u = x - crpix[0]`` and ``v = y - crpix[1]``. Each point is returned as
    ``(x + f(u, v), y + g(u, v))``, where ``f = sum A[p, q] * u**p * v**q`` and
    ``g = sum B[p, q] * u**p * v**q``. In ``pix2world_pipe`` this step runs
    before the CD-matrix transform.

    Parameters
    ----------
    pixcrd : array, shape (N, 2)
        Pixel coordinates as (x, y) columns.
    crpix : sequence or array of 2 values
        Reference pixel. The polynomials are evaluated on offsets from it.
    a_coeffs_matrix, b_coeffs_matrix : 2-D arrays (NumPy or CuPy)
        Entry ``[p, q]`` is the coefficient of ``u**p * v**q``. The two
        matrices may have different shapes (A_ORDER != B_ORDER).

    Returns
    -------
    array, shape (N, 2)
        Distortion-corrected (x, y) coordinates.
    """
    u = pixcrd[:, 0] - crpix[0]
    v = pixcrd[:, 1] - crpix[1]

    def _polynomial(coeffs):
        # Sum coeffs[p, q] * u**p * v**q over the non-zero coefficients.
        # tolist() copies the coefficients to the host once, so the zero test
        # is a plain Python comparison rather than one array-element
        # comparison per coefficient.
        total = cp.zeros_like(u)
        for p, row in enumerate(coeffs.tolist()):
            for q, coeff in enumerate(row):
                if coeff != 0:
                    total += coeff * (u ** p) * (v ** q)
        return total

    # Each matrix is walked over its own shape, so A and B may be of
    # different orders.
    x = u + _polynomial(a_coeffs_matrix) + crpix[0]
    y = v + _polynomial(b_coeffs_matrix) + crpix[1]
    return cp.stack([x, y], axis=-1)


def apply_sip_inverse(pixcrd, crpix, ap_coeffs_matrix, bp_coeffs_matrix):
    """
    Apply the SIP inverse polynomials to (linearly projected) pixel coordinates.

    Let ``u = x - crpix[0]`` and ``v = y - crpix[1]``. Each point is returned as
    ``(x + f(u, v), y + g(u, v))``, where ``f = sum AP[p, q] * u**p * v**q``
    and ``g = sum BP[p, q] * u**p * v**q``. In ``world2pix_pipe`` this runs on
    the pixel positions obtained from the inverse CD-matrix transform. It turns
    them back into distorted pixel coordinates.

    Parameters
    ----------
    pixcrd : array, shape (N, 2)
        Pixel coordinates as (x, y) columns.
    crpix : sequence or array of 2 values
        Reference pixel. The polynomials are evaluated on offsets from it.
    ap_coeffs_matrix, bp_coeffs_matrix : 2-D arrays (NumPy or CuPy)
        Entry ``[p, q]`` is the coefficient of ``u**p * v**q``. The two
        matrices may have different shapes (AP_ORDER != BP_ORDER).

    Returns
    -------
    array, shape (N, 2)
        Transformed (x, y) coordinates.
    """
    u = pixcrd[:, 0] - crpix[0]
    v = pixcrd[:, 1] - crpix[1]

    def _polynomial(coeffs):
        # Sum coeffs[p, q] * u**p * v**q over the non-zero coefficients.
        # tolist() copies the coefficients to the host once, so the zero test
        # is a plain Python comparison rather than one array-element
        # comparison per coefficient.
        total = cp.zeros_like(u)
        for p, row in enumerate(coeffs.tolist()):
            for q, coeff in enumerate(row):
                if coeff != 0:
                    total += coeff * (u ** p) * (v ** q)
        return total

    # Each matrix is walked over its own shape, so AP and BP may be of
    # different orders.
    x = u + _polynomial(ap_coeffs_matrix) + crpix[0]
    y = v + _polynomial(bp_coeffs_matrix) + crpix[1]
    return cp.stack([x, y], axis=-1)


def pix2world_pipe(pixcrd, crpix, cd_matrix, crval, a_coeffs, b_coeffs):
    """
    Convert pixel coordinates to world coordinates.

    The conversion applies the SIP forward distortion model, then the linear
    CD-matrix transformation, then the tangent (gnomonic) deprojection.

    Parameters
    ----------
    pixcrd : array-like, shape (N, 2)
        Pixel coordinates (x, y) to be transformed.
    crpix : array-like, shape (2,)
        Reference pixel (CRPIX1, CRPIX2). It is subtracted from ``pixcrd`` as-is.
        NOTE(review): FITS CRPIX is 1-based, so ``pixcrd`` must use the same
        origin for the offset to be exact. Confirm against callers.
    cd_matrix : array-like, shape (2, 2)
        Linear map from pixel offsets to intermediate world coordinates, in
        degrees.
    crval : array-like, shape (2,)
        (RA, Dec) of the reference point (CRVAL1, CRVAL2), in degrees.
    a_coeffs, b_coeffs : 2D array (numpy or cupy array)
        SIP forward coefficient matrices. The distortion step runs when either
        matrix contains a non-zero coefficient.

    Returns
    -------
    ra_dec_coords : cupy.ndarray, shape (N, 2)
        (RA, Dec) in degrees, with RA wrapped into [0, 360).
    """
    pixcrd = cp.array(pixcrd)
    crpix = cp.array(crpix)
    cd_matrix = cp.array(cd_matrix)
    crval = cp.array(crval)
    a_coeffs = cp.array(a_coeffs)
    b_coeffs = cp.array(b_coeffs)

    # Apply SIP when *either* polynomial is non-trivial. Requiring both would
    # silently drop a distortion that only one of A/B describes.
    if a_coeffs.any() or b_coeffs.any():
        pixcrd = apply_sip_forward(pixcrd, crpix, a_coeffs, b_coeffs)

    # Linear transform: world_i = sum_j CD_ij * delta_j.
    delta_pix = pixcrd - crpix
    world_coords = cp.dot(delta_pix, cd_matrix.T)

    # Gnomonic (TAN) deprojection around (ra0, dec0).
    xi = cp.deg2rad(world_coords[:, 0])
    eta = cp.deg2rad(world_coords[:, 1])

    ra0 = cp.deg2rad(crval[0])
    dec0 = cp.deg2rad(crval[1])

    denominator = cp.cos(dec0) - eta * cp.sin(dec0)
    ra = cp.arctan2(xi, denominator) + ra0
    dec = cp.arctan2(eta * cp.cos(dec0) + cp.sin(dec0), cp.sqrt(xi**2 + denominator**2))

    # ra0 + arctan2(...) can leave [0, 360); normalise to the usual RA range.
    ra = cp.mod(cp.rad2deg(ra), 360.0)
    dec = cp.rad2deg(dec)

    ra_dec_coords = cp.vstack((ra, dec)).T

    return ra_dec_coords

def world2pix_pipe(ra_dec, crpix, cd_matrix, crval, ap_coeffs, bp_coeffs):
    """
    Convert world coordinates (RA, Dec) to pixel coordinates.

    This is the inverse of ``pix2world_pipe``. It applies the gnomonic (TAN)
    projection about ``crval``, then the inverse of the CD matrix, then the SIP
    inverse polynomials.

    Parameters
    ----------
    ra_dec : array-like, shape (N, 2)
        (RA, Dec) in degrees.
    crpix : array-like, shape (2,)
        Reference pixel (CRPIX1, CRPIX2). It is added back to the pixel offsets.
    cd_matrix : array-like, shape (2, 2)
        CD matrix (degrees per pixel). It is inverted here.
    crval : array-like, shape (2,)
        (RA, Dec) of the projection's tangent point, in degrees.
    ap_coeffs, bp_coeffs : 2-D arrays (NumPy or CuPy)
        SIP inverse coefficient matrices. The distortion step runs when either
        matrix contains a non-zero coefficient.

    Returns
    -------
    cupy.ndarray, shape (N, 2)
        Pixel coordinates as (x, y) columns.

    Notes
    -----
    Points are not filtered. When the projection denominator is <= 0 (the
    point is 90 degrees or more from the tangent point), the results are
    meaningless.
    """
    ra_dec = cp.array(ra_dec)
    crpix = cp.array(crpix)
    cd_matrix = cp.array(cd_matrix)
    crval = cp.array(crval)
    ap_coeffs = cp.array(ap_coeffs)
    bp_coeffs = cp.array(bp_coeffs)

    ra = cp.deg2rad(ra_dec[:, 0])
    dec = cp.deg2rad(ra_dec[:, 1])
    ra0 = cp.deg2rad(crval[0])
    dec0 = cp.deg2rad(crval[1])

    # Shared trigonometric terms, computed once.
    d_ra = ra - ra0
    sin_dec, cos_dec = cp.sin(dec), cp.cos(dec)
    sin_dec0, cos_dec0 = cp.sin(dec0), cp.cos(dec0)
    cos_d_ra = cp.cos(d_ra)
    # Cosine of the angular distance to the tangent point (gnomonic denominator).
    denom = sin_dec * sin_dec0 + cos_dec * cos_dec0 * cos_d_ra

    xi = cp.rad2deg(cos_dec * cp.sin(d_ra) / denom)
    eta = cp.rad2deg((sin_dec * cos_dec0 - cos_dec * sin_dec0 * cos_d_ra) / denom)

    world_coords = cp.vstack((xi, eta)).T

    # Undo the linear transform: delta = CD^-1 · world, applied row-wise.
    inv_cd_matrix = cp.linalg.inv(cd_matrix)
    delta_pix = cp.dot(world_coords, inv_cd_matrix.T)
    pixcrd = delta_pix + crpix

    # Apply SIP inverse when *either* polynomial is non-trivial. Requiring both
    # would silently drop a distortion that only one of AP/BP describes.
    if ap_coeffs.any() or bp_coeffs.any():
        pixcrd = apply_sip_inverse(pixcrd, crpix, ap_coeffs, bp_coeffs)

    return pixcrd

def pix2pix_pipe(pixcrd, i_crpix, i_cd_matrix, i_crval, i_a_coeffs, i_b_coeffs, t_crpix, t_cd_matrix, t_crval, t_ap_coeffs, t_bp_coeffs):
    """Re-express pixel coordinates of one WCS frame in another WCS frame.

    The points go through the sky: ``pix2world_pipe`` runs with the ``i_*``
    parameters (``i_a_coeffs``/``i_b_coeffs`` act as forward SIP). Then
    ``world2pix_pipe`` runs with the ``t_*`` parameters
    (``t_ap_coeffs``/``t_bp_coeffs`` act as inverse SIP).

    Returns whatever ``world2pix_pipe`` returns: an (N, 2) array of (x, y).
    """
    return world2pix_pipe(
        pix2world_pipe(
            pixcrd,
            crpix=i_crpix,
            cd_matrix=i_cd_matrix,
            crval=i_crval,
            a_coeffs=i_a_coeffs,
            b_coeffs=i_b_coeffs,
        ),
        crpix=t_crpix,
        cd_matrix=t_cd_matrix,
        crval=t_crval,
        ap_coeffs=t_ap_coeffs,
        bp_coeffs=t_bp_coeffs,
    )


def find_points_pipe(i_crpix, i_cd_matrix, i_crval, i_a_coeffs, i_b_coeffs, t_crpix, t_cd_matrix, t_crval, t_ap_coeffs, t_bp_coeffs, data_source, step=10):
    """
    Sample a regular pixel grid and map it from the target frame to the source frame.

    A grid with spacing ``step`` is laid over an image of ``data_source.shape``.
    Its points are treated as target-frame pixels. They are converted to the
    sky with the ``t_*`` WCS and then to source-frame pixels with the ``i_*``
    WCS (through ``pix2pix_pipe``).

    Coefficient routing: ``t_ap_coeffs``/``t_bp_coeffs`` fill the forward-SIP
    slot of ``pix2world_pipe``, and ``i_a_coeffs``/``i_b_coeffs`` fill the
    inverse-SIP slot of ``world2pix_pipe``. The commented-out example at the
    end of this module therefore passes the target's A/B and the source's
    AP/BP under these names.

    Parameters
    ----------
    i_crpix, i_cd_matrix, i_crval, i_a_coeffs, i_b_coeffs
        Source-frame WCS parameters (see routing note above).
    t_crpix, t_cd_matrix, t_crval, t_ap_coeffs, t_bp_coeffs
        Target-frame WCS parameters (see routing note above).
    data_source : 2-D array
        Only its shape is used, to size the sampling grid.
    step : int or float
        Grid spacing in pixels. It is truncated with ``int()`` and must be
        >= 1 after truncation.

    Returns
    -------
    (points_source, points_target) : tuple of cupy.ndarray, each of shape (N, 2)
        Matched (x, y) coordinates in row-major grid order. ``points_target``
        holds the grid (x = column index, y = row index) and ``points_source``
        holds the corresponding source-frame positions.

    Raises
    ------
    ValueError
        If ``int(step) < 1`` or ``data_source`` is not 2-D.
    """
    try:
        step = int(step)
        if step < 1:
            raise ValueError(f"step must be >= 1 after int() truncation, got {step}")
        rows, cols = data_source.shape

        # Build only the sampled grid; materialising the full-resolution
        # meshgrid first and slicing it would waste two image-sized arrays.
        grid_y, grid_x = cp.meshgrid(
            cp.arange(0, rows, step, dtype=float),
            cp.arange(0, cols, step, dtype=float),
            indexing="ij",
        )
        points_target = cp.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

        pixel_in = pix2pix_pipe(points_target, t_crpix, t_cd_matrix, t_crval, t_ap_coeffs, t_bp_coeffs, i_crpix, i_cd_matrix, i_crval, i_a_coeffs, i_b_coeffs)
        points_source = cp.asarray(pixel_in)
    except Exception as e:
        print(f"Error: {e}")
        raise

    return points_source, points_target



def warpAffine(image, M):
    """Warp a 2-D image by a forward affine matrix using bilinear interpolation.

    This follows the convention of ``cv2.warpAffine`` without
    ``WARP_INVERSE_MAP``. ``M`` maps input pixel (x, y) to output pixel, so
    each output pixel (x, y) is sampled from ``image`` at
    ``M^-1 · (x, y, 1)``.

    Parameters
    ----------
    image : 2-D array (NumPy or CuPy)
        Input image; the output has the same shape and dtype.
    M : array-like, shape (2, 3)
        Forward affine transform.

    Returns
    -------
    array, same shape and dtype as ``image``
        Output pixels whose sample point has x < 0, x >= width, y < 0 or
        y >= height are set to 0. Sample points in the last partial pixel
        reuse the edge value. For integer dtypes the interpolated values are
        rounded to nearest instead of truncated.
    """
    rows, cols = image.shape

    # Extend M to 3x3 and invert it: output pixels pull their values from the input.
    full = cp.vstack([cp.asarray(M, dtype=cp.float64),
                      cp.asarray([[0.0, 0.0, 1.0]])])
    inv = cp.linalg.inv(full)

    # Output pixel grid, then its preimage in the input image (evaluated
    # directly instead of stacking a 3 x N homogeneous coordinate array).
    x, y = cp.meshgrid(cp.arange(cols, dtype=cp.float64),
                       cp.arange(rows, dtype=cp.float64))
    src_x = inv[0, 0] * x + inv[0, 1] * y + inv[0, 2]
    src_y = inv[1, 0] * x + inv[1, 1] * y + inv[1, 2]

    # Top-left neighbour and fractional offsets.
    x0 = cp.floor(src_x).astype(int)
    y0 = cp.floor(src_y).astype(int)
    a = src_x - x0
    b = src_y - y0

    # Clip neighbour indices so the gathers below stay in bounds.
    x0c = cp.clip(x0, 0, cols - 1)
    x1c = cp.clip(x0 + 1, 0, cols - 1)
    y0c = cp.clip(y0, 0, rows - 1)
    y1c = cp.clip(y0 + 1, 0, rows - 1)

    dst = ((1 - a) * (1 - b) * image[y0c, x0c]
           + a * (1 - b) * image[y0c, x1c]
           + (1 - a) * b * image[y1c, x0c]
           + a * b * image[y1c, x1c])

    # Zero the pixels whose sample point falls outside the input image.
    outside = (src_x < 0) | (src_x >= cols) | (src_y < 0) | (src_y >= rows)
    dst[outside] = 0

    # Round before casting back to an integer dtype; a plain cast truncates,
    # biasing every interpolated value downwards by up to one count.
    if image.dtype.kind in "iu":
        dst = cp.rint(dst)
    return dst.astype(image.dtype)
def align_points_pipe(i_crpix, i_cd_matrix, i_crval, i_a_coeffs, i_b_coeffs, t_crpix, t_cd_matrix, t_crval, t_ap_coeffs, t_bp_coeffs, data_source, step=10):
    """
    Align ``data_source`` to the target frame with a partial-affine fit.

    Steps:

    1. ``find_points_pipe`` maps a sampled grid from the target frame to the
       source frame. SIP coefficient slots follow that function's routing:
       target A/B in ``t_ap_coeffs``/``t_bp_coeffs``, source AP/BP in
       ``i_a_coeffs``/``i_b_coeffs``.
    2. ``cv2.estimateAffinePartial2D`` fits a source→target transform to the
       matched points.
    3. ``warpAffine`` resamples ``data_source`` with that transform.

    Parameters
    ----------
    i_crpix, i_cd_matrix, i_crval, i_a_coeffs, i_b_coeffs
        Source-frame WCS parameters.
    t_crpix, t_cd_matrix, t_crval, t_ap_coeffs, t_bp_coeffs
        Target-frame WCS parameters.
    data_source : 2-D array (NumPy or CuPy)
        Image to align. It is copied to the GPU with ``cp.asarray``.
    step : int
        Grid spacing in pixels for sampling control points. Default is 10.

    Returns
    -------
    cupy.ndarray
        Aligned image, with the same shape and dtype as ``data_source``.

    Raises
    ------
    RuntimeError
        If OpenCV cannot estimate a transform from the sampled points.
    """
    data_source = cp.asarray(data_source)

    try:
        points_source, points_target = find_points_pipe(i_crpix, i_cd_matrix, i_crval, i_a_coeffs, i_b_coeffs, t_crpix, t_cd_matrix, t_crval, t_ap_coeffs, t_bp_coeffs, data_source, step=step)

        # OpenCV works on host memory.
        points_source = cp.asnumpy(points_source)
        points_target = cp.asnumpy(points_target)
        M_affine, _ = cv2.estimateAffinePartial2D(points_source, points_target)
        if M_affine is None:
            # Fail loudly here instead of inside warpAffine with an obscure error.
            raise RuntimeError(
                "cv2.estimateAffinePartial2D could not estimate a transform "
                "from the sampled points."
            )
        align_data = warpAffine(data_source, cp.asarray(M_affine))
    except Exception as e:
        print(f"Error in align_points_pipe: {e}")
        raise
    return align_data




# from astropy.io import fits
# import os
# import sys
# sys.path.append("/home/image_processing_tools")
# import AIPT

# fits_folder = "i_image"
# tem_file_path = "/home/image_processing_tools/t_image/G021_mon_objt_180209T-G021_mon_objt_180210-template.fit"
# file_list1 = os.listdir(fits_folder)
# fits_file_list = [os.path.join(fits_folder,fits_file) for fits_file in file_list1 if fits_file.endswith('.fit')]
# t_hdul = fits.open(tem_file_path)
# t_header = t_hdul[0].header
# t_crpix, t_cd_matrix, t_crval, t_a_coeffs, t_b_coeffs, t_ap_coeffs, t_bp_coeffs = parse_header(t_header)
# for fits_file in fits_file_list:
#     print(fits_file)
#     i_hdu = fits.open(fits_file)
#     i_header = i_hdu[0].header
#     i_image = i_hdu[0].data
#     i_crpix, i_cd_matrix, i_crval, i_a_coeffs, i_b_coeffs, i_ap_coeffs, i_bp_coeffs = parse_header(i_header)
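#     # t_crpix, t_cd_matrix, t_crval, t_a_coeffs, t_b_coeffs, t_ap_coeffs, t_bp_coeffs = parse_header(t_header)
#     # aligned_data1 = AIPT.align_points_gpu(t_header, i_header, i_image, step=10)
#     # NOTE: the source's AP/BP (inverse SIP) go in the i_a/i_b slots and the target's
#     # A/B (forward SIP) go in the t_ap/t_bp slots. This is deliberate:
#     # find_points_pipe maps target pixels -> sky -> source pixels.
#     aligned_data = align_points_pipe(i_crpix, i_cd_matrix, i_crval, i_ap_coeffs, i_bp_coeffs, t_crpix, t_cd_matrix, t_crval, t_a_coeffs, t_b_coeffs, i_image, step=10)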
#     fits.writeto(f"aligned_{os.path.basename(fits_file)}", aligned_data.get(), i_header, overwrite=True)
    