From 4661be20f9a66321877ffa66498a1a3a18b17f77 Mon Sep 17 00:00:00 2001 From: Philip Loche Date: Wed, 17 Jan 2024 10:18:58 +0100 Subject: [PATCH] code cleanup and proofread of docs --- docs/src/dev-docs/utils/index.rst | 1 + docs/src/dev-docs/utils/omegaconf.rst | 10 ++ docs/src/dev-docs/utils/readers/index.rst | 4 +- .../getting-started/custom_dataset_conf.rst | 25 +-- .../models/cli/conf/dataset/gradient.yaml | 3 - .../cli/conf/dataset/gradients_avail.yaml | 3 - .../models/cli/conf/dataset/structures.yaml | 3 - .../models/cli/conf/dataset/targets.yaml | 5 - src/metatensor/models/cli/train_model.py | 132 +-------------- .../models/utils/data/readers/readers.py | 2 +- src/metatensor/models/utils/omegaconf.py | 155 ++++++++++++++++++ tests/cli/test_train_model.py | 137 ---------------- tests/utils/test_omegaconf.py | 146 +++++++++++++++++ 13 files changed, 332 insertions(+), 294 deletions(-) create mode 100644 docs/src/dev-docs/utils/omegaconf.rst delete mode 100644 src/metatensor/models/cli/conf/dataset/gradient.yaml delete mode 100644 src/metatensor/models/cli/conf/dataset/gradients_avail.yaml delete mode 100644 src/metatensor/models/cli/conf/dataset/structures.yaml delete mode 100644 src/metatensor/models/cli/conf/dataset/targets.yaml create mode 100644 src/metatensor/models/utils/omegaconf.py create mode 100644 tests/utils/test_omegaconf.py diff --git a/docs/src/dev-docs/utils/index.rst b/docs/src/dev-docs/utils/index.rst index 7b7852442..bfcddaea1 100644 --- a/docs/src/dev-docs/utils/index.rst +++ b/docs/src/dev-docs/utils/index.rst @@ -10,3 +10,4 @@ This is the API for the ``utils`` module of ``metatensor-models``. 
readers/index writers model-io + omegaconf diff --git a/docs/src/dev-docs/utils/omegaconf.rst b/docs/src/dev-docs/utils/omegaconf.rst new file mode 100644 index 000000000..2514daf17 --- /dev/null +++ b/docs/src/dev-docs/utils/omegaconf.rst @@ -0,0 +1,10 @@ +Custom omegaconf functions +========================== + +Resolvers to handle special fields in our configs as well as the expansion/completion of +the dataset section. + +.. automodule:: metatensor.models.utils.omegaconf + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/dev-docs/utils/readers/index.rst b/docs/src/dev-docs/utils/readers/index.rst index 68767ec11..cd7b9a458 100644 --- a/docs/src/dev-docs/utils/readers/index.rst +++ b/docs/src/dev-docs/utils/readers/index.rst @@ -1,5 +1,5 @@ -General Structure and Target data Readers -========================================= +Structure and Target data Readers +================================= The main entry point for reading structure and target information are the two reader functions diff --git a/docs/src/getting-started/custom_dataset_conf.rst b/docs/src/getting-started/custom_dataset_conf.rst index d2e9de108..5dec91fe0 100644 --- a/docs/src/getting-started/custom_dataset_conf.rst +++ b/docs/src/getting-started/custom_dataset_conf.rst @@ -63,7 +63,7 @@ Understanding the YAML Block The ``training_set`` is divided into sections ``structures`` and ``targets``: Structures Section ------------------- +^^^^^^^^^^^^^^^^^^ Describes the structure data like positions and cell information. :param read_from: The file containing structure data. @@ -76,16 +76,17 @@ A single string in this section automatically expands, using the string as the .. note:: - Metatensor-models does not convert units during training or evaluation. Units are - necessary for MD simulations. + ``metatensor-models`` does not convert units during training or evaluation. Units are + only required if the model should be used to run MD simulations. 
Targets Section ---------------- +^^^^^^^^^^^^^^^ Allows defining multiple target sections, each with a unique name. -- Commonly, a section named ``energy`` is defined, which is essential for MD - simulations. -- For other target sections, gradients are disabled by default. +- Commonly, a section named ``energy`` should be defined, which is essential for MD + simulations. For this section gradients like ``forces`` and ``stress`` are enabled by + default. See :ref:`energy-section` for further details on this section. +- For other target sections, all gradients are disabled by default. Target section parameters include: @@ -105,8 +106,8 @@ A single string in a target section automatically expands, using the string as t .. _gradient-section: -Gradient Sections ------------------ +Gradient Section +^^^^^^^^^^^^^^^^ Each gradient section (like ``forces`` or ``stress``) has similar parameters: :param read_from: The file for gradient data. @@ -115,8 +116,10 @@ Each gradient section (like ``forces`` or ``stress``) has similar parameters: Sections set to ``true`` or ``on`` automatically expand with default parameters. +.. _energy-section: + Energy Section --------------- +^^^^^^^^^^^^^^ The ``energy`` section is mandatory for MD simulations, with forces and stresses enabled by default. @@ -126,4 +129,4 @@ by default. .. note:: - Metatensor-models ignores unknown keys in these sections during dataset parsing. + In all sections, unknown keys are ignored but kept (not deleted) during dataset parsing. 
diff --git a/src/metatensor/models/cli/conf/dataset/gradient.yaml b/src/metatensor/models/cli/conf/dataset/gradient.yaml deleted file mode 100644 index e38b18018..000000000 --- a/src/metatensor/models/cli/conf/dataset/gradient.yaml +++ /dev/null @@ -1,3 +0,0 @@ -read_from: ${..read_from} -file_format: ${..file_format} -key: diff --git a/src/metatensor/models/cli/conf/dataset/gradients_avail.yaml b/src/metatensor/models/cli/conf/dataset/gradients_avail.yaml deleted file mode 100644 index 205c1a040..000000000 --- a/src/metatensor/models/cli/conf/dataset/gradients_avail.yaml +++ /dev/null @@ -1,3 +0,0 @@ -forces: off -stress: off -virial: off diff --git a/src/metatensor/models/cli/conf/dataset/structures.yaml b/src/metatensor/models/cli/conf/dataset/structures.yaml deleted file mode 100644 index 83ac0f1db..000000000 --- a/src/metatensor/models/cli/conf/dataset/structures.yaml +++ /dev/null @@ -1,3 +0,0 @@ -read_from: -file_format: -unit: diff --git a/src/metatensor/models/cli/conf/dataset/targets.yaml b/src/metatensor/models/cli/conf/dataset/targets.yaml deleted file mode 100644 index 2a3c0c095..000000000 --- a/src/metatensor/models/cli/conf/dataset/targets.yaml +++ /dev/null @@ -1,5 +0,0 @@ -quantity: energy -read_from: ${...structures.read_from} -file_format: -key: -unit: diff --git a/src/metatensor/models/cli/train_model.py b/src/metatensor/models/cli/train_model.py index 27e077b6b..bfd98167c 100644 --- a/src/metatensor/models/cli/train_model.py +++ b/src/metatensor/models/cli/train_model.py @@ -3,7 +3,6 @@ import logging import warnings from pathlib import Path -from typing import Union import hydra import torch @@ -14,6 +13,7 @@ from .. 
import CONFIG_PATH from ..utils.model_io import save_model +from ..utils.omegaconf import expand_dataset_config from .formatter import CustomHelpFormatter @@ -71,133 +71,6 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None: ) -def _resolve_single_str(config): - if isinstance(config, str): - return OmegaConf.create({"read_from": config}) - else: - return config - - -def expand_dataset_config(conf: Union[str, DictConfig]) -> DictConfig: - """Expands shorthand notations in a dataset configuration to their full formats. - - This function takes a dataset configuration, either as a string or a DictConfig, and - expands it into a detailed configuration format. It processes structures, targets, - and gradient sections, setting default values and inferring missing information. - Unknown keys are ignored, allowing for flexibility. - - The function performs the following steps: - - - Loads base configurations for structures, targets, and gradients from predefined - YAML files. - - Merges and interpolates the input configuration with the base configurations. - - Expands shorthand notations like file paths or simple true/false settings to full - dictionary structures. - - Handles special cases, such as the mandatory nature of the 'energy' section for MD - simulations and the mutual exclusivity of 'stress' and 'virial' sections. - - Validates the final expanded configuration, particularly for gradient-related - settings, to ensure consistency and prevent conflicts during training. - - :param conf: The dataset configuration, either as a file path string or a DictConfig - object. - :returns: The fully expanded dataset configuration. - :raises ValueError: If both ``virial`` and ``stress`` sections are enabled in the - 'energy' target, as this is not permissible for training. 
- """ - - conf_path = CONFIG_PATH / "dataset" - base_conf_structures = OmegaConf.load(conf_path / "structures.yaml") - base_conf_target = OmegaConf.load(conf_path / "targets.yaml") - base_conf_gradients_avail = OmegaConf.load(conf_path / "gradients_avail.yaml") - base_conf_gradient = OmegaConf.load(conf_path / "gradient.yaml") - - known_gradient_keys = list(base_conf_gradients_avail.keys()) - - # merge confif to get default configs for energies and other target config. - base_conf_target = OmegaConf.merge(base_conf_target, base_conf_gradients_avail) - base_conf_energy = base_conf_target.copy() - base_conf_energy["forces"] = base_conf_gradient.copy() - base_conf_energy["stress"] = base_conf_gradient.copy() - - if isinstance(conf, str): - read_from = conf - conf = OmegaConf.create( - {"structures": read_from, "targets": {"energy": read_from}} - ) - - if type(conf["structures"]) is str: - conf["structures"] = _resolve_single_str(conf["structures"]) - - conf["structures"] = OmegaConf.merge(base_conf_structures, conf["structures"]) - - if conf["structures"]["file_format"] is None: - conf["structures"]["file_format"] = Path(conf["structures"]["read_from"]).suffix - - for target_key, target in conf["targets"].items(): - if type(target) is str: - target = _resolve_single_str(target) - - # Add default gradients "energy" target section - if target_key == "energy": - # For special case of the "energy" we add the section for force and stress - # gradient by default - target = OmegaConf.merge(base_conf_energy, target) - else: - target = OmegaConf.merge(base_conf_target, target) - - if target["key"] is None: - target["key"] = target_key - - # Update DictConfig to allow for config node interpolation of variables like - # "read_from" - conf["targets"][target_key] = target - - # Check with respect to full config `conf` to avoid errors of node interpolation - if conf["targets"][target_key]["file_format"] is None: - conf["targets"][target_key]["file_format"] = Path( - 
conf["targets"][target_key]["read_from"] - ).suffix - - # merge and interpolate possible present gradients with default gradient config - for gradient_key, gradient_conf in conf["targets"][target_key].items(): - if gradient_key in known_gradient_keys: - if gradient_key is True: - gradient_conf = base_conf_gradient.copy() - elif type(gradient_key) is str: - gradient_conf = _resolve_single_str(gradient_conf) - - if isinstance(gradient_conf, DictConfig): - gradient_conf = OmegaConf.merge(base_conf_gradient, gradient_conf) - - if gradient_conf["key"] is None: - gradient_conf["key"] = gradient_key - - conf["targets"][target_key][gradient_key] = gradient_conf - - # If user sets the virial gradient and leaves the stress section untouched, - # we disable the by default enabled stress gradient section. - base_stress_gradient_conf = base_conf_gradient.copy() - base_stress_gradient_conf["key"] = "stress" - - if ( - target_key == "energy" - and conf["targets"][target_key]["virial"] - and conf["targets"][target_key]["stress"] == base_stress_gradient_conf - ): - conf["targets"][target_key]["stress"] = False - - if ( - conf["targets"][target_key]["stress"] - and conf["targets"][target_key]["virial"] - ): - raise ValueError( - f"Cannot perform training with respect to virials and stress as in " - f"section {target_key}. Set either `virials: off` or `stress: off`." - ) - - return conf - - @hydra.main(config_path=str(CONFIG_PATH), version_base=None) def train_model(options: DictConfig) -> None: """Train an atomistic machine learning model using configurations provided by Hydra. @@ -301,7 +174,8 @@ def train_model(options: DictConfig) -> None: logger.info("Run training") output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir - # HACK: Avoid passing a Subset which we can not handle yet. We now + # HACK: Avoid passing a Subset which we can not handle yet. For now we pass + # the complete training set even though it was split before... 
if isinstance(train_dataset, torch.utils.data.Subset): model = architecture.train( train_dataset=train_dataset.dataset, diff --git a/src/metatensor/models/utils/data/readers/readers.py b/src/metatensor/models/utils/data/readers/readers.py index 5dc2d5e6a..7ffb9f1d4 100644 --- a/src/metatensor/models/utils/data/readers/readers.py +++ b/src/metatensor/models/utils/data/readers/readers.py @@ -132,7 +132,7 @@ def read_targets(conf: DictConfig) -> Dict[str, TensorMap]: """Reading all target information from a fully expanded config. To get such a config you can use - :func:`metatensor.models.cli.train_model.expand_dataset_config`. + :func:`metatensor.models.utils.omegaconf.expand_dataset_config`. This function uses subfunctions like :func:`read_energy` to parse the requested target quantity. Currently only `energy` is a supported target property. But, within diff --git a/src/metatensor/models/utils/omegaconf.py b/src/metatensor/models/utils/omegaconf.py new file mode 100644 index 000000000..0eedc0321 --- /dev/null +++ b/src/metatensor/models/utils/omegaconf.py @@ -0,0 +1,155 @@ +from pathlib import Path +from typing import Union + +from omegaconf import DictConfig, OmegaConf + + +def file_format(_parent_): + """Custom OmegaConf resolver to find the file format. 
+ + File format is obtained based on the suffix of the ``read_from`` field in the same + section.""" + return Path(_parent_["read_from"]).suffix + + +# Register custom resolvers +OmegaConf.register_new_resolver("file_format", file_format) + + +def _resolve_single_str(config): + if isinstance(config, str): + return OmegaConf.create({"read_from": config}) + else: + return config + + +# BASE CONFIGURATIONS +CONF_STRUCTURES = OmegaConf.create( + { + "read_from": "${..read_from}", + "file_format": "${file_format:}", + "key": None, + } +) + +CONF_TARGET_FIELDS = OmegaConf.create( + { + "quantity": "energy", + "read_from": "${...structures.read_from}", + "file_format": "${file_format:}", + "key": None, + "unit": None, + } +) + +CONF_GRADIENTS = OmegaConf.create({"forces": False, "stress": False, "virial": False}) +CONF_GRADIENT = OmegaConf.create( + { + "read_from": "${..read_from}", + "file_format": "${file_format:}", + "key": None, + } +) + +KNWON_GRADIENTS = list(CONF_GRADIENTS.keys()) + +# merge configs to get default configs for energies and other targets +CONF_TARGET = OmegaConf.merge(CONF_TARGET_FIELDS, CONF_GRADIENTS) +CONF_ENERGY = CONF_TARGET.copy() +CONF_ENERGY["forces"] = CONF_GRADIENT.copy() +CONF_ENERGY["stress"] = CONF_GRADIENT.copy() + + +def expand_dataset_config(conf: Union[str, DictConfig]) -> DictConfig: + """Expands shorthand notations in a dataset configuration to their full formats. + + This function takes a dataset configuration, either as a string or a DictConfig, and + expands it into a detailed configuration format. It processes structures, targets, + and gradient sections, setting default values and inferring missing information. + Unknown keys are ignored, allowing for flexibility. + + The function performs the following steps: + + - Loads base configurations for structures, targets, and gradients from predefined + YAML files. + - Merges and interpolates the input configuration with the base configurations. 
+ - Expands shorthand notations like file paths or simple true/false settings to full + dictionary structures. + - Handles special cases, such as the mandatory nature of the 'energy' section for MD + simulations and the mutual exclusivity of 'stress' and 'virial' sections. + - Validates the final expanded configuration, particularly for gradient-related + settings, to ensure consistency and prevent conflicts during training. + + :param conf: The dataset configuration, either as a file path string or a DictConfig + object. + :returns: The fully expanded dataset configuration. + :raises ValueError: If both ``virial`` and ``stress`` sections are enabled in the + 'energy' target, as this is not permissible for training. + """ + if isinstance(conf, str): + read_from = conf + conf = OmegaConf.create( + {"structures": read_from, "targets": {"energy": read_from}} + ) + + if type(conf["structures"]) is str: + conf["structures"] = _resolve_single_str(conf["structures"]) + + conf["structures"] = OmegaConf.merge(CONF_STRUCTURES, conf["structures"]) + + for target_key, target in conf["targets"].items(): + if type(target) is str: + target = _resolve_single_str(target) + + # Add default gradients "energy" target section + if target_key == "energy": + # For special case of the "energy" we add the section for force and stress + # gradient by default + target = OmegaConf.merge(CONF_ENERGY, target) + else: + target = OmegaConf.merge(CONF_TARGET, target) + + if target["key"] is None: + target["key"] = target_key + + # Update DictConfig to allow for config node interpolation + conf["targets"][target_key] = target + + # merge and interpolate possibly present gradients with default gradient config + for gradient_key, gradient_conf in conf["targets"][target_key].items(): + if gradient_key in KNWON_GRADIENTS: + if gradient_conf is True: + gradient_conf = CONF_GRADIENT.copy() + elif type(gradient_conf) is str: + gradient_conf = _resolve_single_str(gradient_conf) + + if 
isinstance(gradient_conf, DictConfig): + gradient_conf = OmegaConf.merge(CONF_GRADIENT, gradient_conf) + + if gradient_conf["key"] is None: + gradient_conf["key"] = gradient_key + + conf["targets"][target_key][gradient_key] = gradient_conf + + # If user sets the virial gradient and leaves the stress section untouched, + # we disable the by default enabled stress gradient section. + base_stress_gradient_conf = CONF_GRADIENT.copy() + base_stress_gradient_conf["key"] = "stress" + + if ( + target_key == "energy" + and conf["targets"][target_key]["virial"] + and conf["targets"][target_key]["stress"] == base_stress_gradient_conf + ): + conf["targets"][target_key]["stress"] = False + + if ( + conf["targets"][target_key]["stress"] + and conf["targets"][target_key]["virial"] + ): + raise ValueError( + f"Cannot perform training with respect to virials and stress as in " + f"section {target_key}. Set either `virials: off` or `stress: off`." + ) + + return conf diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py index 0e6ab4211..06566dc0a 100644 --- a/tests/cli/test_train_model.py +++ b/tests/cli/test_train_model.py @@ -3,9 +3,6 @@ from pathlib import Path import pytest -from omegaconf import OmegaConf - -from metatensor.models.cli.train_model import expand_dataset_config RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources" @@ -52,137 +49,3 @@ def test_hydra_arguments(): # TODO: test split of train/test/validation using floats and combinations of these. 
- - -def test_expand_dataset_config(): - file_name = "foo.xyz" - file_format = ".xyz" - - structure_section = {"read_from": file_name, "unit": "angstrom"} - - target_section = { - "quantity": "energy", - "forces": file_name, - "virial": file_name, - "bar": {"read_from": "my_grad.dat", "key": "foo"}, - } - - conf = { - "structures": structure_section, - "targets": {"energy": target_section, "energy2": target_section}, - } - - conf_expanded = expand_dataset_config(OmegaConf.create(conf)) - - assert conf_expanded["structures"]["read_from"] == file_name - assert conf_expanded["structures"]["file_format"] == file_format - assert conf_expanded["structures"]["unit"] == "angstrom" - - targets_conf = conf_expanded["targets"] - assert len(targets_conf) == 2 - - assert targets_conf["energy"]["quantity"] == "energy" - assert targets_conf["energy"]["read_from"] == file_name - assert targets_conf["energy"]["file_format"] == file_format - assert targets_conf["energy"]["file_format"] == file_format - assert targets_conf["energy"]["key"] == "energy" - assert targets_conf["energy"]["unit"] is None - - for gradient in ["forces", "virial"]: - assert targets_conf["energy"][gradient]["read_from"] == file_name - assert targets_conf["energy"][gradient]["file_format"] == file_format - assert targets_conf["energy"][gradient]["key"] == gradient - - assert targets_conf["energy"]["bar"]["read_from"] == "my_grad.dat" - assert targets_conf["energy"]["bar"]["key"] == "foo" - - # If a virial is parsed as in the conf above the by default enabled section "stress" - # should be disabled automatically - assert targets_conf["energy"]["stress"] is False - - assert targets_conf["energy2"]["key"] == "energy2" - assert targets_conf["energy"]["quantity"] == "energy" - - -def test_expand_dataset_config_not_energy(): - file_name = "foo.xyz" - - structure_section = {"read_from": file_name, "unit": "angstrom"} - - target_section = { - "quantity": "my_dipole_moment", - } - - conf = { - "structures": 
structure_section, - "targets": {"dipole_moment": target_section}, - } - - conf_expanded = expand_dataset_config(OmegaConf.create(conf)) - - assert conf_expanded["targets"]["dipole_moment"]["key"] == "dipole_moment" - assert conf_expanded["targets"]["dipole_moment"]["quantity"] == "my_dipole_moment" - assert conf_expanded["targets"]["dipole_moment"]["forces"] is False - assert conf_expanded["targets"]["dipole_moment"]["stress"] is False - assert conf_expanded["targets"]["dipole_moment"]["virial"] is False - - -def test_expand_dataset_config_min(): - file_name = "dataset.dat" - file_format = ".dat" - - conf_expanded = expand_dataset_config(file_name) - - assert conf_expanded["structures"]["read_from"] == file_name - assert conf_expanded["structures"]["file_format"] == file_format - - targets_conf = conf_expanded["targets"] - assert targets_conf["energy"]["quantity"] == "energy" - assert targets_conf["energy"]["read_from"] == file_name - assert targets_conf["energy"]["file_format"] == file_format - assert targets_conf["energy"]["file_format"] == file_format - assert targets_conf["energy"]["key"] == "energy" - assert targets_conf["energy"]["unit"] is None - - for gradient in ["forces", "stress"]: - assert targets_conf["energy"][gradient]["read_from"] == file_name - assert targets_conf["energy"][gradient]["file_format"] == file_format - assert targets_conf["energy"][gradient]["key"] == gradient - - assert targets_conf["energy"]["virial"] is False - - -def test_expand_dataset_config_error(): - file_name = "foo.xyz" - - conf = { - "structures": file_name, - "targets": { - "energy": { - "virial": file_name, - "stress": {"read_from": file_name, "key": "foo"}, - } - }, - } - - with pytest.raises( - ValueError, match="Cannot perform training with respect to virials and stress" - ): - expand_dataset_config(OmegaConf.create(conf)) - - -def test_expand_dataset_gradient_true(): - conf = { - "structures": "foo.xyz", - "targets": { - "energy": { - "virial": True, - "stress": 
False, - } - }, - } - - conf_expanded = expand_dataset_config(OmegaConf.create(conf)) - - assert conf_expanded["targets"]["energy"]["stress"] is False - conf_expanded["targets"]["energy"]["virial"]["read_from"] diff --git a/tests/utils/test_omegaconf.py b/tests/utils/test_omegaconf.py new file mode 100644 index 000000000..151f7a324 --- /dev/null +++ b/tests/utils/test_omegaconf.py @@ -0,0 +1,146 @@ +import pytest +from omegaconf import OmegaConf + +from metatensor.models.utils.omegaconf import expand_dataset_config + + +def test_file_format_resolver(): + conf = OmegaConf.create({"read_from": "foo.xyz", "file_format": "${file_format:}"}) + + assert (conf["file_format"]) == ".xyz" + + +def test_expand_dataset_config(): + file_name = "foo.xyz" + file_format = ".xyz" + + structure_section = {"read_from": file_name, "unit": "angstrom"} + + target_section = { + "quantity": "energy", + "forces": file_name, + "virial": {"read_from": "my_grad.dat", "key": "foo"}, + } + + conf = { + "structures": structure_section, + "targets": {"energy": target_section, "my_target": target_section}, + } + + conf_expanded = expand_dataset_config(OmegaConf.create(conf)) + + assert conf_expanded["structures"]["read_from"] == file_name + assert conf_expanded["structures"]["file_format"] == file_format + assert conf_expanded["structures"]["unit"] == "angstrom" + + targets_conf = conf_expanded["targets"] + assert len(targets_conf) == 2 + + for target_key in ["energy", "my_target"]: + assert targets_conf[target_key]["quantity"] == "energy" + assert targets_conf[target_key]["read_from"] == file_name + assert targets_conf[target_key]["file_format"] == file_format + assert targets_conf[target_key]["file_format"] == file_format + assert targets_conf[target_key]["unit"] is None + + assert targets_conf[target_key]["forces"]["read_from"] == file_name + assert targets_conf[target_key]["forces"]["file_format"] == file_format + assert targets_conf[target_key]["forces"]["key"] == "forces" + + assert 
targets_conf[target_key]["virial"]["read_from"] == "my_grad.dat" + assert targets_conf[target_key]["virial"]["file_format"] == ".dat" + assert targets_conf[target_key]["virial"]["key"] == "foo" + + assert targets_conf[target_key]["stress"] is False + + # If a virial is parsed as in the conf above the by default enabled section "stress" + # should be disabled automatically + assert targets_conf["energy"]["stress"] is False + + +def test_expand_dataset_config_not_energy(): + file_name = "foo.xyz" + + structure_section = {"read_from": file_name, "unit": "angstrom"} + + target_section = { + "quantity": "my_dipole_moment", + } + + conf = { + "structures": structure_section, + "targets": {"dipole_moment": target_section}, + } + + conf_expanded = expand_dataset_config(OmegaConf.create(conf)) + + assert conf_expanded["targets"]["dipole_moment"]["key"] == "dipole_moment" + assert conf_expanded["targets"]["dipole_moment"]["quantity"] == "my_dipole_moment" + assert conf_expanded["targets"]["dipole_moment"]["forces"] is False + assert conf_expanded["targets"]["dipole_moment"]["stress"] is False + assert conf_expanded["targets"]["dipole_moment"]["virial"] is False + + +def test_expand_dataset_config_min(): + file_name = "dataset.dat" + file_format = ".dat" + + conf_expanded = expand_dataset_config(file_name) + + assert conf_expanded["structures"]["read_from"] == file_name + assert conf_expanded["structures"]["file_format"] == file_format + + targets_conf = conf_expanded["targets"] + assert targets_conf["energy"]["quantity"] == "energy" + assert targets_conf["energy"]["read_from"] == file_name + assert targets_conf["energy"]["file_format"] == file_format + assert targets_conf["energy"]["file_format"] == file_format + assert targets_conf["energy"]["key"] == "energy" + assert targets_conf["energy"]["unit"] is None + + for gradient in ["forces", "stress"]: + assert targets_conf["energy"][gradient]["read_from"] == file_name + assert targets_conf["energy"][gradient]["file_format"] == 
file_format + assert targets_conf["energy"][gradient]["key"] == gradient + + assert targets_conf["energy"]["virial"] is False + + +def test_expand_dataset_config_error(): + file_name = "foo.xyz" + + conf = { + "structures": file_name, + "targets": { + "energy": { + "virial": file_name, + "stress": {"read_from": file_name, "key": "foo"}, + } + }, + } + + with pytest.raises( + ValueError, match="Cannot perform training with respect to virials and stress" + ): + expand_dataset_config(OmegaConf.create(conf)) + + +def test_expand_dataset_gradient(): + conf = { + "structures": "foo.xyz", + "targets": { + "my_energy": { + "forces": "data.txt", + "virial": True, + "stress": False, + } + }, + } + + conf_expanded = expand_dataset_config(OmegaConf.create(conf)) + + assert conf_expanded["targets"]["my_energy"]["forces"]["read_from"] == "data.txt" + assert conf_expanded["targets"]["my_energy"]["forces"]["file_format"] == ".txt" + + assert conf_expanded["targets"]["my_energy"]["stress"] is False + conf_expanded["targets"]["my_energy"]["virial"]["read_from"]