{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Construction of the Light Curve Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Read the ZTF variable star catalog and write the unified g-band label table\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "file_path = \"data\\\\ZTF\\\\ztf_Table2.txt\"\n",
    "# Whitespace-delimited table without a header row; the first 34 lines are skipped\n",
    "# (presumably the catalog's descriptive preamble -- confirm against the file).\n",
    "# A regex separator is used because `delim_whitespace=True` is deprecated in pandas >= 2.2.\n",
    "catalog = pd.read_table(file_path, header=None, sep=r\"\\s+\", skiprows=34)\n",
    "\n",
    "# Map each required column name to its positional index in the raw table\n",
    "columns = {\n",
    "    'ID': 0, 'SourceID': 1, 'RAdeg': 2, 'DEdeg': 3, 'Per_g': 10, 'Per_r': 11,\n",
    "    'Num_g': 12, 'Num_r': 13, 'R21_g': 14, 'R21_r': 15, 'phi21_g': 16,\n",
    "    'phi21_r': 17, 'R^2_g': 18, 'R^2_r': 19, 'Amp_g': 20, 'Amp_r': 21,\n",
    "    'FAP_g': 22, 'FAP_r': 23, 'Type': 24\n",
    "}\n",
    "df = catalog[list(columns.values())].copy()\n",
    "df.columns = list(columns.keys())\n",
    "\n",
    "# Keep sources with at least 20 detections in both bands and FAP_g / FAP_r below -3\n",
    "# (i.e. FAP < 0.001, assuming the catalog stores log10(FAP) -- TODO confirm)\n",
    "df_filtered = df[(df.Num_g >= 20) & (df.Num_r >= 20) & (df.FAP_g < -3) & (df.FAP_r < -3)]\n",
    "print(df_filtered['Type'].value_counts())\n",
    "# df_filtered.to_csv(\"ZTF_classification_selected.csv\", index=False)\n",
    "\n",
    "# Unify variable star names\n",
    "label_mapping = {\n",
    "    \"Mira\": 'M',\n",
    "    \"RR\": 'RRAB',\n",
    "    \"RRc\": 'RRC',\n",
    "    \"CEP\": 'CEP',\n",
    "    \"CEPII\": 'CEP',\n",
    "    \"RSCVN\": 'ROT',\n",
    "    \"BYDra\": 'ROT'\n",
    "}\n",
    "# Build the label table: the catalog 'Type' column becomes 'Class' with unified names.\n",
    "# Types missing from label_mapping are kept unchanged.\n",
    "df_labels = (\n",
    "    df_filtered[['ID', 'Type', 'FAP_g']]\n",
    "    .rename(columns={'Type': 'Class'})\n",
    "    .assign(Class=lambda d: d['Class'].replace(label_mapping))\n",
    ")\n",
    "df_labels.to_csv(\"data\\\\labels\\\\ZTF_lables_g.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the label tables for ASAS-SN (g and V band) and Gaia\n",
    "import os  # needed for os.listdir below; no earlier cell imports it\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Unify variable star names\n",
    "label_mapping = {\n",
    "    \"RRD\": 'RRC',\n",
    "    \"HADS\": 'DSCT',\n",
    "    \"DCEP\": 'CEP',\n",
    "    \"DCEPS\": 'CEP',\n",
    "    \"CWA\": 'CEP',\n",
    "    \"CWB\": 'CEP',\n",
    "    \"RVA\": 'CEP',\n",
    "}\n",
    "\n",
    "# Load and filter ASAS-SN g-band data for variables with classification probability >= 0.99\n",
    "g = pd.read_csv(\"data\\\\ASAS-SN\\\\g\\\\asassn_variables_x.csv\")\n",
    "g_filtered = g[g['ML_probability'] >= 0.99][['ID', 'ML_classification', 'ML_probability']]\n",
    "g_filtered = g_filtered.rename(columns={'ML_classification': 'Class', 'ML_probability': 'class_probability'})\n",
    "# keep=False discards every ID that occurs more than once instead of keeping one of its rows\n",
    "g_filtered = g_filtered.sort_values(by=['Class', 'ID']).drop_duplicates('ID', keep=False)\n",
    "g_filtered['Class'] = g_filtered['Class'].replace(label_mapping)\n",
    "g_filtered.to_csv(\"data\\\\labels\\\\ASAS-SN_labels_g.csv\", index=False)\n",
    "\n",
    "# Load and filter ASAS-SN V-band data for variables with classification probability >= 0.99\n",
    "v = pd.read_csv(\"data\\\\ASAS-SN\\\\V\\\\asassn_catalog_full.csv\")\n",
    "v_filtered = v[v['class_probability'] >= 0.99][['asassn_name', 'variable_type', 'class_probability']]\n",
    "v_filtered = v_filtered.rename(columns={'asassn_name': 'ID', 'variable_type': 'Class'})\n",
    "# Strip ':' characters from the type names; regex=False makes the literal match explicit\n",
    "v_filtered['Class'] = v_filtered['Class'].str.replace(':', '', regex=False)\n",
    "v_filtered = v_filtered.sort_values(by=['Class', 'ID']).drop_duplicates('ID', keep=False)\n",
    "v_filtered['Class'] = v_filtered['Class'].replace(label_mapping)\n",
    "v_filtered.to_csv(\"data\\\\labels\\\\ASAS-SN_labels_v.csv\", index=False)\n",
    "\n",
    "# Load and filter Gaia g-band data for variables with classification probability >= 0.99\n",
    "path = 'data\\\\Gaia\\\\classification\\\\'\n",
    "files = os.listdir(path)\n",
    "gaia_columns = ['source_id', 'best_class_name', 'best_class_score']\n",
    "gaia_frames = []\n",
    "for file in files:\n",
    "    df = pd.read_csv(path + file, comment='#')  # skip comment lines\n",
    "    df = df[gaia_columns]\n",
    "    gaia_frames.append(df[df.best_class_score >= 0.99])\n",
    "\n",
    "# Concatenate once after the loop; pd.concat raises on an empty list, so an empty\n",
    "# directory falls back to an empty table with the expected columns.\n",
    "if gaia_frames:\n",
    "    gaia_labels = pd.concat(gaia_frames, ignore_index=True)\n",
    "else:\n",
    "    gaia_labels = pd.DataFrame(columns=gaia_columns)\n",
    "\n",
    "gaia_labels = gaia_labels.rename(columns={'source_id': 'ID', 'best_class_name': 'Class'})\n",
    "gaia_labels['Class'] = gaia_labels['Class'].replace(\"ECL\", 'EB')\n",
    "gaia_labels.to_csv(\"data\\\\labels\\\\Gaia_labels_g.csv\", index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Calculate the magnitude error of variables from Gaia\n",
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "\n",
    "def get_mag_err(flux, flux_error):\n",
    "    \"\"\"Return the magnitude error for a flux and its error.\n",
    "\n",
    "    Computes 1.086 * flux_error / flux, where 1.086 approximates 2.5 / ln(10).\n",
    "    The arithmetic is element-wise, so scalars and pandas Series both work.\n",
    "    \"\"\"\n",
    "    return 1.086 * flux_error / flux\n",
    "\n",
    "\n",
    "# Directory containing Gaia light curve files (backslashes escaped so the literal is valid)\n",
    "dst = 'data\\\\Lightcurves\\\\Gaia/'\n",
    "files = os.listdir(dst)\n",
    "for f in files:\n",
    "    df = pd.read_csv(dst + f)\n",
    "    # Whole-column arithmetic replaces the row-by-row apply and also copes with empty files\n",
    "    df['mag_err'] = get_mag_err(df['flux'], df['flux_err'])\n",
    "    df.to_csv(dst + f, index=False)  # the light-curve file is rewritten in place\n",
    "    print(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We collect light curves of the non-variable sources in ASAS-SN.\n",
    "# Select a sky position 2' away from the variable stars, and retrieve all light curves in a radius of 1.2'.\n",
    "\n",
    "# Imported here so the cell runs on a fresh kernel (numpy is not imported by any earlier cell)\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "alldata = pd.read_csv(\"data\\\\ASAS-SN\\\\G\\\\asassn_variables_x.csv\")\n",
    "\n",
    "# query list\n",
    "arcmin_to_deg = 1.0/60.0\n",
    "n_queries = len(alldata)\n",
    "q_ra = alldata['RAJ2000']\n",
    "q_dec = alldata['DEJ2000'] + 2.0 * arcmin_to_deg  # shift the query centre by 2 arcmin in declination\n",
    "q_radius = np.full(n_queries, 1.2 * 60.0)  # arcsec\n",
    "q_radius2 = np.full(n_queries, 1.2 * arcmin_to_deg)  # degree\n",
    "\n",
    "# First list: ra, dec, radius in arcsec\n",
    "np.savetxt('data\\\\query\\\\list_to_query.txt', np.column_stack([q_ra, q_dec, q_radius]), fmt=\"%f %f %f\")\n",
    "\n",
    "# Second list: running ID (starting at 1000001), ra, dec, radius in degrees\n",
    "non_trans_id = 1_000_000 + np.arange(1, n_queries + 1)\n",
    "np.savetxt('data\\\\query\\\\list_to_query2.txt', np.column_stack([non_trans_id, q_ra, q_dec, q_radius2]), fmt=\"%d %f %f %f\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Extract features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import feets\n",
    "from feets.feets import *\n",
    "from feets.feets.datasets.base import Data\n",
    "import feets.feets.preprocess\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings(\"ignore\")\n",
    "\n",
    "def extract_features(files):\n",
    "    \"\"\"Extract feets features from ZTF light curves and save them as CSV files.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    files : iterable of str\n",
    "        Names of the light-curve CSV files inside the ZTF light-curve directory.\n",
    "        Each file is read with pandas and must provide the columns\n",
    "        `time`, `mag` and `mag_err`.\n",
    "\n",
    "    Light curves with 20 or fewer rows are skipped. For every remaining light\n",
    "    curve one CSV named after the object ID is written to the ZTF feature\n",
    "    directory, and a combined table of all extracted features is written once\n",
    "    all files have been processed.\n",
    "    \"\"\"\n",
    "    path = \"data\\\\Lightcurves\\\\ZTF\"\n",
    "    out_dir = \"data\\\\Features\\\\ZTF\"\n",
    "    os.makedirs(out_dir, exist_ok=True)\n",
    "\n",
    "    # The feature space does not depend on the light curve, so it is built once\n",
    "    fs = feets.feets.FeatureSpace(data=['time', 'magnitude', 'error'],\n",
    "                                  exclude=['DMDT', 'SignaturePhMag'])\n",
    "\n",
    "    # One dict per processed light curve; the DataFrame is built once at the end\n",
    "    # (DataFrame.append was removed in pandas 2.0)\n",
    "    records = []\n",
    "    for file in files:\n",
    "        oid = file.replace(\".csv\", \"\")\n",
    "        table = pd.read_csv(os.path.join(path, file))\n",
    "        if table.shape[0] <= 20:\n",
    "            continue\n",
    "\n",
    "        bands = (\"g\", \"V\")\n",
    "        # Both band entries are filled from the same columns of the file\n",
    "        data = {\n",
    "            \"g\": {\"time\": table['time'], \"magnitude\": table['mag'], \"error\": table['mag_err']},\n",
    "            \"V\": {\"time\": table['time'], \"magnitude\": table['mag'], \"error\": table['mag_err']},\n",
    "        }\n",
    "        descr = ()  # no description text is supplied\n",
    "\n",
    "        lc = Data(\n",
    "            id=oid,\n",
    "            metadata=None,\n",
    "            ds_name=\"ZTF\",\n",
    "            description=descr,\n",
    "            bands=bands,\n",
    "            data=data,\n",
    "        )\n",
    "\n",
    "        # Preprocess: remove noise (only the g-band entry is used)\n",
    "        time, mag, error = feets.feets.preprocess.remove_noise(**lc.data.g)\n",
    "\n",
    "        # Extract features using feets\n",
    "        features, values = fs.extract(time, mag, error)\n",
    "\n",
    "        fdict = {'ID': oid}\n",
    "        fdict.update(dict(zip(features, values)))\n",
    "        records.append(fdict)\n",
    "\n",
    "        # Save the individual feature file as <feature dir>/<object id>.csv\n",
    "        pd.DataFrame([fdict]).to_csv(os.path.join(out_dir, oid + \".csv\"), index=False)\n",
    "\n",
    "        print(file)\n",
    "\n",
    "    # Save all features to a single file\n",
    "    df = pd.DataFrame(records)\n",
    "    df.to_csv('data\\\\Features\\\\all_features_ztf.csv', index=False)\n",
    "\n",
    "# Entry point: run the extraction over every file found in the ZTF light-curve folder\n",
    "if __name__ == \"__main__\":\n",
    "    path = \"data\\\\Lightcurves\\\\ZTF\"\n",
    "    files = os.listdir(path)  # names of the light-curve files to process\n",
    "    extract_features(files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Match the features and labels of each variable star in ZTF, as well as ASAS-SN and Gaia\n",
    "import pandas as pd\n",
    "import os\n",
    "\n",
    "allf = pd.read_csv(\"data\\\\Features\\\\all_features_ztf.csv\")\n",
    "\n",
    "labels = pd.read_csv(\"data\\\\labels\\\\ZTF_lables_g.csv\")\n",
    "# The label file is written with the columns ID, Class and FAP_g, so the class\n",
    "# column is 'Class'; selecting 'Type' here raised a KeyError.\n",
    "labels = labels[['ID', 'Class']]\n",
    "# keep=False drops every ID that appears more than once, leaving only unambiguous labels\n",
    "labels = labels.drop_duplicates('ID', keep=False)\n",
    "\n",
    "# Inner join: keep only objects that have both a label and extracted features\n",
    "full = pd.merge(labels, allf, on='ID', how='inner')\n",
    "full.to_csv(\"data\\\\ZTF_g_full.csv\", index=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Merge the four files for subsequent experiments.\n",
    "import os\n",
    "\n",
    "import numpy as np  # needed for np.inf / np.nan; no earlier cell imports numpy\n",
    "import pandas as pd\n",
    "\n",
    "df1 = pd.read_csv(\"data\\\\ZTF_g_full.csv\")\n",
    "df2 = pd.read_csv(\"data\\\\Gaia_g_full.csv\")\n",
    "df3 = pd.read_csv(\"data\\\\ASAS-SN_g_full.csv\")\n",
    "df4 = pd.read_csv(\"data\\\\ASAS-SN_v_full.csv\")\n",
    "full = pd.concat([df1, df2, df3, df4], ignore_index=True)\n",
    "\n",
    "# Treat infinite values as missing, then drop every row that has a missing value.\n",
    "# This keeps exactly the rows that contain neither NaN nor +/-inf.\n",
    "full = full.replace([np.inf, -np.inf], np.nan).dropna(axis=0, how='any')\n",
    "full.to_csv(\"data\\\\LEAVES_full.csv\", index=False)\n",
    "# The class column is named 'Class' (the HBRF cell reads it from LEAVES_full.csv)\n",
    "print(full['Class'].value_counts())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiments of Variable Classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make datasets for HBRF: add one label column per level of the class hierarchy\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# Infinite feature values are turned into NaN right after loading\n",
    "df_feat = pd.read_csv('data\\\\LEAVES_full.csv').replace([np.inf, -np.inf], np.nan)\n",
    "\n",
    "# Preview only the first rows instead of dumping the whole table\n",
    "print(df_feat.head())\n",
    "\n",
    "label_order = ['EA','EW','ROT', # extrinsic\n",
    "               'RRAB','RRC','CEP','M','SR','DSCT', # intrinsic\n",
    "               'Non-var']\n",
    "\n",
    "# .copy() yields an independent frame, so adding columns does not raise SettingWithCopyWarning\n",
    "df_labels = df_feat[['Class']].copy()\n",
    "df_labels['class_original'] = df_labels['Class']\n",
    "\n",
    "# Keep only rows whose class is one of the classes listed in label_order\n",
    "labels = df_labels.loc[df_labels['class_original'].isin(label_order), ['class_original']].copy()\n",
    "\n",
    "# defining init classes: variable vs. non-variable\n",
    "labels['class_init'] = 'Variable'\n",
    "labels.loc[labels['class_original'] == 'Non-var', 'class_init'] = 'Non-var'\n",
    "\n",
    "cm_classes_init = ['Variable','Non-var']\n",
    "\n",
    "# defining variable classes: extrinsic vs. intrinsic ('Non-var' rows keep their label)\n",
    "labels['class_variable'] = labels['class_original']\n",
    "labels.loc[labels['class_variable'].isin(['ROT', 'EA', 'EW']), 'class_variable'] = 'Extrinsic'\n",
    "labels.loc[labels['class_variable'].isin(['RRAB', 'RRC', 'DSCT', 'CEP', 'M', 'SR']), 'class_variable'] = 'Intrinsic'\n",
    "\n",
    "cm_classes_variable = ['Extrinsic','Intrinsic']\n",
    "cm_classes_original = label_order\n",
    "\n",
    "cm_classes_extrinsic = ['EB', 'ROT']\n",
    "cm_classes_intrinsic = ['RR', 'CEP','LPV', 'DSCT']\n",
    "cm_classes_variable2 = ['EB', 'ROT', 'RR', 'CEP','LPV', 'DSCT', 'Non-var']\n",
    "\n",
    "# defining EB / RR / LPV classes: merge the sub-types, all other classes stay unchanged\n",
    "labels['class_variable2'] = labels['class_original'].replace({\n",
    "    'EA': 'EB', 'EW': 'EB',\n",
    "    'RRAB': 'RR', 'RRC': 'RR',\n",
    "    'M': 'LPV', 'SR': 'LPV',\n",
    "})\n",
    "\n",
    "cm_classes_ecl = ['EA','EW']\n",
    "cm_classes_rrl = ['RRAB','RRC']\n",
    "cm_classes_lpv = ['M','SR']\n",
    "\n",
    "print(labels['class_variable'].values.shape)\n",
    "\n",
    "# 'Class' is dropped from the features because class_original already carries it\n",
    "rm_nd_cols = ['Class']\n",
    "\n",
    "# Index-aligned inner join: only the rows kept in `labels` survive\n",
    "df = labels.join(df_feat.drop(columns=rm_nd_cols), how='inner')\n",
    "df.to_csv('data\\\\LEAVES_full_HBRF.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "astronomical",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
