Skip to content

Commit

Permalink
Add API docs and allow bool in grad options
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Jan 16, 2024
1 parent 3893055 commit d8d1e5c
Show file tree
Hide file tree
Showing 15 changed files with 190 additions and 43 deletions.
7 changes: 7 additions & 0 deletions docs/src/dev-docs/cli/eval_model.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
eval_model
##########

.. automodule:: metatensor.models.cli.eval_model
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/src/dev-docs/cli/export_model.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
export_model
############

.. automodule:: metatensor.models.cli.export_model
:members:
:undoc-members:
:show-inheritance:
7 changes: 7 additions & 0 deletions docs/src/dev-docs/cli/formatter.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
formatter
#########

.. automodule:: metatensor.models.cli.formatter
:members:
:undoc-members:
:show-inheritance:
20 changes: 20 additions & 0 deletions docs/src/dev-docs/cli/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
CLI API
=======

This is the API for the command line interface ``cli`` functions of
``metatensor-models``.

.. toctree::
:maxdepth: 1

train_model
eval_model
export_model

We provide a custom formatter class for the formatting the help message of the
`argparse` package.

.. toctree::
:maxdepth: 1

formatter
7 changes: 7 additions & 0 deletions docs/src/dev-docs/cli/train_model.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
train_model
###########

.. automodule:: metatensor.models.cli.train_model
:members:
:undoc-members:
:show-inheritance:
1 change: 1 addition & 0 deletions docs/src/dev-docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ module.
:maxdepth: 1

adding-models
cli/index
utils/index
27 changes: 21 additions & 6 deletions docs/src/dev-docs/utils/readers/index.rst
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
Structure and Target data Readers
=================================
General Structure and Target data Readers
=========================================

The main entry point for reading structure and target information are the two reader
functions

.. automodule:: metatensor.models.utils.data.readers
:members:
.. autofunction:: metatensor.models.utils.data.read_structures
.. autofunction:: metatensor.models.utils.data.read_targets

Based on the provided filename they chose which child reader they use. For details on
Target type specific readers
----------------------------

:func:`metatensor.models.utils.data.read_targets` uses sub-functions to parse supported
target properties like the `energy` or `forces`. Currently we support reading the
following target properties via

.. autofunction:: metatensor.models.utils.data.read_energy
.. autofunction:: metatensor.models.utils.data.read_forces
.. autofunction:: metatensor.models.utils.data.read_virial
.. autofunction:: metatensor.models.utils.data.read_stress

File type specific readers
--------------------------

Based on the provided `file_format` they chose which sub-reader they use. For details on
these refer to their documentation

.. toctree::
:maxdepth: 1

structure
target
targets
13 changes: 0 additions & 13 deletions docs/src/dev-docs/utils/readers/target.rst

This file was deleted.

63 changes: 63 additions & 0 deletions docs/src/dev-docs/utils/readers/targets.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
Target data Readers
###################

Parsers for obtaining target informations from target files. All readers return a
:py:class:`metatensor.torch.TensorBlock`. Currently we support the following target
properties

- :ref:`energy`
- :ref:`forces`
- :ref:`stress`
- :ref:`virial`

The mapping which reader is used for which file type is stored in a dictionary.

.. _energy:

Energy
======

.. autodata:: metatensor.models.utils.data.readers.targets.ENERGY_READERS

Implemented Readers
-------------------

.. autofunction:: metatensor.models.utils.data.readers.targets.read_energy_ase


.. _forces:

Forces
======

.. autodata:: metatensor.models.utils.data.readers.targets.FORCES_READERS

Implemented Readers
-------------------

.. autofunction:: metatensor.models.utils.data.readers.targets.read_forces_ase

.. _stress:

Stress
======

.. autodata:: metatensor.models.utils.data.readers.targets.STRESS_READERS

Implemented Readers
-------------------

.. autofunction:: metatensor.models.utils.data.readers.targets.read_stress_ase

.. _virial:

Virial
======

.. autodata:: metatensor.models.utils.data.readers.targets.VIRIAL_READERS

Implemented Readers
-------------------

.. autofunction:: metatensor.models.utils.data.readers.targets.read_virial_ase

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ dependencies = [
"ase",
"torch",
"hydra-core",
"rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch",
#"rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch",
"metatensor-core",
"metatensor-operations",
"metatensor-torch",
Expand Down
3 changes: 3 additions & 0 deletions src/metatensor/models/cli/conf/dataset/gradients_avail.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
forces: off
stress: off
virial: off
3 changes: 0 additions & 3 deletions src/metatensor/models/cli/conf/dataset/targets.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,3 @@ read_from: ${...structures.read_from}
file_format:
key:
unit:
forces: off
stress: off
virial: off
48 changes: 28 additions & 20 deletions src/metatensor/models/cli/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,24 @@ def _resolve_single_str(config):


def expand_dataset_config(conf: Union[str, DictConfig]) -> DictConfig:
"""Expand a short hand notation in a dataset config to actual format."""
base_conf_structures = OmegaConf.load(CONFIG_PATH / "dataset" / "structures.yaml")
base_conf_target = OmegaConf.load(CONFIG_PATH / "dataset" / "targets.yaml")
base_gradient_conf = OmegaConf.load(CONFIG_PATH / "dataset" / "gradient.yaml")
"""Expand a short hand notation in a dataset config to actual format.
base_conf_energy = OmegaConf.load(CONFIG_PATH / "dataset" / "targets.yaml")
base_conf_energy["forces"] = base_gradient_conf.copy()
base_conf_energy["stress"] = base_gradient_conf.copy()
base_conf_energy["virial"] = False
Currently the config si not checked if all keys are known. Unknown keys can be added
and will be ignored and not deleted."""

conf_path = CONFIG_PATH / "dataset"
base_conf_structures = OmegaConf.load(conf_path / "structures.yaml")
base_conf_target = OmegaConf.load(conf_path / "targets.yaml")
base_conf_gradients_avail = OmegaConf.load(conf_path / "gradients_avail.yaml")
base_conf_gradient = OmegaConf.load(conf_path / "gradient.yaml")

known_gradient_keys = list(base_conf_gradients_avail.keys())

# merge confif to get default configs for energies and other target config.
base_conf_target = OmegaConf.merge(base_conf_target, base_conf_gradients_avail)
base_conf_energy = base_conf_target.copy()
base_conf_energy["forces"] = base_conf_gradient.copy()
base_conf_energy["stress"] = base_conf_gradient.copy()

if isinstance(conf, str):
read_from = conf
Expand Down Expand Up @@ -130,24 +139,23 @@ def expand_dataset_config(conf: Union[str, DictConfig]) -> DictConfig:

# merge and interpolate possible present gradients with default config
for gradient_key, gradient_conf in conf["targets"][target_key].items():
if (
type(gradient_conf) is str
and Path(gradient_conf).suffix # field is a file with a suffix
and gradient_key not in ["read_from", "file_format"]
):
gradient_conf = _resolve_single_str(gradient_conf)
if gradient_key in known_gradient_keys:
if gradient_key:
gradient_conf = base_conf_gradient.copy()
elif type(gradient_key) is str:
gradient_conf = _resolve_single_str(gradient_conf)

if isinstance(gradient_conf, DictConfig):
gradient_conf = OmegaConf.merge(base_gradient_conf, gradient_conf)
if isinstance(gradient_conf, DictConfig):
gradient_conf = OmegaConf.merge(base_conf_gradient, gradient_conf)

if gradient_conf["key"] is None:
gradient_conf["key"] = gradient_key
if gradient_conf["key"] is None:
gradient_conf["key"] = gradient_key

conf["targets"][target_key][gradient_key] = gradient_conf
conf["targets"][target_key][gradient_key] = gradient_conf

# If user sets the virial gradient and leaves the stress section untouched
# we disable the by default enabled stress gradients.
base_stress_gradient_conf = base_gradient_conf.copy()
base_stress_gradient_conf = base_conf_gradient.copy()
base_stress_gradient_conf["key"] = "stress"

if (
Expand Down
8 changes: 8 additions & 0 deletions src/metatensor/models/utils/data/readers/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ def read_virial(
def read_targets(conf: DictConfig) -> Dict[str, TensorMap]:
"""Reading all target information from a fully expanded config.
To get such a config you can use
:func:`metatensor.models.cli.train_model.expand_dataset_config`.
This function uses subfunctions like :func:`read_energy` to parse the requested
target quantity. Currently only `energy` is a supported target property. But, within
the `energy` section gradients such as `forces`, the `stress` or the `virial` can be
added. Other gradients are silentlty irgnored.
:param conf: config containing the keys for what should be read.
:returns: Dictionary containing one TensorMaps for each target section in the
config."""
Expand Down
17 changes: 17 additions & 0 deletions tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,20 @@ def test_expand_dataset_config_error():
ValueError, match="Cannot perform training with respect to virials and stress"
):
expand_dataset_config(OmegaConf.create(conf))


def test_expand_dataset_gradient_true():
conf = {
"structures": "foo.xyz",
"targets": {
"energy": {
"virial": True,
"stress": False,
}
},
}

conf_expanded = expand_dataset_config(OmegaConf.create(conf))

assert conf_expanded["targets"]["energy"]["stress"] is False
conf_expanded["targets"]["energy"]["virial"]["read_from"]

0 comments on commit d8d1e5c

Please sign in to comment.