diff --git a/src/metatensor/models/cli/train_model.py b/src/metatensor/models/cli/train_model.py index 62dc3a0cc..23f9a8459 100644 --- a/src/metatensor/models/cli/train_model.py +++ b/src/metatensor/models/cli/train_model.py @@ -109,8 +109,8 @@ def train_model(options: DictConfig) -> None: if not isinstance(test_options, float): test_options = expand_dataset_config(test_options) test_structures = read_structures( - filename=train_options["structures"]["read_from"], - fileformat=train_options["structures"]["file_format"], + filename=test_options["structures"]["read_from"], + fileformat=test_options["structures"]["file_format"], ) test_targets = read_targets(test_options["targets"]) test_dataset = Dataset(test_structures, test_targets) @@ -125,8 +125,8 @@ def train_model(options: DictConfig) -> None: if not isinstance(validation_options, float): validation_options = expand_dataset_config(validation_options) validation_structures = read_structures( - filename=train_options["structures"]["read_from"], - fileformat=train_options["structures"]["file_format"], + filename=validation_options["structures"]["read_from"], + fileformat=validation_options["structures"]["file_format"], ) validation_targets = read_targets(validation_options["targets"]) validation_dataset = Dataset(validation_structures, validation_targets) diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py index 06566dc0a..4aee7db5a 100644 --- a/tests/cli/test_train_model.py +++ b/tests/cli/test_train_model.py @@ -2,7 +2,9 @@ import subprocess from pathlib import Path +import ase.io import pytest +from omegaconf import OmegaConf RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources" @@ -26,6 +28,43 @@ def test_train(monkeypatch, tmp_path, output): assert Path(output).is_file() +@pytest.mark.parametrize("test_set_file", (True, False)) +@pytest.mark.parametrize("validation_set_file", (True, False)) +@pytest.mark.parametrize("output", [None, "mymodel.pt"]) +def test_train_explicit_validation_test( + monkeypatch, tmp_path, test_set_file, validation_set_file, output +): + """Test that training via the training cli runs without an error raise + also when the validation and test sets are provided explicitly.""" + monkeypatch.chdir(tmp_path) + + structures = ase.io.read(RESOURCES_PATH / "qm9_reduced_100.xyz", ":") + options = OmegaConf.load(RESOURCES_PATH / "options.yaml") + + ase.io.write("qm9_reduced_100.xyz", structures[:50]) + + if test_set_file: + ase.io.write("test.xyz", structures[50:80]) + options["validation_set"] = options["training_set"].copy() + options["validation_set"]["structures"]["read_from"] = "test.xyz" + + if validation_set_file: + ase.io.write("validation.xyz", structures[80:]) + options["test_set"] = options["training_set"].copy() + options["test_set"]["structures"]["read_from"] = "validation.xyz" + + OmegaConf.save(config=options, f="options.yaml") + command = ["metatensor-models", "train", "options.yaml"] + + if output is not None: + command += ["-o", output] + else: + output = "model.pt" + + subprocess.check_call(command) + assert Path(output).is_file() + + def test_yml_error(): """Test error raise of the option file is not a .yaml file.""" try: diff --git a/tests/resources/model.pt b/tests/resources/model.pt deleted file mode 100644 index a8aa5a3d3..000000000 Binary files a/tests/resources/model.pt and /dev/null differ