From b269f6afc07fa55eb4b28ef9d796667a770d627a Mon Sep 17 00:00:00 2001
From: Philip Loche
Date: Wed, 24 Jan 2024 13:08:24 +0100
Subject: [PATCH] Add length_units and rename variables

---
 src/metatensor/models/cli/train_model.py | 79 ++++++++++++------------
 src/metatensor/models/utils/omegaconf.py |  1 +
 tests/utils/test_omegaconf.py            |  5 +-
 3 files changed, 45 insertions(+), 40 deletions(-)

diff --git a/src/metatensor/models/cli/train_model.py b/src/metatensor/models/cli/train_model.py
index cfd236edf..62dc3a0cc 100644
--- a/src/metatensor/models/cli/train_model.py
+++ b/src/metatensor/models/cli/train_model.py
@@ -96,50 +96,50 @@ def train_model(options: DictConfig) -> None:
     generator = torch.Generator()

     logger.info("Setting up training set")
-    conf_training_set = expand_dataset_config(options["training_set"])
-    structures_train = read_structures(
-        filename=conf_training_set["structures"]["read_from"],
-        fileformat=conf_training_set["structures"]["file_format"],
+    train_options = expand_dataset_config(options["training_set"])
+    train_structures = read_structures(
+        filename=train_options["structures"]["read_from"],
+        fileformat=train_options["structures"]["file_format"],
     )
-    targets_train = read_targets(conf_training_set["targets"])
-    train_dataset = Dataset(structures_train, targets_train)
+    train_targets = read_targets(train_options["targets"])
+    train_dataset = Dataset(train_structures, train_targets)

     logger.info("Setting up test set")
-    conf_test_set = options["test_set"]
-    if not isinstance(conf_test_set, float):
-        conf_test_set = expand_dataset_config(conf_test_set)
-        structures_test = read_structures(
-            filename=conf_training_set["structures"]["read_from"],
-            fileformat=conf_training_set["structures"]["file_format"],
+    test_options = options["test_set"]
+    if not isinstance(test_options, float):
+        test_options = expand_dataset_config(test_options)
+        test_structures = read_structures(
+            filename=train_options["structures"]["read_from"],
+            fileformat=train_options["structures"]["file_format"],
         )
-        targets_test = read_targets(conf_test_set["targets"])
-        test_dataset = Dataset(structures_test, targets_test)
-        fraction_test_set = 0.0
+        test_targets = read_targets(test_options["targets"])
+        test_dataset = Dataset(test_structures, test_targets)
+        test_fraction = 0.0
     else:
-        if conf_test_set < 0 or conf_test_set >= 1:
+        if test_options < 0 or test_options >= 1:
             raise ValueError("Test set split must be between 0 and 1.")
-        fraction_test_set = conf_test_set
+        test_fraction = test_options

     logger.info("Setting up validation set")
-    conf_validation_set = options["validation_set"]
-    if not isinstance(conf_validation_set, float):
-        conf_validation_set = expand_dataset_config(conf_validation_set)
-        structures_validation = read_structures(
-            filename=conf_training_set["structures"]["read_from"],
-            fileformat=conf_training_set["structures"]["file_format"],
+    validation_options = options["validation_set"]
+    if not isinstance(validation_options, float):
+        validation_options = expand_dataset_config(validation_options)
+        validation_structures = read_structures(
+            filename=train_options["structures"]["read_from"],
+            fileformat=train_options["structures"]["file_format"],
         )
-        targets_validation = read_targets(conf_validation_set["targets"])
-        validation_dataset = Dataset(structures_validation, targets_validation)
-        fraction_validation_set = 0.0
+        validation_targets = read_targets(validation_options["targets"])
+        validation_dataset = Dataset(validation_structures, validation_targets)
+        validation_fraction = 0.0
     else:
-        if conf_validation_set < 0 or conf_validation_set >= 1:
+        if validation_options < 0 or validation_options >= 1:
             raise ValueError("Validation set split must be between 0 and 1.")
-        fraction_validation_set = conf_validation_set
+        validation_fraction = validation_options

     # Split train dataset if requested
-    if fraction_test_set or fraction_validation_set:
-        fraction_train_set = 1 - fraction_test_set - fraction_validation_set
-        if fraction_train_set < 0:
+    if test_fraction or validation_fraction:
+        train_fraction = 1 - test_fraction - validation_fraction
+        if train_fraction < 0:
             raise ValueError("fraction of the train set is smaller then 0!")

         # ignore warning of possible empty dataset
@@ -148,17 +148,17 @@ def train_model(options: DictConfig) -> None:
         subsets = torch.utils.data.random_split(
             dataset=train_dataset,
             lengths=[
-                fraction_train_set,
-                fraction_test_set,
-                fraction_validation_set,
+                train_fraction,
+                test_fraction,
+                validation_fraction,
             ],
             generator=generator,
         )

         train_dataset = subsets[0]
-        if fraction_test_set and not fraction_validation_set:
+        if test_fraction and not validation_fraction:
             test_dataset = subsets[1]
-        elif not fraction_validation_set and fraction_validation_set:
+        elif not validation_fraction and validation_fraction:
             validation_dataset = subsets[1]
         else:
             test_dataset = subsets[1]
@@ -184,12 +184,13 @@ def train_model(options: DictConfig) -> None:
     outputs = {
         key: ModelOutput(
             quantity=value["quantity"],
-            unit=(value["unit"] if value["unit"] is not None else ""),  # potential HACK
+            unit=(value["unit"] if value["unit"] is not None else ""),
         )
         for key, value in options["training_set"]["targets"].items()
     }
+    length_unit = train_options["structures"]["length_unit"]
     model_capabilities = ModelCapabilities(
-        length_unit="Angstrom",
+        length_unit=length_unit if length_unit is not None else "",
         species=all_species,
         outputs=outputs,
     )
@@ -203,3 +204,5 @@ def train_model(options: DictConfig) -> None:
     )

     save_model(model, options["output_path"])
+
+    # TODO: add evaluation of the test set
diff --git a/src/metatensor/models/utils/omegaconf.py b/src/metatensor/models/utils/omegaconf.py
index 0eedc0321..000475884 100644
--- a/src/metatensor/models/utils/omegaconf.py
+++ b/src/metatensor/models/utils/omegaconf.py
@@ -29,6 +29,7 @@ def _resolve_single_str(config):
         "read_from": "${..read_from}",
         "file_format": "${file_format:}",
         "key": None,
+        "length_unit": None,
     }
 )

diff --git a/tests/utils/test_omegaconf.py b/tests/utils/test_omegaconf.py
index 151f7a324..d8bc38f16 100644
--- a/tests/utils/test_omegaconf.py
+++ b/tests/utils/test_omegaconf.py
@@ -14,7 +14,7 @@ def test_expand_dataset_config():
     file_name = "foo.xyz"
     file_format = ".xyz"

-    structure_section = {"read_from": file_name, "unit": "angstrom"}
+    structure_section = {"read_from": file_name, "length_unit": "angstrom"}

     target_section = {
         "quantity": "energy",
@@ -31,7 +31,7 @@

     assert conf_expanded["structures"]["read_from"] == file_name
     assert conf_expanded["structures"]["file_format"] == file_format
-    assert conf_expanded["structures"]["unit"] == "angstrom"
+    assert conf_expanded["structures"]["length_unit"] == "angstrom"

     targets_conf = conf_expanded["targets"]
     assert len(targets_conf) == 2
@@ -89,6 +89,7 @@ def test_expand_dataset_config_min():

     assert conf_expanded["structures"]["read_from"] == file_name
     assert conf_expanded["structures"]["file_format"] == file_format
+    assert conf_expanded["structures"]["length_unit"] is None

     targets_conf = conf_expanded["targets"]
     assert targets_conf["energy"]["quantity"] == "energy"