import os
import nvidia.dali
import nvidia.dali.experimental
import nvidia.dali.fn as fn
# import nvidia.dali.libdali
import nvidia.dali.plugin_manager as pm
from nvidia.dali.pipeline import pipeline_def
# from nvidia.dali.plugin.pytorch import DALIGenericIterator
import sys
# Absolute directory containing this script. It is added to sys.path so the
# sibling `py_params` package can be imported below, and it is reused later to
# locate the compiled plugin and the `default.nnw` file.
current_directory = os.path.abspath(os.path.dirname(__file__))
sys.path += [current_directory]

from py_params.brain import BrainStruct
from py_params.filter import Filter
from py_params.flag_obj import FlagObj
from py_params.obj_struct import Obj2Struct
from py_params.plist import PlistVars
from py_params.pic_struct import PicStruct
from py_params.prefs import Prefs

# Load the compiled DALI plugin shipped next to this script (presumably the
# library that registers the custom `fn.scan` operator used below).
pm.load_library(os.path.join(current_directory, "extract/build/libextract.so"))

# Directory scanned by the FITS reader in `extract_pipe`. It can be overridden
# with the FITS_DIR environment variable; otherwise the original hard-coded
# location is used.
fits_dir = os.environ.get(
    "FITS_DIR",
    "/home/dell461/cl/Paper_data/fits_processing_toolkit/Pipeline/Function/sub_pipe/data",
)

def source_detection_function(images, prefs_file_path=None):
    """Run the custom ``fn.scan`` source-detection operator on ``images``.

    Builds the parameter objects from the local ``py_params`` package. Their
    fields are forwarded as ``fn.scan`` arguments, and the raw outputs are
    reshaped into structured arrays.

    Args:
        images: DALI data node holding the input images (``extract_pipe``
            passes GPU images cast to FLOAT).
        prefs_file_path: Optional path forwarded to ``Prefs``. ``None`` leaves
            the choice of defaults to ``Prefs``.

    Returns:
        Tuple ``(d_mean, d_sigma, backmean, backsig, numValidObj,
        numValidPix, obj_double, obj_float)`` of DALI data nodes, where:

        - ``d_mean`` and ``d_sigma`` are reshaped to ``[nx_mesh, ny_mesh]``;
        - ``obj_double`` is reshaped to ``[-1, 5]``;
        - ``obj_float`` is reshaped to ``[-1, 24]``.

        The image output of ``fn.scan`` is not returned.
    """
    prefs = Prefs(prefs_file_path)
    field = PicStruct(prefs)
    conv_filter = Filter()  # not named `filter` to avoid shadowing the built-in
    flag = FlagObj()
    plist = PlistVars()
    brain = BrainStruct(os.path.join(current_directory, "default.nnw"))
    obj2 = Obj2Struct()

    # NOTE(review): field_nbackx/field_nbacky receive nbackfx/nbackfy (the same
    # values as the nbackfx/nbackfy arguments). Confirm this is intended.
    (d_mean, d_sigma, backmean, backsig, numValidObj, numValidPix,
     _scanned_images, obj_double, obj_float) = fn.scan(
        images,
        backw=field.backw, backfthresh=prefs.backfthresh,
        nbackfx=field.nbackfx, nbackfy=field.nbackfy,
        field_back_type=field.back_type,
        prefs_ndthresh=prefs.ndthresh, prefs_nthresh=prefs.nthresh,
        prefs_thresh=prefs.thresh, prefs_thresh_type=prefs.thresh_type,
        prefs_dthresh=prefs.dthresh,
        prefs_deblend_nthresh=prefs.deblend_nthresh,
        flag_poserr_mx2=flag.poserr_mx2, flagobj_peakx=flag.peakx,
        flagobj_iso_0=flag.iso[0], flagobj_fwhm=flag.fwhm,
        field_gain=field.gain, field_satur_level=field.satur_level,
        field_ngamma=field.ngamma,
        field_nbackx=field.nbackfx, field_nbacky=field.nbackfy,
        plist=[plist.plistexist_var, plist.plistoff_var],
        prefs_clean_flag=prefs.clean_flag,
        prefs_weightgain_flag=prefs.weightgain_flag,
        prefs_detect_type=prefs.detect_type,
        prefs_clean_param=prefs.clean_param,
        prefs_clean_stacksize=prefs.clean_stacksize,
        field_ymin=field.ymin, field_ymax=field.ymax,
        ext_minarea=prefs.ext_minarea,
        conv=conv_filter.conv, convw=conv_filter.convw,
        device='gpu',
        **prefs.build_endobject_param(),
        **brain.build_endobject_param(),
        **obj2.build_endobject_param()
    )

    # Number of background meshes per axis: (size - 1) // backw + 1, i.e.
    # ceil(size / backw) for positive integer sizes. Integer division is
    # required because fn.reshape needs integer extents. The previous true
    # division ("/") always produced a float here.
    nx_mesh = int((field.width - 1) // field.backw + 1)
    ny_mesh = int((field.height - 1) // field.backw + 1)

    # NOTE(review): the mesh is reshaped as [nx, ny]. If the plugin stores it
    # row-major by y, [ny, nx] may be the intended order. Confirm.
    d_mean = fn.reshape(d_mean, shape=[nx_mesh, ny_mesh])
    d_sigma = fn.reshape(d_sigma, shape=[nx_mesh, ny_mesh])
    obj_double = fn.reshape(obj_double, shape=[-1, 5])
    obj_float = fn.reshape(obj_float, shape=[-1, 24])

    return d_mean, d_sigma, backmean, backsig, numValidObj, numValidPix, obj_double, obj_float



@pipeline_def(num_threads=1, device_id=0)
def extract_pipe():
    """DALI pipeline: read FITS images from ``fits_dir`` and run source detection.

    Reads HDU 1 of every ``*.fit`` file under ``fits_dir`` on the GPU. Each
    image is cast to FLOAT and handed to ``source_detection_function``.

    Returns:
        ``(d_mean, d_sigma, backmean, backsig, images, numValidObj,
        numValidPix, obj_double, obj_float)``. Here ``images`` is the cast
        reader output, not the image output of ``fn.scan``, because
        ``source_detection_function`` does not return the latter.
    """
    images = fn.experimental.readers.fits(
        file_root=fits_dir, file_filter='*.fit', name='FITS_READER',
        device='gpu', hdu_indices=[1],
    )
    images = fn.cast(images, dtype=nvidia.dali.types.FLOAT)

    (d_mean, d_sigma, backmean, backsig, numValidObj, numValidPix,
     obj_double, obj_float) = source_detection_function(images)

    return d_mean, d_sigma, backmean, backsig, images, numValidObj, numValidPix, obj_double, obj_float

if __name__ == "__main__":
    pipe = extract_pipe(batch_size=2)
    pipe.build()

    # Run a few iterations. The outputs of the last iteration are kept for
    # inspection below.
    for _ in range(5):
        (d_mean, d_sigma, backmean, backsig, images,
         numValidObj, numValidPix, obj_double, obj_float) = pipe.run()

    # On DALI TensorLists, `shape` is a method that returns the per-sample
    # shapes. Without the call, only the bound method object would be printed.
    print(obj_double.shape(), "------------------------------------------------")