'''
    File name: run_bursts_in_folder.py
    Author: Peijin Zhang 张沛锦
    Date : 2022-4-26
    
    A script that runs the radio-burst detection on every FITS file found under a folder tree.
'''

import sys
import matplotlib.dates as mdates
import datetime
import matplotlib.pyplot as plt
import numpy as np
import astropy.io.fits as fits
import scipy
import glob
import matplotlib as mpl
import matplotlib
matplotlib.use('agg')
import radioTools as rt
import glob
import json
import os
from tqdm import tqdm

# Pin matplotlib's date epoch to the Unix epoch so that date numbers are
# expressed relative to 1970-01-01.
mpl.rcParams['date.epoch']='1970-01-01T00:00:00'
try:
    mdates.set_epoch('1970-01-01T00:00:00')
except (AttributeError, RuntimeError):
    # Best effort only: set_epoch() is missing on older matplotlib
    # (AttributeError) and refuses to change an epoch that is already in
    # use (RuntimeError).  Anything else (e.g. KeyboardInterrupt) is no
    # longer silently swallowed.
    pass
import detectRadioburst as drb

# ---- run configuration -------------------------------------------------
fits_root = './out2'            # folder tree that is searched for *.fits files
out_folder_base = './detect2'   # folder tree that receives the detection plots
plot_fig = True                 # save a PNG of every processed spectrum
dump_info_to_json = True        # NOTE(review): JSON switch; confirm the loop below honours it
write_csv = True                # append the detected events to csv_fname
incremental=True                # NOTE(review): not referenced in the visible part of this script

csv_fname = 'event.csv'
#os.system('rm '+csv_fname)
# Running event counter used as the CSV "ID" column.  It always restarts at
# zero, even when an existing CSV file is being appended to.
id_event  = 0
if not os.path.exists(csv_fname):
    # Create the CSV with its header row.  The header is terminated by a
    # plain newline; the previous triple-quoted literal also emitted the
    # source indentation, which ended up in front of the first data row.
    with open(csv_fname,'w') as fp:
        fp.write('ID, t, t0_num, t1_num,f_0,f_1, dfdt(MHz/s), v_b(c)\n')

# Walk the whole input tree.  Every *.fits file is processed on its own:
# detect line features in the dynamic spectrum, derive the burst parameters,
# update the JSON file next to the FITS file, save a diagnostic plot and
# append the accepted events to the CSV file.
for path, subdirs, files in os.walk(fits_root):
    for name in tqdm(files):
        if not name.endswith('.fits'):
            continue

        fname = os.path.join(path, name)
        # Swap only the trailing extension / the leading root folder, so a
        # '.fits' or fits_root substring elsewhere in the path is left alone.
        fname_json = os.path.splitext(fname)[0] + '.json'
        out_folder = path.replace(fits_root, out_folder_base, 1)
        fname_detect = os.path.join(out_folder, os.path.splitext(name)[0] + '.png')

        os.makedirs(out_folder, exist_ok=True)

        # ---- detection pipeline (all steps live in detectRadioburst) ----
        (dyspec,t_fits,f_fits,hdu) = drb.read_fits(fname)
        (dyspec,f_fits) =  drb.cut_low(dyspec,f_fits,f_low_cut_val=25)
        (data_fits_new_tmp,data_fits_new) = drb.preproc(
            dyspec,gauss_sigma=2)
        bmap = drb.binarization(data_fits_new,N_order=6,peak_r=1+1e-4)
        lines = drb.hough_detect(bmap,dyspec,threshold=40,line_gap=10,line_length=30,
            theta=np.linspace(np.pi/2-np.pi/18,np.pi/2-1e-3,1000))
        if len(lines)<=1:
            # zero or one detected line: skip this file
            continue
        line_sets = drb.line_grouping(lines)
        (v_beam, f_range_burst, t_range_burst, model_curve_set,
            t_set_arr_set,f_set_arr_set,t_model_arr,f_model_arr
            )= drb.get_info_from_linegroup(line_sets,t_fits,f_fits)

        # ---- update the JSON file that sits next to the FITS file ----
        # The JSON file must already exist; a missing file raises, as before.
        if dump_info_to_json:
            with open(fname_json, 'r') as fp:
                old_json = json.load(fp)
            drb.append_into_json(old_json, v_beam, f_range_burst, t_range_burst)
            with open(fname_json, 'w') as fp:
                json.dump(old_json,fp)

        # ---- diagnostic plot ----
        if plot_fig:
            fig,ax = plt.subplots(1,1,figsize=[6,3],dpi=200)
            try:
                dt = t_fits[1]-t_fits[0]

                # Colour scale: use the central 2%..98% of the sorted values
                # taken from the 10%..99% slice of the frequency axis.
                freq_safe0, freq_safe1 = int(0.1 * f_fits.shape[0]), int(0.99 * f_fits.shape[0])
                data_safe_arr = data_fits_new[:, freq_safe0:freq_safe1].ravel()
                n_safe = data_safe_arr.shape[0]
                data_safe = np.sort(data_safe_arr)[int(n_safe * 0.02):int(n_safe * 0.98)]

                safe_mean = np.nanmean(data_safe)
                safe_std = np.nanstd(data_safe)
                vmin = safe_mean - 2 * safe_std
                vmax = safe_mean + 2 * safe_std + 0.9*np.nanmax(data_safe)

                ax.imshow(data_fits_new.T,aspect='auto',origin='lower', vmin=vmin,vmax=vmax,cmap='gray',
                          extent=[t_fits[0]-dt,t_fits[-1]+dt,f_fits[0],f_fits[-1]])
                # Overlay only the models whose v_beam lies in (0.03, 0.9).
                for idx,model in enumerate(model_curve_set):
                    if 0.03 < v_beam[idx] < 0.9:
                        ax.plot(model[0],model[1],ls='--')
                        ax.plot(t_range_burst[idx],f_range_burst[idx],'k+')

                ax.xaxis_date()
                ax.xaxis.set_major_formatter(mdates.DateFormatter('%H:%M'))
                ax.set_xlabel('Time (UT)')
                ax.set_ylabel('Frequency (MHz)')
                ax.set_title(hdu[0].header['CONTENT'])
                fig.tight_layout()
                fig.savefig(fname_detect)
            finally:
                # Release the figure; otherwise one figure per FITS file
                # stays alive in pyplot for the whole run.
                plt.close(fig)

        # ---- append accepted events to the CSV file ----
        if write_csv:
            with open(csv_fname,'a') as fp:
                for idx,v_cur in enumerate(v_beam):
                    if not (0 < v_cur < 0.9):
                        continue
                    t_start = t_range_burst[idx][0]
                    t_end = t_range_burst[idx][1]
                    f_span = np.max(f_set_arr_set[idx])-np.min(f_set_arr_set[idx])
                    t_span = np.max(t_set_arr_set[idx])-np.min(t_set_arr_set[idx])
                    # NOTE(review): the CSV header labels this column
                    # MHz/s - confirm the time unit of t_set_arr_set.
                    drift = f_span/t_span
                    row = [
                        id_event,
                        mdates.num2date(t_start).strftime("%y-%m-%d %H:%M:%S"),
                        t_start,
                        t_end,
                        f_range_burst[idx][0],
                        f_range_burst[idx][1],
                        drift,
                        v_cur,
                    ]
                    # One event per line, terminated by a plain newline (the
                    # old literal also wrote the source indentation into
                    # the file, in front of the next row).
                    fp.write(','.join(str(item) for item in row)+'\n')
                    id_event+=1
            
        

