from ctypes import POINTER, Array, Structure, c_float, c_int, pointer
from typing_extensions import Self
import numpy as np
from math import fabs, sqrt

from .define import MAXMASK

class FilterStruct(Structure):
  """ctypes record describing a convolution filter.

  Fields:
    conv: pointer to a fixed block of 1024 c_float kernel coefficients.
    nconv: number of coefficients actually in use.
    convw, convh: kernel width and height.
    varnorm: filter_init() stores here the square root of the sum of the
      squared (un-normalised) kernel coefficients.
  """
  _fields_ = [
    # NOTE(review): filter_init() builds the buffer as c_float * MAXMASK;
    # ctypes only accepts that pointer here if MAXMASK == 1024 (otherwise the
    # assignment raises TypeError). Keep the two lengths in sync.
    ("conv", POINTER(c_float * 1024)),
    ("nconv", c_int),
    ("convw", c_int),
    ("convh", c_int),
    ("varnorm", c_float),
    # ("bpann", POINTER(BpannStruct))
  ]

class Filter:
  """Pure-Python description of the default 3x3 convolution kernel.

  The raw kernel coefficients live in the class attribute ``conv_py``. Each
  instance gets its own copies of ``conv`` and ``conv_py``, holding the kernel
  divided by the sum of the absolute values of its coefficients.

  Attributes:
    conv: list with the class default's length (1024); the first ``nconv``
      slots hold the normalised kernel and the rest stay 0.0.
    nconv: number of kernel coefficients.
    convw, convh: both set to int(sqrt(nconv)), i.e. a square kernel is
      assumed.
    varnorm: square root of the sum of squares of the *un-normalised*
      coefficients.
    conv_py: on instances, the normalised kernel coefficients.
  """

  # Class-level defaults. They are never mutated: __init__ works on
  # per-instance copies so every Filter() starts from the raw kernel.
  conv = [0.0] * 1024
  nconv = 0
  convw = 0
  convh = 0
  varnorm = 0.0
  conv_py = [1.0, 2.0, 1.0, 2.0, 4.0, 2.0, 1.0, 2.0, 1.0]

  def __init__(self) -> None:
    # Copy the shared class-level lists. Normalising them in place would leak
    # into every later instance (and change its varnorm).
    kernel = list(type(self).conv_py)
    conv = list(type(self).conv)

    total = sum(fabs(c) for c in kernel)
    # Computed from the raw (pre-normalisation) coefficients.
    sq_sum = sum(c * c for c in kernel)

    side = int(sqrt(len(kernel)))
    self.convw = side
    self.convh = side

    normalised = [c / total for c in kernel]
    for i, value in enumerate(normalised):
      # Index assignment (not slice assignment) so an oversized kernel still
      # raises IndexError instead of silently growing the buffer.
      conv[i] = value

    self.conv_py = normalised
    self.conv = conv
    self.varnorm = sqrt(sq_sum)
    self.nconv = len(kernel)


def filter_init():
  """Build a FilterStruct holding the default normalised 3x3 kernel.

  The kernel [1, 2, 1, 2, 4, 2, 1, 2, 1] is divided by the sum of the
  absolute values of its coefficients. The result is copied into a fresh,
  zero-initialised ``c_float * MAXMASK`` buffer, and the returned structure
  points to that buffer.

  Returns:
    FilterStruct with ``conv`` pointing at the normalised kernel, ``nconv``
    set to the number of coefficients, ``convw``/``convh`` set to
    int(sqrt(nconv)), and ``varnorm`` set to the square root of the sum of
    squares of the un-normalised coefficients.
  """
  kernel = [1.0, 2.0, 1.0, 2.0, 4.0, 2.0, 1.0, 2.0, 1.0]
  side = int(sqrt(len(kernel)))
  total = sum(fabs(c) for c in kernel)
  # Computed from the raw (pre-normalisation) coefficients.
  sq_sum = sum(c * c for c in kernel)

  conv = (c_float * MAXMASK)()
  for i, value in enumerate(kernel):
    conv[i] = value / total

  result = FilterStruct(convw=side, convh=side)
  result.varnorm = sqrt(sq_sum)
  result.nconv = len(kernel)
  # ctypes keeps the assigned pointer (and through it `conv`) referenced by
  # the structure, so the buffer stays alive as long as `result` does.
  result.conv = pointer(conv)
  return result