Skip to content
This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Pad train/val/test data #218

Merged
merged 13 commits into from
May 11, 2023
2 changes: 1 addition & 1 deletion river_dl/postproc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def plot_ts_obs_preds(pred_file, obs_file, index_start = 0, index_end=3, outfile
ax.legend()
ax.set_title(seg)
plt.tight_layout()
if out_file:
if outfile:
plt.savefig(outfile)
else:
plt.show()
Expand Down
42 changes: 38 additions & 4 deletions river_dl/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def predict_from_io_data(
tst_val_offset = 1.0,
spatial_idx_name="seg_id_nat",
time_idx_name="date",
trn_latest_time=None,
val_latest_time=None,
tst_latest_time=None
):
"""
make predictions from trained model
Expand All @@ -70,14 +73,34 @@ def predict_from_io_data(
prep
:param trn_offset: [str] value for the training offset
:param tst_val_offset: [str] value for the testing and validation offset
:param trn_latest_time: [str] when specified, the training partition preds will
be trimmed to use trn_latest_time as the last date
:param trn_latest_time: [str] when specified, the validation partition preds will
be trimmed to use val_latest_time as the last date
:param trn_latest_time: [str] when specified, the test partition preds will
be trimmed to use tst_latest_time as the last date
:return: [pd dataframe] predictions
"""
io_data = get_data_if_file(io_data)
if partition == "trn":
keep_portion = trn_offset
else:
if trn_latest_time:
latest_time = trn_latest_time
else:
latest_time = None
elif partition == "val":
keep_portion = tst_val_offset

if val_latest_time:
latest_time = val_latest_time
else:
latest_time = None
elif partition == "tst":
keep_portion = tst_val_offset
if tst_latest_time:
latest_time = tst_latest_time
else:
latest_time = None

preds = predict(
model,
io_data[f"x_{partition}"],
Expand All @@ -91,6 +114,7 @@ def predict_from_io_data(
log_vars=log_vars,
spatial_idx_name=spatial_idx_name,
time_idx_name=time_idx_name,
latest_time=latest_time
)
return preds

Expand All @@ -108,6 +132,7 @@ def predict(
log_vars=False,
spatial_idx_name="seg_id_nat",
time_idx_name="date",
latest_time=None
):
"""
use trained model to make predictions
Expand All @@ -126,6 +151,8 @@ def predict(
:param y_vars:[np array] the variable names of the y_dataset data
:param outfile: [str] the file where the output data should be stored
:param log_vars: [list-like] which variables_to_log (if any) were logged in data
:param latest_time: [str] when provided, the latest time that should be included
in the returned dataframe
prep
:return: out predictions
"""
Expand Down Expand Up @@ -154,8 +181,15 @@ def predict(
pred_dates = pred_dates[:, -frac_seq_len:,...]

y_pred_pp = prepped_array_to_df(y_pred, pred_dates, pred_ids, y_vars, spatial_idx_name, time_idx_name)

y_pred_pp = unscale_output(y_pred_pp, y_stds, y_means, y_vars, log_vars,)

y_pred_pp = unscale_output(y_pred_pp, y_stds, y_means, y_vars, log_vars)

#remove data that were added to fill batches
if latest_time:
y_pred_pp = (y_pred_pp.drop(y_pred_pp[y_pred_pp[time_idx_name] > np.datetime64(latest_time)].index)
.reset_index()
.drop(columns='index')
)

if outfile:
y_pred_pp.to_feather(outfile)
Expand Down
Loading