from ctypes import CDLL, POINTER, alignment, byref, c_double, c_float, c_int, c_uint, cast, create_string_buffer, sizeof
from define import MAXNAPER, NAPER
from flag_obj import FlagObj
from plist import PlistVars
from prefs import Prefs
from structs.brain import BrainStruct
from structs.obj_struct import Obj2Struct
from structs.outobj_struct import OutobjStruct
from structs.pic_struct import PicStruct
from structs.prefs2 import Prefs2
from structs.wcs_struct import WcsStruct
import os

def analyse_py(detected_num, field: PicStruct, lib: CDLL, prefs: Prefs, flag_obj: FlagObj, plist_vars: PlistVars):
  """Drive the native analyse / clean / end-object passes and number the results.

  The function performs, in order:

  1. ``lib.analyse_gpu`` with the per-object buffers held by ``flag_obj``,
     the field-level parameters and a handful of preference flags.
  2. ``lib.run_clean`` on a freshly allocated ``c_uint`` index array with one
     slot per detection.
  3. ``lib.endobject_gpu`` with a ``Prefs2`` copy of the relevant preferences,
     a ``BrainStruct`` pointing at ``default.nnw`` next to this file, and a
     freshly allocated ``OutobjStruct`` array with one slot per detection.
  4. ``lib.clear_detection`` and ``lib.clear_device(field.strip)``.
  5. Sequential numbering (starting at 0) of every output object whose index
     slot holds the all-ones ``c_uint`` value.

  Args:
    detected_num: Number of detections; sizes both native arrays.
    field: Picture/field description (background, gain, thresholds, geometry).
    lib: Loaded shared library exposing the ``*_gpu`` entry points.
    prefs: User preferences to forward to the native code.
    flag_obj: Holder of the per-object measurement handles passed to
      ``analyse_gpu``.
    plist_vars: Pixel-list layout variables passed to ``analyse_gpu``.

  Returns:
    A ``(objlist, num)`` tuple: ``objlist`` is a ``POINTER(OutobjStruct)`` to
    ``detected_num`` entries and ``num`` is how many of them were numbered.
  """
  # NOTE(review): argtypes for analyse_gpu are intentionally left undeclared
  # in the original code; the sketch below was commented out there. Confirm
  # the real C signature before enabling it.
  # lib.analyse_gpu.argtypes = [
  #   c_double, c_int, c_int, c_float,
  #   c_float, c_double, c_double, c_float, c_float, c_double, # field global vars
  #   c_int, c_int, c_int, c_int, c_int, c_int, c_int, c_int, c_int, c_int, c_int, c_int, # constant vars
  # ]
  lib.analyse_gpu(
    flag_obj.poserr_mx2, flag_obj.peakx, flag_obj.iso[0], flag_obj.fwhm,
    c_float(field.backsig), c_double(field.gain), c_double(field.ngamma), c_float(field.thresh), c_float(field.dthresh), c_double(field.satur_level),
    plist_vars.plistexist_var, plist_vars.plistoff_var,
    c_int(prefs.weightgain_flag), c_int(prefs.ext_minarea), c_int(prefs.clean_flag), c_int(prefs.detect_type),
    c_int(field.nbackx), c_int(field.nbacky), c_int(field.backw), c_int(field.backh), c_int(field.width),
    c_int(detected_num),
  )

  # One c_uint slot per detection, filled in by run_clean.
  index_buffer = create_string_buffer(detected_num * sizeof(c_uint))
  masterIndex = cast(index_buffer, POINTER(c_uint))

  lib.run_clean.argtypes = [POINTER(c_uint), c_double, c_int, c_int]
  lib.run_clean(masterIndex, prefs.clean_param, prefs.clean_stacksize, detected_num)

  # Copy the preferences the native end-object pass reads into its C struct.
  prefs2 = Prefs2()
  prefs2.naper = prefs.naper
  prefs2.detect_type = prefs.detect_type
  prefs2.mask_type = prefs.mask_type
  prefs2.apert = (c_double * MAXNAPER)()
  for i in range(MAXNAPER):
    prefs2.apert[i] = prefs.apert[i]
  prefs2.flux_apersize = prefs.flux_apersize
  prefs2.flux_err_apersize = prefs.flux_err_apersize
  prefs2.mag_apersize = prefs.mag_apersize
  prefs2.mag_zeropoint = prefs.mag_zeropoint
  prefs2.autoaper = (c_double * 2)()
  prefs2.autoparam = (c_double * 2)()
  for i in range(2):
    prefs2.autoaper[i] = prefs.autoaper[i]
    # Fixed: this used to assign prefs2.autoparam to itself, leaving zeros.
    prefs2.autoparam[i] = prefs.autoparam[i]
  prefs2.growth_flag = prefs.growth_flag
  prefs2.seeing_fwhm = prefs.seeing_fwhm
  prefs2.world_flag = prefs.world_flag

  # The network weights file is looked up next to this module.
  script_dir = os.path.dirname(os.path.abspath(__file__))
  brain = BrainStruct(nnw_name=os.path.join(script_dir,"default.nnw"))

  flagobj2 = Obj2Struct()
  # One OutobjStruct per detection; a separate name keeps it distinct from
  # the index buffer allocated above.
  objlist_buffer = create_string_buffer(detected_num * sizeof(OutobjStruct))
  objlist = cast(objlist_buffer, POINTER(OutobjStruct))
  wcs = WcsStruct()

  lib.endobject_gpu.argtypes = [
    POINTER(Prefs2), POINTER(BrainStruct), POINTER(Obj2Struct), POINTER(OutobjStruct), POINTER(WcsStruct),
    c_int, c_int, c_int, c_int, c_float, c_float, c_double, c_double, c_double, c_double, c_int
  ]

  lib.endobject_gpu(
    byref(prefs2),
    byref(brain),
    byref(flagobj2),
    objlist,
    byref(wcs),
    field.ymin,
    field.ymax,
    field.width,
    field.height,
    c_float(field.backsig),
    c_float(field.thresh),
    c_double(field.satur_level),
    c_double(field.pixscale),
    c_double(field.ngamma),
    c_double(field.gain),
    detected_num
  )

  lib.clear_detection()
  lib.clear_device(field.strip)

  # All-ones value of a c_uint (0xFFFFFFFF for a 4-byte unsigned int).
  # Presumably run_clean leaves this in the slots of objects that survive
  # cleaning -- TODO confirm against the native implementation.
  keep_marker = (1 << (sizeof(c_uint) * 8)) - 1

  num = 0
  for i in range(detected_num):
    if masterIndex[i] == keep_marker:
      objlist[i].number = num
      num += 1

  # Fixed: the values used to be passed as extra print() arguments, so the
  # "%d" placeholders were printed literally instead of being substituted.
  print("Objects: detected %d \t sextracted %d\n" % (detected_num, num))
  return objlist, num