# Script Name: ps1_correction
# Author: Xiao, K., a PhD student supervised by Yuan, H.
# Date: 2023-07-05
#
# Refs: [1] Xiao, K. & Yuan, H. 2022, AJ, 163, 185
#       [2] Xiao, K., Yuan, H., Huang, B., et al. 2023, ApJS
#
# Description:
#     This script provides a convenient way to correct position-dependent systematic errors in PS1 photometry.
#     It numerically interpolates a 2D correction map (see [2]) to obtain the magnitude offset
#     at any given position in any PS1 band (g, r, i, z, y).
#     Adding this offset to the original PS1 magnitude gives the corrected PS1 magnitude.
#     The correction map is read from 'PS1_Correction.fits' in the current working directory.


import numpy as np
from astropy.io import fits
from scipy.interpolate import RBFInterpolator

# PS1 passbands, in the order of their extensions in PS1_Correction.fits:
# band k of this list is read from HDU k + 1 by correction().
ps1_filter = list('grizy')


def correction(ra, dec, band):
    """
    Interpolate the PS1 magnitude offset at the given sky position(s).

    Add the returned offset to the original PS1 magnitude to obtain the
    corrected magnitude (see reference [2]).

    The correction map is read from 'PS1_Correction.fits' (resolved against
    the current working directory) and the interpolator is rebuilt on every
    call, so pass all positions for a band in a single call rather than
    calling once per star.

    :param ra: right ascension; scalar, list or array, in the same units as
        the map's 'ra' column (presumably degrees -- TODO confirm)
    :param dec: declination; scalar, list or array, broadcastable with ``ra``
    :param band: str; one of 'g', 'r', 'i', 'z', 'y'
    :return: magnitude offset(s) as a float ndarray with the broadcast shape
        of ``ra`` and ``dec`` (shape (N,) for two length-N sequences)
    :raises ValueError: if ``band`` is not one of the PS1 bands

    Side effect: the module-level global ``delmag0`` is set to the full
    'del_<band>' column of the map, including rows that are skipped as
    interpolation nodes.
    """
    global delmag0

    try:
        n = ps1_filter.index(band)
    except ValueError:
        raise ValueError(
            f"band must be one of {ps1_filter}, got {band!r}"
        ) from None

    # Broadcast the query coordinates and flatten them into (N, 2) points;
    # the result is reshaped back at the end so scalars and N-d inputs work.
    ra_q, dec_q = np.broadcast_arrays(np.asarray(ra, dtype=float),
                                      np.asarray(dec, dtype=float))
    points = np.column_stack((ra_q.ravel(), dec_q.ravel()))

    # Band n lives in extension n + 1 (extension 0 is the primary HDU).
    # Read the columns while the file is open, as native float arrays.
    with fits.open('PS1_Correction.fits') as hdul:
        table = hdul[n + 1].data
        ra0 = np.asarray(table['ra'], dtype=float)
        dec0 = np.asarray(table['dec'], dtype=float)
        delmag0 = np.asarray(table['del_' + ps1_filter[n]], dtype=float)

    # Rows without a position (NaN dec) cannot serve as interpolation nodes.
    # NOTE(review): rows with a valid position but a NaN offset are kept, so
    # NaN can propagate to nearby queries -- confirm this is intended.
    keep = ~np.isnan(dec0)
    nodes = np.column_stack((ra0[keep], dec0[keep]))

    # Local interpolation: each query uses only its 2 nearest map nodes with
    # a linear RBF kernel. Distances are plain Euclidean in (ra, dec).
    # NOTE(review): RA wrap-around at 0/360 is therefore not handled.
    interpolator = RBFInterpolator(nodes, delmag0[keep],
                                   neighbors=2, kernel='linear')
    return interpolator(points).reshape(ra_q.shape)


# Example: magnitude offset in the i band at a single position
# print(correction([192.825484], [27.27490232845929], 'i'))


if __name__ == '__main__':
    # Reproduce Figure 12 of reference [2]: an all-sky Mollweide view of the
    # PS1 correction map for one band.

    import healpy as hp
    import matplotlib.pyplot as plt
    from matplotlib import rcParams
    rcParams['font.family'] = 'Times New Roman'
    rcParams.update({'mathtext.fontset': 'cm'})

    # Band whose correction map is displayed.
    band = 'i'

    # Read the map directly: the ``delmag0`` global is only set as a side
    # effect of correction(), which this script does not call.
    with fits.open('PS1_Correction.fits') as hdul:
        table = hdul[ps1_filter.index(band) + 1].data
        delmag_map = np.asarray(table['del_' + band], dtype=float)

    # NOTE(review): coord=["G", "C"] tells healpy the map is pixelised in
    # Galactic coordinates and rotates it to equatorial for display --
    # confirm this matches how the map's rows are ordered.
    hp.mollview(delmag_map, coord=["G", "C"], min=-0.012, max=0.012, norm="hist", cmap='jet')
    hp.graticule(color='k', dpar=30)
    plt.show()
