Skip to content
This repository has been archived by the owner on Jun 30, 2023. It is now read-only.

Commit

Permalink
Merge pull request #231 from jds485/230-RGCN-workflows
Browse files Browse the repository at this point in the history
RGCN workflows
  • Loading branch information
jds485 authored Apr 7, 2023
2 parents 8e292d3 + f442303 commit 9eed240
Show file tree
Hide file tree
Showing 4 changed files with 633 additions and 0 deletions.
273 changes: 273 additions & 0 deletions Snakefile_insal_rgcn_pytorch.smk
Original file line number Diff line number Diff line change
@@ -0,0 +1,273 @@
import os
import numpy as np
import torch
import torch.optim as optim

#river-dl
code_dir = config['code_dir']
# if using river_dl installed with pip this is not needed
import sys
sys.path.insert(1, code_dir)

from river_dl.preproc_utils import asRunConfig
from river_dl.preproc_utils import prep_all_data
from river_dl.torch_utils import train_torch
from river_dl.torch_utils import rmse_masked
from river_dl.evaluate import combined_metrics
from river_dl.torch_models import RGCN_v1
from river_dl.predict import predict_from_io_data

#user-defined functions
from utils import *

# Output directory for all workflow results; created up front so every rule
# can write into it without its own makedirs call.
out_dir = config['out_dir']
os.makedirs(out_dir, exist_ok=True)

#spatial holdout info
# Reach/segment id lists for the train/val/test spatial partitions, read from
# the files named in the config. check_spatial_segs comes from utils
# (`from utils import *`); presumably it returns None when the config value
# is the string 'None' (see the config comments) -- confirm against utils.py.
train_segs = check_spatial_segs(config['train_segs_f'])
val_segs = check_spatial_segs(config['val_segs_f'])
test_segs = check_spatial_segs(config['test_segs_f'])

# Pseudo-rule listing the final targets of the workflow: finetuned model
# weights and training log, one metrics csv per metric_type, the as-run
# config and Snakefile copies, and the per-partition prediction+obs csvs.
rule all:
    input:
        f"{out_dir}/finetuned_weights.pth",
        f"{out_dir}/finetune_log.csv",
        expand("{outdir}/{metric_type}_metrics.csv",
                outdir=out_dir,
                # each metric_type maps to a grouping combination in get_grp_arg below
                metric_type=['overall', 'month', 'reach', 'month_reach', 'monthly_all_sites', 'monthly_site_based', 'monthly_reach', 'biweekly_all_sites', 'biweekly_site_based', 'biweekly_reach', 'year', 'year_reach', 'yearly_all_sites', 'yearly_site_based', 'yearly_reach'],
        ),
        expand("{outdir}/asRunConfig.yml", outdir = out_dir),
        expand("{outdir}/Snakefile", outdir = out_dir),
        f"{out_dir}/trn_preds_obs.csv",
        f"{out_dir}/val_preds_obs.csv",
        f"{out_dir}/tst_preds_obs.csv"

#save the as-run config settings to a text file
rule as_run_config:
    group: "prep"
    output:
        "{outdir}/asRunConfig.yml"
    run:
        # asRunConfig (river_dl.preproc_utils) writes the config values used
        # for this run to output[0], for reproducibility
        asRunConfig(config, code_dir, output[0])

#save the as-run snakefile to the output
rule copy_snakefile:
    group: "prep"
    # NOTE(review): scp is used for a purely local copy here; plain `cp`
    # would do the same without an ssh dependency -- confirm scp was
    # intentional (e.g. for a remote out_dir) before changing.
    output:
        "{outdir}/Snakefile"
    shell:
        """
        scp Snakefile_insal_rgcn_pytorch.smk {output[0]}
        """

# Prepare the model-ready io data (x/y arrays, partitions, scalers) from the
# attribute, observation, and distance-matrix inputs, saved as one npz file.
rule prep_io_data:
    group: "prep"
    input:
        config['attrs_file'],
        config['obs_file'],
        config['dist_matrix_file']
    output:
        "{outdir}/prepped.npz"
    threads: 3
    run:
        # NOTE(review): train_segs is computed at the top of this file but
        # not passed here -- per the config comments, training reaches are
        # presumably whatever remains after removing val/test sites when
        # explicit_spatial_partition=True; confirm in prep_all_data.
        prep_all_data(
                  x_data_file = input[0],
                  y_data_file = input[1],
                  distfile = input[2],
                  # driver variable names are listed one-per-line in a text file
                  x_vars = txt_to_list(config['x_vars_file']),
                  spatial_idx_name = "PRMS_segid",
                  time_idx_name = "Date",
                  y_vars_finetune = config['y_vars'],
                  # start/end dates are also read from text files (may hold
                  # multiple periods, hence lists)
                  train_start_date = txt_to_list(config['train_start_date_f']),
                  train_end_date = txt_to_list(config['train_end_date_f']),
                  val_start_date = txt_to_list(config['val_start_date_f']),
                  val_end_date = txt_to_list(config['val_end_date_f']),
                  test_start_date = txt_to_list(config['test_start_date_f']),
                  test_end_date = txt_to_list(config['test_end_date_f']),
                  val_sites = val_segs,
                  test_sites = test_segs,
                  explicit_spatial_partition = True,
                  # single overall time bounds: first element of each file
                  earliest_time = txt_to_list(config['earliest_date'])[0],
                  latest_time = txt_to_list(config['latest_date'])[0],
                  out_file = output[0],
                  trn_offset = config['trn_offset'],
                  tst_val_offset = config['tst_val_offset'],
                  seq_len = config['seq_len'],
                  log_y_vars = config['log_y_vars']
                  )


rule finetune_train:
group: "train"
input:
"{outdir}/prepped.npz"
output:
"{outdir}/finetuned_weights.pth",
"{outdir}/finetune_log.csv"
threads: 3
run:
data = np.load(input[0], allow_pickle=True)
num_segs = len(np.unique(data['ids_trn']))
adj_mx = data['dist_matrix']
in_dim = len(data['x_vars'])
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = RGCN_v1(input_dim = in_dim,
hidden_dim = config['hidden_size'],
adj_matrix = adj_mx,
recur_dropout = config['recurrent_dropout'],
dropout = config['dropout'],
device = device,
seed = config['seed'])
opt = optim.Adam(model.parameters(), lr = config['learning_rate'])
train_torch(model,
loss_function = rmse_masked,
optimizer = opt,
x_train = data['x_trn'],
y_train = data['y_obs_trn'],
x_val = data['x_val'],
y_val = data['y_obs_val'],
max_epochs = config['epochs'],
early_stopping_patience = config['early_stopping'],
batch_size = num_segs,
weights_file = output[0],
log_file = output[1],
device = device)


# Generate predictions for one partition (trn/val/tst wildcard) from the
# finetuned weights and the prepped io data.
rule make_predictions:
    input:
        "{outdir}/finetuned_weights.pth",
        "{outdir}/prepped.npz"
    output:
        "{outdir}/{partition}_preds.feather"
    group: "train_predict_evaluate"
    run:
        # allow_pickle=True for consistency with finetune_train: the prepped
        # npz contains object arrays (e.g. string segment ids / dates)
        data = np.load(input[1], allow_pickle=True)
        adj_mx = data['dist_matrix']
        # input feature dimension = number of driver variables
        in_dim = len(data['x_vars'])
        # predict on GPU when available, otherwise CPU
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        # the model must be rebuilt with the same hyperparameters used in
        # training before the saved state dict can be loaded
        model = RGCN_v1(input_dim = in_dim,
                        hidden_dim = config['hidden_size'],
                        adj_matrix = adj_mx,
                        recur_dropout = config['recurrent_dropout'],
                        dropout = config['dropout'],
                        device = device,
                        seed = config['seed'])
        # map_location lets weights trained on a GPU be loaded on a CPU-only
        # host (torch.load otherwise tries to restore to the saved device).
        # The unused Adam optimizer the original built here was removed --
        # no optimizer is needed for inference.
        model.load_state_dict(torch.load(input[0], map_location = device))
        predict_from_io_data(model = model,
                             io_data = input[1],
                             partition = wildcards.partition,
                             outfile = output[0],
                             trn_offset = config['trn_offset'],
                             tst_val_offset = config['tst_val_offset'],
                             spatial_idx_name = "PRMS_segid",
                             time_idx_name = "Date",
                             log_vars = config['log_y_vars'])


#Order in the list is:
# spatial (bool), temporal (False or timestep to use), time_aggregation (bool), site_based (bool)
_GRP_ARGS = {
    'overall':             [False, False, False, False],
    'month':               [False, 'M', False, False],
    'reach':               [True, False, False, False],
    'month_reach':         [True, 'M', False, False],
    'monthly_site_based':  [False, 'M', True, True],
    'monthly_all_sites':   [False, 'M', True, False],
    'monthly_reach':       [True, 'M', True, False],
    'year':                [False, 'Y', False, False],
    'year_reach':          [True, 'Y', False, False],
    'yearly_site_based':   [False, 'Y', True, True],
    'yearly_all_sites':    [False, 'Y', True, False],
    'yearly_reach':        [True, 'Y', True, False],
    'biweekly_site_based': [False, '2W', True, True],
    'biweekly_all_sites':  [False, '2W', True, False],
    'biweekly_reach':      [True, '2W', True, False],
}

def get_grp_arg(wildcards):
    """Return the combined_metrics grouping arguments for a metric type.

    Replaces a 15-branch if/elif chain with a table lookup, and raises for
    unknown metric types instead of silently returning None (which produced
    confusing errors downstream in combine_metrics).

    Parameters
    ----------
    wildcards : object
        Snakemake wildcards object with a ``metric_type`` attribute.

    Returns
    -------
    list
        ``[group_spatially, group_temporally, time_aggregation, site_based]``
        where group_temporally is False or a pandas resample code
        ('M', 'Y', '2W').

    Raises
    ------
    ValueError
        If ``wildcards.metric_type`` is not a recognized metric type.
    """
    try:
        # return a copy so callers cannot mutate the shared table
        return list(_GRP_ARGS[wildcards.metric_type])
    except KeyError:
        raise ValueError(
            f"unknown metric_type: {wildcards.metric_type!r}") from None

#compute performance metrics
# One metrics csv per metric_type wildcard; the grouping combination for the
# metric type is resolved by get_grp_arg via the params directive.
rule combine_metrics:
    group: 'train_predict_evaluate'
    input:
        config['obs_file'],
        "{outdir}/trn_preds.feather",
        "{outdir}/val_preds.feather",
        "{outdir}/tst_preds.feather"
    output:
        "{outdir}/{metric_type}_metrics.csv"
    params:
        # evaluated lazily per-job with this job's wildcards
        grp_arg = get_grp_arg
    run:
        combined_metrics(obs_file = input[0],
                         pred_trn = input[1],
                         pred_val = input[2],
                         pred_tst = input[3],
                         train_sites = train_segs,
                         val_sites = val_segs,
                         test_sites = test_segs,
                         # grp_arg order: spatial, temporal, time_aggregation,
                         # site_based (see get_grp_arg)
                         group_spatially = params.grp_arg[0],
                         group_temporally = params.grp_arg[1],
                         time_aggregation = params.grp_arg[2],
                         site_based = params.grp_arg[3],
                         outfile = output[0],
                         spatial_idx_name = "PRMS_segid",
                         time_idx_name = "Date")

#write prediction+obs files to use in R plot targets
# The three partitions are handled by one loop instead of three copy-pasted
# calls: pred file input[i+1] -> csv output[i] for partition i.
rule write_preds_obs:
    group: 'train_predict_evaluate'
    input:
        config['obs_file'],
        "{outdir}/trn_preds.feather",
        "{outdir}/val_preds.feather",
        "{outdir}/tst_preds.feather"
    output:
        "{outdir}/trn_preds_obs.csv",
        "{outdir}/val_preds_obs.csv",
        "{outdir}/tst_preds_obs.csv"
    run:
        # write_preds_obs (the function, from utils) joins each partition's
        # predictions with the observations; when config['spatial_write'] is
        # True it also trims to the configured reach lists
        for i, part in enumerate(['trn', 'val', 'tst']):
            write_preds_obs(pred_file = input[i + 1],
                            obs_file = input[0],
                            partition = part,
                            spatial_idx_name = "PRMS_segid",
                            time_idx_name = "Date",
                            filepath = output[i],
                            spatial = config['spatial_write'],
                            train_sites = train_segs,
                            val_sites = val_segs,
                            test_sites = test_segs)
59 changes: 59 additions & 0 deletions config_insal_spatial-dyn_rgcn_pytorch.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
runDescription: "DRB Inland Salinity RGCN prediction with spatial holdout - dynamic attributes"

# Input files
obs_file: "2_process/out/drb_SC_obs_PRMS.zarr"
attrs_file: "2_process/out/drb_attrs_PRMS.zarr"
dist_matrix_file: "1_fetch/out/drb_distance_matrix.npz"
# File containing the attribute names to use within the attrs_file
x_vars_file: "4_predict/out/dynamic_attrs.txt"

#river-dl code directory
code_dir: "../river-dl"

# Attribute names within obs_file
y_vars: ['mean_value']
#Attributes to take the log of for model training.
log_y_vars: False

# output location
out_dir: "4_predict/out/spatial/RGCN/dynamic_test1"

#Random seed for training. False == no seed; otherwise specify the seed to use.
seed: 21023

lambdas: [1]

#Define number of epochs and the early stopping criteria (number of epochs)
epochs: 200
#setting this to False will turn off early stopping rounds
early_stopping: 10


#Hyperparameter settings
hidden_size: 20
dropout: 0.2
recurrent_dropout: 0.2
learning_rate: 0.005
trn_offset: 1.0
tst_val_offset: 1.0
seq_len: 365

# Define the spatial test segments (to be loaded from a file)
# Use 'None' if all segments may be used
# using f instead of file because file triggers preproc_utils to generate metadata that wouldn't exist for 'None'
# None for training means that val and test segs are removed to create the training data.
train_segs_f: 'None'
val_segs_f: '4_predict/out/spatial_val_reaches1.txt'
test_segs_f: '4_predict/out/spatial_test_reaches.txt'
#when True, the final file of predictions and observations are trimmed to the reaches in each of the 3 above files
spatial_write: True

# Define the earliest date, latest date, training, validation, and testing time periods
earliest_date: "4_predict/out/train_start.txt"
latest_date: "4_predict/out/test_end.txt"
train_start_date_f: "4_predict/out/train_start.txt"
train_end_date_f: "4_predict/out/test_end.txt"
val_start_date_f: "4_predict/out/train_start.txt"
val_end_date_f: "4_predict/out/test_end.txt"
test_start_date_f: "4_predict/out/train_start.txt"
test_end_date_f: "4_predict/out/test_end.txt"
58 changes: 58 additions & 0 deletions config_insal_temporal-minstatdyn_rgcn_pytorch.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
runDescription: "DRB Inland Salinity RGCN prediction with temporal holdout - min static and dynamic attributes"

# Input files
obs_file: "2_process/out/drb_SC_obs_PRMS.zarr"
attrs_file: "2_process/out/drb_attrs_PRMS.zarr"
dist_matrix_file: "1_fetch/out/drb_distance_matrix.npz"
# File containing the attribute names to use within the attrs_file
x_vars_file: "4_predict/out/min_static_dynamic_attrs.txt"

#river-dl code directory
code_dir: "../river-dl"

# Attribute names within obs_file
y_vars: ['mean_value']
#Attributes to take the log of for model training.
log_y_vars: False

# output location
out_dir: "4_predict/out/temporal/RGCN/min_static_dynamic_test1"

#Random seed for training. False == no seed; otherwise specify the seed to use.
seed: 21023

lambdas: [1]

#Define number of epochs and the early stopping criteria (number of epochs)
epochs: 200
#setting this to False will turn off early stopping rounds
early_stopping: 10


#Hyperparameter settings
hidden_size: 20
dropout: 0.2
recurrent_dropout: 0.2
learning_rate: 0.005
trn_offset: 1.0
tst_val_offset: 1.0
seq_len: 365

# Define the spatial test segments (to be loaded from a file)
# Use 'None' if all segments may be used
# using f instead of file because file triggers preproc_utils to generate metadata that wouldn't exist for 'None'
train_segs_f: 'None'
val_segs_f: 'None'
test_segs_f: 'None'
#when True, the final file of predictions and observations are trimmed to the reaches in each of the 3 above files
spatial_write: False

# Define the earliest date, latest date, training, validation, and testing time periods
earliest_date: "4_predict/out/train_start.txt"
latest_date: "4_predict/out/test_end.txt"
train_start_date_f: "4_predict/out/train_start.txt"
train_end_date_f: "4_predict/out/val_end4.txt"
val_start_date_f: "4_predict/out/val_start5.txt"
val_end_date_f: "4_predict/out/train_end.txt"
test_start_date_f: "4_predict/out/test_start.txt"
test_end_date_f: "4_predict/out/test_end.txt"
Loading

0 comments on commit 9eed240

Please sign in to comment.