diff --git a/docs/src/dev-docs/cli/eval_model.rst b/docs/src/dev-docs/cli/eval_model.rst new file mode 100644 index 000000000..dcec7896d --- /dev/null +++ b/docs/src/dev-docs/cli/eval_model.rst @@ -0,0 +1,7 @@ +eval_model +########## + +.. automodule:: metatensor.models.cli.eval_model + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/dev-docs/cli/export_model.rst b/docs/src/dev-docs/cli/export_model.rst new file mode 100644 index 000000000..3facdb13c --- /dev/null +++ b/docs/src/dev-docs/cli/export_model.rst @@ -0,0 +1,7 @@ +export_model +############ + +.. automodule:: metatensor.models.cli.export_model + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/dev-docs/cli/formatter.rst b/docs/src/dev-docs/cli/formatter.rst new file mode 100644 index 000000000..bbe577be8 --- /dev/null +++ b/docs/src/dev-docs/cli/formatter.rst @@ -0,0 +1,7 @@ +formatter +######### + +.. automodule:: metatensor.models.cli.formatter + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/dev-docs/cli/index.rst b/docs/src/dev-docs/cli/index.rst new file mode 100644 index 000000000..2141d8228 --- /dev/null +++ b/docs/src/dev-docs/cli/index.rst @@ -0,0 +1,20 @@ +CLI API +======= + +This is the API for the command line interface ``cli`` functions of +``metatensor-models``. + +.. toctree:: + :maxdepth: 1 + + train_model + eval_model + export_model + +We provide a custom formatter class for formatting the help message of the +`argparse` package. + +.. toctree:: + :maxdepth: 1 + + formatter diff --git a/docs/src/dev-docs/cli/train_model.rst b/docs/src/dev-docs/cli/train_model.rst new file mode 100644 index 000000000..6e9981cbf --- /dev/null +++ b/docs/src/dev-docs/cli/train_model.rst @@ -0,0 +1,7 @@ +train_model +########### + +.. 
automodule:: metatensor.models.cli.train_model + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/src/dev-docs/index.rst b/docs/src/dev-docs/index.rst index c46ce355d..9f6950d68 100644 --- a/docs/src/dev-docs/index.rst +++ b/docs/src/dev-docs/index.rst @@ -10,4 +10,5 @@ module. :maxdepth: 1 adding-models + cli/index utils/index diff --git a/docs/src/dev-docs/utils/readers/index.rst b/docs/src/dev-docs/utils/readers/index.rst index c221ecfd5..68767ec11 100644 --- a/docs/src/dev-docs/utils/readers/index.rst +++ b/docs/src/dev-docs/utils/readers/index.rst @@ -1,17 +1,32 @@ -Structure and Target data Readers -================================= +General Structure and Target data Readers +========================================= The main entry point for reading structure and target information are the two reader functions -.. automodule:: metatensor.models.utils.data.readers - :members: +.. autofunction:: metatensor.models.utils.data.read_structures +.. autofunction:: metatensor.models.utils.data.read_targets -Based on the provided filename they chose which child reader they use. For details on +Target type specific readers +---------------------------- + +:func:`metatensor.models.utils.data.read_targets` uses sub-functions to parse supported +target properties like the `energy` or `forces`. Currently we support reading the +following target properties via + +.. autofunction:: metatensor.models.utils.data.read_energy +.. autofunction:: metatensor.models.utils.data.read_forces +.. autofunction:: metatensor.models.utils.data.read_virial +.. autofunction:: metatensor.models.utils.data.read_stress + +File type specific readers +-------------------------- + +Based on the provided `file_format` they choose which sub-reader they use. For details on these refer to their documentation .. 
toctree:: :maxdepth: 1 structure - target + targets diff --git a/docs/src/dev-docs/utils/readers/target.rst b/docs/src/dev-docs/utils/readers/target.rst deleted file mode 100644 index deb4958c9..000000000 --- a/docs/src/dev-docs/utils/readers/target.rst +++ /dev/null @@ -1,13 +0,0 @@ -Target data Reader -################## - -Parsers for obtaining information from target files. All readers return a of -:py:class:`metatensor.torch.TensorMap`. The mapping which reader is used for which file -type is stored in - -.. autodata:: metatensor.models.utils.data.readers.targets.ENERGY_READERS - -Implemented Readers -------------------- - -.. autofunction:: metatensor.models.utils.data.readers.targets.read_energy_ase diff --git a/docs/src/dev-docs/utils/readers/targets.rst b/docs/src/dev-docs/utils/readers/targets.rst new file mode 100644 index 000000000..6fd1d4e87 --- /dev/null +++ b/docs/src/dev-docs/utils/readers/targets.rst @@ -0,0 +1,63 @@ +Target data Readers +################### + +Parsers for obtaining target information from target files. All readers return a +:py:class:`metatensor.torch.TensorBlock`. Currently we support the following target +properties + +- :ref:`energy` +- :ref:`forces` +- :ref:`stress` +- :ref:`virial` + +The mapping which reader is used for which file type is stored in a dictionary. + +.. _energy: + +Energy +====== + +.. autodata:: metatensor.models.utils.data.readers.targets.ENERGY_READERS + +Implemented Readers +------------------- + +.. autofunction:: metatensor.models.utils.data.readers.targets.read_energy_ase + + +.. _forces: + +Forces +====== + +.. autodata:: metatensor.models.utils.data.readers.targets.FORCES_READERS + +Implemented Readers +------------------- + +.. autofunction:: metatensor.models.utils.data.readers.targets.read_forces_ase + +.. _stress: + +Stress +====== + +.. autodata:: metatensor.models.utils.data.readers.targets.STRESS_READERS + +Implemented Readers +------------------- + +.. 
autofunction:: metatensor.models.utils.data.readers.targets.read_stress_ase + +.. _virial: + +Virial +====== + +.. autodata:: metatensor.models.utils.data.readers.targets.VIRIAL_READERS + +Implemented Readers +------------------- + +.. autofunction:: metatensor.models.utils.data.readers.targets.read_virial_ase + diff --git a/pyproject.toml b/pyproject.toml index 12dd7b594..1a250ca7a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ dependencies = [ "ase", "torch", "hydra-core", - "rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch", + #"rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch", "metatensor-core", "metatensor-operations", "metatensor-torch", diff --git a/src/metatensor/models/cli/conf/dataset/gradients_avail.yaml b/src/metatensor/models/cli/conf/dataset/gradients_avail.yaml new file mode 100644 index 000000000..205c1a040 --- /dev/null +++ b/src/metatensor/models/cli/conf/dataset/gradients_avail.yaml @@ -0,0 +1,3 @@ +forces: off +stress: off +virial: off diff --git a/src/metatensor/models/cli/conf/dataset/targets.yaml b/src/metatensor/models/cli/conf/dataset/targets.yaml index 557ba7f31..2a3c0c095 100644 --- a/src/metatensor/models/cli/conf/dataset/targets.yaml +++ b/src/metatensor/models/cli/conf/dataset/targets.yaml @@ -3,6 +3,3 @@ read_from: ${...structures.read_from} file_format: key: unit: -forces: off -stress: off -virial: off diff --git a/src/metatensor/models/cli/train_model.py b/src/metatensor/models/cli/train_model.py index ffef7c530..7a735c2cb 100644 --- a/src/metatensor/models/cli/train_model.py +++ b/src/metatensor/models/cli/train_model.py @@ -79,15 +79,24 @@ def _resolve_single_str(config): def expand_dataset_config(conf: Union[str, DictConfig]) -> DictConfig: - """Expand a short hand notation in a dataset config to actual format.""" - base_conf_structures = OmegaConf.load(CONFIG_PATH / "dataset" / "structures.yaml") - base_conf_target = 
OmegaConf.load(CONFIG_PATH / "dataset" / "targets.yaml") - base_gradient_conf = OmegaConf.load(CONFIG_PATH / "dataset" / "gradient.yaml") + """Expand a short hand notation in a dataset config to actual format. - base_conf_energy = OmegaConf.load(CONFIG_PATH / "dataset" / "targets.yaml") - base_conf_energy["forces"] = base_gradient_conf.copy() - base_conf_energy["stress"] = base_gradient_conf.copy() - base_conf_energy["virial"] = False + Currently the config is not checked if all keys are known. Unknown keys can be added + and will be ignored and not deleted.""" + + conf_path = CONFIG_PATH / "dataset" + base_conf_structures = OmegaConf.load(conf_path / "structures.yaml") + base_conf_target = OmegaConf.load(conf_path / "targets.yaml") + base_conf_gradients_avail = OmegaConf.load(conf_path / "gradients_avail.yaml") + base_conf_gradient = OmegaConf.load(conf_path / "gradient.yaml") + + known_gradient_keys = list(base_conf_gradients_avail.keys()) + + # merge config to get default configs for energies and other target config. 
+ base_conf_target = OmegaConf.merge(base_conf_target, base_conf_gradients_avail) + base_conf_energy = base_conf_target.copy() + base_conf_energy["forces"] = base_conf_gradient.copy() + base_conf_energy["stress"] = base_conf_gradient.copy() if isinstance(conf, str): read_from = conf @@ -130,24 +139,23 @@ def expand_dataset_config(conf: Union[str, DictConfig]) -> DictConfig: # merge and interpolate possible present gradients with default config for gradient_key, gradient_conf in conf["targets"][target_key].items(): - if ( - type(gradient_conf) is str - and Path(gradient_conf).suffix # field is a file with a suffix - and gradient_key not in ["read_from", "file_format"] - ): - gradient_conf = _resolve_single_str(gradient_conf) + if gradient_key in known_gradient_keys: + if gradient_key: + gradient_conf = base_conf_gradient.copy() + elif type(gradient_key) is str: + gradient_conf = _resolve_single_str(gradient_conf) - if isinstance(gradient_conf, DictConfig): - gradient_conf = OmegaConf.merge(base_gradient_conf, gradient_conf) + if isinstance(gradient_conf, DictConfig): + gradient_conf = OmegaConf.merge(base_conf_gradient, gradient_conf) - if gradient_conf["key"] is None: - gradient_conf["key"] = gradient_key + if gradient_conf["key"] is None: + gradient_conf["key"] = gradient_key - conf["targets"][target_key][gradient_key] = gradient_conf + conf["targets"][target_key][gradient_key] = gradient_conf # If user sets the virial gradient and leaves the stress section untouched # we disable the by default enabled stress gradients. 
+ base_stress_gradient_conf = base_conf_gradient.copy() base_stress_gradient_conf["key"] = "stress" if ( diff --git a/src/metatensor/models/utils/data/readers/readers.py b/src/metatensor/models/utils/data/readers/readers.py index 7e46b5f98..5dc2d5e6a 100644 --- a/src/metatensor/models/utils/data/readers/readers.py +++ b/src/metatensor/models/utils/data/readers/readers.py @@ -131,6 +131,14 @@ def read_virial( def read_targets(conf: DictConfig) -> Dict[str, TensorMap]: """Reading all target information from a fully expanded config. + To get such a config you can use + :func:`metatensor.models.cli.train_model.expand_dataset_config`. + + This function uses subfunctions like :func:`read_energy` to parse the requested + target quantity. Currently only `energy` is a supported target property. But, within + the `energy` section gradients such as `forces`, the `stress` or the `virial` can be + added. Other gradients are silently ignored. + :param conf: config containing the keys for what should be read. :returns: Dictionary containing one TensorMaps for each target section in the config.""" diff --git a/tests/cli/test_train_model.py b/tests/cli/test_train_model.py index a3428e9e3..0e6ab4211 100644 --- a/tests/cli/test_train_model.py +++ b/tests/cli/test_train_model.py @@ -169,3 +169,20 @@ def test_expand_dataset_config_error(): ValueError, match="Cannot perform training with respect to virials and stress" ): expand_dataset_config(OmegaConf.create(conf)) + + +def test_expand_dataset_gradient_true(): + conf = { + "structures": "foo.xyz", + "targets": { + "energy": { + "virial": True, + "stress": False, + } + }, + } + + conf_expanded = expand_dataset_config(OmegaConf.create(conf)) + + assert conf_expanded["targets"]["energy"]["stress"] is False + conf_expanded["targets"]["energy"]["virial"]["read_from"]