import numpy as np
import cupy as cp
import cv2
from cupyx.scipy.ndimage import map_coordinates

def parse_header(header):
    """
    Extract the linear WCS terms and the SIP distortion polynomials from a FITS header.

    The linear transformation is taken from the CD matrix when all four CDi_j keywords
    are present. Otherwise the PC matrix is used and each row is scaled by the matching
    CDELTi keyword (defaulting to 1.0 when absent), following the FITS WCS relation
    CDi_j = CDELTi * PCi_j, so the returned matrix is comparable in both cases.

    Parameters:
    ----------
    header : dict-like
        The FITS header. It must support ``in``, ``[]`` and ``.get(key, default)``.

    Returns:
    --------
    crpix : list of floats
        Reference pixel coordinates (CRPIX1, CRPIX2).

    matrix : list of list of floats
        The 2x2 linear transformation matrix (CD, or PC scaled by CDELT).

    crval : list of floats
        World coordinates at the reference pixel (CRVAL1, CRVAL2).

    a_coeffs : 2D numpy array
        Forward SIP coefficients A_i_j, shape (A_ORDER + 1, A_ORDER + 1).

    b_coeffs : 2D numpy array
        Forward SIP coefficients B_i_j, shape (B_ORDER + 1, B_ORDER + 1).

    ap_coeffs : 2D numpy array
        Inverse SIP coefficients AP_i_j, shape (AP_ORDER + 1, AP_ORDER + 1).

    bp_coeffs : 2D numpy array
        Inverse SIP coefficients BP_i_j, shape (BP_ORDER + 1, BP_ORDER + 1).

    Raises:
    -------
    ValueError:
        If neither the CD matrix nor the PC matrix is found in the header.
    """
    def _read_sip_coeffs(prefix):
        # The order may arrive as a string or float depending on how the header was
        # built, so coerce it before using it as an array dimension.
        order = int(header.get(f'{prefix}_ORDER', 0))
        coeffs = np.zeros((order + 1, order + 1))
        # Only terms with i + j <= order are read; everything else stays zero.
        for i in range(order + 1):
            for j in range(order + 1 - i):
                coeffs[i, j] = float(header.get(f'{prefix}_{i}_{j}', 0.0))
        return coeffs

    crpix = [float(header['CRPIX1']), float(header['CRPIX2'])]

    cd_keys = ('CD1_1', 'CD1_2', 'CD2_1', 'CD2_2')
    pc_keys = ('PC1_1', 'PC1_2', 'PC2_1', 'PC2_2')
    if all(key in header for key in cd_keys):
        matrix = [[float(header['CD1_1']), float(header['CD1_2'])],
                  [float(header['CD2_1']), float(header['CD2_2'])]]
    elif all(key in header for key in pc_keys):
        # A PC matrix carries no pixel scale on its own; the scale lives in CDELTi.
        cdelt1 = float(header.get('CDELT1', 1.0))
        cdelt2 = float(header.get('CDELT2', 1.0))
        matrix = [[cdelt1 * float(header['PC1_1']), cdelt1 * float(header['PC1_2'])],
                  [cdelt2 * float(header['PC2_1']), cdelt2 * float(header['PC2_2'])]]
    else:
        raise ValueError("Neither CD matrix nor PC matrix found in header.")

    crval = [float(header['CRVAL1']), float(header['CRVAL2'])]

    a_coeffs = _read_sip_coeffs('A')
    b_coeffs = _read_sip_coeffs('B')
    ap_coeffs = _read_sip_coeffs('AP')
    bp_coeffs = _read_sip_coeffs('BP')

    return crpix, matrix, crval, a_coeffs, b_coeffs, ap_coeffs, bp_coeffs

def apply_sip_forward(pixcrd, crpix, a_coeffs_matrix, b_coeffs_matrix):
    """
    Apply the SIP (Simple Imaging Polynomial) forward distortion polynomials to pixel coordinates.

    With u = x - CRPIX1 and v = y - CRPIX2, the function evaluates
    f(u, v) = sum A[i, j] * u**i * v**j and g(u, v) = sum B[i, j] * u**i * v**j
    and returns (u + f + CRPIX1, v + g + CRPIX2). The A and B matrices are evaluated
    independently, so they may have different shapes (different polynomial orders).
    The input array is not modified.

    Parameters:
    ----------
    pixcrd : array-like, shape (N, 2)
        A 2D array of pixel coordinates, where each row represents a coordinate (x, y).

    crpix : list or array-like
        The reference pixel coordinates (CRPIX1, CRPIX2). These are subtracted from the input pixel
        coordinates to center the distortion around the reference pixel.

    a_coeffs_matrix : 2D array (numpy or cupy array)
        The A coefficient matrix for distortion in the u-direction (x-axis).

    b_coeffs_matrix : 2D array (numpy or cupy array)
        The B coefficient matrix for distortion in the v-direction (y-axis).

    Returns:
    --------
    result : array-like, shape (N, 2)
        The SIP-corrected pixel coordinates.
    """
    u = pixcrd[:, 0] - crpix[0]
    v = pixcrd[:, 1] - crpix[1]

    def _evaluate(coeffs):
        # Each polynomial uses its own matrix bounds so A and B need not share an order.
        total = cp.zeros_like(u)
        rows, cols = coeffs.shape
        for i in range(rows):
            for j in range(cols):
                if coeffs[i, j] != 0:  # skip zero coefficients
                    total += coeffs[i, j] * (u ** i) * (v ** j)
        return total

    f_u = _evaluate(a_coeffs_matrix)
    f_v = _evaluate(b_coeffs_matrix)

    result = cp.stack([u + f_u + crpix[0], v + f_v + crpix[1]], axis=-1)

    return result


def apply_sip_inverse(pixcrd, crpix, ap_coeffs_matrix, bp_coeffs_matrix):
    """
    Apply the SIP (Simple Imaging Polynomial) inverse distortion polynomials to pixel coordinates.

    With u = x - CRPIX1 and v = y - CRPIX2, the function evaluates
    f(u, v) = sum AP[i, j] * u**i * v**j and g(u, v) = sum BP[i, j] * u**i * v**j
    and returns (u + f + CRPIX1, v + g + CRPIX2). The AP and BP matrices are evaluated
    independently, so they may have different shapes (different polynomial orders).
    The input array is not modified.

    Parameters:
    ----------
    pixcrd : array-like, shape (N, 2)
        A 2D array of pixel coordinates, where each row represents a coordinate (x, y).

    crpix : list or array-like
        The reference pixel coordinates (CRPIX1, CRPIX2). These are subtracted from the input pixel
        coordinates to center the distortion around the reference pixel.

    ap_coeffs_matrix : 2D array (numpy or cupy array)
        The AP coefficient matrix for inverse distortion in the u-direction (x-axis).

    bp_coeffs_matrix : 2D array (numpy or cupy array)
        The BP coefficient matrix for inverse distortion in the v-direction (y-axis).

    Returns:
    --------
    result : array-like, shape (N, 2)
        The SIP-inverse-corrected pixel coordinates.
    """
    u = pixcrd[:, 0] - crpix[0]
    v = pixcrd[:, 1] - crpix[1]

    def _evaluate(coeffs):
        # Each polynomial uses its own matrix bounds so AP and BP need not share an order.
        total = cp.zeros_like(u)
        rows, cols = coeffs.shape
        for i in range(rows):
            for j in range(cols):
                if coeffs[i, j] != 0:  # skip zero coefficients
                    total += coeffs[i, j] * (u ** i) * (v ** j)
        return total

    f_u = _evaluate(ap_coeffs_matrix)
    f_v = _evaluate(bp_coeffs_matrix)

    result = cp.stack([u + f_u + crpix[0], v + f_v + crpix[1]], axis=-1)

    return result


def pix2world_gpu(pixcrd, header):
    """
    Converts pixel coordinates to world coordinates (RA, Dec) using header information and
    applying SIP forward distortion correction, linear transformation, and tangent projection.

    The SIP step runs whenever either the A or the B polynomial has a non-zero coefficient,
    so a distortion that affects only one axis is still applied. Pixel coordinates are
    offset by CRPIX exactly as given; no 0-based/1-based origin adjustment is made here.

    Parameters:
    ----------
    pixcrd : array-like, shape (N, 2)
        A 2D array of pixel coordinates (x, y) to be transformed.

    header : dict-like
        The FITS header containing the World Coordinate System (WCS) and distortion information.

    Returns:
    --------
    ra_dec_coords : array-like, shape (N, 2)
        The transformed world coordinates (RA, Dec) in degrees corresponding to the input
        pixel coordinates. RA is not wrapped into [0, 360).
    """
    crpix, cd_matrix, crval, a_coeffs, b_coeffs, _, _ = parse_header(header)
    pixcrd = cp.array(pixcrd)
    crpix = cp.array(crpix)
    cd_matrix = cp.array(cd_matrix)
    crval = cp.array(crval)

    # Apply SIP forward transformation when at least one axis is distorted.
    if a_coeffs.any() or b_coeffs.any():
        # Zero-pad both matrices to a common square size so the polynomial
        # evaluation can index them with the same (i, j) range.
        size = max(a_coeffs.shape[0], b_coeffs.shape[0])
        a_coeffs = np.pad(a_coeffs, (0, size - a_coeffs.shape[0]), mode="constant")
        b_coeffs = np.pad(b_coeffs, (0, size - b_coeffs.shape[0]), mode="constant")
        pixcrd = apply_sip_forward(pixcrd, crpix, a_coeffs, b_coeffs)

    # Linear transformation to intermediate world coordinates (degrees)
    delta_pix = pixcrd - crpix
    world_coords = cp.dot(delta_pix, cd_matrix.T)

    # Tangent (gnomonic) deprojection
    xi = cp.deg2rad(world_coords[:, 0])
    eta = cp.deg2rad(world_coords[:, 1])

    ra0 = cp.deg2rad(crval[0])
    dec0 = cp.deg2rad(crval[1])

    denominator = cp.cos(dec0) - eta * cp.sin(dec0)
    ra = cp.arctan2(xi, denominator) + ra0
    dec = cp.arctan2(eta * cp.cos(dec0) + cp.sin(dec0), cp.sqrt(xi**2 + denominator**2))

    ra = cp.rad2deg(ra)
    dec = cp.rad2deg(dec)

    ra_dec_coords = cp.vstack((ra, dec)).T
    return ra_dec_coords

def world2pix_gpu(ra_dec, header):
    """
    Converts world coordinates (RA, Dec) to pixel coordinates using the WCS information in the header.
    The function applies tangent projection, the inverse linear transformation, and the SIP inverse
    distortion polynomials (AP/BP) when at least one of them has a non-zero coefficient.
    CuPy is used for the array computations.

    If the header carries no AP/BP coefficients, no distortion correction is applied on the
    way back to pixels, even when forward A/B coefficients exist.

    Parameters:
    ----------
    ra_dec : ndarray, shape (N, 2)
        The world coordinates (RA, Dec) in degrees, to be transformed to pixel coordinates.

    header : dict
        The FITS header containing the World Coordinate System (WCS) and SIP distortion information.

    Returns:
    --------
    pixcrd : ndarray, shape (N, 2)
        The corresponding pixel coordinates after the transformation.
    """
    crpix, cd_matrix, crval, _, _, ap_coeffs, bp_coeffs = parse_header(header)
    ra_dec = cp.array(ra_dec)
    crpix = cp.array(crpix)
    cd_matrix = cp.array(cd_matrix)
    crval = cp.array(crval)
    ra = cp.deg2rad(ra_dec[:, 0])
    dec = cp.deg2rad(ra_dec[:, 1])

    ra0 = cp.deg2rad(crval[0])
    dec0 = cp.deg2rad(crval[1])

    # Tangent (gnomonic) projection. Both standard coordinates share the same
    # denominator (cosine of the angular distance to the tangent point).
    delta_ra = ra - ra0
    cos_dist = cp.sin(dec) * cp.sin(dec0) + cp.cos(dec) * cp.cos(dec0) * cp.cos(delta_ra)
    xi = cp.cos(dec) * cp.sin(delta_ra) / cos_dist
    eta = (cp.sin(dec) * cp.cos(dec0) - cp.cos(dec) * cp.sin(dec0) * cp.cos(delta_ra)) / cos_dist

    xi = cp.rad2deg(xi)
    eta = cp.rad2deg(eta)

    world_coords = cp.vstack((xi, eta)).T

    # Undo the linear transformation.
    inv_cd_matrix = cp.linalg.inv(cd_matrix)
    delta_pix = cp.dot(world_coords, inv_cd_matrix.T)

    pixcrd = delta_pix + crpix

    # Apply SIP inverse transformation when at least one axis is distorted.
    if ap_coeffs.any() or bp_coeffs.any():
        # Zero-pad both matrices to a common square size so the polynomial
        # evaluation can index them with the same (i, j) range.
        size = max(ap_coeffs.shape[0], bp_coeffs.shape[0])
        ap_coeffs = np.pad(ap_coeffs, (0, size - ap_coeffs.shape[0]), mode="constant")
        bp_coeffs = np.pad(bp_coeffs, (0, size - bp_coeffs.shape[0]), mode="constant")
        pixcrd = apply_sip_inverse(pixcrd, crpix, ap_coeffs, bp_coeffs)
    return pixcrd

    
def pix2pix_gpu(pixcrd, header1, header2):
    """
    Map pixel coordinates of one image onto the pixel grid of another image.

    The coordinates are first converted to world coordinates with ``header1`` and the
    result is then converted back to pixel coordinates with ``header2``.

    Parameters:
    ----------
    pixcrd : array-like, shape (N, 2)
        Pixel coordinates (x, y) expressed in the frame described by ``header1``.

    header1 : dict-like
        FITS header used for the pixel-to-world step.

    header2 : dict-like
        FITS header used for the world-to-pixel step.

    Returns:
    --------
    new_pixcrd : array-like, shape (N, 2)
        The pixel coordinates in the frame described by ``header2``.
    """
    return world2pix_gpu(pix2world_gpu(pixcrd, header1), header2)


def find_points_gpu(header_target, header_source, data_source, step=10):
    """
    Find corresponding points between source and target images using WCS headers.

    A regular grid of pixel positions is laid over an image with the shape of
    ``data_source``, sampling every ``step`` pixels along both axes. Each grid position
    is interpreted with ``header_target`` and mapped into the frame of ``header_source``.
    Only the sampled positions are generated, so the full-resolution grid is never
    allocated.

    Parameters:
    ----------
    header_target : dict-like
        Target WCS header information.

    header_source : dict-like
        Source WCS header information.

    data_source : ndarray
        2D source data; only its shape is used here.

    step : int or float, optional
        Sampling stride in pixels, truncated to an integer. Must be at least 1.
        Default is 10.

    Returns:
    --------
    points_source : cp.ndarray, shape (N, 2)
        (x, y) positions in the source frame matching each grid position.

    points_target : cp.ndarray, shape (N, 2)
        (x, y) grid positions in the target frame.

    Raises:
    -------
    ValueError:
        If ``step`` truncates to a value smaller than 1.
    """
    stride = int(step)
    if stride < 1:
        raise ValueError(f"step must be >= 1, got {step!r}")

    n_rows, n_cols = data_source.shape
    row_idx = cp.arange(0, n_rows, stride, dtype=float)
    col_idx = cp.arange(0, n_cols, stride, dtype=float)
    grid_y, grid_x = cp.meshgrid(row_idx, col_idx, indexing="ij")

    # Points are stored as (x, y), i.e. (column, row), in row-major grid order.
    points_target = cp.stack((grid_x.ravel(), grid_y.ravel()), axis=1)
    points_source = cp.asarray(pix2pix_gpu(points_target, header_target, header_source))

    return points_source, points_target

def find_points_gpu_all(header_target, header_source,  data_source):
    """
    Find corresponding points between source and target images for every pixel.

    This is the dense variant of ``find_points_gpu``: it samples the grid with a
    stride of 1, so one point pair is produced per pixel of ``data_source``.

    Parameters:
    ----------
    header_target : dict-like
        Target WCS header information.

    header_source : dict-like
        Source WCS header information.

    data_source : ndarray
        2D source data; only its shape is used here.

    Returns:
    --------
    points_source : cp.ndarray, shape (N, 2)
        (x, y) positions in the source frame matching each grid position.

    points_target : cp.ndarray, shape (N, 2)
        (x, y) grid positions in the target frame, in row-major pixel order.
    """
    return find_points_gpu(header_target, header_source, data_source, step=1)

def warpAffine(image, M):
    """
    Applies an affine transformation to the input image using the transformation matrix M.

    ``M`` maps input coordinates to output coordinates. Every output pixel is filled by
    mapping its (x, y) position back through the inverse of ``M`` and bilinearly
    interpolating the input image at that location. Output pixels whose source position
    falls outside the input image are set to 0.

    Parameters:
    ----------
    image : ndarray
        The input 2D image to be transformed, represented as a CuPy array.

    M : ndarray
        The 2x3 affine transformation matrix.

    Returns:
    --------
    dst : ndarray
        The transformed image after applying the affine transformation, with the same shape
        and dtype as the input image.
    """
    rows, cols = image.shape

    # Extend M to a 3x3 homogeneous matrix so it can be inverted.
    bottom_row = cp.asarray([[0.0, 0.0, 1.0]])
    M_inv = cp.linalg.inv(cp.vstack([cp.asarray(M), bottom_row]))

    # Broadcasting a row of x indices against a column of y indices yields the
    # source position of every output pixel without materialising a 3xN array.
    x = cp.arange(cols)[None, :]
    y = cp.arange(rows)[:, None]
    src_x = M_inv[0, 0] * x + M_inv[0, 1] * y + M_inv[0, 2]
    src_y = M_inv[1, 0] * x + M_inv[1, 1] * y + M_inv[1, 2]

    # Integer corners surrounding each source position.
    x0 = cp.floor(src_x).astype(int)
    x1 = x0 + 1
    y0 = cp.floor(src_y).astype(int)
    y1 = y0 + 1

    # Fractional offsets used as bilinear weights.
    a = src_x - x0
    b = src_y - y0

    # Clip the corner indices so the gathers below stay in bounds; positions that
    # are truly outside the image are zeroed afterwards.
    x0_clipped = cp.clip(x0, 0, cols - 1)
    x1_clipped = cp.clip(x1, 0, cols - 1)
    y0_clipped = cp.clip(y0, 0, rows - 1)
    y1_clipped = cp.clip(y1, 0, rows - 1)

    Ia = image[y0_clipped, x0_clipped]
    Ib = image[y0_clipped, x1_clipped]
    Ic = image[y1_clipped, x0_clipped]
    Id = image[y1_clipped, x1_clipped]

    dst = (1 - a) * (1 - b) * Ia + a * (1 - b) * Ib + (1 - a) * b * Ic + a * b * Id

    outside = (src_x < 0) | (src_x >= cols) | (src_y < 0) | (src_y >= rows)
    dst[outside] = 0

    return dst.astype(image.dtype, copy=False)


def align_points_gpu(header_target, header_source, data_source, step=10):
    """
    Aligns the source data (data_source) to the target coordinate system using an affine transformation,
    based on corresponding points derived from the FITS headers.

    Point pairs are generated on the GPU, copied to the host, and passed to
    ``cv2.estimateAffinePartial2D`` to estimate the 2x3 matrix that maps source positions
    to target positions. The source image is then warped with that matrix on the GPU.
    Point pairs containing non-finite coordinates are discarded before the estimation.

    Parameters:
    ----------
    header_target : dict
        The FITS header of the target image, which defines the target world coordinate system (WCS).

    header_source : dict
        The FITS header of the source image, which defines the source world coordinate system (WCS).

    data_source : np.ndarray
        The source image data (2D array) that needs to be aligned to the target WCS.

    step : int, optional
        The step size for sampling points from the source and target headers. A smaller step size results
        in more points for the affine transformation, but increases computation time. Default is 10.

    Returns:
    --------
    align_data : cp.ndarray
        The aligned source image data after applying the affine transformation.

    Raises:
    -------
    ValueError:
        If no affine transformation could be estimated from the point pairs.
    """
    data_source = cp.asarray(data_source)

    try:
        points_source, points_target = find_points_gpu(header_target, header_source, data_source, step=step)
        # OpenCV works on host memory, so bring the point sets back from the GPU.
        points_source = cp.asnumpy(points_source)
        points_target = cp.asnumpy(points_target)

        # Drop pairs with NaN/inf coordinates so they cannot corrupt the estimate.
        finite = np.isfinite(points_source).all(axis=1) & np.isfinite(points_target).all(axis=1)
        M_affine, _ = cv2.estimateAffinePartial2D(points_source[finite], points_target[finite])
        if M_affine is None:
            # OpenCV signals a failed estimation by returning None instead of raising.
            raise ValueError("could not estimate an affine transformation from the point pairs")

        align_data = warpAffine(data_source, cp.asarray(M_affine))
    except Exception as e:
        print(f"Error in align_points_gpu: {e}")
        raise

    return align_data

# Direct reprojection
def align_image_gpu(header_target, header_source, data_source, order=1, cval=0):
    """
    Reproject the source image onto the target pixel grid by per-pixel resampling.

    For every pixel of a grid with the shape of ``data_source``, the matching position in
    the source image is computed from the two headers, and the source image is
    interpolated at that position with ``map_coordinates``.

    Parameters:
    ----------
    header_target : dict
        The FITS header of the target image, which defines the target world coordinate system (WCS).

    header_source : dict
        The FITS header of the source image, which defines the world coordinate system (WCS) of the source image.

    data_source : ndarray
        The source image data (2D array) that needs to be aligned to the target WCS.

    order : int, optional
        The order of the spline interpolation (default is 1, meaning linear interpolation).

    cval : float, optional
        The value used for positions that fall outside the source image. Default is 0.

    Returns:
    --------
    align_data : ndarray
        The aligned image data, with the same shape and dtype as data_source.
    """
    source = cp.asarray(data_source)
    source_xy, _ = find_points_gpu_all(header_target, header_source, source)

    # The point list is (x, y); map_coordinates expects one row per array axis,
    # i.e. (row, col), hence the column swap and transpose.
    sample_at = source_xy[:, [1, 0]].T

    # Interpolate into a flat default-dtype buffer, one value per output pixel.
    flat_result = cp.empty(source.size)
    map_coordinates(
        input=source,
        coordinates=sample_at,
        output=flat_result,
        order=order,
        cval=cval,
        mode="constant"
    )

    return flat_result.reshape(source.shape).astype(source.dtype)


