
import cupy as cp
from cupyx.scipy.ndimage import  zoom, median_filter

def mad_std(a, factor=0.6745):
    """Return a MAD-based robust standard deviation of all of ``a``, ignoring NaNs.

    Computes ``nanmedian(|a - nanmedian(a)|) / factor`` over the flattened
    input. The default ``factor`` 0.6745 is approximately the 0.75 quantile
    of the standard normal distribution. For Gaussian data, the result
    therefore approximates the standard deviation.
    """
    center = cp.nanmedian(a)
    abs_dev = cp.abs(a - center)
    return cp.nanmedian(abs_dev) / factor

def mad_std_along_axis(arr, axis=1, factor=0.6745):
    """Return the MAD-based robust standard deviation of ``arr`` along ``axis``.

    NaNs are ignored. This is the vectorised equivalent of applying
    ``mad_std`` to every 1-D slice along ``axis``: each slice yields
    ``nanmedian(|slice - nanmedian(slice)|) / factor``. Two array-wide
    reductions replace a per-slice Python loop (``apply_along_axis``), which
    matters when there are many slices.

    Parameters
    ----------
    arr : array_like
        Input data.
    axis : int
        Axis along which each robust standard deviation is computed.
    factor : float
        Divisor applied to the median absolute deviation (same default as
        ``mad_std``).

    Returns
    -------
    Array with ``axis`` removed, holding one value per slice. An all-NaN
    slice yields NaN.
    """
    arr = cp.asarray(arr)
    # keepdims=True so the per-slice median broadcasts against ``arr``.
    center = cp.nanmedian(arr, axis=axis, keepdims=True)
    return cp.nanmedian(cp.abs(arr - center), axis=axis) / factor

def sigma_clip_along_axis(data, axis=1, cenfunc="median", stdfunc="std", sigma_lower=3.0, sigma_upper=3.0, maxiters=-1):
    """Iteratively sigma-clip ``data`` along ``axis``, replacing outliers with NaN.

    Each iteration does the following:
    - compute a NaN-aware centre and spread of every slice along ``axis``;
    - set to NaN every value outside
      ``[centre - sigma_lower * spread, centre + sigma_upper * spread]``.

    Iteration stops when an iteration clips nothing new, because the bounds
    would then never change again. It also stops after ``maxiters``
    iterations.

    Parameters
    ----------
    data : array_like
        Input values. They are copied, so the caller's array is never
        modified. Non-float input is promoted to float64 so it can hold NaN.
    axis : int
        Axis along which centre and spread are computed.
    cenfunc : {"median", "mean"}
        NaN-aware centre estimator.
    stdfunc : {"std", "mad_std"}
        NaN-aware spread estimator.
    sigma_lower, sigma_upper : float
        Clipping thresholds in units of the spread. Their absolute values
        are used.
    maxiters : int
        Maximum number of iterations. ``-1`` means iterate until convergence.

    Returns
    -------
    A new array with the same shape as ``data``, with clipped values set to
    NaN.

    Raises
    ------
    ValueError
        If ``cenfunc`` or ``stdfunc`` is not recognised.
    """
    # Select the centre estimator.
    if cenfunc == "median":
        cen_func = cp.nanmedian
    elif cenfunc == "mean":
        cen_func = cp.nanmean
    else:
        raise ValueError("Invalid cenfunc")

    # Select the spread (standard deviation) estimator.
    if stdfunc == "std":
        std_func = cp.nanstd
    elif stdfunc == "mad_std":
        std_func = mad_std_along_axis
    else:
        raise ValueError("Invalid stdfunc")

    # Work on a private copy. cp.asarray would alias a CuPy input, and the NaN
    # assignment below would then overwrite the caller's data.
    clipped = cp.array(data, copy=True)
    if clipped.dtype.kind not in "fc":
        # Integer/bool arrays cannot store NaN.
        clipped = clipped.astype(cp.float64)

    # Plain Python floats: no need for device-side scalars here.
    sigma_lower = abs(float(sigma_lower))
    sigma_upper = abs(float(sigma_upper))

    iterations = 0
    while maxiters == -1 or iterations < maxiters:
        # expand_dims so the per-slice statistics broadcast against the data.
        center = cp.expand_dims(cen_func(clipped, axis=axis), axis=axis)
        spread = cp.expand_dims(std_func(clipped, axis=axis), axis=axis)
        # NaN compares False, so already-clipped values are never re-flagged.
        outliers = (clipped < center - sigma_lower * spread) | (clipped > center + sigma_upper * spread)
        iterations += 1
        if not outliers.any():
            # Converged: the data, and hence the bounds, can no longer change.
            break
        clipped[outliers] = cp.nan

    return clipped

def Source_Extractor_background(data, axis=1):
    """SExtractor-style background estimate along ``axis`` (NaN-aware).

    For each slice:
    - the standard deviation is zero: the mean;
    - otherwise, ``|mean - median| / std < 0.3``: ``2.5 * median - 1.5 * mean``;
    - otherwise: the median.

    Returns a 1-D array of per-slice estimates. When there is exactly one
    estimate, that single element is returned instead of an array.
    """
    med = cp.atleast_1d(cp.nanmedian(data, axis=axis))
    avg = cp.atleast_1d(cp.nanmean(data, axis=axis))
    sd = cp.atleast_1d(cp.nanstd(data, axis=axis))

    # Divide by 1 where sd == 0; those entries are overridden by the mean below.
    safe_sd = cp.where(sd == 0, cp.ones_like(sd), sd)
    mildly_skewed = cp.abs(avg - med) / safe_sd < 0.3
    estimate = cp.where(mildly_skewed, (2.5 * med) - (1.5 * avg), med)
    bkg = cp.where(sd == 0, avg, estimate)

    return bkg[0] if bkg.size == 1 else bkg

def ModeEstimator_background(data, axis=1, median_factor=3.0, mean_factor=2.0):
    """Mode-style background estimate along ``axis``, ignoring NaNs.

    Returns ``median_factor * nanmedian(data) - mean_factor * nanmean(data)``
    per slice. With the defaults this is ``3 * median - 2 * mean``.
    """
    med = cp.nanmedian(data, axis=axis)
    avg = cp.nanmean(data, axis=axis)
    return (median_factor * med) - (mean_factor * avg)

def reshape_data(data, box_size, mode = 'median', axis=1):
    """Compute one background estimate per box of a 2-D image.

    Steps:
    1. Pad the image with NaN on its bottom and right edges so that an
       integer number of boxes fits in each dimension.
    2. Regroup the pixels so that each row of a ``(n_boxes, box_pixels)``
       matrix holds one box.
    3. Sigma-clip each row (3 sigma, at most 5 iterations).
    4. Reduce each row with the estimator selected by ``mode``.

    Parameters
    ----------
    data : array_like
        2-D image. It is copied into a new float32 CuPy array, so the
        caller's array is never modified by the in-place sigma clipping.
    box_size : int or tuple of two ints
        Box size ``(ny, nx)`` in pixels. An int means square boxes.
    mode : {"median", "mean", "Source_Extractor", "Source_Extrator", "ModeEstimator"}
        Per-box estimator. "Source_Extrator" is the historical (misspelled)
        name and is kept for backward compatibility.
    axis : int
        Axis of the box matrix that holds the pixels of one box. Only 1
        matches the layout built here.

    Returns
    -------
    float64 array of shape ``(nboxes_y, nboxes_x)``, one value per box.

    Raises
    ------
    ValueError
        If ``mode`` is unknown, ``data`` is not 2-D, or ``box_size`` is not
        positive.
    """
    estimators = {
        "median": cp.nanmedian,
        "mean": cp.nanmean,
        "Source_Extrator": Source_Extractor_background,
        "Source_Extractor": Source_Extractor_background,
        "ModeEstimator": ModeEstimator_background,
    }
    try:
        handle = estimators[mode]
    except KeyError:
        raise ValueError("Invalid mode") from None

    # cp.array always copies. Without a copy, the box matrix below can be a
    # view of the caller's array (e.g. a single box), and sigma clipping
    # would write NaN into it.
    data = cp.array(data, dtype=cp.float32)
    if data.ndim != 2:
        raise ValueError("data must be a 2-D array")

    try:
        box_y, box_x = box_size
    except TypeError:
        # A single int means square boxes.
        box_y = box_x = box_size
    box_y, box_x = int(box_y), int(box_x)
    if box_y <= 0 or box_x <= 0:
        raise ValueError("box_size must be positive")

    # Shape arithmetic stays on the host as plain ints.
    ny, nx = data.shape
    # (-n) % b pads n up to the next multiple of b (0 if already a multiple).
    pad_y = -ny % box_y
    pad_x = -nx % box_x
    if pad_y or pad_x:
        data = cp.pad(data, ((0, pad_y), (0, pad_x)), mode='constant',
                      constant_values=cp.nan)
    nboxes_y = (ny + pad_y) // box_y
    nboxes_x = (nx + pad_x) // box_x

    # (nby, by, nbx, bx) -> (nby, nbx, by, bx) -> one box per row, row-major box order.
    box_data = (data.reshape(nboxes_y, box_y, nboxes_x, box_x)
                .swapaxes(1, 2)
                .reshape(nboxes_y * nboxes_x, box_y * box_x))
    box_data = sigma_clip_along_axis(box_data, axis=axis, maxiters=5)
    bkg_data = handle(box_data, axis=axis)

    # Box k sits at grid cell (k // nboxes_x, k % nboxes_x), so a row-major
    # reshape places each estimate in its cell.
    return cp.asarray(bkg_data, dtype=cp.float64).reshape(nboxes_y, nboxes_x)

def get_background_g(data, box_size=(64,64), filter_size=(3,3), axis=1, mode='Source_Extrator'):
    """Estimate a smooth background map with the same shape as the 2-D ``data``.

    Steps:
    1. Compute per-box estimates with ``reshape_data``.
    2. Median-filter them with a ``filter_size`` window (edges: 'nearest').
    3. Upsample by ``box_size`` with a cubic spline (``order=3``,
       ``grid_mode=True``).
    4. Crop the result to ``data``'s first two dimensions.
    5. Clip it to the value range of the filtered box grid.
    """
    mesh = reshape_data(data, box_size, axis=axis, mode=mode)
    mesh = median_filter(mesh, size=filter_size, mode='nearest')

    upsampled = zoom(mesh, box_size, order=3, mode='reflect', cval=0.0, grid_mode=True)
    rows, cols = data.shape[0], data.shape[1]
    cropped = upsampled[:rows, :cols]

    # The cubic spline can overshoot; keep the map within the mesh's range.
    lo, hi = cp.min(mesh), cp.max(mesh)
    return cp.clip(cropped, lo, hi)
