Skip to content

Commit

Permalink
fix wrong name assignment in parser
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jan 31, 2024
1 parent e07e2a8 commit 0115dcf
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/metatensor/models/cli/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,8 @@ def train_model(options: DictConfig) -> None:
if not isinstance(test_options, float):
test_options = expand_dataset_config(test_options)
test_structures = read_structures(
filename=train_options["structures"]["read_from"],
fileformat=train_options["structures"]["file_format"],
filename=test_options["structures"]["read_from"],
fileformat=test_options["structures"]["file_format"],
)
test_targets = read_targets(test_options["targets"])
test_dataset = Dataset(test_structures, test_targets)
Expand All @@ -125,8 +125,8 @@ def train_model(options: DictConfig) -> None:
if not isinstance(validation_options, float):
validation_options = expand_dataset_config(validation_options)
validation_structures = read_structures(
filename=train_options["structures"]["read_from"],
fileformat=train_options["structures"]["file_format"],
filename=validation_options["structures"]["read_from"],
fileformat=validation_options["structures"]["file_format"],
)
validation_targets = read_targets(validation_options["targets"])
validation_dataset = Dataset(validation_structures, validation_targets)
Expand Down
39 changes: 39 additions & 0 deletions tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import subprocess
from pathlib import Path

import ase.io
import pytest
from omegaconf import OmegaConf


RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"
Expand All @@ -26,6 +28,43 @@ def test_train(monkeypatch, tmp_path, output):
assert Path(output).is_file()


@pytest.mark.parametrize("test_set_file", (True, False))
@pytest.mark.parametrize("validation_set_file", (True, False))
@pytest.mark.parametrize("output", [None, "mymodel.pt"])
def test_train_explicit_validation_test(
monkeypatch, tmp_path, test_set_file, validation_set_file, output
):
"""Test that training via the training cli runs without an error raise
also when the validation and test sets are provided explicitly."""
monkeypatch.chdir(tmp_path)

structures = ase.io.read(RESOURCES_PATH / "qm9_reduced_100.xyz", ":")
options = OmegaConf.load(RESOURCES_PATH / "options.yaml")

ase.io.write("qm9_reduced_100.xyz", structures[:50])

if test_set_file:
ase.io.write("test.xyz", structures[50:80])
options["validation_set"] = options["training_set"].copy()
options["validation_set"]["structures"]["read_from"] = "test.xyz"

if validation_set_file:
ase.io.write("validation.xyz", structures[80:])
options["test_set"] = options["training_set"].copy()
options["test_set"]["structures"]["read_from"] = "validation.xyz"

OmegaConf.save(config=options, f="options.yaml")
command = ["metatensor-models", "train", "options.yaml"]

if output is not None:
command += ["-o", output]
else:
output = "model.pt"

subprocess.check_call(command)
assert Path(output).is_file()


def test_yml_error():
"""Test error raise of the option file is not a .yaml file."""
try:
Expand Down
Binary file removed tests/resources/model.pt
Binary file not shown.

0 comments on commit 0115dcf

Please sign in to comment.