from skimage.transform import resize
import numpy as np
import joblib
from photutils.aperture import CircularAperture, aperture_photometry

def possible_target_points(kernel_5, img_mat, x_y_w_h, x_y_w_h_f, n1):
    """Keep only the detection boxes whose pixels look like a point source.

    Each box is cropped out of ``img_mat`` and must pass three criteria, in
    this order, to be kept:

    1. the crop, resized to ``width x width`` (``width = kernel_5.shape[0]``),
       has an absolute Pearson correlation with ``kernel_5`` of at least
       ``min_corrcoef`` (and the resized crop's total is positive);
    2. the share of that total lying on the central row / central column
       ("cross") of the resized crop is at least ``min_scale``;
    3. the exact circular-aperture (r = 3 px) sum of the crop is at least
       ``photometry_value_init``.

    ``n1 == 512`` selects the threshold set for faint sources; any other
    value selects the stricter set for bright sources.

    Args:
        kernel_5: 2-D PSF template. Assumed square: a non-square kernel makes
            the correlation step raise ``ValueError``, the box is printed and
            rejected.
        img_mat: 2-D image indexed as ``img_mat[y, x]``.
        x_y_w_h: sequence of boxes ``(x_center, y_center, width, height, ...)``
            in ``img_mat`` pixel coordinates.
        x_y_w_h_f: sequence parallel to ``x_y_w_h``; the entry at the same
            index is kept together with each accepted box.
        n1: exclusive upper bound for crop coordinates (presumably the side
            length of ``img_mat``) and threshold-set selector.

    Returns:
        ``(kept_boxes, kept_boxes_f)``: lists of the accepted entries of
        ``x_y_w_h`` and ``x_y_w_h_f``, in their original order.
    """

    def aperture_sum(img, h_size, w_size):
        """Return the exact r = 3 px aperture sum of ``img`` at (x=w_size, y=h_size)."""
        aperture = CircularAperture([(float(w_size), float(h_size))], r=3)
        phot_table = aperture_photometry(img, aperture, method='exact')
        # Select the column by name instead of by position.
        return phot_table['aperture_sum'][0]

    if n1 == 512:
        # Faint sources: loose shape criteria, low flux floor.
        min_scale = 0.20
        min_corrcoef = 0.0
        photometry_value_init = 8
    else:
        # Bright sources. The two shape thresholds are deliberately moderate so
        # they mainly reject elongated (streak-like) responses; the higher flux
        # floor keeps precision up.
        min_scale = 0.38
        min_corrcoef = 0.60
        photometry_value_init = 10

    psf_col = kernel_5.reshape(kernel_5.size, order='C')
    width = kernel_5.shape[0]

    # Boolean mask selecting the central row and central column of a
    # width x width patch (the centre pixel is counted once).
    centre = width // 2
    cross_mask = np.zeros((width, width), dtype=bool)
    cross_mask[centre, :] = True
    cross_mask[:, centre] = True

    kept_boxes = []
    kept_boxes_f = []
    for i, box in enumerate(x_y_w_h):
        x1 = max(round(box[0] - box[2] // 2), 0)
        y1 = max(round(box[1] - box[3] // 2), 0)
        # Slice stops are exclusive: clamp them to n1 (not n1 - 1), otherwise
        # the last image row/column can never be part of a crop.
        x2 = min(round(box[0] + box[2] // 2 + 1), n1)
        y2 = min(round(box[1] + box[3] // 2 + 1), n1)
        if x2 <= x1 or y2 <= y1:
            # The box lies outside the image: the slice would be empty or,
            # for a negative stop, wrap around to unrelated pixels.
            continue
        img = img_mat[y1:y2, x1:x2]

        # Criterion 1: shape correlation with the PSF template.
        try:
            small_m = resize(img, (width, width), preserve_range=True)
            corrcoef = abs(np.corrcoef(small_m.ravel(), psf_col)[0, 1])
        except ValueError:
            # Report the offending box and reject it.
            print(box)
            continue
        total = small_m.sum()
        # Phrased as "not (x >= t)" so a NaN (e.g. the undefined correlation
        # of a constant patch) rejects the box.
        if not (total > 0 and corrcoef >= min_corrcoef):
            continue

        # Criterion 2: concentration of the signal on the central cross.
        if not (small_m[cross_mask].sum() / total >= min_scale):
            continue

        # Criterion 3 (most expensive, evaluated last): aperture photometry
        # centred on the crop.
        # NOTE(review): photutils puts pixel centres at integer coordinates,
        # so the crop's geometric centre is ((x2 - x1) - 1) / 2; using
        # (x2 - x1) / 2 shifts the aperture half a pixel right/down. Kept as-is
        # because photometry_value_init was presumably tuned with it — confirm.
        h_size = (y2 - y1) / 2
        w_size = (x2 - x1) / 2
        if aperture_sum(img, h_size, w_size) >= photometry_value_init:
            kept_boxes.append(box)
            kept_boxes_f.append(x_y_w_h_f[i])

    return kept_boxes, kept_boxes_f

def use_classify_4096(det, img_fits, Model_PATH):
    """Filter detections with a pre-trained classifier run on a downsampled image.

    Box coordinates in ``det`` are divided by ``scale`` (16) and clamped to
    ``size`` (256) to index ``img_fits``; presumably ``img_fits`` is the
    4096-px frame downsampled 16x — TODO confirm against the caller. Each
    box's patch is resized to ``object_size x object_size`` (13x13), flattened
    and passed to the classifier. Detections predicted as class 0 are kept.

    Args:
        det: 2-D array, one row per detection with columns
            ``(x_center, y_center, width, height, ...)``.
        img_fits: 2-D image indexed as ``img_fits[y, x]``.
        Model_PATH: path to a joblib-serialised model exposing ``predict``.

    Returns:
        List of the rows of ``det`` (original order) predicted as class 0;
        an empty list when ``det`` is empty.
    """
    object_size = 13  # side length of the patches fed to the classifier
    scale = 16        # det coordinates -> img_fits pixel coordinates
    size = 256        # img_fits extent used to clamp the patch slices

    if not len(det):  # nothing detected
        return []

    def patch_slice(lo, hi):
        """Map a det-coordinate extent ``[lo, hi]`` to a non-empty img_fits slice.

        The start is clamped into ``[0, size - 1]`` and the stop into
        ``[start + 1, size]``. Without the start clamp, a box touching the
        right/bottom edge (e.g. centre 4095, width 10) rounds its start up to
        256, giving an empty patch with nothing for resize() to interpolate.
        """
        start = min(max(int(round(lo / scale)), 0), size - 1)
        stop = min(max(int(round(hi / scale)) + 1, start + 1), size)
        return slice(start, stop)

    x1 = det[:, 0] - det[:, 2] // 2
    x2 = det[:, 0] + det[:, 2] // 2
    y1 = det[:, 1] - det[:, 3] // 2
    y2 = det[:, 1] + det[:, 3] // 2

    # NOTE: joblib.load unpickles the file, which can execute arbitrary code;
    # only point Model_PATH at trusted model files.
    classifier_model = joblib.load(Model_PATH)

    # One flattened object_size x object_size patch per detection.
    test_data = np.zeros((len(det), object_size * object_size))
    for k in range(len(det)):
        rows = patch_slice(y1[k], y2[k])
        cols = patch_slice(x1[k], x2[k])
        patch = resize(img_fits[rows, cols], (object_size, object_size), preserve_range=True)
        test_data[k, :] = patch.ravel()

    predicted = classifier_model.predict(test_data)
    # Detections whose predicted label is 0 are the ones retained.
    return [row for row, label in zip(det, predicted) if label == 0]