import os
import pprint
from typing import Dict
import numpy as np
import nvidia.dali.fn as fn
import nvidia.dali.plugin_manager as plugin_manager
from nvidia.dali.pipeline import pipeline_def
import cupy as cp
from astropy.io import fits


# Directory holding this module; the plugin binary is located relative to it
# so the module works regardless of the caller's working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))

# Custom DALI operator library, presumably the one that registers `fn.sub`
# used below (name-based assumption — confirm against the build).
_SUB_PLUGIN_PATH = os.path.join(current_dir, 'build/libsub.so')
plugin_manager.load_library(_SUB_PLUGIN_PATH)


def get_w_h(image):
    """Return the sizes of the first two axes of ``image`` as 0-d NumPy arrays.

    Despite the function name, the order is ``(image.shape[0], image.shape[1])``,
    i.e. rows (height) first, then columns (width). The values are wrapped
    with ``np.asarray`` so each output is an array rather than a Python int
    (presumably so it can be returned from ``fn.python_function`` — confirm).

    Args:
        image: Any array-like object exposing a ``shape`` with at least two
            entries.

    Returns:
        A ``(height, width)`` tuple of 0-d ``numpy.ndarray`` values.
    """
    rows = image.shape[0]
    cols = image.shape[1]
    return np.asarray(rows), np.asarray(cols)

def sub_function(image_inputs, image_templates, device='gpu'):
    """Apply the custom ``fn.sub`` DALI operator and reshape its output.

    Both the input and the template images are declared to the operator as
    4096 x 4096, and the operator's output is reshaped to ``[4096, 4096]``.

    Args:
        image_inputs: DALI data node with the input images.
        image_templates: DALI data node with the template images.
        device: Backend on which ``fn.sub`` runs (default ``'gpu'``).

    Returns:
        The DALI data node produced by ``fn.reshape``.
    """
    # NOTE(review): the operator is always told the images are
    # IMAGE_SIZE x IMAGE_SIZE regardless of their actual size.
    # TODO: derive the sizes from the data (e.g. via fn.python_function with
    # get_w_h) instead of hard-coding them.
    image_size = 4096

    res = fn.sub(
        image_inputs,
        image_templates,
        height_i=image_size,
        width_i=image_size,
        height_t=image_size,
        width_t=image_size,
        device=device,
    )
    return fn.reshape(res, shape=[image_size, image_size])