
from nvidia.dali.pipeline import Pipeline
from  nvidia.dali.fn.experimental import readers
import nvidia.dali.fn as fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator
from nvidia.dali.types import DALIDataType
import cupy as cp
import numpy as np  
from astropy.io import fits
import os
import sys

from Function import align_function, sub_function, source_detection_function, get_background_function, get_limits_zscale_function, Log_function
import time


def _read_sip_coeffs(header, prefix):
    """Read one SIP distortion polynomial (``A``, ``B``, ``AP`` or ``BP``) from a header.

    The polynomial order comes from the ``<prefix>_ORDER`` keyword (0 when
    absent). Returns a float64 NumPy array of shape ``(order + 1, order + 1)``.
    Entry ``[i, j]`` holds keyword ``<prefix>_i_j`` for ``i + j <= order``
    (0.0 when that keyword is missing); every other entry stays zero.
    """
    order = int(header.get(f'{prefix}_ORDER', 0))
    coeffs = np.zeros((order + 1, order + 1))
    for i in range(order + 1):
        for j in range(order + 1 - i):
            coeffs[i, j] = header.get(f'{prefix}_{i}_{j}', 0.0)
    return coeffs


def _read_linear_matrix(header):
    """Return the 2x2 linear pixel-to-world matrix as nested lists of floats.

    A complete ``CDi_j`` matrix is used as-is. Otherwise a complete ``PCi_j``
    matrix is converted to the equivalent CD form, ``CDi_j = CDELTi * PCi_j``.
    Here ``CDELTi`` defaults to 1.0 when absent, following the FITS WCS convention.

    Raises:
        ValueError: if neither a complete CD nor a complete PC matrix is present.
    """
    if all(key in header for key in ('CD1_1', 'CD1_2', 'CD2_1', 'CD2_2')):
        return [[float(header['CD1_1']), float(header['CD1_2'])],
                [float(header['CD2_1']), float(header['CD2_2'])]]
    if all(key in header for key in ('PC1_1', 'PC1_2', 'PC2_1', 'PC2_2')):
        # PCi_j carries no CDELT scale. Callers receive this matrix without any
        # CDELT information, so apply the scale here (row i scaled by CDELTi).
        # That makes a PC-based header interchangeable with a CD-based one.
        cdelt1 = float(header.get('CDELT1', 1.0))
        cdelt2 = float(header.get('CDELT2', 1.0))
        return [[cdelt1 * float(header['PC1_1']), cdelt1 * float(header['PC1_2'])],
                [cdelt2 * float(header['PC2_1']), cdelt2 * float(header['PC2_2'])]]
    raise ValueError("Neither CD matrix nor PC matrix found in header.")


def parse_header(header):
    """Extract the WCS linear transform and SIP distortion terms from a FITS header.

    Args:
        header: a mapping of FITS keywords (e.g. ``astropy.io.fits.Header``).

    Returns:
        A 7-tuple of CuPy arrays
        ``(crpix, matrix, crval, a_coeffs, b_coeffs, ap_coeffs, bp_coeffs)``:

        - ``crpix`` and ``crval`` have shape (2,).
        - ``matrix`` is the 2x2 CD matrix. A PC matrix is returned pre-scaled by CDELT.
        - Each SIP array is square with size ``*_ORDER + 1``.

    Raises:
        KeyError: if ``CRPIX1/2`` or ``CRVAL1/2`` is missing.
        ValueError: if neither a CD nor a PC matrix is present.
    """
    crpix = [float(header['CRPIX1']), float(header['CRPIX2'])]
    matrix = _read_linear_matrix(header)
    crval = [float(header['CRVAL1']), float(header['CRVAL2'])]

    # Forward (A/B) and inverse (AP/BP) SIP polynomials.
    a_coeffs = _read_sip_coeffs(header, 'A')
    b_coeffs = _read_sip_coeffs(header, 'B')
    ap_coeffs = _read_sip_coeffs(header, 'AP')
    bp_coeffs = _read_sip_coeffs(header, 'BP')

    return (cp.array(crpix), cp.array(matrix), cp.array(crval),
            cp.array(a_coeffs), cp.array(b_coeffs),
            cp.array(ap_coeffs), cp.array(bp_coeffs))

class Header_InputIterator(object):
    """Cycle through FITS files and yield batches of parsed header data.

    Intended as the ``source`` of a DALI ``fn.external_source`` with
    ``num_outputs=7``. Each ``__next__`` call returns a 7-tuple of lists, one
    list per element returned by ``parse_header``. Each list holds
    ``batch_size`` entries.

    Args:
        file_list: paths of the FITS files, read in the given order.
        batch_size: number of files per batch.
        hdu_index: sequence whose first entry is the 1-based HDU number (as
            passed to DALI's ``readers.fits``). It is converted to astropy's
            0-based ``ext`` number.
    """

    def __init__(self, file_list, batch_size=1, hdu_index=(1,)):
        self.batch_size = batch_size
        # 1-based HDU number -> astropy's 0-based extension index.
        self.index = int(hdu_index[0] - 1)
        self.fits_files = file_list
        # Basenames of every file served so far.
        # NOTE(review): never cleared, so it grows across epochs.
        self.name_list = []
        # Initialise the cursor here too, so __next__ works even if
        # __iter__ was never called.
        self.i = 0
        self.n = len(file_list)

    def __iter__(self):
        self.i = 0
        self.n = len(self.fits_files)
        return self

    def __next__(self):
        if self.n == 0:
            # Nothing to read; previously this failed with an IndexError.
            raise StopIteration
        # crpix, matrix, crval, a, b, ap, bp
        columns = ([], [], [], [], [], [], [])
        for _ in range(self.batch_size):
            # Wrap around instead of stopping, so an incomplete final batch
            # is topped up with files from the start of the list.
            if self.i >= self.n:
                self.i = 0
            self.fits_file = self.fits_files[self.i]
            self.name_list.append(os.path.basename(self.fits_file))
            header = fits.getheader(self.fits_file, ext=self.index)
            for column, value in zip(columns, parse_header(header)):
                column.append(value)
            self.i += 1
        return columns

    def __len__(self):
        return len(self.fits_files)
    
    
class AstroPipeline(Pipeline):
    """DALI pipeline that aligns science frames to templates and subtracts them.

    For each sample it reads the science and template images with
    ``readers.fits``, and reads their header data through
    ``Header_InputIterator`` external sources. It then aligns the science
    image with ``align_function`` and subtracts the template with
    ``sub_function``.

    The reader file lists are written as ``file_list.txt`` / ``template_list.txt``
    into the directory of the first path of each list. They contain basenames
    only, so every file of a list must live in that same directory.
    """

    def __init__(self, file_list, template_file_list=None, hdu_index=(1,), batch_size=1,
                 num_threads=1, device_id=0, exec_pipelined=False, exec_async=False):
        super(AstroPipeline, self).__init__(batch_size=batch_size,
                                            num_threads=num_threads,
                                            device_id=device_id,
                                            seed=12,
                                            exec_async=exec_async,
                                            exec_pipelined=exec_pipelined,
                                            prefetch_queue_depth={"cpu_size": 1, "gpu_size": 1})
        # Copied into a list; it is passed to readers.fits(hdu_indices=...).
        self.hdu_index = list(hdu_index)
        self.source_iterator = Header_InputIterator(file_list, self.batch_size, self.hdu_index)
        # Defaults so read_template_header reports a missing template list
        # with its ValueError instead of an AttributeError.
        self.template_iterator = None
        self.template_folder = None
        self.template_list_path = None
        if template_file_list is not None:
            self.template_iterator = Header_InputIterator(template_file_list, self.batch_size, self.hdu_index)
            self.template_folder, self.template_list_path = self._write_reader_file_list(
                template_file_list, 'template_list.txt')
        self.file_folder, self.file_list_path = self._write_reader_file_list(file_list, 'file_list.txt')

    @staticmethod
    def _write_reader_file_list(paths, list_name):
        """Write the basenames of ``paths`` to ``list_name`` next to ``paths[0]``.

        Returns ``(folder, list_path)``. The folder is made absolute so the
        ``file_root`` later given to ``readers.fits`` is never an empty
        string. For example, ``os.path.dirname('template.fit')`` is ``''``.
        """
        folder = os.path.dirname(os.path.abspath(paths[0]))
        list_path = os.path.join(folder, list_name)
        with open(list_path, 'w') as file:
            file.writelines(f"{os.path.basename(path)}\n" for path in paths)
        return folder, list_path

    def read_header(self, num_outputs=7, device='gpu'):
        """External source yielding the science-frame header data (7 outputs)."""
        return fn.external_source(source=self.source_iterator, num_outputs=num_outputs, device=device)

    def read_template_header(self, num_outputs=7, device='gpu'):
        """External source yielding the template header data (7 outputs).

        Raises:
            ValueError: if no template file list was given to the constructor.
        """
        if self.template_iterator is None:
            raise ValueError("Template file list is not provided.")
        return fn.external_source(source=self.template_iterator, num_outputs=num_outputs, device=device)

    def _read_fits(self, file_root, file_list_path):
        """GPU FITS reader over the files named in ``file_list_path``.

        ``random_shuffle=False`` keeps the reader order identical to the
        order of the header external sources.
        """
        return readers.fits(
            file_root=file_root,
            dtypes=DALIDataType.FLOAT,
            file_list=file_list_path,
            # HDU numbers for readers.fits; the same values feed the header iterators.
            hdu_indices=self.hdu_index,
            random_shuffle=False,
            device="gpu"
        )

    def define_graph(self):
        """Build the graph; returns ``(sub_image, align_image, template_data)``."""
        crpix_list, matrix_list, crval_list, a_coeffs_list, b_coeffs_list, _, _ = self.read_header()
        t_crpix_list, t_matrix_list, t_crval_list, _, _, t_ap_coeffs_list, t_bp_coeffs_list = \
            self.read_template_header()
        img_data = self._read_fits(self.file_folder, self.file_list_path)
        template_data = self._read_fits(self.template_folder, self.template_list_path)
        img_data = fn.cast(img_data, dtype=DALIDataType.FLOAT)
        template_data = fn.cast(template_data, dtype=DALIDataType.FLOAT)
        align_image = align_function(crpix_list, matrix_list, crval_list, a_coeffs_list, b_coeffs_list,
                                     t_crpix_list, t_matrix_list, t_crval_list, t_ap_coeffs_list,
                                     t_bp_coeffs_list, img_data)
        sub_image = sub_function(align_image, template_data)

        return (sub_image, align_image, template_data)



if __name__ == '__main__':
    fits_folder = "i_image"
    tem_file_path = "template.fit"
    fits_file_list = sorted(
        os.path.join(fits_folder, name)
        for name in os.listdir(fits_folder)
        if name.endswith('.fit')
    )
    if not fits_file_list:
        raise SystemExit(f"No .fit files found in {fits_folder!r}")
    total_num = len(fits_file_list)
    # One template entry per science frame, so both readers advance in lockstep.
    template_file_list = [tem_file_path] * total_num
    # The template list must be passed: define_graph always reads templates.
    pipe = AstroPipeline(file_list=fits_file_list, template_file_list=template_file_list,
                         hdu_index=[1], batch_size=4, device_id=1,
                         exec_async=True, exec_pipelined=True)
    # define_graph returns three outputs (sub_image, align_image, template_data),
    # so the output map needs three names. auto_reset lets every epoch run,
    # not just the first.
    dali_loader = DALIGenericIterator(pipe, ['images', 'label', 'template'],
                                      reader_name=None, size=total_num, auto_reset=True)

    num_epochs = 5
    for epoch in range(num_epochs):
        for i, data in enumerate(dali_loader):
            # Extract PyTorch tensors from the DALI loader (first sample of the batch).
            inputs = data[0]['images'][0]
            labels = data[0]['label'][0]

            # Convert to float and move to GPU.
            # NOTE(review): outputs already live on device_id=1; .cuda() copies
            # to the current default device. Confirm that is intended.
            inputs = inputs.float().cuda()
            labels = labels.float().cuda()
            print("image__________", inputs)
            print("label__________", labels)
                     
       
