Skip to content
This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Commit

Permalink
Merge pull request #200 from janetrbarclay/main
Browse files Browse the repository at this point in the history
adding catch_prop var list
  • Loading branch information
janetrbarclay authored Apr 13, 2022
2 parents 29fab34 + 94db715 commit b25a45e
Showing 1 changed file with 15 additions and 2 deletions.
17 changes: 15 additions & 2 deletions river_dl/preproc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,24 +236,32 @@ def join_catch_properties(x_data_ts, catch_props):
return xr.merge([x_data_ts, ds_catch], join="left")


def prep_catch_props(x_data_ts, catch_prop_file, spatial_idx_name, replace_nan_with_mean=True):
def prep_catch_props(x_data_ts, catch_prop_file, catch_prop_vars, spatial_idx_name, replace_nan_with_mean=True):
"""
read catch property file and join with ts data
:param x_data_ts: [xr dataset] timeseries x-data
:param catch_prop_file: [str] the feather file of catchment attributes
:param catch_prop_vars: [list of str] the catchment attributes to use, if None, all attributes will be kept
:param spatial_idx_name: [str] name of column that is used for spatial
index (e.g., 'seg_id_nat')
:param replace_nan_with_mean: [bool] if true, any nan will be replaced with
the mean of that variable
:return: [xr dataset] merged datasets
"""
df_catch_props = pd.read_feather(catch_prop_file)

#keep only the requested variables
if catch_prop_vars:
catch_prop_vars.append(spatial_idx_name)
df_catch_props = df_catch_props[catch_prop_vars]

# replace nans with column means
if replace_nan_with_mean:
df_catch_props = df_catch_props.apply(
lambda x: x.fillna(x.mean()), axis=0
)
ds_catch_props = df_catch_props.set_index(spatial_idx_name).to_xarray()

return join_catch_properties(x_data_ts, ds_catch_props)


Expand Down Expand Up @@ -759,6 +767,7 @@ def prep_all_data(
dist_idx_name="rowcolnames",
dist_type="updown",
catch_prop_file=None,
catch_prop_vars=None,
exclude_file=None,
log_y_vars=False,
out_file=None,
Expand Down Expand Up @@ -812,6 +821,8 @@ def prep_all_data(
"updown")
:param catch_prop_file: [str] the path to the catchment properties file. If
left unfilled, the catchment properties will not be included as predictors
:param catch_prop_vars: [list of str] list of catchment properties to use. If
left unfilled and a catchment property file is supplied all variables will be used.
:param exclude_file: [str] path to exclude file
:param log_y_vars: [bool] whether or not to take the log of discharge in
training
Expand Down Expand Up @@ -860,7 +871,9 @@ def prep_all_data(
x_data = x_data[x_vars]

if catch_prop_file:
x_data = prep_catch_props(x_data, catch_prop_file, spatial_idx_name)
x_data = prep_catch_props(x_data, catch_prop_file, catch_prop_vars, spatial_idx_name)
#update the list of x_vars
x_vars = list(x_data.data_vars)
# make sure we don't have any weird or missing input values
check_if_finite(x_data)
x_trn, x_val, x_tst = separate_trn_tst(
Expand Down

0 comments on commit b25a45e

Please sign in to comment.