#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 21 09:56:44 2023

@author: hezhihong
"""
import os

import numpy as np
import pandas as pd
from scipy.optimize import curve_fit
############################
def linear_regression(x, y):
    """Least-squares straight-line fit ``y = a + k * x`` via the normal equations.

    In this file it provides the initial inclination of the cluster layer
    (slope of zz against yy).

    Parameters
    ----------
    x, y : array-like
        Equal-length 1-D samples (NumPy arrays, pandas Series or lists).
        Values are paired by position; pandas indexes are ignored.

    Returns
    -------
    numpy.ndarray
        ``[a, k]``: intercept and slope.

    Raises
    ------
    numpy.linalg.LinAlgError
        If the 2x2 normal matrix is singular (e.g. empty input).
    """
    # Plain float ndarrays instead of np.mat (removed in NumPy 2.0); this also
    # keeps pandas from aligning x * y on the index instead of by position.
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    n = x.size
    sumx = x.sum()
    sumy = y.sum()
    normal = np.array([[n, sumx], [sumx, (x ** 2).sum()]])
    rhs = np.array([sumy, (x * y).sum()])
    return np.linalg.solve(normal, rhs)
####
# Trial offsets (km/s) added to the observed vz before each Equation 2 fit:
# 0.0, 0.1, ..., 1.9 (the "delta_vz_sun 1+-1 km/s" scan).
_DVZ_OFFSETS = np.arange(0., 2, .1)
# Minimum number of precise-vz clusters required before Equation 2 is fitted.
_MIN_N_VZ = 30


def _robust_sigma(values):
    """Return 1.4826 * MAD of *values* (Gaussian-equivalent dispersion, Cantat-Gaudin+2020)."""
    arr = np.asarray(values)
    return 1.4826 * np.median(abs(arr - np.median(arr)))


def _fit_equation2(ss1z, rrr, kk, vphi, inc):
    """Fit Equation 2 (vz + dvz against gcphi) once per offset in _DVZ_OFFSETS.

    Parameters
    ----------
    ss1z : pandas.DataFrame
        Clusters with columns ``gcphi`` and ``vz``.
    rrr : float
        Ring radius.
    kk : float
        zz-vs-yy slope; only its sign enters the model.
    vphi : float
        Median vphi of the ring.
    inc : float
        Inclination in degrees, held fixed during the fit.

    Returns
    -------
    tuple of three lists
        Fitted ``vp``, ``vzd`` and ``plon``, one entry per offset whose fit
        did not raise. All three are empty if ``ss1z`` has fewer than
        ``_MIN_N_VZ`` rows.
    """
    vpps, vzds, plons = [], [], []
    if len(ss1z) < _MIN_N_VZ:
        return vpps, vzds, plons

    # curve_fit infers the number of free parameters (vp, vzd, plon) from
    # this signature, so keep it an explicit 4-argument function.
    def kine(theta, vp, vzd, plon):
        term_v = ((abs(kk) / kk) * (vphi - rrr * vp)
                  / ((1 + 1. / (np.tan(np.radians(inc))) ** 2)
                     / (np.cos(np.radians(plon - theta)) ** 2) - 1) ** .5)
        term_vzd = vzd * rrr * np.sin(np.radians(plon - theta)) * np.cos(np.radians(inc))
        return term_v + term_vzd

    for dvz in _DVZ_OFFSETS:
        try:
            popt, _ = curve_fit(kine, ss1z.gcphi, ss1z.vz + dvz,
                                bounds=([-100, -1, 100], [100, 1, 250]))
        except Exception:  # best effort: a failed fit only drops this offset
            continue
        vpps.append(popt[0])
        vzds.append(popt[1])
        plons.append(popt[2])
    return vpps, vzds, plons


def _fit_inclination(ss1z, rrr, plon):
    """Fit Equation 1 (zz against gcphi) with phi_LON fixed to *plon*.

    Returns ``popt = [i, z0]`` with i in [-10, 10] deg and z0 in [-0.05, 0.05].
    Any exception raised by curve_fit propagates to the caller.
    """
    def ince(theta, i, z0):
        return rrr * np.sin(np.radians(plon - theta)) * np.sin(np.radians(i)) + z0

    popt, _ = curve_fit(ince, ss1z.gcphi, ss1z.zz, bounds=([-10, -0.05], [10, 0.05]))
    return popt


def get_inc_z0_pphi(sample,rrr,ss):
    """Fit inclination, phi_LON and kinematic parameters for one radial ring.

    Clusters with ``rrr - 1 <= gcr < rrr + 1`` are selected and clipped at
    3 * (1.4826 * MAD) in zz. Then:

    1. An initial inclination ``inc1`` comes from the zz-vs-yy slope.
    2. Equation 2 is fitted for each trial vz offset; the median gives phi_LON.
    3. Equation 1 is fitted for the inclination. It is used for step 4 only if
       |inc| < 9.9, i.e. not pinned at the +-10 deg bound; otherwise inc1 is kept.
    4. Equation 2 is refitted with that inclination.
    5. Equation 1 is refitted with the new phi_LON.

    Parameters
    ----------
    sample : object
        Label returned unchanged as the first output element.
    rrr : float
        Ring centre radius, in the same unit as ``gcr`` (presumably kpc).
    ss : pandas.DataFrame
        Clusters with columns gcr, zz, yy, vphi, e_vphi, vz, e_vz, gcphi.

    Returns
    -------
    list
        ``[sample, n, inc1, angz0, plon, disp_plon, vpp, disp_vpp,
        vzd, disp_vzd, inc, dz0]``:

        - ``n``: clusters left after the ring cut and zz clip.
        - ``inc1``: initial inclination in degrees.
        - ``angz0``: arctan(median zz / rrr) in degrees.
        - ``plon, vpp, vzd``: medians of the step-4 Equation 2 parameters over
          the vz-offset scan. The matching ``disp_*`` values are their
          1.4826*MAD scatter over that scan; both are NaN if no fit succeeded.
        - ``inc, dz0``: step-5 Equation 1 fit; both are '' if the fit fails.
          ``dz0`` is also '' when |z0| >= 0.049, i.e. at its +-0.05 bound.
    """
    # Ring of +-1 around rrr, then 3-sigma clip in zz using the robust dispersion.
    ss = ss[(ss.gcr >= rrr - 1) & (ss.gcr < rrr + 1)]
    disp_z = _robust_sigma(ss.zz)
    ss = ss[abs(ss.zz - np.median(ss.zz)) < 3 * disp_z]

    # Initial geometric estimates.
    angz0 = 180 * np.arctan(np.median(ss.zz) / rrr) / np.pi
    _, kk = linear_regression(ss.yy, ss.zz)
    inc1 = np.arctan(kk) * 180 / np.pi

    # Median vphi from clusters with e_vphi < 2.5 km/s; fits use e_vz < 1 km/s.
    vphi = np.median(ss[ss.e_vphi < 2.5].vphi)
    ss1z = ss[ss.e_vz < 1]

    # First pass: phi_LON with inc1, then Equation 1 for the inclination.
    plon = np.median(_fit_equation2(ss1z, rrr, kk, vphi, inc1)[2])
    try:
        inc = _fit_inclination(ss1z, rrr, plon)[0]
    except Exception:
        inc = None
    inc2 = inc if inc is not None and abs(inc) < 9.9 else inc1

    # Second pass: refit Equation 2 with the updated inclination.
    vpps, vzds, plons = _fit_equation2(ss1z, rrr, kk, vphi, inc2)
    vpp, disp_vpp = np.median(vpps), _robust_sigma(vpps)
    vzd, disp_vzd = np.median(vzds), _robust_sigma(vzds)
    plon, disp_plon = np.median(plons), _robust_sigma(plons)

    # Final Equation 1 fit with the refined phi_LON.
    try:
        inc, z0_fit = _fit_inclination(ss1z, rrr, plon)
    except Exception:
        inc, dz0 = '', ''
    else:
        dz0 = z0_fit if abs(z0_fit) < 0.049 else ''  # reject fits pinned at the bound

    return [sample, len(ss),
            inc1, angz0, plon, disp_plon,
            vpp, disp_vpp,
            vzd, disp_vzd,
            inc, dz0]
####
# Output columns, in the order get_inc_z0_pphi returns its values
# (max_age <- sample label, theta_z0 <- angz0, phi_lon <- plon,
#  wz_p <- vpp, wz_n <- vzd).
col = ['max_age', 'n', 'inc1', 'theta_z0', 'phi_lon', 'disp_phi_lon',
       'wz_p', 'disp_wz_p', 'wz_n', 'disp_wz_n', 'inc', 'dz0']

ss0 = pd.read_csv('./ocs.csv', encoding='utf-8')
# Keep 100 <= gcphi < 250; this removes most of the distant OCs.
ss0 = ss0[(ss0.gcphi < 250) & (ss0.gcphi >= 100)]

# DataFrame.to_csv does not create missing directories.
os.makedirs('./files', exist_ok=True)

for rgc in np.arange(4, 14.1, 0.5):  # ring centres 4.0, 4.5, ..., 14.0
    vals = []
    # NOTE(review): np.arange with a float step can include the nominal stop;
    # confirm whether the last age in this grid should be 9.2 or 9.3.
    for age in np.arange(7.3, 9.3, 0.1):
        ss = ss0[(ss0.age >= 0) & (ss0.age < age)]  # clusters younger than `age`
        vals.append(get_inc_z0_pphi(age, rgc, ss))
    df = pd.DataFrame(vals, columns=col)
    df.to_csv('./files/inc_pre_nut_' + str(rgc) + 'kpc.csv', index=False)

# Next step: run read_files.py to read the files written above.
