From b44cf644533f9018c94be5377610ff27c2c43f24 Mon Sep 17 00:00:00 2001
From: Rasmus Oersoe
Date: Mon, 7 Feb 2022 14:05:53 +0100
Subject: [PATCH 1/2] added noise upgrade study

---
 studies/upgrade_noise/make_plots.py     | 203 ++++++++++++++++++
 studies/upgrade_noise/make_selection.py |  86 ++++++++
 studies/upgrade_noise/run_jobs.py       | 262 ++++++++++++++++++++++++
 3 files changed, 551 insertions(+)
 create mode 100644 studies/upgrade_noise/make_plots.py
 create mode 100644 studies/upgrade_noise/make_selection.py
 create mode 100644 studies/upgrade_noise/run_jobs.py

diff --git a/studies/upgrade_noise/make_plots.py b/studies/upgrade_noise/make_plots.py
new file mode 100644
index 000000000..a70ed9a6d
--- /dev/null
+++ b/studies/upgrade_noise/make_plots.py
@@ -0,0 +1,203 @@
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import sqlite3
+from sklearn.metrics import auc
+from sklearn.metrics import roc_curve
+
+
+def add_truth(data, database):
+    data = data.sort_values('event_no').reset_index(drop=True)
+    with sqlite3.connect(database) as con:
+        query = 'select event_no, energy, interaction_type, pid from truth where event_no in %s' % str(tuple(data['event_no']))
+        truth = pd.read_sql(query, con).sort_values('event_no').reset_index(drop=True)
+
+    truth['track'] = 0
+    truth.loc[(abs(truth['pid']) == 14) & (truth['interaction_type'] == 1), 'track'] = 1
+    add_these = []
+    for key in truth.columns:
+        if key not in data.columns:
+            add_these.append(key)
+    for key in add_these:
+        data[key] = truth[key]
+    return data
+
+
+def get_interaction_type(row):
+    if row["interaction_type"] == 1:  # CC
+        particle_type = "nu_" + {12: 'e', 14: 'mu', 16: 'tau'}[abs(row['pid'])]
+        return f"{particle_type} CC"
+    else:
+        return "NC"
+
+
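+# resolution_fn quotes the resolution as half the central 68% interval of the
+# residual distribution, i.e. (84th percentile - 16th percentile) / 2. For
+# Gaussian residuals this equals one standard deviation, but unlike np.std it
+# is robust against outliers in the tails.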
+def resolution_fn(r):
+    if len(r) > 1:
+        return (np.percentile(r, 84) - np.percentile(r, 16)) / 2.
+    else:
+        return np.nan
+
+
+def add_energylog10(df):
+    df['energy_log10'] = np.log10(df['energy'])
+    return df
+
+
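+# get_error estimates the statistical uncertainty on resolution_fn via a
+# non-parametric bootstrap: the residuals are resampled with replacement 150
+# times, the resolution is recomputed on each resample, and the standard
+# deviation of those 150 resolutions is returned as the 1-sigma error.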
+def get_error(residual):
+    rng = np.random.default_rng(42)
+    w = []
+    for i in range(150):
+        new_sample = rng.choice(residual, size=len(residual), replace=True)
+        w.append(resolution_fn(new_sample))
+    return np.std(w)
+
+
+def get_roc_and_auc(data, target):
+    fpr, tpr, _ = roc_curve(data[target], data[target + '_pred'])
+    auc_score = auc(fpr, tpr)
+    return fpr, tpr, auc_score
+
+
+def plot_roc(target, runids, save_dir, save_as_csv=False):
+    width = 3.176 * 2
+    height = 2.388 * 2
+    fig = plt.figure(figsize=(width, height))
+    for runid in runids:
+        data = pd.read_csv('/home/iwsatlas1/oersoe/phd/upgrade_noise/results/dev_step4_numu_%s_second_run/upgrade_%s_regression_45e_GraphSagePulses/results.csv' % (runid, target))
+        database = '/mnt/scratch/rasmus_orsoe/databases/dev_step4_numu_%s_second_run/data/dev_step4_numu_%s_second_run.db' % (runid, runid)
+        if save_as_csv:
+            data = add_truth(data, database)
+            data = add_energylog10(data)
+            data.to_csv(save_dir + '/%s_%s.csv' % (runid, target))
+        # Use a distinct name so that sklearn's `auc` function is not shadowed
+        # on subsequent iterations of this loop.
+        fpr, tpr, auc_score = get_roc_and_auc(data, target)
+        plt.plot(fpr, tpr, label='%s : %s' % (runid, round(auc_score, 3)))
+    plt.legend()
+    plt.title('Track/Cascade Classification')
+    plt.ylabel('True Positive Rate', fontsize=12)
+    plt.xlabel('False Positive Rate', fontsize=12)
+    ymax = 0.3
+    x_text = 0.2
+    y_text = ymax - 0.05
+    y_sep = 0.1
+    plt.text(x_text, y_text - 0 * y_sep, "IceCubeUpgrade/nu_simulation/detector/step4/(%s,%s)" % (runids[0], runids[1]), va='top', fontsize=8)
+    plt.text(x_text, y_text - 1 * y_sep, "Pulsemaps used: SplitInIcePulses_GraphSage_Pulses", va='top', fontsize=8)
+    plt.text(x_text, y_text - 2 * y_sep, "n_pulses > (%s, %s) selection applied during training" % (10, 20), va='top', fontsize=8)
+
+    fig.savefig('/home/iwsatlas1/oersoe/phd/upgrade_noise/plots/preliminary_upgrade_performance_%s.pdf' % target, bbox_inches="tight")
+    return
+
+
+def calculate_width(data_sliced, target):
+    track = data_sliced.loc[data_sliced['track'] == 1, :].reset_index(drop=True)
+    cascade = data_sliced.loc[data_sliced['track'] == 0, :].reset_index(drop=True)
+    if target == 'energy':
+        # Relative deviation in percent.
+        residual_track = ((track[target + '_pred'] - track[target]) / track[target]) * 100
+        residual_cascade = ((cascade[target + '_pred'] - cascade[target]) / cascade[target]) * 100
+    elif target == 'zenith':
+        # Radians to degrees.
+        residual_track = (track[target + '_pred'] - track[target]) * (360 / (2 * np.pi))
+        residual_cascade = (cascade[target + '_pred'] - cascade[target]) * (360 / (2 * np.pi))
+    else:
+        residual_track = track[target + '_pred'] - track[target]
+        residual_cascade = cascade[target + '_pred'] - cascade[target]
+
+    return resolution_fn(residual_track), resolution_fn(residual_cascade), get_error(residual_track), get_error(residual_cascade)
+
+
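+# get_width evaluates the resolution separately for tracks and cascades in
+# 0.1-wide bins of log10(E/GeV) from 0 to 3, returning one DataFrame per
+# morphology with the mean bin energy, the width, and its bootstrap error.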
+def get_width(df, target):
+    track_widths = []
+    cascade_widths = []
+    track_errors = []
+    cascade_errors = []
+    energy = []
+    bins = np.arange(0, 3.1, 0.1)
+    if target in ['zenith', 'energy', 'XYZ']:
+        for i in range(1, len(bins)):
+            idx = (df['energy_log10'] > bins[i - 1]) & (df['energy_log10'] < bins[i])
+            data_sliced = df.loc[idx, :].reset_index(drop=True)
+            energy.append(np.mean(data_sliced['energy_log10']))
+            track_width, cascade_width, track_error, cascade_error = calculate_width(data_sliced, target)
+            track_widths.append(track_width)
+            cascade_widths.append(cascade_width)
+            track_errors.append(track_error)
+            cascade_errors.append(cascade_error)
+        track_plot_data = pd.DataFrame({'mean': energy, 'width': track_widths, 'width_error': track_errors})
+        cascade_plot_data = pd.DataFrame({'mean': energy, 'width': cascade_widths, 'width_error': cascade_errors})
+        return track_plot_data, cascade_plot_data
+    else:
+        raise ValueError('target not supported: %s' % target)
+
+
+def make_plot(target, runids, save_dir, save_as_csv=False):
+    colors = {140021: 'tab:blue', 140022: 'tab:orange'}
+    fig = plt.figure(constrained_layout=True)
+    ax1 = plt.subplot2grid((6, 6), (0, 0), colspan=6, rowspan=6)
+    for runid in runids:
+        predictions_path = '/home/iwsatlas1/oersoe/phd/upgrade_noise/results/dev_step4_numu_%s_second_run/upgrade_%s_regression_45e_GraphSagePulses/results.csv' % (runid, target)
+        database = '/mnt/scratch/rasmus_orsoe/databases/dev_step4_numu_%s_second_run/data/dev_step4_numu_%s_second_run.db' % (runid, runid)
+        df = pd.read_csv(predictions_path).sort_values('event_no').reset_index(drop=True)
+        df = add_truth(df, database)
+        df = add_energylog10(df)
+        if save_as_csv:
+            df.to_csv(save_dir + '/%s_%s.csv' % (runid, target))
+        plot_data_track, plot_data_cascade = get_width(df, target)
+
+        ax1.plot(plot_data_track['mean'], plot_data_track['width'], linestyle='solid', lw=0.5, color='black', alpha=1)
+        ax1.fill_between(plot_data_track['mean'], plot_data_track['width'] - plot_data_track['width_error'], plot_data_track['width'] + plot_data_track['width_error'], color=colors[runid], alpha=0.8, label='Track %s' % runid)
+
+        ax1.plot(plot_data_cascade['mean'], plot_data_cascade['width'], linestyle='dashed', color='black', lw=0.5, alpha=1)
+        ax1.fill_between(plot_data_cascade['mean'], plot_data_cascade['width'] - plot_data_cascade['width_error'], plot_data_cascade['width'] + plot_data_cascade['width_error'], color=colors[runid], alpha=0.3, label='Cascade %s' % runid)
+
+        # Overlay the deposited-energy distribution on a secondary axis.
+        ax2 = ax1.twinx()
+        ax2.hist(df['energy_log10'], histtype='step', label='deposited energy', color=colors[runid])
+
+    ax1.tick_params(axis='x', labelsize=6)
+    ax1.tick_params(axis='y', labelsize=6)
+    ax1.set_xlim((0, 3.1))
+
+    leg = ax1.legend(frameon=False, fontsize=8)
+    for line in leg.get_lines():
+        line.set_linewidth(4.0)
+
+    if target == 'energy':
+        ax1.set_ylim((0, 175))
+        ymax = 23.
+        y_sep = 8
+        unit_tag = '(%)'
+    else:
+        unit_tag = '(deg.)'
+    if target == 'angular_res':
+        target = 'direction'
+    if target == 'XYZ':
+        target = 'vertex'
+        unit_tag = '(m)'
+    if target == 'zenith':
+        ymax = 10.
+        y_sep = 2.3
+        ax1.set_ylim((0, 45))
+
+    plt.tick_params(right=False, labelright=False)
+    ax1.set_ylabel('%s Resolution %s' % (target.capitalize(), unit_tag), size=10)
+    ax1.set_xlabel('Energy (log10 GeV)', size=10)
+
+    x_text = 0.5
+    y_text = ymax - 2.
+    ax1.text(x_text, y_text - 0 * y_sep, "IceCubeUpgrade/nu_simulation/detector/step4/(%s,%s)" % (runids[0], runids[1]), va='top', fontsize=8)
+    ax1.text(x_text, y_text - 1 * y_sep, "Pulsemaps used: SplitInIcePulses_GraphSage_Pulses", va='top', fontsize=8)
+    ax1.text(x_text, y_text - 2 * y_sep, "n_pulses > (%s, %s) selection applied during training" % (10, 20), va='top', fontsize=8)
+
+    fig.suptitle("%s regression Upgrade MC using GNN" % target)
+    fig.savefig('/home/iwsatlas1/oersoe/phd/upgrade_noise/plots/preliminary_upgrade_performance_%s.pdf' % target)
+    return
+
+
+runids = [140021, 140022]
+targets = ['zenith', 'energy', 'track']
+save_as_csv = True
+save_dir = '/home/iwsatlas1/oersoe/phd/tmp/upgrade_csv'
+for target in targets:
+    if target != 'track':
+        make_plot(target, runids, save_dir, save_as_csv)
+    else:
+        plot_roc(target, runids, save_dir, save_as_csv)
\ No newline at end of file
diff --git a/studies/upgrade_noise/make_selection.py b/studies/upgrade_noise/make_selection.py
new file mode 100644
index 000000000..19f747aa7
--- /dev/null
+++ b/studies/upgrade_noise/make_selection.py
@@ -0,0 +1,86 @@
+import os
+from multiprocessing import Pool
+
+import numpy as np
+import pandas as pd
+import sqlite3
+import tqdm
+
+
+def parallel_this(settings):
+    events, worker_id, tmp_dir, db = settings
+    con = sqlite3.connect(db)
+    labels = []
+    for index in tqdm.tqdm(events):
+        features = con.execute("SELECT event_no FROM SplitInIcePulses_GraphSage_Pulses WHERE event_no = {}".format(index))
+        features = features.fetchall()
+        if len(features) > 10:
+            labels.append(index)
+    labels = pd.DataFrame(data=labels, columns=['event_no'])
+    labels.to_csv('%s/tmp_%s.csv' % (tmp_dir, worker_id))
+    return
+
+
+def merge_tmp(tmp_dir):
+    tmps = os.listdir(tmp_dir)
+    dfs = []
+    for file in tmps:
+        if '.csv' in file:
+            dfs.append(pd.read_csv(tmp_dir + '/' + file))
+    df = pd.concat(dfs, ignore_index=True)
+    df = df.sort_values('event_no').reset_index(drop=True)
+    return df
+
+
+def over_10_pulses(n_workers, path, tmp_dir, db, events):
+    settings = []
+    event_batches = np.array_split(events.values.ravel().tolist(), n_workers)
+    for i in range(n_workers):
+        settings.append([event_batches[i], i, tmp_dir, db])
+    p = Pool(processes=n_workers)
+    p.map_async(parallel_this, settings)
+    p.close()
+    p.join()
+    selection = merge_tmp(tmp_dir)
+    selection.to_csv(path + '/over10pulses.csv')
+    return
+
+
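+# make_even_track_cascade balances the track/cascade classification sample by
+# downsampling the majority class to the size of the minority class, then
+# shuffles the combined selection (sample(frac=1)) so the classes are
+# interleaved rather than concatenated.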
+def make_even_track_cascade(events, db):
+    with sqlite3.connect(db) as con:
+        query = 'select event_no from truth where abs(pid) = 14 and interaction_type = 1 and event_no in %s' % str(tuple(events['event_no']))
+        tracks = pd.read_sql(query, con)
+        query = 'select event_no from truth where event_no not in %s and event_no in %s' % (str(tuple(tracks['event_no'])), str(tuple(events['event_no'])))
+        cascades = pd.read_sql(query, con)
+    print('found %s tracks' % len(tracks))
+    print('found %s cascades' % len(cascades))
+    if len(tracks) > len(cascades):
+        return pd.concat([tracks.sample(len(cascades)), cascades], ignore_index=True).sample(frac=1).reset_index(drop=True)
+    else:
+        return pd.concat([tracks, cascades.sample(len(tracks))], ignore_index=True).sample(frac=1).reset_index(drop=True)
+
+
+if __name__ == '__main__':
+    n_workers = 50
+    db = '/mnt/scratch/rasmus_orsoe/databases/dev_step4_numu_140021_second_run/data/dev_step4_numu_140021_second_run.db'
+    tmp_dir = '/home/iwsatlas1/oersoe/phd/upgrade_noise/tmp'
+    path = '/mnt/scratch/rasmus_orsoe/databases/dev_step4_numu_140021_second_run/selection'
+    with sqlite3.connect(db) as con:
+        query = 'select event_no from truth'
+        events = pd.read_sql(query, con)
+    over_10_pulses(n_workers, path, tmp_dir, db, events)
+    events = pd.read_csv(path + '/over10pulses.csv')
+    selection = make_even_track_cascade(events, db)
+    selection.to_csv(path + '/even_track_cascade_over10pulses.csv')
diff --git a/studies/upgrade_noise/run_jobs.py b/studies/upgrade_noise/run_jobs.py
new file mode 100644
index 000000000..3e4187579
--- /dev/null
+++ b/studies/upgrade_noise/run_jobs.py
@@ -0,0 +1,262 @@
+import os
+
+import dill
+import pandas as pd
+import torch
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import EarlyStopping
+from torch.optim.adam import Adam
+
+from graphnet.components.loss_functions import BinaryCrossEntropyLoss, EuclideanDistance, LogCoshLoss, VonMisesFisher2DLoss
+from graphnet.data.constants import FEATURES, TRUTH
+from graphnet.models import Model
+from graphnet.models.detector.icecube import IceCubeUpgrade
+from graphnet.models.gnn import DynEdge_V2
+from graphnet.models.graph_builders import KNNGraphBuilder
+from graphnet.models.task.reconstruction import BinaryClassificationTask, PassOutput1, XYZReconstruction, ZenithReconstructionWithKappa
+from graphnet.models.training.callbacks import ProgressBar, PiecewiseLinearLR
+from graphnet.models.training.utils import get_predictions, make_train_validation_dataloader
+
+
+# Local version of save_results; intentionally not imported from
+# graphnet.models.training.utils, which provides a function of the same name.
+def save_results(db, tag, results, archive, model, validation_loss=None, training_loss=None):
+    db_name = db.split('/')[-1].split('.')[0]
+    path = archive + '/' + db_name + '/' + tag
+    os.makedirs(path, exist_ok=True)
+    results.to_csv(path + '/results.csv')
+    model.save(path + '/' + tag + '.pth')
+    if validation_loss is not None:
+        pd.DataFrame({'training_loss': training_loss, 'validation_loss': validation_loss}).to_csv(path + '/training_hist.csv')
+    print('Results saved at: \n %s' % path)
+
+
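+# Target transforms: energy is regressed in log10-space (transform_to_log10 is
+# applied to the target during training and remove_log10 inverts it at
+# inference, so saved predictions are in GeV). scale_XYZ/unscale_XYZ divide by
+# per-axis constants (presumably the maximum absolute detector coordinates in
+# metres) so each vertex component is roughly within [-1, 1] during training.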
+def remove_log10(x):
+    return torch.pow(10, x)
+
+
+def transform_to_log10(x):
+    return torch.log10(x)
+
+
+def scale_XYZ(x):
+    x[:, 0] = x[:, 0] / 764.431509
+    x[:, 1] = x[:, 1] / 785.041607
+    x[:, 2] = x[:, 2] / 1083.249944
+    return x
+
+
+def unscale_XYZ(x):
+    x[:, 0] = 764.431509 * x[:, 0]
+    x[:, 1] = 785.041607 * x[:, 1]
+    x[:, 2] = 1083.249944 * x[:, 2]
+    return x
+
+
+# Configurations
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+# Constants
+features = FEATURES.UPGRADE
+truth = TRUTH.UPGRADE
+
+
+def build_model(run_name, device, archive):
+    model = torch.load(os.path.join(archive, f"{run_name}.pth"), pickle_module=dill)
+    model.to('cuda:%s' % device[0])
+    model.eval()
+    model.inference()
+    return model
+
+
+def train_and_predict_on_validation_set(target, selection, database, pulsemap, batch_size, num_workers, n_epochs, device, run_name, archive, train, patience=5):
+    try:
+        del truth[truth.index('interaction_time')]
+    except ValueError:
+        # not found in list
+        pass
+
+    print(f"features: {features}")
+    print(f"truth: {truth}")
+
+    training_dataloader, validation_dataloader = make_train_validation_dataloader(
+        db=database,
+        selection=selection,
+        pulsemaps=pulsemap,
+        features=features,
+        truth=truth,
+        batch_size=batch_size,
+        num_workers=num_workers,
+    )
+
+    # Building model
+    detector = IceCubeUpgrade(
+        graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8),
+    )
+    gnn = DynEdge_V2(
+        nb_inputs=detector.nb_outputs,
+    )
+    if target == 'zenith':
+        task = ZenithReconstructionWithKappa(
+            hidden_size=gnn.nb_outputs,
+            target_label=target,
+            loss_function=VonMisesFisher2DLoss(),
+        )
+    elif target == 'energy':
+        task = PassOutput1(hidden_size=gnn.nb_outputs, target_label=target, loss_function=LogCoshLoss(), transform_target=transform_to_log10, transform_inference=remove_log10)
+    elif target == 'XYZ':
+        task = XYZReconstruction(hidden_size=gnn.nb_outputs, target_label=target, loss_function=EuclideanDistance(), transform_target=scale_XYZ, transform_inference=unscale_XYZ)
+    elif target == 'track':
+        task = BinaryClassificationTask(hidden_size=gnn.nb_outputs, target_label=target, loss_function=BinaryCrossEntropyLoss())
+    else:
+        raise ValueError('task not found for target: %s' % target)
+    model = Model(
+        detector=detector,
+        gnn=gnn,
+        tasks=[task],
+        optimizer_class=Adam,
+        optimizer_kwargs={'lr': 1e-03, 'eps': 1e-03},
+        scheduler_class=PiecewiseLinearLR,
+        scheduler_kwargs={
+            'milestones': [0, len(training_dataloader) / 2, len(training_dataloader) * n_epochs],
+            'factors': [1e-2, 1, 1e-02],
+        },
+        scheduler_config={
+            'interval': 'step',
+        },
+    )
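+
+    # The PiecewiseLinearLR schedule above ramps the learning rate from 1% of
+    # the base Adam learning rate to its full value over the first half-epoch,
+    # then decays it linearly back to 1% by the final training step; the
+    # 'interval': 'step' setting advances the scheduler every batch rather
+    # than every epoch.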
+
+    # Training model
+    callbacks = [
+        EarlyStopping(
+            monitor='val_loss',
+            patience=patience,
+        ),
+        ProgressBar(),
+    ]
+
+    if train:
+        trainer = Trainer(
+            default_root_dir=archive,
+            gpus=device,
+            max_epochs=n_epochs,
+            callbacks=callbacks,
+            log_every_n_steps=1,
+            logger=None,
+        )
+
+        try:
+            trainer.fit(model, training_dataloader, validation_dataloader)
+        except KeyboardInterrupt:
+            print("[ctrl+c] Exiting gracefully.")
+            pass
+
+        # Saving model
+        model.save(os.path.join(archive, f"{run_name}.pth"))
+        model.save_state_dict(os.path.join(archive, f"{run_name}_state_dict.pth"))
+        predict(model, trainer, target, selection, database, pulsemap, batch_size, num_workers, device, run_name, archive)
+
+
+def predict(model, trainer, target, selection, database, pulsemap, batch_size, num_workers, device, run_name, archive):
+    try:
+        del truth[truth.index('interaction_time')]
+    except ValueError:
+        # not found in list
+        pass
+    device = 'cuda:%s' % device[0]
+    model.to(device)
+    _, validation_dataloader = make_train_validation_dataloader(
+        db=database,
+        selection=selection,
+        pulsemaps=pulsemap,
+        features=features,
+        truth=truth,
+        batch_size=batch_size,
+        num_workers=num_workers,
+    )
+    if target in ['zenith', 'azimuth']:
+        results = get_predictions(
+            trainer,
+            model,
+            validation_dataloader,
+            [target + '_pred', target + '_kappa'],
+            [target, 'event_no', 'energy'],
+        )
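+        # In addition to the point prediction, the angular task exports
+        # 'kappa', the concentration parameter of the von Mises-Fisher
+        # likelihood it was trained with; larger kappa corresponds to a
+        # narrower, more confident prediction.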
-""" -import dill -import os -import pandas as pd - -from pytorch_lightning import Trainer -from pytorch_lightning.callbacks import EarlyStopping -from pytorch_lightning.loggers import WandbLogger -import torch -from torch.optim.adam import Adam - -from graphnet.components.loss_functions import BinaryCrossEntropyLoss, LogCoshLoss, VonMisesFisher2DLoss, XYZWithMaxScaling -from graphnet.data.constants import FEATURES, TRUTH -from graphnet.models import Model -from graphnet.models.detector.icecube import IceCubeUpgrade -from graphnet.models.gnn import DynEdge_V2 -from graphnet.models.graph_builders import KNNGraphBuilder -from graphnet.models.task.reconstruction import EnergyReconstruction, ZenithReconstructionWithKappa, XYZReconstruction, BinaryClassificationTask -from graphnet.models.training.callbacks import ProgressBar, PiecewiseLinearLR -from graphnet.models.training.utils import get_predictions, make_train_validation_dataloader, save_results, Predictor - -# Configurations -torch.multiprocessing.set_sharing_strategy('file_system') - -# Constants -features = FEATURES.UPGRADE -truth = TRUTH.UPGRADE - -# Utility methods -def load_model(run_name, device, archive): - model = torch.load(os.path.join(archive, f"{run_name}.pth"), pickle_module=dill) - model.to('cuda:%s'%device[0]) - return model - -def remove_log10(x, target): - x[target + '_pred'] = 10**x[target + '_pred'] - return x - -def rescale_XYZ(x, target): - x['position_x_pred'] = 764.431509*x['position_x_pred'] - x['position_y_pred'] = 785.041607*x['position_y_pred'] - x['position_z_pred'] = 1083.249944*x['position_z_pred'] - return x - - -def train_and_predict_on_validation_set(target,selection, database, pulsemap, batch_size, num_workers, n_epochs, device, patience = 5): - - # Initialise Weights & Biases (W&B) run - wandb_logger = WandbLogger( - project=f"upgrade-{target}-new-noise-model-GraphSAGE-cleaned", - entity="graphnet-team", - save_dir='./wandb/', - log_model=True, - ) - - try: - del truth[truth.index('interaction_time')] - except ValueError: - # not found in list - pass - - print(f"features: {features}") - print(f"truth: {truth}") - - training_dataloader, validation_dataloader = make_train_validation_dataloader( - db = database, - selection = selection, - pulsemaps = pulsemap, - features = features, - truth = truth, - batch_size = batch_size, - num_workers=num_workers, - ) - - # Building model - detector = IceCubeUpgrade( - graph_builder=KNNGraphBuilder(nb_nearest_neighbours=8), - ) - gnn = DynEdge_V2( - nb_inputs=detector.nb_outputs, - ) - if target == 'zenith': - task = ZenithReconstructionWithKappa( - hidden_size=gnn.nb_outputs, - target_labels=target, - loss_function=VonMisesFisher2DLoss(), - ) - elif target == 'energy': - task = EnergyReconstruction(hidden_size=gnn.nb_outputs, target_labels=target, loss_function=LogCoshLoss()) - elif target == 'track': - task = BinaryClassificationTask(hidden_size=gnn.nb_outputs,target_labels=target,loss_function=BinaryCrossEntropyLoss()) - elif isinstance(target, list): - task = XYZReconstruction(hidden_size=gnn.nb_outputs, target_labels=target, loss_function=XYZWithMaxScaling()) - else: - print('task not found') - - model = Model( - detector=detector, - gnn=gnn, - tasks=[task], - optimizer_class=Adam, - optimizer_kwargs={'lr': 1e-03, 'eps': 1e-03}, - scheduler_class=PiecewiseLinearLR, - scheduler_kwargs={ - 'milestones': [0, len(training_dataloader) / 2, len(training_dataloader) * n_epochs], - 'factors': [1e-2, 1, 1e-02], - }, - scheduler_config={ - 'interval': 'step', - }, - ) - - # 
-    # Training model
-    callbacks = [
-        EarlyStopping(
-            monitor='val_loss',
-            patience=patience,
-        ),
-        ProgressBar(),
-    ]
-
-    trainer = Trainer(
-        default_root_dir=archive,
-        gpus=device,
-        max_epochs=n_epochs,
-        callbacks=callbacks,
-        log_every_n_steps=1,
-        logger=wandb_logger,
-    )
-
-    try:
-        trainer.fit(model, training_dataloader, validation_dataloader)
-    except KeyboardInterrupt:
-        print("[ctrl+c] Exiting gracefully.")
-        pass
-
-    # Saving model
-    model.save(os.path.join(archive, f"{run_name}.pth"))
-    model.save_state_dict(os.path.join(archive, f"{run_name}_state_dict.pth"))
-
-def predict(target, selection, database, pulsemap, batch_size, num_workers, n_epochs, device):
-    try:
-        del truth[truth.index('interaction_time')]
-    except ValueError:
-        # not found in list
-        pass
-    model = load_model(run_name, device, archive)
-
-    device = 'cuda:%s'%device[0]
-    _, validation_dataloader = make_train_validation_dataloader(
-        db = database,
-        selection = selection,
-        pulsemaps = pulsemap,
-        features = features,
-        truth = truth,
-        batch_size = batch_size,
-        num_workers=num_workers,
-    )
-    if target in ['zenith', 'azimuth']:
-        predictor_valid = Predictor(
-            dataloader=validation_dataloader,
-            target=target,
-            device=device,
-            output_column_names=[target + '_pred', target + '_kappa'],
-        )
-    if target in ['track', 'neutrino']:
-        predictor_valid = Predictor(
-            dataloader=validation_dataloader,
-            target=target,
-            device=device,
-            output_column_names=[target + '_pred'],
-        )
-    if target == 'energy':
-        predictor_valid = Predictor(
-            dataloader=validation_dataloader,
-            target=target,
-            device=device,
-            output_column_names=[target + '_pred'],
-            post_processing_method= remove_log10,
-        )
-    if isinstance(target, list):
-        predictor_valid = Predictor(
-            dataloader=validation_dataloader,
-            target=target,
-            device=device,
-            output_column_names=['position_x_pred','position_y_pred','position_z_pred'],#,'interaction_time_pred'],
-            post_processing_method= rescale_XYZ,
-        )
-
-    results = predictor_valid(model)
-
-    save_results(database, run_name, results, archive, model)
-
-# Main function call
-if __name__ == "__main__":
-
-    # Run management
-    archive = "/lustre/hpc/icecube/asogaard/gnn/results"
-    targets = [['position_x', 'position_y', 'position_z']]  # 'track', 'zenith', 'energy', 'XYZ',
-    batch_size = 256
-    database ='/lustre/hpc/icecube/asogaard/data/sqlite/dev_step4_numu_140022_second_run/data/dev_step4_numu_140022_second_run.db'
-    device = [0]
-    n_epochs = 30
-    num_workers = 40
-    patience = 5
-    pulsemap = 'SplitInIcePulses_GraphSage_Pulses'
-
-    # Common variables
-    for target in targets:
-        if target == 'track':
-            selection = pd.read_csv('/lustre/hpc/icecube/asogaard/data/sqlite/dev_step4_numu_140022_second_run/selection/even_track_cascade_over20pulses.csv')['event_no'].values.ravel().tolist()
-        else:
-            selection = pd.read_csv('/lustre/hpc/icecube/asogaard/data/sqlite/dev_step4_numu_140022_second_run/selection/over20pulses.csv')['event_no'].values.ravel().tolist()
-
-        if isinstance(target, list):
-            target_name = 'XYZ'
-        else:
-            target_name = target
-
-        run_name = "upgrade_{}_regression_GraphSagePulses".format(target_name)
-
-        train_and_predict_on_validation_set(target, selection, database, pulsemap, batch_size, num_workers, n_epochs, device, patience)
-        predict(target, selection, database, pulsemap, batch_size, num_workers, n_epochs, device)