Skip to content

Commit

Permalink
New virial dataset (#313)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Jul 24, 2024
1 parent b2fc734 commit 85fd06d
Show file tree
Hide file tree
Showing 12 changed files with 677 additions and 1,808 deletions.
4 changes: 2 additions & 2 deletions src/metatrain/experimental/alchemical_model/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz")

ALCHEMICAL_DATASET_PATH = str(
Path(__file__).parents[5] / "tests/resources/alchemical_reduced_10.xyz"
QM9_DATASET_PATH = str(
Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz"
)

DEFAULT_HYPERS = get_default_hypers("experimental.alchemical_model")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,22 @@
from metatrain.utils.data import DatasetInfo, TargetInfo, TargetInfoDict, read_systems
from metatrain.utils.neighbor_lists import get_system_with_neighbor_lists

from . import ALCHEMICAL_DATASET_PATH, MODEL_HYPERS
from . import MODEL_HYPERS, QM9_DATASET_PATH


random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

systems = read_systems(ALCHEMICAL_DATASET_PATH)
systems = read_systems(QM9_DATASET_PATH)
systems = [system.to(torch.float32) for system in systems]
nl_options = NeighborListOptions(
cutoff=5.0,
full_list=True,
)
systems = [get_system_with_neighbor_lists(system, [nl_options]) for system in systems]

frames = read(ALCHEMICAL_DATASET_PATH, ":")
frames = read(QM9_DATASET_PATH, ":")
dataset = AtomisticDataset(
frames,
target_properties=["energies", "forces"],
Expand Down
4 changes: 1 addition & 3 deletions src/metatrain/experimental/pet/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

DATASET_PATH = str(
Path(__file__).parents[5] / "tests/resources/alchemical_reduced_10.xyz"
)
DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/carbon_reduced_100.xyz")
19 changes: 16 additions & 3 deletions src/metatrain/utils/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def __init__(
for name, metrics_dict in zip(names, initial_metrics):
for key, value in metrics_dict.items():
target_name = key.split(" ", 1)[0]
if "loss" in key:
if key == "loss":
# losses will be printed in scientific notation
continue
unit = self._get_units(target_name)
Expand Down Expand Up @@ -110,7 +110,8 @@ def log(
logging_string = f"Epoch {epoch:4}"

for name, metrics_dict in zip(self.names, metrics):
for key, value in metrics_dict.items():
for key in _sort_metric_names(metrics_dict.keys()):
value = metrics_dict[key]

new_key = key
if key != "loss": # special case: not a metric associated with a target
Expand All @@ -122,7 +123,7 @@ def log(
logging_string += f", {new_key}: "
else:
logging_string += f", {name} {new_key}: "
if "loss" in key: # print losses with scientific notation
if key == "loss": # print losses with scientific notation
logging_string += f"{value:.3e}"
else:
unit = self._get_units(target_name)
Expand Down Expand Up @@ -264,3 +265,15 @@ def get_cli_input(argv: Optional[List[str]] = None) -> str:
# Add additional quotes for connected arguments.
arguments = [f'"{arg}"' if " " in arg else arg for arg in argv[1:]]
return f"{program_name} {' '.join(arguments)}"


def _sort_metric_names(name_list):
name_list = list(name_list)
sorted_name_list = []
if "loss" in name_list:
# loss goes first
loss_index = name_list.index("loss")
sorted_name_list.append(name_list.pop(loss_index))
# then alphabetical order
sorted_name_list.extend(sorted(name_list))
return sorted_name_list
1 change: 1 addition & 0 deletions tests/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

DATASET_PATH_QM9 = RESOURCES_PATH / "qm9_reduced_100.xyz"
DATASET_PATH_ETHANOL = RESOURCES_PATH / "ethanol_reduced_100.xyz"
DATASET_PATH_CARBON = RESOURCES_PATH / "carbon_reduced_100.xyz"
EVAL_OPTIONS_PATH = RESOURCES_PATH / "eval.yaml"
MODEL_PATH = RESOURCES_PATH / "model-32-bit.pt"
MODEL_PATH_64_BIT = RESOURCES_PATH / "model-64-bit.ckpt"
Expand Down
41 changes: 41 additions & 0 deletions tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from metatrain.utils.errors import ArchitectureError

from . import (
DATASET_PATH_CARBON,
DATASET_PATH_ETHANOL,
DATASET_PATH_QM9,
MODEL_PATH_64_BIT,
Expand Down Expand Up @@ -510,3 +511,43 @@ def test_train_issue_290(monkeypatch, tmp_path):
options["test_set"] = 0.85

train_model(options)


def test_train_log_order(caplog, monkeypatch, tmp_path, options):
"""Tests that the log is always printed in the same order for forces
and virials."""

monkeypatch.chdir(tmp_path)
shutil.copy(DATASET_PATH_CARBON, "carbon_reduced_100.xyz")

options["architecture"]["training"]["num_epochs"] = 5
options["architecture"]["training"]["log_interval"] = 1

options["training_set"]["systems"]["read_from"] = str(DATASET_PATH_CARBON)
options["training_set"]["targets"]["energy"]["read_from"] = str(DATASET_PATH_CARBON)
options["training_set"]["targets"]["energy"]["key"] = "energy"
options["training_set"]["targets"]["energy"]["forces"] = {
"key": "force",
}
options["training_set"]["targets"]["energy"]["virial"] = True

caplog.set_level(logging.INFO)
train_model(options)
log_test = caplog.text

# find all the lines that have "Epoch" in them; these are the lines that
# contain the training metrics
epoch_lines = [line for line in log_test.split("\n") if "Epoch" in line]

# check that "training forces RMSE" comes before "training virial RMSE"
# in every line
for line in epoch_lines:
force_index = line.index("training forces RMSE")
virial_index = line.index("training virial RMSE")
assert force_index < virial_index

# same for validation
for line in epoch_lines:
force_index = line.index("validation forces RMSE")
virial_index = line.index("validation virial RMSE")
assert force_index < virial_index
Loading

0 comments on commit 85fd06d

Please sign in to comment.