Skip to content

Commit

Permalink
Add evaluation timings
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Oct 4, 2024
1 parent 161fd56 commit a59ff7d
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import argparse
import itertools
import logging
import time
from pathlib import Path
from typing import Dict, List, Optional, Union

Expand Down Expand Up @@ -194,6 +195,10 @@ def _eval_targets(
if return_predictions:
all_predictions = []

# Set up timings:
total_time = 0.0
timings_per_atom = []

# Evaluate the model
for batch in dataloader:
systems, batch_targets = batch
Expand All @@ -202,13 +207,21 @@ def _eval_targets(
key: value.to(dtype=dtype, device=device)
for key, value in batch_targets.items()
}

start_time = time.time()

batch_predictions = evaluate_model(
model,
systems,
options,
is_training=False,
check_consistency=check_consistency,
)

if torch.cuda.is_available():
torch.cuda.synchronize()
end_time = time.time()

batch_predictions = average_by_num_atoms(
batch_predictions, systems, per_structure_keys=[]
)
Expand All @@ -219,6 +232,10 @@ def _eval_targets(
if return_predictions:
all_predictions.append(batch_predictions)

time_taken = end_time - start_time
total_time += time_taken
timings_per_atom.append(time_taken / sum(len(system) for system in systems))

# Finalize the RMSEs
rmse_values = rmse_accumulator.finalize(not_per_atom=["positions_gradients"])
# print the RMSEs with MetricLogger
Expand All @@ -229,6 +246,13 @@ def _eval_targets(
)
metric_logger.log(rmse_values)

# Log timings
mean_time_per_atom = sum(timings_per_atom) / len(timings_per_atom)
logger.info(
f"evaluation time: {total_time:.2f} s "
f"[{1000.0*mean_time_per_atom:.2f} ms per atom]"
)

if return_predictions:
# concatenate the TensorMaps
all_predictions_joined = _concatenate_tensormaps(all_predictions)
Expand Down
2 changes: 2 additions & 0 deletions tests/cli/test_eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ def test_eval(monkeypatch, tmp_path, caplog, model_name, options):
log = "".join([rec.message for rec in caplog.records])
assert "energy RMSE (per atom)" in log
assert "dataset with index" not in log
assert "evaluation time" in log
assert "ms per atom" in log

# Test file is written predictions
frames = ase.io.read("foo.xyz", ":")
Expand Down

0 comments on commit a59ff7d

Please sign in to comment.