import glob

import numpy as np
from astropy.io import fits
from scipy.ndimage.filters import generic_filter
import os
import SEx1
import re


# Outlier removal
def remove_abnormal_point(image, mask, size=9):
    '''
    Replace bad pixels (NaN entries of *mask*) with a local box average.

    Every pixel where ``mask`` is NaN is overwritten with the mean of the
    ``size`` x ``size`` window centred on it. Bad pixels, NaN pixels of the
    image and pixels outside the image count as 0 but still occupy a slot in
    the window. The result is therefore the sum of the usable neighbours
    divided by ``size**2``.

    :param image: The FITS image data; modified in place.       Type: Numpy.
    :param mask: The mask matrix of the image; NaN = bad pixel.  Type: Numpy.
    :param size: Side length of the square averaging window (default 9).
    :return: image: The same array with bad pixels replaced.    Type: Numpy.
    '''
    from scipy.ndimage import uniform_filter

    # Boolean matrix: True where the pixel is flagged as abnormal.
    bad = np.isnan(mask)
    if not bad.any():
        return image

    # Float working copy with bad pixels set to NaN. nan_to_num then zero-fills
    # every NaN (masked or originally NaN in the image) and clips +/-inf.
    masked_image = np.asarray(image).astype(float)
    masked_image[bad] = np.nan
    filled = np.nan_to_num(masked_image, nan=0.0)

    # Zero-padded box mean over size x size. This is equivalent to
    # generic_filter(..., mean_filter, mode='constant', cval=nan) but runs in C
    # rather than calling back into Python once per pixel.
    filtered_image = uniform_filter(filled, size=size, mode='constant', cval=0.0)

    # Only the bad pixels take the filtered value; the input array is updated.
    image[bad] = filtered_image[bad]

    return image


# Average-filter callback: mean of one filter window, counting NaN as 0
def mean_filter(x):
    """
    Average one filter window, treating NaN entries as 0.

    Written as a callback for ``scipy.ndimage.generic_filter``. NaN values
    still occupy a slot in the window, so the result is the sum of the
    non-NaN values divided by the number of window elements. As with
    ``np.nan_to_num``, +/-inf are clipped to the largest finite floats.

    :param x: Values of the current filter window.
    :return: Mean of the window after NaN -> 0 replacement.
    """
    window = np.nan_to_num(x, nan=0.0)
    return window.mean()


def save_new_fits(img, dir, name):
    """
    Purpose: Save array data as a single-HDU FITS file, replacing any existing file.

    :param img: Image data stored in the primary HDU. Type: Numpy array.
    :param dir: The directory where the FITS file will be saved. Type: String.
    :param name: The file name of the FITS file. Type: String.
    """
    path = os.path.join(dir, name)
    grev = fits.PrimaryHDU(img)
    grevHDU = fits.HDUList([grev])
    # Let astropy replace an existing file instead of doing a separate
    # exists()/remove() check before writing.
    grevHDU.writeto(path, overwrite=True)
    print("save fits:{}".format(path))




# Obtain star coordinates via SExtractor
def Generate_Celestial_Coordinates(fits):
    """
    Placeholder for star detection.

    :param fits: Input for the source extraction. Inside this function the name
        shadows the module-level ``fits`` import.
    :return: List of star coordinates; currently always a new empty list.
    """
    # TODO: user-defined source extraction goes here.
    coordinates = list()
    return coordinates


# Generate the mask matrix
def Generate_mask(mask, Galactic_coordinates, half_size=5):
    """
    Build a 0/1 background mask that excludes bad pixels and star regions.

    :param mask: Mask matrix of the image; NaN marks a bad pixel. Type: Numpy.
    :param Galactic_coordinates: Iterable of star positions. ``coord[0]``
        indexes axis 0 (rows) and ``coord[1]`` axis 1 (columns). Values may be
        floats; they are rounded to the nearest pixel.
        NOTE(review): if the coordinates come from SExtractor as (x, y), confirm
        the ordering matches (row, col).
    :param half_size: Box half-width around each star (default 5). Rows and
        columns ``c - half_size`` up to ``c + half_size - 1`` are zeroed, clipped
        to the image.
    :return: background_mask: float array; 1 = background, 0 = bad/star pixel.
    """
    background_mask = np.copy(mask).astype(float)

    # Non-negative mask values become 1; NaN (bad pixels) become 0.
    background_mask[background_mask >= 0] = 1
    background_mask[np.isnan(background_mask)] = 0

    for coord in Galactic_coordinates:
        # Slicing needs integer indices; sub-pixel positions are rounded.
        row = int(round(float(coord[0])))
        col = int(round(float(coord[1])))
        # Clamp at 0: a negative slice start would wrap to the far edge.
        r0, r1 = max(row - half_size, 0), max(row + half_size, 0)
        c0, c1 = max(col - half_size, 0), max(col + half_size, 0)
        background_mask[r0:r1, c0:c1] = 0

    return background_mask


# Select a dark field image with a similar background
def _read_primary_data(file_path):
    # Load the primary-HDU data into memory and close the file handle.
    with fits.open(file_path) as hdul:
        return np.array(hdul[0].data)


def Select_similar_backgrounds(background_mask, image, path, interval):
    """
    Pick the dark-field frame whose masked background mean best matches *image*.

    Dark frames are the files in *path* whose names contain a timestamp of the
    form ``YYYY-MM-DD-hh-mm-ss-fff``. They are ordered by that timestamp, and the
    i-th frame is assigned the exposure time ``i * interval``.

    :param background_mask: Weight matrix multiplied into both images before
        averaging. Type: Numpy.
    :param image: The image to be processed. Type: Numpy.
    :param path: Directory containing the dark-field FITS files. Type: String.
    :param interval: Exposure-time step between consecutive dark frames.
    :return: (Exposure_time, dark_image) of the closest dark frame.
    :raises ValueError: if *path* contains no timestamped files.
    """
    pattern = re.compile(r"(\d{4})-(\d{2})-(\d{2})-(\d{2})-(\d{2})-(\d{2})-(\d{3})")

    # (file name, timestamp groups) for every file whose name carries a timestamp.
    timed_files = []
    for filename in os.listdir(path):
        match = pattern.search(filename)
        if match:
            timed_files.append((filename, match.groups()))
    if not timed_files:
        raise ValueError("no timestamped dark-field files found in {}".format(path))

    # The fields are zero-padded, so sorting the string tuples is chronological.
    timed_files.sort(key=lambda item: item[1])
    sorted_names = [filename for filename, _ in timed_files]

    # Calculate the background mean of the image to be processed.
    image_background = np.multiply(background_mask, image).mean()

    # Background difference of each dark frame, in time order. Frames are read
    # one at a time so that not all of them are held in memory.
    differences = []
    for filename in sorted_names:
        dark_image = _read_primary_data(os.path.join(path, filename))
        dark_background = np.multiply(background_mask, dark_image).mean()
        differences.append(abs(image_background - dark_background))

    min_idx = np.argmin(differences)

    # Exposure time of the closest dark frame.
    Exposure_time = min_idx * interval

    # Re-read only the winning frame.
    dark_image = _read_primary_data(os.path.join(path, sorted_names[min_idx]))

    return Exposure_time, dark_image


# Calculate the dark current of one cluster
def Calculate_Dark_Current(Polynomial, Clucster, Exposure_time):
    """
    Placeholder: dark current of one cluster at a given exposure time.

    :param Polynomial: Polynomial coefficients of all clusters, indexed by cluster.
    :param Clucster: Index of the cluster to evaluate.
    :param Exposure_time: Exposure time at which to evaluate the model.
    :return: Currently always a new empty list.
    """
    # Look up this cluster's coefficients; an invalid index raises here.
    cluster_coefficients = Polynomial[Clucster]  # noqa: F841 (for user code)
    # TODO: user-defined evaluation of cluster_coefficients at Exposure_time.
    dark_values = list()
    return dark_values


# Generates a dark current matrix
def Generate_Dark_Matrix(mask, Polynomial, Exposure_time, dark_image):
    """
    Build the dark-current matrix from per-cluster models, minus a dark frame.

    Each pixel labelled with cluster ``i`` in *mask* receives
    ``Calculate_Dark_Current(Polynomial, i, Exposure_time)``. NaN (bad) pixels
    receive 0. Pixels whose label is not in ``range(Polynomial.shape[0])``
    keep their label value. *dark_image* is then subtracted from the result.

    :param mask: Cluster-label matrix; NaN marks a bad pixel. Type: Numpy.
    :param Polynomial: Polynomial coefficients, one entry per cluster on axis 0.
    :param Exposure_time: Exposure time passed to the dark-current model.
    :param dark_image: Dark frame subtracted from the model. Type: Numpy.
    :return: Float matrix ``model - dark_image``.
    """
    labels = np.asarray(mask)

    # Float working copy, so fractional dark-current values are not truncated
    # when the label matrix has an integer dtype.
    temporary_mask = labels.astype(float)

    # Select pixels by the ORIGINAL labels. Testing the array being rewritten
    # would let a dark value written for cluster i match label j on a later
    # iteration and be overwritten.
    for i in range(Polynomial.shape[0]):
        Dark = Calculate_Dark_Current(Polynomial, i, Exposure_time)
        temporary_mask[labels == i] = Dark

    # Bad pixels carry no dark-current model.
    temporary_mask[np.isnan(labels)] = 0

    temporary_mask = temporary_mask - dark_image

    return temporary_mask


# Subtract the dark current from an image
def subtract_dark(dark_matrix, image):
    """
    Remove the dark-current matrix from an image.

    :param dark_matrix: Dark-current matrix to subtract. Type: Numpy.
    :param image: Image to correct. Type: Numpy.
    :return: Absolute value of ``image - dark_matrix``. Type: Numpy.
    """
    # NOTE(review): abs() turns negative residuals positive, which biases
    # noise upward; confirm this is intended.
    difference = image - dark_matrix
    return abs(difference)

