Skip to content
This repository has been archived by the owner on Jun 2, 2023. It is now read-only.

Commit

Permalink
Merge pull request #201 from SimonTopp/main
Browse files (browse the repository at this point in the history)
Specify train sites to remove from test metrics
  • Loading branch information
SimonTopp authored May 31, 2022
2 parents b25a45e + 8102332 commit 4f1500a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 6 deletions.
13 changes: 10 additions & 3 deletions river_dl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def partition_metrics(
outfile=None,
val_sites=None,
test_sites=None,

train_sites=None,
):
"""
calculate metrics for a certain group (or no group at all) for a given
Expand All @@ -222,8 +222,9 @@ def partition_metrics(
names and dict values are the id values. These are added as columns to the
metrics information
:param outfile: [str] file where the metrics should be written
:param val_sites: [list] sites to exclude from training metrics
:param val_sites: [list] sites to exclude from training and test metrics
:param test_sites: [list] sites to exclude from validation and training metrics
:param train_sites: [list] sites to exclude from test metrics
:return: [pd dataframe] the condensed metrics
"""
var_data = fmt_preds_obs(preds, obs_file, spatial_idx_name,
Expand All @@ -240,6 +241,10 @@ def partition_metrics(
# mask out test sites from val partition
if test_sites and partition=='val':
data = data[~data[spatial_idx_name].isin(test_sites)]
if train_sites and partition=='tst':
data = data[~data[spatial_idx_name].isin(train_sites)]
if val_sites and partition=='tst':
data = data[~data[spatial_idx_name].isin(val_sites)]

if not group:
metrics = calc_metrics(data)
Expand Down Expand Up @@ -286,6 +291,7 @@ def combined_metrics(
pred_tst=None,
val_sites=None,
test_sites=None,
train_sites=None,
spatial_idx_name="seg_id_nat",
time_idx_name="date",
group=None,
Expand Down Expand Up @@ -349,7 +355,8 @@ def combined_metrics(
id_dict=id_dict,
group=group,
val_sites = val_sites,
test_sites = test_sites)
test_sites = test_sites,
train_sites=train_sites)
df_all.extend([metrics])

df_all = pd.concat(df_all, axis=0)
Expand Down
7 changes: 4 additions & 3 deletions river_dl/torch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,19 +238,20 @@ def predict_torch(x_data, model, batch_size):
@param device: [str] cuda or cpu
@return: [tensor] predicted values
"""
device = next(model.parameters()).device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model.to(device)
data = []
for i in range(len(x_data)):
data.append(torch.from_numpy(x_data[i]).float())

dataloader = torch.utils.data.DataLoader(data, batch_size=batch_size, shuffle=False, pin_memory=True)
model.to(device)
model.eval()
predicted = []
for iter, x in enumerate(dataloader):
trainx = x.to(device)
with torch.no_grad():
output = model(trainx.to(device)).cpu()
output = model(trainx).detach().cpu()
predicted.append(output)
predicted = torch.cat(predicted, dim=0)
return predicted
Expand Down

0 comments on commit 4f1500a

Please sign in to comment.