diff --git a/Snakefile_insal_rgcn_pytorch.smk b/Snakefile_insal_rgcn_pytorch.smk
new file mode 100644
index 0000000..214a230
--- /dev/null
+++ b/Snakefile_insal_rgcn_pytorch.smk
@@ -0,0 +1,273 @@
+import os
+import numpy as np
+import torch
+import torch.optim as optim
+
+#river-dl
+code_dir = config['code_dir']
+# if using river_dl installed with pip this is not needed
+import sys
+sys.path.insert(1, code_dir)
+
+from river_dl.preproc_utils import asRunConfig
+from river_dl.preproc_utils import prep_all_data
+from river_dl.torch_utils import train_torch
+from river_dl.torch_utils import rmse_masked
+from river_dl.evaluate import combined_metrics
+from river_dl.torch_models import RGCN_v1
+from river_dl.predict import predict_from_io_data
+
+#user-defined functions
+from utils import *
+
+out_dir = config['out_dir']
+os.makedirs(out_dir, exist_ok=True)
+
+#spatial holdout info
+train_segs = check_spatial_segs(config['train_segs_f'])
+val_segs = check_spatial_segs(config['val_segs_f'])
+test_segs = check_spatial_segs(config['test_segs_f'])
+
+rule all:
+    input:
+        f"{out_dir}/finetuned_weights.pth",
+        f"{out_dir}/finetune_log.csv",
+        expand("{outdir}/{metric_type}_metrics.csv",
+               outdir=out_dir,
+               metric_type=['overall', 'month', 'reach', 'month_reach',
+                            'monthly_all_sites', 'monthly_site_based', 'monthly_reach',
+                            'biweekly_all_sites', 'biweekly_site_based', 'biweekly_reach',
+                            'year', 'year_reach',
+                            'yearly_all_sites', 'yearly_site_based', 'yearly_reach'],
+        ),
+        expand("{outdir}/asRunConfig.yml", outdir=out_dir),
+        expand("{outdir}/Snakefile", outdir=out_dir),
+        f"{out_dir}/trn_preds_obs.csv",
+        f"{out_dir}/val_preds_obs.csv",
+        f"{out_dir}/tst_preds_obs.csv"
+
+#save the as-run config settings to a text file
+rule as_run_config:
+    group: "prep"
+    output:
+        "{outdir}/asRunConfig.yml"
+    run:
+        asRunConfig(config, code_dir, output[0])
+
+#save the as-run snakefile to the output
+rule copy_snakefile:
+    group: "prep"
+    output:
+        "{outdir}/Snakefile"
+    shell:
+        """
+        cp Snakefile_insal_rgcn_pytorch.smk {output[0]}
+        """
+
+rule prep_io_data:
+    group: "prep"
+    input:
+        config['attrs_file'],
+        config['obs_file'],
+        config['dist_matrix_file']
+    output:
+        "{outdir}/prepped.npz"
+    threads: 3
+    run:
+        prep_all_data(
+            x_data_file = input[0],
+            y_data_file = input[1],
+            distfile = input[2],
+            x_vars = txt_to_list(config['x_vars_file']),
+            spatial_idx_name = "PRMS_segid",
+            time_idx_name = "Date",
+            y_vars_finetune = config['y_vars'],
+            train_start_date = txt_to_list(config['train_start_date_f']),
+            train_end_date = txt_to_list(config['train_end_date_f']),
+            val_start_date = txt_to_list(config['val_start_date_f']),
+            val_end_date = txt_to_list(config['val_end_date_f']),
+            test_start_date = txt_to_list(config['test_start_date_f']),
+            test_end_date = txt_to_list(config['test_end_date_f']),
+            val_sites = val_segs,
+            test_sites = test_segs,
+            explicit_spatial_partition = True,
+            earliest_time = txt_to_list(config['earliest_date'])[0],
+            latest_time = txt_to_list(config['latest_date'])[0],
+            out_file = output[0],
+            trn_offset = config['trn_offset'],
+            tst_val_offset = config['tst_val_offset'],
+            seq_len = config['seq_len'],
+            log_y_vars = config['log_y_vars']
+        )
+
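+# Optional sanity check on the prepped file. The array names listed here are
+# the ones the training and prediction rules below read; this is a sketch,
+# assuming prep_all_data wrote them to the .npz bundle:
+#   data = np.load(f"{out_dir}/prepped.npz", allow_pickle=True)
+#   print(data.files)  # expect 'x_trn', 'y_obs_trn', 'x_val', 'y_obs_val',
+#                      # 'ids_trn', 'dist_matrix', 'x_vars', among others
+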
+rule finetune_train:
+    group: "train"
+    input:
+        "{outdir}/prepped.npz"
+    output:
+        "{outdir}/finetuned_weights.pth",
+        "{outdir}/finetune_log.csv"
+    threads: 3
+    run:
+        data = np.load(input[0], allow_pickle=True)
+        num_segs = len(np.unique(data['ids_trn']))
+        adj_mx = data['dist_matrix']
+        in_dim = len(data['x_vars'])
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        model = RGCN_v1(input_dim = in_dim,
+                        hidden_dim = config['hidden_size'],
+                        adj_matrix = adj_mx,
+                        recur_dropout = config['recurrent_dropout'],
+                        dropout = config['dropout'],
+                        device = device,
+                        seed = config['seed'])
+        opt = optim.Adam(model.parameters(), lr = config['learning_rate'])
+        train_torch(model,
+                    loss_function = rmse_masked,
+                    optimizer = opt,
+                    x_train = data['x_trn'],
+                    y_train = data['y_obs_trn'],
+                    x_val = data['x_val'],
+                    y_val = data['y_obs_val'],
+                    max_epochs = config['epochs'],
+                    early_stopping_patience = config['early_stopping'],
+                    batch_size = num_segs,
+                    weights_file = output[0],
+                    log_file = output[1],
+                    device = device)
+
+
+rule make_predictions:
+    input:
+        "{outdir}/finetuned_weights.pth",
+        "{outdir}/prepped.npz"
+    output:
+        "{outdir}/{partition}_preds.feather"
+    group: "train_predict_evaluate"
+    run:
+        data = np.load(input[1], allow_pickle=True)
+        adj_mx = data['dist_matrix']
+        in_dim = len(data['x_vars'])
+        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+        model = RGCN_v1(input_dim = in_dim,
+                        hidden_dim = config['hidden_size'],
+                        adj_matrix = adj_mx,
+                        recur_dropout = config['recurrent_dropout'],
+                        dropout = config['dropout'],
+                        device = device,
+                        seed = config['seed'])
+        model.load_state_dict(torch.load(input[0]))
+        predict_from_io_data(model = model,
+                             io_data = input[1],
+                             partition = wildcards.partition,
+                             outfile = output[0],
+                             trn_offset = config['trn_offset'],
+                             tst_val_offset = config['tst_val_offset'],
+                             spatial_idx_name = "PRMS_segid",
+                             time_idx_name = "Date",
+                             log_vars = config['log_y_vars'])
+
+
+#Order in the list is:
+# spatial (bool), temporal (False or timestep to use), time_aggregation (bool), site_based (bool)
+def get_grp_arg(wildcards):
+    if wildcards.metric_type == 'overall':
+        return [False, False, False, False]
+    elif wildcards.metric_type == 'month':
+        return [False, 'M', False, False]
+    elif wildcards.metric_type == 'reach':
+        return [True, False, False, False]
+    elif wildcards.metric_type == 'month_reach':
+        return [True, 'M', False, False]
+    elif wildcards.metric_type == 'monthly_site_based':
+        return [False, 'M', True, True]
+    elif wildcards.metric_type == 'monthly_all_sites':
+        return [False, 'M', True, False]
+    elif wildcards.metric_type == 'monthly_reach':
+        return [True, 'M', True, False]
+    elif wildcards.metric_type == 'year':
+        return [False, 'Y', False, False]
+    elif wildcards.metric_type == 'year_reach':
+        return [True, 'Y', False, False]
+    elif wildcards.metric_type == 'yearly_site_based':
+        return [False, 'Y', True, True]
+    elif wildcards.metric_type == 'yearly_all_sites':
+        return [False, 'Y', True, False]
+    elif wildcards.metric_type == 'yearly_reach':
+        return [True, 'Y', True, False]
+    elif wildcards.metric_type == 'biweekly_site_based':
+        return [False, '2W', True, True]
+    elif wildcards.metric_type == 'biweekly_all_sites':
+        return [False, '2W', True, False]
+    elif wildcards.metric_type == 'biweekly_reach':
+        return [True, '2W', True, False]
+
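+# For example, 'monthly_reach' maps to [True, 'M', True, False]: metrics grouped
+# spatially by reach, aggregated to a monthly timestep, and not site-based.
+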
+#compute performance metrics
+rule combine_metrics:
+    group: 'train_predict_evaluate'
+    input:
+        config['obs_file'],
+        "{outdir}/trn_preds.feather",
+        "{outdir}/val_preds.feather",
+        "{outdir}/tst_preds.feather"
+    output:
+        "{outdir}/{metric_type}_metrics.csv"
+    params:
+        grp_arg = get_grp_arg
+    run:
+        combined_metrics(obs_file = input[0],
+                         pred_trn = input[1],
+                         pred_val = input[2],
+                         pred_tst = input[3],
+                         train_sites = train_segs,
+                         val_sites = val_segs,
+                         test_sites = test_segs,
+                         group_spatially = params.grp_arg[0],
+                         group_temporally = params.grp_arg[1],
+                         time_aggregation = params.grp_arg[2],
+                         site_based = params.grp_arg[3],
+                         outfile = output[0],
+                         spatial_idx_name = "PRMS_segid",
+                         time_idx_name = "Date")
+
+#write prediction+obs files to use in R plot targets
+rule write_preds_obs:
+    group: 'train_predict_evaluate'
+    input:
+        config['obs_file'],
+        "{outdir}/trn_preds.feather",
+        "{outdir}/val_preds.feather",
+        "{outdir}/tst_preds.feather"
+    output:
+        "{outdir}/trn_preds_obs.csv",
+        "{outdir}/val_preds_obs.csv",
+        "{outdir}/tst_preds_obs.csv"
+    run:
+        write_preds_obs(pred_file = input[1],
+                        obs_file = input[0],
+                        partition = 'trn',
+                        spatial_idx_name = "PRMS_segid",
+                        time_idx_name = "Date",
+                        filepath = output[0],
+                        spatial = config['spatial_write'],
+                        train_sites = train_segs,
+                        val_sites = val_segs,
+                        test_sites = test_segs)
+        write_preds_obs(pred_file = input[2],
+                        obs_file = input[0],
+                        partition = 'val',
+                        spatial_idx_name = "PRMS_segid",
+                        time_idx_name = "Date",
+                        filepath = output[1],
+                        spatial = config['spatial_write'],
+                        train_sites = train_segs,
+                        val_sites = val_segs,
+                        test_sites = test_segs)
+        write_preds_obs(pred_file = input[3],
+                        obs_file = input[0],
+                        partition = 'tst',
+                        spatial_idx_name = "PRMS_segid",
+                        time_idx_name = "Date",
+                        filepath = output[2],
+                        spatial = config['spatial_write'],
+                        train_sites = train_segs,
+                        val_sites = val_segs,
+                        test_sites = test_segs)
\ No newline at end of file
diff --git a/config_insal_spatial-dyn_rgcn_pytorch.yml b/config_insal_spatial-dyn_rgcn_pytorch.yml
new file mode 100644
index 0000000..7e9d3ff
--- /dev/null
+++ b/config_insal_spatial-dyn_rgcn_pytorch.yml
@@ -0,0 +1,59 @@
+runDescription: "DRB Inland Salinity RGCN prediction with spatial holdout - dynamic attributes"
+
+# Input files
+obs_file: "2_process/out/drb_SC_obs_PRMS.zarr"
+attrs_file: "2_process/out/drb_attrs_PRMS.zarr"
+dist_matrix_file: "1_fetch/out/drb_distance_matrix.npz"
+# File containing the attribute names to use within the attrs_file
+x_vars_file: "4_predict/out/dynamic_attrs.txt"
+
+#river-dl code directory
+code_dir: "../river-dl"
+
+# Attribute names within obs_file
+y_vars: ['mean_value']
+#Attributes to take the log of for model training.
+log_y_vars: False
+
+# output location
+out_dir: "4_predict/out/spatial/RGCN/dynamic_test1"
+
+#random seed for training. False==no seed; otherwise specify the seed
+seed: 21023
+
+lambdas: [1]
+
+#Define the number of epochs and the early stopping criteria (number of epochs)
+epochs: 200
+#setting this to False will turn off early stopping rounds
+early_stopping: 10
+
+
+#Hyperparameter settings
+hidden_size: 20
+dropout: 0.2
+recurrent_dropout: 0.2
+learning_rate: 0.005
+trn_offset: 1.0
+tst_val_offset: 1.0
+seq_len: 365
+
+# Define the spatial test segments (to be loaded from a file)
+# Use 'None' if all segments may be used
+# using f instead of file because file triggers preproc_utils to generate metadata that wouldn't exist for 'None'
+# 'None' for training means that val and test segs are removed to create the training data.
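+# Example contents of a reach-list file such as spatial_test_reaches.txt
+# (one PRMS_segid per line; hypothetical IDs shown):
+#   1437
+#   1442
+#   1485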
+train_segs_f: 'None'
+val_segs_f: '4_predict/out/spatial_val_reaches1.txt'
+test_segs_f: '4_predict/out/spatial_test_reaches.txt'
+#when True, the final file of predictions and observations is trimmed to the reaches in each of the 3 above files
+spatial_write: True
+
+# Define the earliest date, latest date, training, validation, and testing time periods
+earliest_date: "4_predict/out/train_start.txt"
+latest_date: "4_predict/out/test_end.txt"
+train_start_date_f: "4_predict/out/train_start.txt"
+train_end_date_f: "4_predict/out/test_end.txt"
+val_start_date_f: "4_predict/out/train_start.txt"
+val_end_date_f: "4_predict/out/test_end.txt"
+test_start_date_f: "4_predict/out/train_start.txt"
+test_end_date_f: "4_predict/out/test_end.txt"
\ No newline at end of file
diff --git a/config_insal_temporal-minstatdyn_rgcn_pytorch.yml b/config_insal_temporal-minstatdyn_rgcn_pytorch.yml
new file mode 100644
index 0000000..d093dcd
--- /dev/null
+++ b/config_insal_temporal-minstatdyn_rgcn_pytorch.yml
@@ -0,0 +1,58 @@
+runDescription: "DRB Inland Salinity RGCN prediction with temporal holdout - min static and dynamic attributes"
+
+# Input files
+obs_file: "2_process/out/drb_SC_obs_PRMS.zarr"
+attrs_file: "2_process/out/drb_attrs_PRMS.zarr"
+dist_matrix_file: "1_fetch/out/drb_distance_matrix.npz"
+# File containing the attribute names to use within the attrs_file
+x_vars_file: "4_predict/out/min_static_dynamic_attrs.txt"
+
+#river-dl code directory
+code_dir: "../river-dl"
+
+# Attribute names within obs_file
+y_vars: ['mean_value']
+#Attributes to take the log of for model training.
+log_y_vars: False
+
+# output location
+out_dir: "4_predict/out/temporal/RGCN/min_static_dynamic_test1"
+
+#random seed for training. False==no seed; otherwise specify the seed
+seed: 21023
+
+lambdas: [1]
+
+#Define the number of epochs and the early stopping criteria (number of epochs)
+epochs: 200
+#setting this to False will turn off early stopping rounds
+early_stopping: 10
+
+
+#Hyperparameter settings
+hidden_size: 20
+dropout: 0.2
+recurrent_dropout: 0.2
+learning_rate: 0.005
+trn_offset: 1.0
+tst_val_offset: 1.0
+seq_len: 365
+
+# Define the spatial test segments (to be loaded from a file)
+# Use 'None' if all segments may be used
+# using f instead of file because file triggers preproc_utils to generate metadata that wouldn't exist for 'None'
+train_segs_f: 'None'
+val_segs_f: 'None'
+test_segs_f: 'None'
+#when True, the final file of predictions and observations is trimmed to the reaches in each of the 3 above files
+spatial_write: False
+
+# Define the earliest date, latest date, training, validation, and testing time periods
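+# Each date file is read with txt_to_list and is assumed to hold one date per
+# line; e.g., a hypothetical train_start.txt might contain the single row
+# 1985-10-01.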
+earliest_date: "4_predict/out/train_start.txt"
+latest_date: "4_predict/out/test_end.txt"
+train_start_date_f: "4_predict/out/train_start.txt"
+train_end_date_f: "4_predict/out/val_end4.txt"
+val_start_date_f: "4_predict/out/val_start5.txt"
+val_end_date_f: "4_predict/out/train_end.txt"
+test_start_date_f: "4_predict/out/test_start.txt"
+test_end_date_f: "4_predict/out/test_end.txt"
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..f615134
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,243 @@
+# -*- coding: utf-8 -*-
+"""
+Created on Wed Mar 1 10:16:29 2023
+
+@author: jsmith
+"""
+import numpy as np
+import pandas as pd
+from river_dl.postproc_utils import fmt_preds_obs
+from river_dl.evaluate import calc_metrics
+
+def txt_to_list(file, sep='\n'):
+    """
+    Converts a single-column txt file to a Python list.
+
+    Parameters
+    ----------
+    file : TYPE str
+        DESCRIPTION path to a text file with one column whose rows are converted to list elements
+    sep : TYPE str
+        DESCRIPTION the separator for each row. Typically newline (default)
+
+    Returns
+    -------
+    Python list with elements corresponding to the rows of file
+
+    """
+
+    #use a context manager so the file handle is closed after reading
+    with open(file, 'r') as f:
+        r = [line.split(sep)[0] for line in f]
+    return r
+
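+# Example (hypothetical file): if 'reaches.txt' holds the rows '1437' and
+# '1442', then txt_to_list('reaches.txt') returns ['1437', '1442'].
+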
+def check_spatial_segs(segs, sep='\n'):
+    """
+    Checks whether segs is the string 'None' or a file path.
+    If a file path, it uses txt_to_list to convert the file to a list and returns the list.
+
+    Parameters
+    ----------
+    segs : TYPE str
+        DESCRIPTION 'None' or the path to a text file that is used in txt_to_list
+    sep : TYPE str
+        DESCRIPTION the separator for each row. Typically newline (default)
+
+    Returns
+    -------
+    None or a Python list with elements corresponding to the rows of the segs file
+
+    """
+
+    if segs != 'None':
+        r = txt_to_list(segs, sep)
+    else:
+        r = None
+    return r
+
+def write_preds_obs(pred_file, obs_file, partition, spatial_idx_name,
+                    time_idx_name, filepath, spatial=False, train_sites=None,
+                    val_sites=None, test_sites=None):
+    """
+    Joins the predictions and observations, and writes the result to a csv file
+
+    Parameters
+    ----------
+    pred_file : TYPE str
+        DESCRIPTION filepath to the predictions file
+    obs_file : TYPE str
+        DESCRIPTION filepath to the observations file
+    partition : TYPE str
+        DESCRIPTION one of 'trn', 'val', or 'tst'
+    spatial_idx_name : TYPE str
+        DESCRIPTION name of column that is used for spatial
+        index (e.g., 'seg_id_nat')
+    time_idx_name : TYPE str
+        DESCRIPTION name of column that is used for temporal index
+        (usually 'time')
+    filepath : TYPE str
+        DESCRIPTION path name of the output file
+    spatial : TYPE bool
+        DESCRIPTION when True, the pred_file is trimmed to the reaches according
+        to each of the provided splits
+    train_sites : TYPE list
+        DESCRIPTION sites to exclude from validation and test metrics
+    val_sites : TYPE list
+        DESCRIPTION sites to exclude from training and test metrics
+    test_sites : TYPE list
+        DESCRIPTION sites to exclude from validation and training metrics
+
+    Returns
+    -------
+    None.
+    Writes a file with columns for the time_idx_name, spatial_idx_name, obs, and pred
+    """
+    var_data = fmt_preds_obs(pred_file, obs_file,
+                             spatial_idx_name, time_idx_name)
+
+    for data_var, data in var_data.items():
+        #reset the index so that the time and space indices are attributes
+        data.reset_index(inplace=True)
+
+        if spatial:
+            # mask out validation and test sites from the trn partition
+            if train_sites and partition == 'trn':
+                # simply use the train sites when specified
+                data = data[data[spatial_idx_name].isin(train_sites)]
+            else:
+                #check if validation or testing sites are specified
+                if val_sites and partition == 'trn':
+                    data = data[~data[spatial_idx_name].isin(val_sites)]
+                if test_sites and partition == 'trn':
+                    data = data[~data[spatial_idx_name].isin(test_sites)]
+            # mask out training and test sites from the val partition
+            if val_sites and partition == 'val':
+                data = data[data[spatial_idx_name].isin(val_sites)]
+            else:
+                if test_sites and partition == 'val':
+                    data = data[~data[spatial_idx_name].isin(test_sites)]
+                if train_sites and partition == 'val':
+                    data = data[~data[spatial_idx_name].isin(train_sites)]
+            # mask out training and validation sites from the tst partition
+            if test_sites and partition == 'tst':
+                data = data[data[spatial_idx_name].isin(test_sites)]
+            else:
+                if train_sites and partition == 'tst':
+                    data = data[~data[spatial_idx_name].isin(train_sites)]
+                if val_sites and partition == 'tst':
+                    data = data[~data[spatial_idx_name].isin(val_sites)]
+
+        data.to_csv(filepath)
+
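+# RF_model_metrics mirrors the grouping flags passed to
+# river_dl.evaluate.combined_metrics in the Snakefile, but starts from a single
+# predictions+observations csv (e.g., one written by write_preds_obs). A sketch
+# of a call, assuming a hypothetical csv path:
+#   RF_model_metrics('4_predict/out/RF/tst_preds_obs.csv', 'PRMS_segid', 'Date',
+#                    group_temporally='M', time_aggregation=True,
+#                    outfile='monthly_all_sites_metrics.csv')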
+def RF_model_metrics(pred_obs_csv, spatial_idx_name, time_idx_name,
+                     group_spatially=False, group_temporally=False,
+                     time_aggregation=False, site_based=False,
+                     outfile=None):
+    """
+    Computes performance metrics from a csv of predictions and observations,
+    with optional spatial and temporal grouping. Returns the metrics as a
+    DataFrame and writes them to outfile when one is provided.
+    """
+    data = pd.read_csv(pred_obs_csv)
+    data[time_idx_name] = pd.to_datetime(data[time_idx_name])
+    data.set_index([time_idx_name, spatial_idx_name], inplace=True)
+
+    var_metrics_list = []
+
+    if not group_spatially and not group_temporally:
+        metrics = calc_metrics(data)
+        # need to convert to dataframe and transpose so it looks like the
+        # others
+        metrics = pd.DataFrame(metrics).T
+    elif group_spatially and not group_temporally:
+        #note: same as data.groupby(level=spatial_idx_name)
+        metrics = (data.groupby(pd.Grouper(level=spatial_idx_name))
+                   .apply(calc_metrics)
+                   .reset_index()
+                   )
+    elif not group_spatially and group_temporally:
+        if time_aggregation:
+            #performance metrics computed at the group_temporally timestep:
+            #the data are aggregated with .mean() first, then calc_metrics is
+            #applied to the aggregated values.
+            if site_based:
+                #create a group_temporally timeseries for each observation site
+                metrics = calc_metrics(data
+                                       #filter the data to remove nans before averaging
+                                       #so that the same days contribute to each group
+                                       .dropna()
+                                       .groupby([pd.Grouper(level=time_idx_name, freq=group_temporally),
+                                                 pd.Grouper(level=spatial_idx_name)])
+                                       .mean()
+                                       )
+            else:
+                #create a group_temporally timeseries using data from all reaches
+                data_agg = (data
+                            .dropna()
+                            .groupby(pd.Grouper(level=time_idx_name, freq=group_temporally))
+                            .mean()
+                            )
+                #pd.Grouper creates empty groups for periods with no
+                #observations. Remove those periods before calculating
+                #metrics. Get the indices with 0 obs:
+                data_count_0 = np.where(data
+                                        .dropna()
+                                        .groupby(pd.Grouper(level=time_idx_name, freq=group_temporally))
+                                        .count()
+                                        .reset_index()
+                                        .obs == 0
+                                        )[0]
+                if len(data_count_0) > 0:
+                    data_agg = data_agg.drop(index=data_agg.index[data_count_0])
+                metrics = calc_metrics(data_agg)
+                metrics = pd.DataFrame(metrics).T
+        else:
+            if group_temporally != 'M':
+                #native timestep performance metrics within the group_temporally groups
+                #This method will report one row per group_temporally group
+                # examples: year-month-week would be a group when group_temporally is 'W'
+                #           year would be a group when group_temporally is 'Y'
+                metrics = (data
+                           .groupby(pd.Grouper(level=time_idx_name, freq=group_temporally))
+                           .apply(calc_metrics)
+                           .reset_index()
+                           )
+            else:
+                #This method reports one row per calendar month (1-12)
+                metrics = (data.reset_index()
+                           .groupby(data.reset_index()[time_idx_name].dt.month)
+                           .apply(calc_metrics)
+                           .reset_index()
+                           )
+    elif group_spatially and group_temporally:
+        if time_aggregation:
+            #performance metrics for each reach computed at the group_temporally timestep
+            data_calc = (data
+                         .dropna()
+                         .groupby([pd.Grouper(level=time_idx_name, freq=group_temporally),
+                                   pd.Grouper(level=spatial_idx_name)])
+                         .mean()
+                         )
+            #regroup the aggregated values by reach to compute metrics per reach
+            metrics = (data_calc.groupby(pd.Grouper(level=spatial_idx_name))
+                       .apply(calc_metrics)
+                       .reset_index()
+                       )
+        else:
+            if group_temporally != 'M':
+                metrics = (data
+                           .groupby([pd.Grouper(level=time_idx_name, freq=group_temporally),
+                                     pd.Grouper(level=spatial_idx_name)])
+                           .apply(calc_metrics)
+                           .reset_index()
+                           )
+            else:
+                metrics = (data.reset_index()
+                           .groupby([data.reset_index()[time_idx_name].dt.month, spatial_idx_name])
+                           .apply(calc_metrics)
+                           .reset_index()
+                           )
+
+    metrics["variable"] = 'mean_value'
+    metrics["partition"] = 'tst'
+    var_metrics_list.append(metrics)
+    var_metrics = pd.concat(var_metrics_list).round(6)
+    if outfile:
+        var_metrics.to_csv(outfile, header=True, index=False)
+    return var_metrics
\ No newline at end of file