from ctypes import POINTER, Array, Structure, c_float, c_int, pointer
from math import fabs, sqrt

from define import MAXMASK

class BpannStruct(Structure):
  """C-compatible (ctypes Structure) record with the fields below, in order.

  Field order and types determine the memory layout. Pointer fields are
  NULL on a freshly constructed instance; nothing here allocates them.
  """
  _fields_ = [
    ("nlayers", c_int),
    # int * -- presumably per-layer sizes; TODO confirm against the C side
    ("nn", POINTER(c_int)),
    # float ** -- pointer-to-pointer, one level of indirection per row/layer
    ("neuron", POINTER(POINTER(c_float))),
    # float **
    ("weight", POINTER(POINTER(c_float))),
    ("linearoutflag", c_int),
  ]

class FilterStruct(Structure):
  """C-compatible (ctypes Structure) record describing a convolution mask.

  Field order and types determine the memory layout. Pointer fields are
  NULL on a freshly constructed instance.
  """
  _fields_ = [
    # Mask coefficient buffer. Sized by MAXMASK (not a literal) so it matches
    # the (c_float * MAXMASK) array that filter_init allocates: ctypes rejects
    # assigning a pointer to a differently-sized array type.
    ("conv", POINTER(c_float * MAXMASK)),
    ("nconv", c_int),     # number of coefficients in use within conv
    ("convw", c_int),     # mask width
    ("convh", c_int),     # mask height
    # filter_init stores the Euclidean norm of the unnormalized coefficients.
    ("varnorm", c_float),
    ("bpann", POINTER(BpannStruct)),
  ]


def filter_init(kernel=None):
  """Build a FilterStruct holding an L1-normalized square convolution mask.

  Args:
    kernel: optional flat sequence of mask coefficients. Its length must be
      a perfect square no larger than MAXMASK. Defaults to the 3x3 mask
      [1, 2, 1, 2, 4, 2, 1, 2, 1].

  Returns:
    A FilterStruct where:
      conv         -- points to a c_float[MAXMASK] buffer. The first nconv
                      entries are the coefficients divided by sum(|coeff|);
                      the rest stay zero (ctypes zero-initializes arrays).
      nconv        -- number of coefficients.
      convw, convh -- side length of the square mask.
      varnorm      -- Euclidean norm of the *unnormalized* coefficients.
      bpann        -- left NULL.

  Raises:
    ValueError: if the kernel is empty, has more than MAXMASK coefficients,
      is not a perfect square in length, or its absolute values sum to zero.
  """
  if kernel is None:
    kernel = [1.0, 2.0, 1.0, 2.0, 4.0, 2.0, 1.0, 2.0, 1.0]
  coeffs = [float(v) for v in kernel]
  ncoeffs = len(coeffs)

  if ncoeffs == 0:
    raise ValueError("kernel must contain at least one coefficient")
  if ncoeffs > MAXMASK:
    raise ValueError(
      "kernel has %d coefficients; at most %d fit in the mask buffer"
      % (ncoeffs, MAXMASK))
  side = int(sqrt(ncoeffs))
  if side * side != ncoeffs:
    raise ValueError("kernel length %d is not a perfect square" % ncoeffs)

  abs_total = sum(fabs(v) for v in coeffs)
  if abs_total == 0:
    raise ValueError("kernel coefficients must not all be zero")
  # The norm is taken over the raw coefficients, before normalization.
  norm = sqrt(sum(v * v for v in coeffs))

  conv = (c_float * MAXMASK)()
  for i, v in enumerate(coeffs):
    conv[i] = v / abs_total

  # Assigning the pointer records it in the struct's ctypes _objects,
  # which keeps the local `conv` array alive after this function returns.
  return FilterStruct(
    conv=pointer(conv),
    nconv=ncoeffs,
    convw=side,
    convh=side,
    varnorm=norm,
  )