# -*- coding: utf-8 -*-
# ---
# jupyter:
#   jupytext:
#     formats: ipynb,py:light
#     text_representation:
#       extension: .py
#       format_name: light
#       format_version: '1.5'
#       jupytext_version: 1.11.4
#   kernelspec:
#     display_name: Python 3
#     language: python
#     name: python3
# ---

# +
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import KFold
from sklearn.model_selection import GroupKFold
import pandas as pd
import seaborn as sns
# from attention import Attention
# import keras_self_attention
from tensorflow.keras import regularizers
from sklearn.linear_model import LinearRegression
import tensorflow.keras.backend as K
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.callbacks import *
from tensorflow.keras.initializers import *
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from tensorflow.keras.layers import Layer
from keras_self_attention import SeqSelfAttention
from keras.models import Model
from scipy.stats import gaussian_kde
from mpl_toolkits.axes_grid1 import make_axes_locatable

# %matplotlib inline

# +
import os

# Restrict this process to the first GPU. The variable must be in the
# environment before the first call that enumerates the devices
# (list_local_devices below); it was previously assigned afterwards, where it
# presumably no longer influenced which GPUs this process could see.
# NOTE(review): TensorFlow is already imported at this point; this relies on
# the import itself not initialising CUDA - confirm on the target TF version.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

from tensorflow.python.client import device_lib

print(device_lib.list_local_devices())


# -

# # Model Structure

# ## model 1

def Bi_GRU_Attention1(dropout_rate=0.2, input_shape=(5, 690)):
    """Build BGANet variant 1: two stacked bidirectional GRUs + self-attention.

    Parameters
    ----------
    dropout_rate : float
        Rate of the Dropout layer placed after each recurrent layer.
    input_shape : tuple
        Shape of one sample as (sequence steps, features per step).

    Returns
    -------
    An uncompiled Keras model mapping the input to one regression output.
    """
    ipt = keras.layers.Input(shape=input_shape)
    # The Input layer already fixes the shape, so the recurrent wrapper does
    # not need its own `input_shape` argument.
    x = keras.layers.Bidirectional(keras.layers.GRU(64, return_sequences=True))(ipt)
    x = keras.layers.Dropout(dropout_rate)(x)
    x = keras.layers.BatchNormalization()(x)
    x = keras.layers.Bidirectional(keras.layers.GRU(32, return_sequences=True))(x)
    x = keras.layers.Dropout(dropout_rate)(x)

    x = SeqSelfAttention()(x)

    # Average the attended sequence over the step axis. This is the same mean
    # the per-step slice list computed, but expressed as a single layer that
    # does not depend on the number of steps.
    x = keras.layers.GlobalAveragePooling1D()(x)
    x = keras.layers.Activation("relu")(x)
    out = keras.layers.Dense(1)(x)

    # Use the Model class from the same `keras` namespace as the layers.
    return keras.Model(ipt, out)


# ## model 2

def Bi_GRU_Attention2(dropout_rate=0.2, input_shape=(10, 345)):
    """Build BGANet variant 2: three stacked bidirectional GRUs + self-attention.

    Parameters
    ----------
    dropout_rate : float
        Rate of the Dropout layer placed after each recurrent layer.
    input_shape : tuple
        Shape of one sample as (sequence steps, features per step).

    Returns
    -------
    An uncompiled Keras model mapping the input to one regression output.
    """
    ipt = keras.layers.Input(shape=input_shape)
    # The Input layer already fixes the shape, so the recurrent wrapper does
    # not need its own `input_shape` argument.
    x = keras.layers.Bidirectional(keras.layers.GRU(128, return_sequences=True))(ipt)
    x = keras.layers.Dropout(dropout_rate)(x)
    x = keras.layers.BatchNormalization()(x)

    x = keras.layers.Bidirectional(keras.layers.GRU(64, return_sequences=True))(x)
    x = keras.layers.Dropout(dropout_rate)(x)

    x = keras.layers.Bidirectional(keras.layers.GRU(32, return_sequences=True))(x)
    x = keras.layers.Dropout(dropout_rate)(x)

    x = SeqSelfAttention()(x)

    # Average the attended sequence over the step axis. This is the same mean
    # the per-step slice list computed, but expressed as a single layer that
    # does not depend on the number of steps.
    x = keras.layers.GlobalAveragePooling1D()(x)
    x = keras.layers.Activation("relu")(x)
    out = keras.layers.Dense(1)(x)

    # Use the Model class from the same `keras` namespace as the layers.
    return keras.Model(ipt, out)


# ## model 3

def Bi_GRU_Attention3(dropout_rate=0.2, input_shape=(15, 230)):
    """Build BGANet variant 3: three stacked bidirectional GRUs + self-attention.

    The layer stack matches variant 2; only the default input layout differs.

    Parameters
    ----------
    dropout_rate : float
        Rate of the Dropout layer placed after each recurrent layer.
    input_shape : tuple
        Shape of one sample as (sequence steps, features per step).

    Returns
    -------
    An uncompiled Keras model mapping the input to one regression output.
    """
    ipt = keras.layers.Input(shape=input_shape)
    # The Input layer already fixes the shape, so the recurrent wrapper does
    # not need its own `input_shape` argument.
    x = keras.layers.Bidirectional(keras.layers.GRU(128, return_sequences=True))(ipt)
    x = keras.layers.Dropout(dropout_rate)(x)
    x = keras.layers.BatchNormalization()(x)

    x = keras.layers.Bidirectional(keras.layers.GRU(64, return_sequences=True))(x)
    x = keras.layers.Dropout(dropout_rate)(x)

    x = keras.layers.Bidirectional(keras.layers.GRU(32, return_sequences=True))(x)
    x = keras.layers.Dropout(dropout_rate)(x)

    x = SeqSelfAttention()(x)

    # Average the attended sequence over the step axis. This is the same mean
    # the per-step slice list computed, but expressed as a single layer that
    # does not depend on the number of steps.
    x = keras.layers.GlobalAveragePooling1D()(x)
    x = keras.layers.Activation("relu")(x)
    out = keras.layers.Dense(1)(x)

    # Use the Model class from the same `keras` namespace as the layers.
    return keras.Model(ipt, out)


# # Loading LAMOST-APOGEE data

# ## Processed spectral data

# Pre-processed spectra, one array entry per spectrum.
Flux_3sigma = np.load(
    "../1_FITS_files_download_and_preprocessing/spectra_after_processing/3sigma/BR_Flux_Preprocessing_payne_above_50.npy"
)

print(Flux_3sigma.shape)

# ## Labels

# Original labels.
# NOTE: allow_pickle=True unpickles objects stored in the file, so only load
# label files from a trusted source.
label = np.load(
    '../1_FITS_files_download_and_preprocessing/LABELS/above_50.npy',
    allow_pickle=True,
)
print(label.shape)

# +
label = pd.DataFrame(
    label,
    columns=[
        'Teff[K]', 'Logg', 'CH', 'NH', 'OH', 'MgH', 'AlH', 'SiH', 'SH',
        'KH', 'CaH', 'TiH', 'CrH', 'MnH', 'FeH', 'NiH', 'snrg',
    ],
)

# The effective temperature is fitted in log10 space.
label['Teff[K]'] = np.log10(label['Teff[K]'].astype(np.float64))
label

# +
# Label standardisation: subtract the column mean and divide by the square
# root of the column variance for the first 17 columns.
_label_block = label.iloc[:, :17]
label_std = (_label_block - _label_block.mean()) / np.sqrt(_label_block.var())

# The signal-to-noise column is carried along unscaled.
label_std['snrg'] = label['snrg']
label_std
# -

# ## Data set division

random_state = 0

# +
# Containers for training-set spectra, test-set spectra, training-set labels
# and test-set labels are filled at the end of this cell.
Flux_3sigma_train_list, Flux_3sigma_test_list = [], []

X_train, X_test, y_train, y_test = train_test_split(
    Flux_3sigma,
    label_std,
    test_size=0.2,
    random_state=random_state,
)

# Training-set standardisation: the matrix is transposed before scaling and
# transposed back afterwards so that each spectrum is standardised on its own.
Flux_3sigma_sc = StandardScaler()
X_train_T = Flux_3sigma_sc.fit_transform(X_train.T)
X_train = X_train_T.T

# Test-set standardisation, done the same way with a separate scaler.
Flux_3sigma_sc2 = StandardScaler()
X_test_T = Flux_3sigma_sc2.fit_transform(X_test.T)
X_test = X_test_T.T

X_train_list = [X_train.astype('float64')]
y_train_list = [np.array(y_train).astype('float64')]
X_test_list = [X_test.astype('float64')]
y_test_list = [np.array(y_test).astype('float64')]
# -

print(X_test_list[0].shape)
print(y_test_list[0].shape)

# Check for missing values
print(pd.isnull(X_test_list[0]).any())
print(pd.isnull(y_test_list[0]).any())

# Fill missing values with 1, writing straight into the stored array.
# The previous `pd.DataFrame(arr).fillna(1, inplace=True)` only modified `arr`
# when the temporary DataFrame happened to share memory with it, which is not
# the case when pandas copies the input (e.g. with Copy-on-Write enabled).
# NOTE(review): the training spectra are not filled here; confirm they
# contain no NaN values before training.
X_test_list[0][np.isnan(X_test_list[0])] = 1

# Check for missing values again
print(pd.isnull(X_test_list[0]).any())

# Save the test set for the model uncertainty estimation
np.save("../data/test_set/X_test_above_50", X_test_list[0])

# +
# Training-set and test-set signal-to-noise ratios. The split reuses the same
# test_size and random_state as the spectra/label split above.
_snrg_values = np.array(label_std['snrg'])
_, _, y_train_snrg, y_test_snrg = train_test_split(
    Flux_3sigma,
    _snrg_values,
    test_size=0.2,
    random_state=random_state,
)

y_train_list_snrg = [y_train_snrg.astype('float64')]
y_test_list_snrg = [y_test_snrg.astype('float64')]
# -

print(y_train_list_snrg[0].shape)
print(y_test_list_snrg[0].shape)

np.save('../data/Bi-GRU-Attention_labels/above_50/y_test_list_snrg', y_test_list_snrg[0])


# # Key Function Definition

# ## Fitting different BGANet models

# ### model 1

def fit_parameters1(train_list_index=0, 
                   param_index=0, 
                   epochs=50, 
                   batch_size=64, 
                   lr_rate=0.001, 
                   patience=5, 
                   seq_num=5,
                   seq_length=690, 
                   model=None, 
                   model_dir='../logs_above_50_model1/', 
                   label_std=label_std,     
                   label=label,     
                   random_state=0):
    """Train BGANet model 1 on one label column and save its outputs.

    Parameters
    ----------
    train_list_index : int
        Index into the module-level X_/y_ train and test lists.
    param_index : int
        Column of the label table to fit (0 is the log10 temperature column).
    epochs, batch_size, lr_rate : training hyper-parameters (Adam, MSE loss).
    patience : int
        Early-stopping patience on the validation loss.
    seq_num, seq_length : int
        Each spectrum is reshaped to (seq_num, seq_length).
    model : keras model or None
        Model to train. When None a fresh Bi_GRU_Attention1 is built for the
        requested (seq_num, seq_length); a model is no longer constructed at
        function-definition time and shared between calls.
    model_dir : str
        Root directory of the TensorBoard logs.
    label_std : unused; kept so existing callers keep working.
    label : DataFrame
        Unstandardised labels, used for the column name and to undo the
        standardisation.
    random_state : int
        Seed of the train/validation split.

    Returns
    -------
    (y_test, y_pred, y_valid_va, y_pred_va) in original label units.
    """
    import os

    if model is None:
        model = Bi_GRU_Attention1(input_shape=(seq_num, seq_length))

    # Name of the fitted label column
    param_name = label.columns[param_index]
    print("Fitting parameter: ", param_name)

    # Training and test sets, reshaped to (samples, seq_num, seq_length)
    X_train = X_train_list[train_list_index]
    X_train = np.reshape(X_train, (X_train.shape[0], seq_num, seq_length))
    y_train = y_train_list[train_list_index][:, param_index]

    X_test = X_test_list[train_list_index]
    X_test = np.reshape(X_test, (X_test.shape[0], seq_num, seq_length))
    y_test = y_test_list[train_list_index][:, param_index]

    model.summary()
    model.compile(optimizer=tf.keras.optimizers.Adam(lr_rate), loss='mse')

    log_name = '/model1/'  + str(param_name) + '_epochs_'  + str(epochs)
    log_dir_name = model_dir + log_name

    tensorboard_cb = keras.callbacks.TensorBoard(log_dir=log_dir_name)
    # Stop once the validation loss has not improved for `patience` epochs
    early_stop_cb = keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience)

    # Hold out 10% of the training set for validation
    X_train_va, X_valid_va, y_train_va, y_valid_va = train_test_split(
        X_train, y_train, test_size=0.1, random_state=random_state)

    history = model.fit(X_train_va, y_train_va, batch_size=batch_size, epochs=epochs,
                        validation_data=(X_valid_va, y_valid_va),
                        callbacks=[tensorboard_cb, early_stop_cb])

    # Predictions on the validation and test sets
    y_pred_va = model.predict(X_valid_va).squeeze()
    y_pred = model.predict(X_test).squeeze()
    print(y_pred.shape)

    # Undo the standardisation with the mean and sample standard deviation of
    # the fitted column. `.iloc` makes the positional lookup explicit instead
    # of relying on integer-key fallback on a string-labelled Series.
    label_sigma = np.sqrt(label.var().iloc[param_index])
    label_mu = label.mean().iloc[param_index]

    y_test = y_test * label_sigma + label_mu
    y_pred = y_pred * label_sigma + label_mu
    print(y_pred.shape)

    y_valid_va = y_valid_va * label_sigma + label_mu
    y_pred_va = y_pred_va * label_sigma + label_mu

    # Column 0 is fitted as log10, so convert back to linear units
    if param_index == 0:
        y_valid_va = 10**y_valid_va
        y_pred_va = 10**y_pred_va
        y_test = 10**y_test
        y_pred = 10**y_pred

    # Training and validation loss curves
    plt.plot(history.history['loss'], label="training loss")
    plt.plot(history.history['val_loss'],  label="validation_loss")
    plt.legend()

    model_file = "../data/Bi-GRU-Attention_models/above_50/model1/" +  str(param_name) + '_model.h5'
    label_file = "../data/Bi-GRU-Attention_labels/above_50/model1/" +  str(param_name) +  '_label'
    prediction_file = "../data/Bi-GRU-Attention_predictions/above_50/model1/" +  str(param_name)  + '_prediction'

    # Create the output folders up front so a finished training run is not
    # lost to a missing directory.
    for out_path in (model_file, label_file, prediction_file):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    model.save_weights(model_file)       # model weights
    np.save(label_file, y_test)          # test-set labels
    np.save(prediction_file, y_pred)     # test-set predictions

    return y_test, y_pred, y_valid_va, y_pred_va


# ### model 2

# +

def fit_parameters2(train_list_index=0, 
                   param_index=0, 
                   epochs=50, 
                   batch_size=64, 
                   lr_rate=0.001, 
                   patience=5, 
                   seq_num=10,
                   seq_length=345, 
                   model=None, 
                   model_dir='../logs_above_50_model2/', 
                   label_std=label_std,     
                   label=label,     
                   random_state=0):
    """Train BGANet model 2 on one label column and save its outputs.

    Parameters
    ----------
    train_list_index : int
        Index into the module-level X_/y_ train and test lists.
    param_index : int
        Column of the label table to fit (0 is the log10 temperature column).
    epochs, batch_size, lr_rate : training hyper-parameters (Adam, MSE loss).
    patience : int
        Early-stopping patience on the validation loss.
    seq_num, seq_length : int
        Each spectrum is reshaped to (seq_num, seq_length).
    model : keras model or None
        Model to train. When None a fresh Bi_GRU_Attention2 is built for the
        requested (seq_num, seq_length); a model is no longer constructed at
        function-definition time and shared between calls.
    model_dir : str
        Root directory of the TensorBoard logs.
    label_std : unused; kept so existing callers keep working.
    label : DataFrame
        Unstandardised labels, used for the column name and to undo the
        standardisation.
    random_state : int
        Seed of the train/validation split.

    Returns
    -------
    (y_test, y_pred, y_valid_va, y_pred_va) in original label units.
    """
    import os

    if model is None:
        model = Bi_GRU_Attention2(input_shape=(seq_num, seq_length))

    # Name of the fitted label column
    param_name = label.columns[param_index]
    print("Fitting parameter: ", param_name)

    # Training and test sets, reshaped to (samples, seq_num, seq_length)
    X_train = X_train_list[train_list_index]
    X_train = np.reshape(X_train, (X_train.shape[0], seq_num, seq_length))
    y_train = y_train_list[train_list_index][:, param_index]

    X_test = X_test_list[train_list_index]
    X_test = np.reshape(X_test, (X_test.shape[0], seq_num, seq_length))
    y_test = y_test_list[train_list_index][:, param_index]

    model.summary()
    model.compile(optimizer=tf.keras.optimizers.Adam(lr_rate), loss='mse')

    log_name = '/model2/'  + str(param_name) + '_epochs_'  + str(epochs)
    log_dir_name = model_dir + log_name

    tensorboard_cb = keras.callbacks.TensorBoard(log_dir=log_dir_name)
    # Stop once the validation loss has not improved for `patience` epochs
    early_stop_cb = keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience)

    # Hold out 10% of the training set for validation
    X_train_va, X_valid_va, y_train_va, y_valid_va = train_test_split(
        X_train, y_train, test_size=0.1, random_state=random_state)

    history = model.fit(X_train_va, y_train_va, batch_size=batch_size, epochs=epochs,
                        validation_data=(X_valid_va, y_valid_va),
                        callbacks=[tensorboard_cb, early_stop_cb])

    # Predictions on the validation and test sets
    y_pred_va = model.predict(X_valid_va).squeeze()
    y_pred = model.predict(X_test).squeeze()
    print(y_pred.shape)

    # Undo the standardisation with the mean and sample standard deviation of
    # the fitted column. `.iloc` makes the positional lookup explicit instead
    # of relying on integer-key fallback on a string-labelled Series.
    label_sigma = np.sqrt(label.var().iloc[param_index])
    label_mu = label.mean().iloc[param_index]

    y_test = y_test * label_sigma + label_mu
    y_pred = y_pred * label_sigma + label_mu
    print(y_pred.shape)

    y_valid_va = y_valid_va * label_sigma + label_mu
    y_pred_va = y_pred_va * label_sigma + label_mu

    # Column 0 is fitted as log10, so convert back to linear units
    if param_index == 0:
        y_valid_va = 10**y_valid_va
        y_pred_va = 10**y_pred_va
        y_test = 10**y_test
        y_pred = 10**y_pred

    # Training and validation loss curves
    plt.plot(history.history['loss'], label="training loss")
    plt.plot(history.history['val_loss'],  label="validation_loss")
    plt.legend()

    model_file = "../data/Bi-GRU-Attention_models/above_50/model2/" +  str(param_name) + '_model.h5'
    label_file = "../data/Bi-GRU-Attention_labels/above_50/model2/" +  str(param_name) +  '_label'
    prediction_file = "../data/Bi-GRU-Attention_predictions/above_50/model2/" +  str(param_name)  + '_prediction'

    # Create the output folders up front so a finished training run is not
    # lost to a missing directory.
    for out_path in (model_file, label_file, prediction_file):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    model.save_weights(model_file)       # model weights
    np.save(label_file, y_test)          # test-set labels
    np.save(prediction_file, y_pred)     # test-set predictions

    return y_test, y_pred, y_valid_va, y_pred_va


# -

# ### model 3

# +

def fit_parameters3(train_list_index=0, 
                   param_index=0, 
                   epochs=50, 
                   batch_size=64, 
                   lr_rate=0.001, 
                   patience=5, 
                   seq_num=15,
                   seq_length=230, 
                   model=None, 
                   model_dir='../logs_above_50_model3/', 
                   label_std=label_std,     # standardised labels
                   label=label,     # labels before standardisation
                   random_state=0):
    """Train BGANet model 3 on one label column and save its outputs.

    Parameters
    ----------
    train_list_index : int
        Index into the module-level X_/y_ train and test lists.
    param_index : int
        Column of the label table to fit (0 is the log10 temperature column).
    epochs, batch_size, lr_rate : training hyper-parameters (Adam, MSE loss).
    patience : int
        Early-stopping patience on the validation loss.
    seq_num, seq_length : int
        Each spectrum is reshaped to (seq_num, seq_length).
    model : keras model or None
        Model to train. When None a fresh Bi_GRU_Attention3 is built for the
        requested (seq_num, seq_length); a model is no longer constructed at
        function-definition time and shared between calls.
    model_dir : str
        Root directory of the TensorBoard logs.
    label_std : unused; kept so existing callers keep working.
    label : DataFrame
        Unstandardised labels, used for the column name and to undo the
        standardisation.
    random_state : int
        Seed of the train/validation split.

    Returns
    -------
    (y_test, y_pred, y_valid_va, y_pred_va) in original label units.
    """
    import os

    if model is None:
        model = Bi_GRU_Attention3(input_shape=(seq_num, seq_length))

    # Name of the fitted label column
    param_name = label.columns[param_index]
    print("Fitting parameter: ", param_name)

    # Training and test sets, reshaped to (samples, seq_num, seq_length)
    X_train = X_train_list[train_list_index]
    X_train = np.reshape(X_train, (X_train.shape[0], seq_num, seq_length))
    y_train = y_train_list[train_list_index][:, param_index]

    X_test = X_test_list[train_list_index]
    X_test = np.reshape(X_test, (X_test.shape[0], seq_num, seq_length))
    y_test = y_test_list[train_list_index][:, param_index]

    model.summary()
    model.compile(optimizer=tf.keras.optimizers.Adam(lr_rate), loss='mse')

    # TensorBoard log location
    log_name = '/model3/'  + str(param_name) + '_epochs_'  + str(epochs)
    log_dir_name = model_dir + log_name

    tensorboard_cb = keras.callbacks.TensorBoard(log_dir=log_dir_name)
    # Early stopping: `monitor` is the tracked quantity (the validation loss
    # here) and `patience` is the number of epochs without improvement that
    # is tolerated before training stops.
    early_stop_cb = keras.callbacks.EarlyStopping(monitor='val_loss', patience=patience)

    # Hold out 10% of the training set for validation
    X_train_va, X_valid_va, y_train_va, y_valid_va = train_test_split(
        X_train, y_train, test_size=0.1, random_state=random_state)

    history = model.fit(X_train_va, y_train_va, batch_size=batch_size, epochs=epochs,
                        validation_data=(X_valid_va, y_valid_va),
                        callbacks=[tensorboard_cb, early_stop_cb])

    # Predictions on the validation and test sets
    y_pred_va = model.predict(X_valid_va).squeeze()
    y_pred = model.predict(X_test).squeeze()
    print(y_pred.shape)

    # Undo the standardisation with the mean and sample standard deviation of
    # the fitted column. `.iloc` makes the positional lookup explicit instead
    # of relying on integer-key fallback on a string-labelled Series.
    label_sigma = np.sqrt(label.var().iloc[param_index])
    label_mu = label.mean().iloc[param_index]

    y_test = y_test * label_sigma + label_mu
    y_pred = y_pred * label_sigma + label_mu
    print(y_pred.shape)

    y_valid_va = y_valid_va * label_sigma + label_mu
    y_pred_va = y_pred_va * label_sigma + label_mu

    # Column 0 is fitted as log10, so convert back to linear units
    if param_index == 0:
        y_valid_va = 10**y_valid_va
        y_pred_va = 10**y_pred_va
        y_test = 10**y_test
        y_pred = 10**y_pred

    # Training and validation loss curves
    plt.plot(history.history['loss'], label="training loss")
    plt.plot(history.history['val_loss'],  label="validation_loss")
    plt.legend()

    model_file = "../data/Bi-GRU-Attention_models/above_50/model3/" +  str(param_name) + '_model.h5'
    label_file = "../data/Bi-GRU-Attention_labels/above_50/model3/" +  str(param_name) +  '_label'
    prediction_file = "../data/Bi-GRU-Attention_predictions/above_50/model3/" +  str(param_name)  + '_prediction'

    # Create the output folders up front so a finished training run is not
    # lost to a missing directory.
    for out_path in (model_file, label_file, prediction_file):
        os.makedirs(os.path.dirname(out_path), exist_ok=True)

    model.save_weights(model_file)       # model weights
    np.save(label_file, y_test)          # test-set labels
    np.save(prediction_file, y_pred)     # test-set predictions

    return y_test, y_pred, y_valid_va, y_pred_va


# -

# ### Blending

# +
import joblib

def blending(y_pred_va, y_pred_va2, y_pred_va3, 
             y_pred, y_pred_2, y_pred_3, 
             y_valid, y_test, 
             param_name='Teff',
             param_index=0,
             train_list_index=0):
    """Blend three BGANet models with a multiple linear regression.

    The regression is fitted on the validation-set predictions of the three
    models against the validation labels and then applied to their test-set
    predictions. The blended result is plotted, its losses are printed, and
    both the fitted regressor and the blended predictions are written to disk.

    Parameters
    ----------
    y_pred_va, y_pred_va2, y_pred_va3 : array-like
        Validation-set predictions of the three BGANet models.
    y_pred, y_pred_2, y_pred_3 : array-like
        Test-set predictions of the three BGANet models.
    y_valid : array-like
        Validation-set labels.
    y_test : array-like
        Test-set labels.
    param_name : str
        Name of the fitted parameter; used in axis labels and file names.
    param_index : int
        Column index of the parameter in the label table.
    train_list_index : int
        Unused; kept for backward compatibility.
    """
    # One column per model. Plain arrays are used for both fitting and
    # prediction so the regressor never sees two differently named sets of
    # feature columns.
    valid_features = np.column_stack([y_pred_va, y_pred_va2, y_pred_va3])
    test_features = np.column_stack([y_pred, y_pred_2, y_pred_3])

    # Multiple linear regression
    model = LinearRegression()
    model.fit(valid_features, y_valid)
    y_pred_blending_Param = model.predict(test_features)

    # Plot the results
    label_name = str(param_name) + "_Payne"
    prediction_name = str(param_name) + "_Prediction"
    plot_result(y_test, y_pred_blending_Param, label_name, prediction_name, param_index=param_index)

    # Compute the losses
    calculate_loss(y_test, y_pred_blending_Param,  param_index=param_index)

    # Plot the residual distribution (raw string: "\D" is not a valid escape)
    delta_name = r'$\Delta$' + str(param_name)
    plot_distribution(y_test, y_pred_blending_Param,  delta_name)

    # Save the blending model
    blending_model_path = "../data/Bi-GRU-Attention_models/above_50/Blending/" +  str(param_name) + '_model.joblib'
    joblib.dump(model, blending_model_path)

    # Save the blended test-set predictions
    blending_result_path = "../data/Bi-GRU-Attention_predictions/above_50/Blending/" + str(param_name) + '_prediction'
    np.save(blending_result_path, y_pred_blending_Param)


# -

# ## Scatter density

# +
# Parameter list:
# true values, predicted values, x-axis label, y-axis label, column index of the fitted parameter
def plot_result(y_test, y_pred, xlabel_name, ylabel_name, param_index=0):
    """Show a density-coloured scatter plot of predictions against true values.

    Parameters
    ----------
    y_test, y_pred : array-like
        True and predicted values; they are stacked for a Gaussian KDE.
    xlabel_name, ylabel_name : str
        Axis labels.
    param_index : int
        Column index into the module-level `label` table; only used for the
        printed parameter name.
    """
    print("Fitting parameter: ", label.columns[param_index])

    fig, ax = plt.subplots(figsize=(5, 5), dpi=100)

    # Point density from a Gaussian KDE, scaled so the densest point is 1
    truth, predicted = y_test[:], y_pred[:]
    samples = np.vstack([truth, predicted])
    density = gaussian_kde(samples)(samples)
    density = density / density.max()

    # Sort by density so the densest points are drawn last (on top)
    order = density.argsort()
    truth, predicted, density = truth[order], predicted[order], density[order]

    scatter = ax.scatter(truth, predicted, marker='o', c=density, s=5,
                         label='LST', cmap='Spectral_r')

    # One-to-one reference line; both axes span the range of the true values
    ax.plot(y_test[:], y_test[:], 'k-', linewidth=1)
    ax.set_xlim([y_test[:].min(), y_test[:].max()])
    ax.set_ylim([y_test[:].min(), y_test[:].max()])

    # Axis labels are supplied by the caller
    ax.set_xlabel(xlabel_name)
    ax.set_ylabel(ylabel_name)

    # Colour bar in its own narrow axes to the right of the plot
    cax = make_axes_locatable(ax).append_axes("right", size="5%", pad=0.1)
    fig.colorbar(scatter, cax=cax, label='frequency')
    plt.show()


# -

# ## Probability density distribution

def plot_distribution(y_test, y_pred, xlabel_name, param_index=0):
    """Plot the distribution of the differences `y_test - y_pred`.

    `xlabel_name` labels the x axis; `param_index` is accepted but not used.
    """
    # NOTE(review): the sign here is true minus predicted, whereas
    # calculate_loss uses predicted minus true - confirm which is intended.
    axes = sns.distplot(y_test - y_pred)
    axes.set_xlabel(xlabel_name)
    axes.set_ylabel('Density')


# ## Error calculation

# +
from sklearn.metrics import mean_absolute_error, mean_squared_error

def calculate_loss(y_test, y_pred, param_index=0): 
    """Print and return the error statistics of a set of predictions.

    Parameters
    ----------
    y_test, y_pred : array-like
        True and predicted values.
    param_index : int
        Column index into the module-level `label` table; only used for the
        printed header.

    Returns
    -------
    (mae, mse, rmse, mu, sigma), where mu and sigma are the mean and the
    standard deviation of the residuals `y_pred - y_test`.
    """
    truth, predicted = y_test[:], y_pred[:]
    blank = " "

    print("----------------Loss of ", label.columns[param_index], "----------------")
    print(blank)

    mae = mean_absolute_error(truth, predicted)
    mse = mean_squared_error(truth, predicted)
    rmse = np.sqrt(mse)

    for caption, value in (('MAE = ', mae), ('MSE = ', mse), ('RMSE = ', rmse)):
        print(caption, value)
        print(blank)

    # Residual statistics
    residuals = predicted - truth
    mu = np.mean(residuals)
    sigma = np.std(residuals)
    print("residual error mu = ", mu)
    print(blank)
    print('residual error sigma = ', sigma)

    return mae, mse, rmse, mu, sigma


# -
# ## Results

def result(y_test, y_pred, param_index=0):
    """Plot and score one model's test-set predictions.

    Draws the density scatter plot, prints the loss statistics and plots the
    residual distribution for the label column selected by `param_index`
    (an index into the module-level `label` table, used for all captions).
    """
    param_name = label.columns[param_index]

    label_name = str(param_name) + "_Payne"
    prediction_name = str(param_name) + "_Prediction"
    plot_result(y_test, y_pred, label_name, prediction_name, param_index=param_index)

    calculate_loss(y_test, y_pred,  param_index=param_index)

    # Raw string: "\D" is not a valid escape sequence in a normal literal.
    delta_name = r'$\Delta$' + str(param_name)
    plot_distribution(y_test, y_pred,  delta_name)


# # Fitting results

# ### Teff

# #### model 1

# +

model = Bi_GRU_Attention1()

# Teff is column 0 of the label table.
y_test, y_pred, y_valid_va, y_pred_va = fit_parameters1(
    train_list_index=0,
    param_index=0,
    epochs=50,
    batch_size=16,
    patience=5,
    lr_rate=0.001,
    model=model,
)

result(y_test, y_pred, param_index=0)
# -

# #### model 2

# +

model = Bi_GRU_Attention2()

y_test_2, y_pred_2, y_valid_va2, y_pred_va2 = fit_parameters2(
    train_list_index=0,
    param_index=0,
    epochs=50,
    batch_size=16,
    patience=5,
    lr_rate=0.001,
    model=model,
)

result(y_test_2, y_pred_2, param_index=0)
# -

# #### model 3

# +

model = Bi_GRU_Attention3()

y_test_3, y_pred_3, y_valid_va3, y_pred_va3 = fit_parameters3(
    train_list_index=0,
    param_index=0,
    epochs=50,
    batch_size=32,
    patience=5,
    lr_rate=0.001,
    model=model,
)

result(y_test_3, y_pred_3, param_index=0)
# -

# #### Blending

# +

blending(
    y_pred_va, y_pred_va2, y_pred_va3,
    y_pred, y_pred_2, y_pred_3,
    y_valid_va, y_test,
    param_name='Teff',
    param_index=0,
)
# -

# ### logg

# #### model1

# +
model = Bi_GRU_Attention1()

# Logg is column 1 of the label table.
y_test_logg1, y_pred_logg1, y_valid_va_logg1, y_pred_va_logg1 = fit_parameters1(
    train_list_index=0,
    param_index=1,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model)

# Pass the column index so the plots and the loss header are labelled with
# this parameter instead of the default column 0.
result(y_test_logg1, y_pred_logg1, param_index=1)
# -
# #### model2

# +
model = Bi_GRU_Attention2()

y_test_logg2, y_pred_logg2, y_valid_va_logg2, y_pred_va_logg2 = fit_parameters2(
    train_list_index=0,
    param_index=1,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model)

result(y_test_logg2, y_pred_logg2, param_index=1)
# -

# #### model3

# +
model = Bi_GRU_Attention3()

y_test_logg3, y_pred_logg3, y_valid_va_logg3, y_pred_va_logg3 = fit_parameters3(
    train_list_index=0,
    param_index=1,
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3,
    model=model)

result(y_test_logg3, y_pred_logg3, param_index=1)
# -

# #### Blending

# +

blending(y_pred_va_logg1, y_pred_va_logg2, y_pred_va_logg3,
         y_pred_logg1, y_pred_logg2, y_pred_logg3,
         y_valid_va_logg1, y_test_logg1,
         'Logg', 1)
# -

# ### [C/H]

# #### MODEL1

# +
model = Bi_GRU_Attention1()

# CH is column 2 of the label table.
y_test_ch1, y_pred_ch1, y_valid_va_ch1, y_pred_va_ch1 = fit_parameters1(
    train_list_index=0,
    param_index=2,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model)

# Plot the results and compute the errors. The column index is passed so the
# plots and the loss header are labelled with this parameter instead of the
# default column 0.
result(y_test_ch1, y_pred_ch1, param_index=2)
# -

# #### MODEL2

# +

model = Bi_GRU_Attention2()

y_test_ch2, y_pred_ch2, y_valid_va_ch2, y_pred_va_ch2 = fit_parameters2(
    train_list_index=0,
    param_index=2,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model)

# Plot the results and compute the errors
result(y_test_ch2, y_pred_ch2, param_index=2)
# -

# #### MODEL3

# +

model = Bi_GRU_Attention3()

y_test_ch3, y_pred_ch3, y_valid_va_ch3, y_pred_va_ch3 = fit_parameters3(
    train_list_index=0,
    param_index=2,
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3,
    model=model)

result(y_test_ch3, y_pred_ch3, param_index=2)
# -

# #### Blending

# +

blending(y_pred_va_ch1, y_pred_va_ch2, y_pred_va_ch3,
         y_pred_ch1, y_pred_ch2, y_pred_ch3,
         y_valid_va_ch1, y_test_ch1,
         'ch', 2)
# -

# ### NH

# #### MODEL1

# +
model = Bi_GRU_Attention1()

# NH is column 3 of the label table.
y_test_nh1, y_pred_nh1, y_valid_va_nh1, y_pred_va_nh1 = fit_parameters1(
    train_list_index=0,
    param_index=3,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model)

# Pass the column index so the plots and the loss header are labelled with
# this parameter instead of the default column 0.
result(y_test_nh1, y_pred_nh1, param_index=3)
# -

# #### MODEL2

# +
model = Bi_GRU_Attention2()

y_test_nh2, y_pred_nh2, y_valid_va_nh2, y_pred_va_nh2 = fit_parameters2(
    train_list_index=0,
    param_index=3,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model)

result(y_test_nh2, y_pred_nh2, param_index=3)
# -

# #### MODEL3

# +
model = Bi_GRU_Attention3()

y_test_nh3, y_pred_nh3, y_valid_va_nh3, y_pred_va_nh3 = fit_parameters3(
    train_list_index=0,
    param_index=3,
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3,
    model=model)

result(y_test_nh3, y_pred_nh3, param_index=3)
# -

# #### BLENDING

# +

blending(y_pred_va_nh1, y_pred_va_nh2, y_pred_va_nh3,
         y_pred_nh1, y_pred_nh2, y_pred_nh3,
         y_valid_va_nh1, y_test_nh1,
         'nh', 3)
# -

# ### OH

# #### MODEL1

# +
model = Bi_GRU_Attention1()

# OH is column 4 of the label table.
y_test_OH1, y_pred_OH1, y_valid_va_OH1, y_pred_va_OH1 = fit_parameters1(
    train_list_index=0,
    param_index=4,
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3,
    model=model)

# Pass the column index so the plots and the loss header are labelled with
# this parameter instead of the default column 0.
result(y_test_OH1, y_pred_OH1, param_index=4)
# -

# #### MODEL2

# +
model = Bi_GRU_Attention2()

y_test_OH2, y_pred_OH2, y_valid_va_OH2, y_pred_va_OH2 = fit_parameters2(
    train_list_index=0,
    param_index=4,
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3,
    model=model)

result(y_test_OH2, y_pred_OH2, param_index=4)
# -

# #### MODEL3

# +
model = Bi_GRU_Attention3()

y_test_OH3, y_pred_OH3, y_valid_va_OH3, y_pred_va_OH3 = fit_parameters3(
    train_list_index=0,
    param_index=4,
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3,
    model=model)

result(y_test_OH3, y_pred_OH3, param_index=4)
# -

# #### BLENDING

# +

blending(y_pred_va_OH1, y_pred_va_OH2, y_pred_va_OH3,
         y_pred_OH1, y_pred_OH2, y_pred_OH3,
         y_valid_va_OH1, y_test_OH1,
         'OH', 4)
# -

# ### MgH

# #### MODEL1

# +
model = Bi_GRU_Attention1()

# MgH is column 5 of the label table.
y_test_MgH1, y_pred_MgH1, y_valid_va_MgH1, y_pred_va_MgH1 = fit_parameters1(
    train_list_index=0,
    param_index=5,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model)

# Pass the column index so the plots and the loss header are labelled with
# this parameter instead of the default column 0.
result(y_test_MgH1, y_pred_MgH1, param_index=5)
# -

# #### MODEL2

# +
model = Bi_GRU_Attention2()

y_test_MgH2, y_pred_MgH2, y_valid_va_MgH2, y_pred_va_MgH2 = fit_parameters2(
    train_list_index=0,
    param_index=5,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model)

result(y_test_MgH2, y_pred_MgH2, param_index=5)
# -

# #### MODEL3

# +
model = Bi_GRU_Attention3()

y_test_MgH3, y_pred_MgH3, y_valid_va_MgH3, y_pred_va_MgH3 = fit_parameters3(
    train_list_index=0,
    param_index=5,
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3,
    model=model)

result(y_test_MgH3, y_pred_MgH3, param_index=5)
# -

# #### BLENDING

# +

blending(y_pred_va_MgH1, y_pred_va_MgH2, y_pred_va_MgH3,
         y_pred_MgH1, y_pred_MgH2, y_pred_MgH3,
         y_valid_va_MgH1, y_test_MgH1,
         'MgH', 5)
# -

# ### AlH

# #### MODEL1

# +
model = Bi_GRU_Attention1()


y_test_AlH1, y_pred_AlH1, y_valid_va_AlH1, y_pred_va_AlH1 = fit_parameters1(train_list_index=0,
    param_index=6,    
    epochs=50,
    batch_size=16,   
    lr_rate=0.001, 
    patience=3,
    model=model)

result(y_test_AlH1, y_pred_AlH1)
# -

# #### MODEL2

# +
model = Bi_GRU_Attention2()


y_test_AlH2, y_pred_AlH2, y_valid_va_AlH2, y_pred_va_AlH2 = fit_parameters2(train_list_index=0,
    param_index=6,     
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model)

result(y_test_AlH2, y_pred_AlH2)
# -

# #### MODEL3

# +
model = Bi_GRU_Attention3()


y_test_AlH3, y_pred_AlH3, y_valid_va_AlH3, y_pred_va_AlH3 = fit_parameters3(train_list_index=0,
    param_index=6,    
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3, 
    model=model)

result(y_test_AlH3, y_pred_AlH3)
# -

# #### BLENDING

# +

# AlH blending: the three models' y_pred_va_* and y_pred_* outputs, plus model 1's
# y_valid_va / y_test, go to blending() tagged 'AlH' with index 6.
blending(
    y_pred_va_AlH1,
    y_pred_va_AlH2,
    y_pred_va_AlH3,
    y_pred_AlH1,
    y_pred_AlH2,
    y_pred_AlH3,
    y_valid_va_AlH1,
    y_test_AlH1,
    'AlH',
    6,
)
# -

# ### SiH

# #### MODEL1

# +
# SiH / model 1: fit a fresh Bi_GRU_Attention1 with param_index=7 (presumably the
# SiH label column) and hand the returned y_test / y_pred pair to result().
model = Bi_GRU_Attention1()

(
    y_test_SiH1,
    y_pred_SiH1,
    y_valid_va_SiH1,
    y_pred_va_SiH1,
) = fit_parameters1(
    train_list_index=0,
    param_index=7,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model,
)

# NOTE(review): this is the only result() call in this part of the notebook that
# passes a third positional argument (7); every sibling cell passes just
# (y_test, y_pred). Confirm result() accepts it and that it is intended.
result(y_test_SiH1, y_pred_SiH1, 7)
# -

# #### MODEL2

# +
# SiH / model 2: fit a fresh Bi_GRU_Attention2 with param_index=7 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention2()

(
    y_test_SiH2,
    y_pred_SiH2,
    y_valid_va_SiH2,
    y_pred_va_SiH2,
) = fit_parameters2(
    train_list_index=0,
    param_index=7,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_SiH2, y_pred_SiH2)
# -

# #### MODEL3

# +
# SiH / model 3: fit a fresh Bi_GRU_Attention3 with param_index=7 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention3()

(
    y_test_SiH3,
    y_pred_SiH3,
    y_valid_va_SiH3,
    y_pred_va_SiH3,
) = fit_parameters3(
    train_list_index=0,
    param_index=7,
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_SiH3, y_pred_SiH3)
# -

# #### BLENDING

# +

# SiH blending: the three models' y_pred_va_* and y_pred_* outputs, plus model 1's
# y_valid_va / y_test, go to blending() tagged 'SiH' with index 7.
blending(
    y_pred_va_SiH1,
    y_pred_va_SiH2,
    y_pred_va_SiH3,
    y_pred_SiH1,
    y_pred_SiH2,
    y_pred_SiH3,
    y_valid_va_SiH1,
    y_test_SiH1,
    'SiH',
    7,
)
# -

# ### SH

# #### MODEL1

# +
# SH / model 1: fit a fresh Bi_GRU_Attention1 with param_index=8 (presumably the
# SH label column) and hand the returned y_test / y_pred pair to result().
model = Bi_GRU_Attention1()

(
    y_test_SH1,
    y_pred_SH1,
    y_valid_va_SH1,
    y_pred_va_SH1,
) = fit_parameters1(
    train_list_index=0,
    param_index=8,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_SH1, y_pred_SH1)
# -

# #### MODEL2

# +
# SH / model 2: fit a fresh Bi_GRU_Attention2 with param_index=8 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention2()

(
    y_test_SH2,
    y_pred_SH2,
    y_valid_va_SH2,
    y_pred_va_SH2,
) = fit_parameters2(
    train_list_index=0,
    param_index=8,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_SH2, y_pred_SH2)
# -

# #### MODEL3

# +
# SH / model 3: fit a fresh Bi_GRU_Attention3 with param_index=8 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention3()

(
    y_test_SH3,
    y_pred_SH3,
    y_valid_va_SH3,
    y_pred_va_SH3,
) = fit_parameters3(
    train_list_index=0,
    param_index=8,
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_SH3, y_pred_SH3)
# -

# #### BLENDING

# +

# SH blending: the three models' y_pred_va_* and y_pred_* outputs, plus model 1's
# y_valid_va / y_test, go to blending() tagged 'SH' with index 8.
blending(
    y_pred_va_SH1,
    y_pred_va_SH2,
    y_pred_va_SH3,
    y_pred_SH1,
    y_pred_SH2,
    y_pred_SH3,
    y_valid_va_SH1,
    y_test_SH1,
    'SH',
    8,
)
# -

# ### KH

# #### MODEL1

# +
# KH / model 1: fit a fresh Bi_GRU_Attention1 with param_index=9 (presumably the
# KH label column) and hand the returned y_test / y_pred pair to result().
model = Bi_GRU_Attention1()

(
    y_test_KH1,
    y_pred_KH1,
    y_valid_va_KH1,
    y_pred_va_KH1,
) = fit_parameters1(
    train_list_index=0,
    param_index=9,
    epochs=50,
    batch_size=16,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_KH1, y_pred_KH1)
# -

# #### MODEL2

# +
# KH / model 2: fit a fresh Bi_GRU_Attention2 with param_index=9 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention2()

(
    y_test_KH2,
    y_pred_KH2,
    y_valid_va_KH2,
    y_pred_va_KH2,
) = fit_parameters2(
    train_list_index=0,
    param_index=9,
    epochs=50,
    batch_size=16,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_KH2, y_pred_KH2)
# -

# #### MODEL3

# +
# KH / model 3: fit a fresh Bi_GRU_Attention3 with param_index=9 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention3()

(
    y_test_KH3,
    y_pred_KH3,
    y_valid_va_KH3,
    y_pred_va_KH3,
) = fit_parameters3(
    train_list_index=0,
    param_index=9,
    epochs=50,
    batch_size=32,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_KH3, y_pred_KH3)
# -

# #### BLENDING

# +

# KH blending: the three models' y_pred_va_* and y_pred_* outputs, plus model 1's
# y_valid_va / y_test, go to blending() tagged 'KH' with index 9.
blending(
    y_pred_va_KH1,
    y_pred_va_KH2,
    y_pred_va_KH3,
    y_pred_KH1,
    y_pred_KH2,
    y_pred_KH3,
    y_valid_va_KH1,
    y_test_KH1,
    'KH',
    9,
)
# -

# ### CaH

# #### MODEL1

# +
# CaH / model 1: fit a fresh Bi_GRU_Attention1 with param_index=10 (presumably the
# CaH label column) and hand the returned y_test / y_pred pair to result().
model = Bi_GRU_Attention1()

(
    y_test_CaH1,
    y_pred_CaH1,
    y_valid_va_CaH1,
    y_pred_va_CaH1,
) = fit_parameters1(
    train_list_index=0,
    param_index=10,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_CaH1, y_pred_CaH1)
# -

# #### MODEL2

# +
# CaH / model 2: fit a fresh Bi_GRU_Attention2 with param_index=10 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention2()

(
    y_test_CaH2,
    y_pred_CaH2,
    y_valid_va_CaH2,
    y_pred_va_CaH2,
) = fit_parameters2(
    train_list_index=0,
    param_index=10,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_CaH2, y_pred_CaH2)
# -

# #### MODEL3

# +
# CaH / model 3: fit a fresh Bi_GRU_Attention3 with param_index=10 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention3()

(
    y_test_CaH3,
    y_pred_CaH3,
    y_valid_va_CaH3,
    y_pred_va_CaH3,
) = fit_parameters3(
    train_list_index=0,
    param_index=10,
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_CaH3, y_pred_CaH3)
# -

# #### BLENDING

# +

# CaH blending: the three models' y_pred_va_* and y_pred_* outputs, plus model 1's
# y_valid_va / y_test, go to blending() tagged 'CaH' with index 10.
blending(
    y_pred_va_CaH1,
    y_pred_va_CaH2,
    y_pred_va_CaH3,
    y_pred_CaH1,
    y_pred_CaH2,
    y_pred_CaH3,
    y_valid_va_CaH1,
    y_test_CaH1,
    'CaH',
    10,
)
# -

# ### TiH

# #### MODEL1

# +
# TiH / model 1: fit a fresh Bi_GRU_Attention1 with param_index=11 (presumably the
# TiH label column) and hand the returned y_test / y_pred pair to result().
model = Bi_GRU_Attention1()

(
    y_test_TiH1,
    y_pred_TiH1,
    y_valid_va_TiH1,
    y_pred_va_TiH1,
) = fit_parameters1(
    train_list_index=0,
    param_index=11,
    epochs=50,
    batch_size=16,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_TiH1, y_pred_TiH1)
# -

# #### MODEL2

# +
# TiH / model 2: fit a fresh Bi_GRU_Attention2 with param_index=11 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention2()

(
    y_test_TiH2,
    y_pred_TiH2,
    y_valid_va_TiH2,
    y_pred_va_TiH2,
) = fit_parameters2(
    train_list_index=0,
    param_index=11,
    epochs=50,
    batch_size=16,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_TiH2, y_pred_TiH2)
# -

# #### MODEL3

# +
# TiH / model 3: fit a fresh Bi_GRU_Attention3 with param_index=11 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention3()

(
    y_test_TiH3,
    y_pred_TiH3,
    y_valid_va_TiH3,
    y_pred_va_TiH3,
) = fit_parameters3(
    train_list_index=0,
    param_index=11,
    epochs=50,
    batch_size=32,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_TiH3, y_pred_TiH3)
# -

# #### BLENDING

# +

# TiH blending: the three models' y_pred_va_* and y_pred_* outputs, plus model 1's
# y_valid_va / y_test, go to blending() tagged 'TiH' with index 11.
blending(
    y_pred_va_TiH1,
    y_pred_va_TiH2,
    y_pred_va_TiH3,
    y_pred_TiH1,
    y_pred_TiH2,
    y_pred_TiH3,
    y_valid_va_TiH1,
    y_test_TiH1,
    'TiH',
    11,
)
# -

# ### CrH

# #### MODEL1

# +
# CrH / model 1: fit a fresh Bi_GRU_Attention1 with param_index=12 (presumably the
# CrH label column) and hand the returned y_test / y_pred pair to result().
model = Bi_GRU_Attention1()

(
    y_test_CrH1,
    y_pred_CrH1,
    y_valid_va_CrH1,
    y_pred_va_CrH1,
) = fit_parameters1(
    train_list_index=0,
    param_index=12,
    epochs=50,
    batch_size=16,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_CrH1, y_pred_CrH1)
# -

# #### MODEL2

# +
# CrH / model 2: fit a fresh Bi_GRU_Attention2 with param_index=12 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention2()

(
    y_test_CrH2,
    y_pred_CrH2,
    y_valid_va_CrH2,
    y_pred_va_CrH2,
) = fit_parameters2(
    train_list_index=0,
    param_index=12,
    epochs=50,
    batch_size=16,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_CrH2, y_pred_CrH2)
# -

# #### MODEL3

# +
# CrH / model 3: fit a fresh Bi_GRU_Attention3 with param_index=12 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention3()

(
    y_test_CrH3,
    y_pred_CrH3,
    y_valid_va_CrH3,
    y_pred_va_CrH3,
) = fit_parameters3(
    train_list_index=0,
    param_index=12,
    epochs=50,
    batch_size=32,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_CrH3, y_pred_CrH3)
# -

# #### BLENDING

# +

# CrH blending: the three models' y_pred_va_* and y_pred_* outputs, plus model 1's
# y_valid_va / y_test, go to blending() tagged 'CrH' with index 12.
blending(
    y_pred_va_CrH1,
    y_pred_va_CrH2,
    y_pred_va_CrH3,
    y_pred_CrH1,
    y_pred_CrH2,
    y_pred_CrH3,
    y_valid_va_CrH1,
    y_test_CrH1,
    'CrH',
    12,
)
# -

# ### MnH

# #### MODEL1



# +
# MnH / model 1: fit a fresh Bi_GRU_Attention1 with param_index=13 (presumably the
# MnH label column) and hand the returned y_test / y_pred pair to result().
model = Bi_GRU_Attention1()

(
    y_test_MnH1,
    y_pred_MnH1,
    y_valid_va_MnH1,
    y_pred_va_MnH1,
) = fit_parameters1(
    train_list_index=0,
    param_index=13,
    epochs=50,
    batch_size=16,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_MnH1, y_pred_MnH1)
# -

# #### MODEL2

# +
# MnH / model 2: fit a fresh Bi_GRU_Attention2 with param_index=13 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention2()

(
    y_test_MnH2,
    y_pred_MnH2,
    y_valid_va_MnH2,
    y_pred_va_MnH2,
) = fit_parameters2(
    train_list_index=0,
    param_index=13,
    epochs=50,
    batch_size=16,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_MnH2, y_pred_MnH2)
# -

# #### MODEL3

# +
# MnH / model 3: fit a fresh Bi_GRU_Attention3 with param_index=13 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention3()

(
    y_test_MnH3,
    y_pred_MnH3,
    y_valid_va_MnH3,
    y_pred_va_MnH3,
) = fit_parameters3(
    train_list_index=0,
    param_index=13,
    epochs=50,
    batch_size=32,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_MnH3, y_pred_MnH3)
# -

# #### BLENDING

# +

# MnH blending: the three models' y_pred_va_* and y_pred_* outputs, plus model 1's
# y_valid_va / y_test, go to blending() tagged 'MnH' with index 13.
blending(
    y_pred_va_MnH1,
    y_pred_va_MnH2,
    y_pred_va_MnH3,
    y_pred_MnH1,
    y_pred_MnH2,
    y_pred_MnH3,
    y_valid_va_MnH1,
    y_test_MnH1,
    'MnH',
    13,
)
# -

# ### FeH

# #### MODEL1

# +
# FeH / model 1: fit a fresh Bi_GRU_Attention1 with param_index=14 (presumably the
# FeH label column) and hand the returned y_test / y_pred pair to result().
model = Bi_GRU_Attention1()

(
    y_test_FeH1,
    y_pred_FeH1,
    y_valid_va_FeH1,
    y_pred_va_FeH1,
) = fit_parameters1(
    train_list_index=0,
    param_index=14,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_FeH1, y_pred_FeH1)
# -

# #### MODEL2

# +
# FeH / model 2: fit a fresh Bi_GRU_Attention2 with param_index=14 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention2()

(
    y_test_FeH2,
    y_pred_FeH2,
    y_valid_va_FeH2,
    y_pred_va_FeH2,
) = fit_parameters2(
    train_list_index=0,
    param_index=14,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_FeH2, y_pred_FeH2)
# -

# #### MODEL3

# +
# FeH / model 3: fit a fresh Bi_GRU_Attention3 with param_index=14 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention3()

(
    y_test_FeH3,
    y_pred_FeH3,
    y_valid_va_FeH3,
    y_pred_va_FeH3,
) = fit_parameters3(
    train_list_index=0,
    param_index=14,
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_FeH3, y_pred_FeH3)
# -

# #### BLENDING

# +

# FeH blending: the three models' y_pred_va_* and y_pred_* outputs, plus model 1's
# y_valid_va / y_test, go to blending() tagged 'FeH' with index 14.
blending(
    y_pred_va_FeH1,
    y_pred_va_FeH2,
    y_pred_va_FeH3,
    y_pred_FeH1,
    y_pred_FeH2,
    y_pred_FeH3,
    y_valid_va_FeH1,
    y_test_FeH1,
    'FeH',
    14,
)
# -

# ### NiH

# #### MODEL1

# +
# NiH / model 1: fit a fresh Bi_GRU_Attention1 with param_index=15 (presumably the
# NiH label column) and hand the returned y_test / y_pred pair to result().
model = Bi_GRU_Attention1()

(
    y_test_NiH1,
    y_pred_NiH1,
    y_valid_va_NiH1,
    y_pred_va_NiH1,
) = fit_parameters1(
    train_list_index=0,
    param_index=15,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_NiH1, y_pred_NiH1)
# -

# #### MODEL2

# +
# NiH / model 2: fit a fresh Bi_GRU_Attention2 with param_index=15 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention2()

(
    y_test_NiH2,
    y_pred_NiH2,
    y_valid_va_NiH2,
    y_pred_va_NiH2,
) = fit_parameters2(
    train_list_index=0,
    param_index=15,
    epochs=50,
    batch_size=16,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_NiH2, y_pred_NiH2)
# -

# #### MODEL3

# +
# NiH / model 3: fit a fresh Bi_GRU_Attention3 with param_index=15 and hand the
# returned y_test / y_pred pair to result().
model = Bi_GRU_Attention3()

(
    y_test_NiH3,
    y_pred_NiH3,
    y_valid_va_NiH3,
    y_pred_va_NiH3,
) = fit_parameters3(
    train_list_index=0,
    param_index=15,
    epochs=50,
    batch_size=32,
    lr_rate=0.001,
    patience=3,
    model=model,
)

result(y_test_NiH3, y_pred_NiH3)
# -

# #### BLENDING

# +

# NiH blending: the three models' y_pred_va_* and y_pred_* outputs, plus model 1's
# y_valid_va / y_test, go to blending() tagged 'NiH' with index 15.
blending(
    y_pred_va_NiH1,
    y_pred_va_NiH2,
    y_pred_va_NiH3,
    y_pred_NiH1,
    y_pred_NiH2,
    y_pred_NiH3,
    y_valid_va_NiH1,
    y_test_NiH1,
    'NiH',
    15,
)
# -

# ### CuH

# +
# Original labels.
# NOTE(review): allow_pickle=True is only required when the .npy file stores
# object arrays; if this is a plain numeric array it can probably be dropped —
# confirm against the file before changing.
label_above_50_payne = np.load(
    '../1_FITS文件下载与预处理/snrg_other/match_good/LABELS/above_50_payne.npy',
    allow_pickle=True,
)
print(label_above_50_payne.shape)

# Parameter names: wrap the raw array in a DataFrame with one named column per label.
label_above_50_payne = pd.DataFrame(
    label_above_50_payne,
    columns=[
        'Teff[K]', 'Logg', 'CH', 'NH', 'OH', 'MgH', 'AlH', 'SiH', 'SH',
        'KH', 'CaH', 'TiH', 'CrH', 'MnH', 'FeH', 'NiH', 'CuH', 'snrg',
    ],
)
label_above_50_payne


# +
# Data standardisation and sub-sequence splitting
def Flux_normalization(Flux=X_test_list[0], step_num=5, step_len=690):
    """Split each spectrum into ``step_num`` sub-sequences of length ``step_len``.

    Despite the name, no scaling happens here: the per-spectrum
    ``StandardScaler`` step that used to precede the reshape is disabled, so
    the flux values are passed through unchanged.

    Parameters
    ----------
    Flux : array-like with a ``shape`` attribute
        Flux values; the first axis is kept and the remaining values must
        total ``step_num * step_len`` per row for the reshape to succeed.
    step_num : int
        Number of sub-sequences per spectrum.
    step_len : int
        Length of each sub-sequence.

    Returns
    -------
    The input reshaped to ``(Flux.shape[0], step_num, step_len)``.
    """
    n_spectra = Flux.shape[0]
    target_shape = (n_spectra, step_num, step_len)
    return np.reshape(Flux, target_shape)

def predict_results_Abundance(Flux=X_test_list[0],
                    step_num=5,
                    step_len=690,
                    param_index=16,
                    model=None,
                    model_path="../data/Bi-GRU-Attention_models/between_5_50/model1/CuH_model.h5"):
    """Run a saved abundance model on a set of spectra and return its outputs.

    The flux array is reshaped by ``Flux_normalization`` into
    ``(n_spectra, step_num, step_len)``, the model is compiled with
    Adam(0.001) / MSE, the weights stored at ``model_path`` are loaded into
    it, and the model is called once on the whole batch in inference mode.

    Parameters
    ----------
    Flux : array-like
        Flux values, forwarded to ``Flux_normalization``.
    step_num, step_len : int
        Sub-sequence layout, forwarded to ``Flux_normalization``.
    param_index : int
        Currently unused; kept so existing callers that pass it keep working.
    model : keras model, optional
        Architecture to load the weights into. When omitted, a fresh
        ``Bi_GRU_Attention1()`` is built for this call. The model is mutated
        (compiled and its weights overwritten).
    model_path : str
        Path of the weight file passed to ``model.load_weights``.

    Returns
    -------
    The value returned by calling the model on the reshaped flux (not
    converted to a NumPy array).
    """
    # Build the default network lazily: a default of `Bi_GRU_Attention1()` in the
    # signature would construct a model at definition time and share that single
    # instance (and whatever weights were last loaded into it) across all calls.
    if model is None:
        model = Bi_GRU_Attention1()

    Flux = Flux_normalization(Flux, step_num, step_len)
    model.compile(optimizer=tf.keras.optimizers.Adam(0.001), loss='mse')
    model.load_weights(model_path)

    # Inference must run with training=False: with training=True the Dropout
    # layers stay active and BatchNormalization uses the statistics of the
    # current batch, so predictions become noisy and depend on which spectra
    # are evaluated together.
    scores = model(Flux, training=False)

    return scores


# -

# #### MODEL1

# +
# CuH / model 1: nothing is trained in this cell; the predictions come from the
# weights saved in CuH_model.h5, loaded into a fresh Bi_GRU_Attention1.
y_pred_CuH1 = predict_results_Abundance(
    Flux=X_test_list[0],
    step_num=5,
    step_len=690,
    param_index=16,
    model=Bi_GRU_Attention1(),
    model_path="../data/Bi-GRU-Attention_models/between_5_50/model1/CuH_model.h5",
)

# Column 16 of the first test-label array is what those predictions are compared to.
y_test_CuH1 = y_test_list[0][:, 16]
result(y_test_CuH1, y_pred_CuH1)
# -

# #### MODEL2

# +
# CuH / model 2: fit a fresh Bi_GRU_Attention2 with param_index=16 (presumably the
# CuH label column) and hand the returned y_test / y_pred pair to result().
model = Bi_GRU_Attention2()

(
    y_test_CuH2,
    y_pred_CuH2,
    y_valid_va_CuH2,
    y_pred_va_CuH2,
) = fit_parameters2(
    train_list_index=0,
    param_index=16,
    epochs=50,
    batch_size=16,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_CuH2, y_pred_CuH2)
# -

# #### MODEL3

# +
# CuH / model 3: fit a fresh Bi_GRU_Attention3 with param_index=16 and hand the
# returned y_test / y_pred pair to result().
# NOTE(review): batch_size is 16 here, while the MODEL3 cells of the preceding
# elements all use 32 — confirm this is a deliberate choice for CuH.
model = Bi_GRU_Attention3()

(
    y_test_CuH3,
    y_pred_CuH3,
    y_valid_va_CuH3,
    y_pred_va_CuH3,
) = fit_parameters3(
    train_list_index=0,
    param_index=16,
    epochs=50,
    batch_size=16,
    lr_rate=0.0005,
    patience=3,
    model=model,
)

result(y_test_CuH3, y_pred_CuH3)
# -

# #### BLENDING

# +

# CuH blending: the three models' y_pred_va_* and y_pred_* outputs, plus model 1's
# y_valid_va / y_test, go to blending() tagged 'CuH' with index 16.
# NOTE(review): y_pred_va_CuH1 and y_valid_va_CuH1 are not assigned by the CuH
# MODEL1 cell, which only produces y_pred_CuH1 / y_test_CuH1 from saved weights.
# Unless they are defined elsewhere in the notebook, this call raises NameError;
# model 1 needs validation-set predictions and labels before it can be blended.
blending(
    y_pred_va_CuH1,
    y_pred_va_CuH2,
    y_pred_va_CuH3,
    y_pred_CuH1,
    y_pred_CuH2,
    y_pred_CuH3,
    y_valid_va_CuH1,
    y_test_CuH1,
    'CuH',
    16,
)
# -


