import nvidia.dali.fn as fn
import fits_processing_toolkit as fpt
from Function.align_pipe.align import align_points_pipe   

def align_points_function(i_crpix, i_cd_matrix, i_crval, i_a_coeffs, i_b_coeffs, t_crpix, t_cd_matrix, t_crval, t_ap_coeffs, t_bp_coeffs, data_source, step=10, device='gpu'):
    """Wrap ``align_points_pipe`` as a single-output DALI ``python_function`` node.

    All eleven data arguments plus ``step`` are forwarded positionally, in
    the order shown, as inputs to ``align_points_pipe``.  The ``i_*`` and
    ``t_*`` groups presumably hold FITS WCS terms (CRPIX, CD matrix, CRVAL)
    and SIP-style distortion coefficients for the input and target images.
    TODO confirm against ``align_points_pipe``.

    Args:
        i_crpix, i_cd_matrix, i_crval, i_a_coeffs, i_b_coeffs: inputs for
            the first ("i_") image, passed through unchanged.
        t_crpix, t_cd_matrix, t_crval, t_ap_coeffs, t_bp_coeffs: inputs for
            the second ("t_") image, passed through unchanged.
        data_source: extra input forwarded after the WCS terms.
        step: forwarded as the last positional input (default 10).
        device: DALI device the operator is placed on (default ``'gpu'``).

    Returns:
        The single output node produced by ``fn.python_function``.

    NOTE(review): DALI documents ``python_function`` as requiring a pipeline
    built with ``exec_async=False`` and ``exec_pipelined=False``. Confirm the
    calling pipeline is configured that way.
    """
    wrapped_inputs = (
        i_crpix, i_cd_matrix, i_crval, i_a_coeffs, i_b_coeffs,
        t_crpix, t_cd_matrix, t_crval, t_ap_coeffs, t_bp_coeffs,
        data_source, step,
    )
    return fn.python_function(
        *wrapped_inputs,
        function=align_points_pipe,
        num_outputs=1,
        device=device,
    )

def get_background_function(data, box_size=(64,64), filter_size= (3,3), device='gpu'):
    """Wrap ``fpt.get_background_g`` as a single-output DALI ``python_function`` node.

    ``data``, ``box_size`` and ``filter_size`` are forwarded positionally, in
    that order, as the operator's inputs.  The two sizes presumably configure
    a box-based background estimate and its smoothing filter. TODO confirm in
    ``fits_processing_toolkit``.

    Args:
        data: image input node.
        box_size: forwarded as the second input (default ``(64, 64)``).
        filter_size: forwarded as the third input (default ``(3, 3)``).
        device: DALI device for the operator (default ``'gpu'``).

    Returns:
        The single output node produced by ``fn.python_function``.
    """
    background_fn = fpt.get_background_g
    return fn.python_function(
        data,
        box_size,
        filter_size,
        function=background_fn,
        num_outputs=1,
        device=device,
    )

def Power_function(data, device='gpu'):
    """Wrap ``fpt.Power_g`` as a single-output DALI ``python_function`` node.

    Args:
        data: input node forwarded unchanged to ``fpt.Power_g``.
        device: DALI device for the operator (default ``'gpu'``).

    Returns:
        The single output node produced by ``fn.python_function``.
    """
    stretch = fpt.Power_g
    return fn.python_function(
        data,
        function=stretch,
        num_outputs=1,
        device=device,
    )

def Linear_function(data, device='gpu'):
    """Wrap ``fpt.Linear_g`` as a single-output DALI ``python_function`` node.

    Args:
        data: input node forwarded unchanged to ``fpt.Linear_g``.
        device: DALI device for the operator (default ``'gpu'``).

    Returns:
        The single output node produced by ``fn.python_function``.
    """
    stretch = fpt.Linear_g
    return fn.python_function(
        data,
        function=stretch,
        num_outputs=1,
        device=device,
    )

def Log_function(data, vmin, vmax, device='gpu'):
    """Wrap ``fpt.Log_g`` as a single-output DALI ``python_function`` node.

    ``data``, ``vmin`` and ``vmax`` are forwarded positionally, in that order,
    as the operator's inputs.  ``vmin``/``vmax`` presumably bound the value
    range before the stretch. TODO confirm in ``fits_processing_toolkit``.

    Args:
        data: image input node.
        vmin: forwarded as the second input.
        vmax: forwarded as the third input.
        device: DALI device for the operator (default ``'gpu'``).

    Returns:
        The single output node produced by ``fn.python_function``.
    """
    stretch = fpt.Log_g
    return fn.python_function(
        data,
        vmin,
        vmax,
        function=stretch,
        num_outputs=1,
        device=device,
    )

def Sinh_function(data, device='gpu'):
    """Wrap ``fpt.Sinh_g`` as a single-output DALI ``python_function`` node.

    Args:
        data: input node forwarded unchanged to ``fpt.Sinh_g``.
        device: DALI device for the operator (default ``'gpu'``).

    Returns:
        The single output node produced by ``fn.python_function``.
    """
    stretch = fpt.Sinh_g
    return fn.python_function(
        data,
        function=stretch,
        num_outputs=1,
        device=device,
    )

def Asinh_function(data, device='gpu'):
    """Wrap ``fpt.Asinh_g`` as a single-output DALI ``python_function`` node.

    Args:
        data: input node forwarded unchanged to ``fpt.Asinh_g``.
        device: DALI device for the operator (default ``'gpu'``).

    Returns:
        The single output node produced by ``fn.python_function``.
    """
    stretch = fpt.Asinh_g
    return fn.python_function(
        data,
        function=stretch,
        num_outputs=1,
        device=device,
    )

def HistEq_function(data, device='gpu'):
    """Wrap ``fpt.HistEq_g`` as a single-output DALI ``python_function`` node.

    Args:
        data: input node forwarded unchanged to ``fpt.HistEq_g``.
        device: DALI device for the operator (default ``'gpu'``).

    Returns:
        The single output node produced by ``fn.python_function``.
    """
    stretch = fpt.HistEq_g
    return fn.python_function(
        data,
        function=stretch,
        num_outputs=1,
        device=device,
    )

def Bytescale_function(data, device='gpu'):
    """Wrap ``fpt.Bytescale_g`` as a single-output DALI ``python_function`` node.

    This module previously defined this function twice with identical
    bodies.  It is now defined once, so there is a single source of truth
    (the second copy silently shadowed the first).

    Args:
        data: input node forwarded unchanged to ``fpt.Bytescale_g``.
        device: DALI device for the operator (default ``'gpu'``).

    Returns:
        The single output node produced by ``fn.python_function``.
    """
    return fn.python_function(
        data,
        function=fpt.Bytescale_g,
        num_outputs=1,
        device=device,
    )

def get_limits_percentile_function(data,vmin, vmax, device='gpu'):
    """Wrap ``fpt.get_limits_percentile_g`` as a two-output DALI ``python_function`` node.

    ``data``, ``vmin`` and ``vmax`` are forwarded positionally, in that order.
    The two outputs presumably are the computed lower/upper display limits.
    TODO confirm in ``fits_processing_toolkit``.

    Args:
        data: image input node.
        vmin: forwarded as the second input.
        vmax: forwarded as the third input.
        device: DALI device for the operator (default ``'gpu'``).

    Returns:
        The two output nodes produced by ``fn.python_function``.
    """
    limits_fn = fpt.get_limits_percentile_g
    return fn.python_function(
        data,
        vmin,
        vmax,
        function=limits_fn,
        num_outputs=2,
        device=device,
    )

def get_limits_zscale_function(data, device='gpu'):
    """Wrap ``fpt.get_limits_zscale_g`` as a two-output DALI ``python_function`` node.

    The two outputs presumably are the lower/upper zscale limits.
    TODO confirm in ``fits_processing_toolkit``.

    Args:
        data: image input node forwarded unchanged.
        device: DALI device for the operator (default ``'gpu'``).

    Returns:
        The two output nodes produced by ``fn.python_function``.
    """
    limits_fn = fpt.get_limits_zscale_g
    return fn.python_function(
        data,
        function=limits_fn,
        num_outputs=2,
        device=device,
    )

def subtract_images(i_data,t_data, device='gpu'):
    """Wrap ``fpt.call_subtract_images`` as a single-output DALI ``python_function`` node.

    ``i_data`` and ``t_data`` are forwarded positionally, in that order, as
    the two inputs to ``fpt.call_subtract_images``.

    Args:
        i_data: first image input node.
        t_data: second image input node.
        device: DALI device for the operator (default ``'gpu'``).

    Returns:
        The single output node produced by ``fn.python_function``.
    """
    subtract_fn = fpt.call_subtract_images
    return fn.python_function(
        i_data,
        t_data,
        function=subtract_fn,
        num_outputs=1,
        device=device,
    )