import cupy as cp

def _zscale_clip_mask(values, center, half_width, low_cut, high_cut):
    """Return a boolean mask selecting the ``values`` kept by zscale clipping.

    A sample is rejected when ``low_cut`` is set and it is <= ``center - half_width``,
    or when ``high_cut`` is set and it is >= ``center + half_width``. With neither
    flag set, every sample is kept.
    """
    keep = cp.ones(values.shape, dtype=bool)
    if low_cut:
        keep &= values > center - half_width
    if high_cut:
        keep &= values < center + half_width
    return keep


def _zscale_fit_slope(x, y):
    """Return the slope of the least-squares line ``y = slope * x + intercept``."""
    design = cp.vstack([x, cp.ones(len(x))]).T
    slope, _intercept = cp.linalg.lstsq(design, y, rcond=None)[0]
    return slope


def _zscale_converged(old_slope, new_slope, percent_fract):
    """Return True when the relative slope change is at most ``percent_fract``."""
    # Multiplied-out form of |old - new| / |new| <= percent_fract, so a zero
    # slope (e.g. from constant data) does not cause a division by zero.
    return bool(cp.fabs(old_slope - new_slope) <= percent_fract * cp.fabs(new_slope))


def get_limits_zscale_g(input_arr, n_samples=1000, contrast=0.165, sig_fract=3.0,
                        percent_fract=0.01, max_iter=5, low_cut=True, high_cut=True):
    """
    Estimates the display range (vmin, vmax) with a simplified zscale algorithm using CuPy.

    The finite pixels are sub-sampled with a regular stride (deterministic) and sorted.
    A straight line is fitted to (rank - midpoint rank, value). The fit is refined by
    iteratively rejecting samples that lie more than ``sig_fract`` standard deviations
    (of the currently kept samples) away from the middle sorted sample. Iteration stops
    when any of the following holds:

    - the slope converges;
    - ``max_iter`` is reached;
    - fewer than half of the samples survive;
    - fewer than two samples survive (the previous slope is kept).

    The fitted slope divided by ``contrast`` is then extrapolated from the middle
    sample to the first and last sample ranks.

    Parameters:
    ----------
    input_arr : array-like
        Image data of any shape. Non-finite values (NaN, +/-inf) are ignored.

    n_samples : int, optional
        Maximum number of pixels sampled from the finite data (default 1000).

    contrast : float, optional
        zscale contrast; must be larger than 0 (default 0.165). Smaller values widen the range.

    sig_fract : float, optional
        Clipping threshold in units of the standard deviation (default 3.0).

    percent_fract : float, optional
        Convergence fraction. Iteration stops once
        ``|old_slope - new_slope| <= percent_fract * |new_slope|`` (default 0.01).

    max_iter : int, optional
        Maximum number of clip-and-fit iterations (default 5).

    low_cut : bool, optional
        If True, samples at or below the lower clipping limit are rejected.

    high_cut : bool, optional
        If True, samples at or above the upper clipping limit are rejected.
        If both ``low_cut`` and ``high_cut`` are False, no samples are rejected.

    Returns:
    --------
    (vmin, vmax) : tuple
        The estimated lower and upper display limits, as array elements of the
        array backend (not converted to Python floats).

    Raises:
    -------
    ValueError
        If ``contrast`` is not larger than 0, or ``input_arr`` has no finite values.
    """
    if contrast <= 0:
        raise ValueError("contrast must be larger than 0")

    work_arr = cp.asarray(input_arr)
    # Boolean-mask indexing also flattens multi-dimensional input.
    work_arr = work_arr[cp.isfinite(work_arr)]
    if work_arr.size == 0:
        raise ValueError("input_arr contains no finite values")

    # Regular-stride subsample of at most n_samples pixels, sorted ascending.
    stride = int(max(1.0, work_arr.size / n_samples))
    work_arr = cp.sort(work_arr[::stride][:n_samples])

    n_points = len(work_arr)
    max_ind = n_points - 1
    midpoint_ind = int(n_points * 0.5)
    mid_value = work_arr[midpoint_ind]
    # Sample ranks centred on the middle sample; the line is fitted against these.
    ranks = cp.arange(n_points) - midpoint_ind

    # Initial fit over all samples, then iteratively clip and refit.
    slope = _zscale_fit_slope(ranks, work_arr)
    y = work_arr
    iteration = 0
    while True:
        iteration += 1
        keep = _zscale_clip_mask(work_arr, mid_value, sig_fract * y.std(), low_cut, high_cut)
        y = work_arr[keep]
        if y.size < 2:
            # Too few survivors to fit a line (e.g. constant data has zero sigma,
            # so the strict inequalities reject everything): keep the previous slope.
            break
        old_slope, slope = slope, _zscale_fit_slope(ranks[keep], y)
        if (_zscale_converged(old_slope, slope, percent_fract)
                or iteration >= max_iter
                or len(y) < midpoint_ind):
            break

    scale = slope / contrast
    vmin = mid_value + scale * (0 - midpoint_ind)
    vmax = mid_value + scale * (max_ind - midpoint_ind)
    return (vmin, vmax)

def get_limits_percentile_g(values, lower_percentile, upper_percentile, n_samples=None):
    """
    Estimates the value range (vmin, vmax) using the specified lower and upper percentiles.
    The function uses CuPy for GPU-accelerated computation, which helps with large datasets.

    Parameters:
    ----------
    values : array-like
        The input array of values for which the percentile range is being calculated.
        It is flattened, and non-finite entries (NaN, +/-inf) are ignored.

    lower_percentile : float
        The lower percentile (e.g., 5 for the 5th percentile) used to estimate the minimum value (vmin).

    upper_percentile : float
        The upper percentile (e.g., 95 for the 95th percentile) used to estimate the maximum value (vmax).

    n_samples : int, optional
        The number of samples to use when calculating percentiles. If `n_samples` is provided and there
        are more finite values than this, `n_samples` of the finite values are drawn at random (with
        replacement, from the global CuPy random state), so results are not deterministic. If None,
        all finite values are used.

    Returns:
    --------
    vmin : array element
        The estimated minimum value based on the lower percentile.

    vmax : array element
        The estimated maximum value based on the upper percentile.

    Raises:
    -------
    ValueError
        If `values` contains no finite entries.
    """
    values = cp.asarray(values).ravel()
    # Drop NaN/inf *before* sub-sampling, so the random sample consists of
    # n_samples usable values rather than a mix diluted by non-finite ones.
    values = values[cp.isfinite(values)]
    if values.size == 0:
        raise ValueError("values contains no finite entries")
    if n_samples is not None and values.size > n_samples:
        # cp.random.choice samples uniformly with replacement by default.
        values = cp.random.choice(values, n_samples)
    vmin, vmax = cp.percentile(values, (lower_percentile, upper_percentile))

    return vmin, vmax