Add parser and update readers for new options.yaml
PicoCentauri committed Jan 15, 2024
1 parent 51df872 commit 9a359c5
Showing 25 changed files with 1,184 additions and 234 deletions.
2 changes: 1 addition & 1 deletion docs/src/dev-docs/utils/readers/target.rst
@@ -5,7 +5,7 @@ Parsers for obtaining information from target files. All readers return a
:py:class:`metatensor.torch.TensorMap`. The mapping of which reader is used for which
file type is stored in

.. autodata:: metatensor.models.utils.data.readers.targets.TARGET_READERS
.. autodata:: metatensor.models.utils.data.readers.targets.ENERGY_READERS

Implemented Readers
-------------------
16 changes: 12 additions & 4 deletions docs/static/options.yaml
@@ -1,8 +1,16 @@
defaults:
- architecture: soap_bpnn # architecture used to train the model
- _self_

# Because _self_ is in the last position, the options defined in this file
# overwrite the default options.

# Section defining the parameters for structure and target data
dataset:
structure_path: "qm9_reduced_100.xyz" # file where the positions are stored
targets_path: "qm9_reduced_100.xyz" # file with target values (i.e energies)
target_value: "U0" # name of the target value in `targets_path`
training_set:
structures: "qm9_reduced_100.xyz" # file where the positions are stored
targets:
energy:
key: "U0" # name of the target value

test_set: 0.1
validation_set: 0.1
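
For orientation, a minimal sketch (run from the repository root, with the file path as shown above) of how the new layout reads back with OmegaConf:

```python
from omegaconf import OmegaConf

options = OmegaConf.load("docs/static/options.yaml")

assert options.training_set.structures == "qm9_reduced_100.xyz"
assert options.training_set.targets.energy.key == "U0"
assert options.test_set == 0.1  # a bare float means "split this fraction off"
```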
3 changes: 0 additions & 3 deletions src/metatensor/models/cli/conf/config.yaml

This file was deleted.

3 changes: 3 additions & 0 deletions src/metatensor/models/cli/conf/dataset/gradient.yaml
@@ -0,0 +1,3 @@
read_from: ${..read_from}
file_format: ${..file_format}
key:
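
A small illustrative sketch of the convention used in this file: a blank YAML value loads as `None`, which the parser added in this commit later replaces with the gradient section's own name (e.g. `forces`):

```python
from omegaconf import OmegaConf

gradient = OmegaConf.create("""
read_from: ${..read_from}
file_format: ${..file_format}
key:
""")

# The interpolations stay unresolved until this node is merged under a parent;
# the blank key simply loads as None.
assert gradient.key is None
```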
3 changes: 3 additions & 0 deletions src/metatensor/models/cli/conf/dataset/structures.yaml
@@ -0,0 +1,3 @@
read_from:
file_format:
unit:
8 changes: 8 additions & 0 deletions src/metatensor/models/cli/conf/dataset/targets.yaml
@@ -0,0 +1,8 @@
quantity: energy
read_from: ${...structures.read_from}
file_format:
key:
unit:
forces: off
stress: off
virial: off
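
To illustrate the relative interpolations above (a hedged sketch with made-up values): `${...structures.read_from}` climbs two levels up from a target node to the config root, while `${..read_from}` in gradient.yaml climbs one level up from a gradient node to its target:

```python
from omegaconf import OmegaConf

conf = OmegaConf.create(
    {
        "structures": {"read_from": "data.xyz"},
        "targets": {
            "energy": {
                "read_from": "${...structures.read_from}",  # two levels up
                "forces": {"read_from": "${..read_from}"},  # one level up
            }
        },
    }
)

assert conf.targets.energy.read_from == "data.xyz"
assert conf.targets.energy.forces.read_from == "data.xyz"  # chains via energy
```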
203 changes: 190 additions & 13 deletions src/metatensor/models/cli/train_model.py
@@ -1,9 +1,12 @@
import argparse
import importlib
import logging
import warnings
from pathlib import Path
from typing import Union

import hydra
import torch
from omegaconf import DictConfig, OmegaConf

from metatensor.models.utils.data import Dataset
@@ -68,7 +71,105 @@ def _add_train_model_parser(subparser: argparse._SubParsersAction) -> None:
)


@hydra.main(config_path=str(CONFIG_PATH), config_name="config", version_base=None)
def _resolve_single_str(config):
if isinstance(config, str):
return OmegaConf.create({"read_from": config})
else:
return config


def expand_dataset_config(conf: Union[str, DictConfig]) -> DictConfig:
"""Expand a short hand notation in a dataset config to actual format."""
base_conf_structures = OmegaConf.load(CONFIG_PATH / "dataset" / "structures.yaml")
base_conf_target = OmegaConf.load(CONFIG_PATH / "dataset" / "targets.yaml")
base_gradient_conf = OmegaConf.load(CONFIG_PATH / "dataset" / "gradient.yaml")

base_conf_energy = OmegaConf.load(CONFIG_PATH / "dataset" / "targets.yaml")
base_conf_energy["forces"] = base_gradient_conf.copy()
base_conf_energy["stress"] = base_gradient_conf.copy()
base_conf_energy["virial"] = False

if isinstance(conf, str):
read_from = conf
conf = OmegaConf.create(
{"structures": read_from, "targets": {"energy": read_from}}
)

if type(conf["structures"]) is str:
conf["structures"] = _resolve_single_str(conf["structures"])

conf["structures"] = OmegaConf.merge(base_conf_structures, conf["structures"])

if conf["structures"]["file_format"] is None:
conf["structures"]["file_format"] = Path(conf["structures"]["read_from"]).suffix

for target_key, target in conf["targets"].items():
if type(target) is str:
target = _resolve_single_str(target)

        # Add default gradient sections for the "energy" target
        if target_key == "energy":
            # For the special case of "energy" we add sections for the force
            # and stress gradients by default
target = OmegaConf.merge(base_conf_energy, target)
else:
target = OmegaConf.merge(base_conf_target, target)

if target["key"] is None:
target["key"] = target_key

# Update DictConfig to allow for config node interpolation of variables like
# "read_from"
conf["targets"][target_key] = target

        # Check against the full config `conf` to avoid node interpolation errors
if conf["targets"][target_key]["file_format"] is None:
conf["targets"][target_key]["file_format"] = Path(
conf["targets"][target_key]["read_from"]
).suffix

        # Merge and interpolate any gradient sections present with the default config
for gradient_key, gradient_conf in conf["targets"][target_key].items():
if (
type(gradient_conf) is str
and Path(gradient_conf).suffix # field is a file with a suffix
and gradient_key not in ["read_from", "file_format"]
):
gradient_conf = _resolve_single_str(gradient_conf)

if isinstance(gradient_conf, DictConfig):
gradient_conf = OmegaConf.merge(base_gradient_conf, gradient_conf)

if gradient_conf["key"] is None:
gradient_conf["key"] = gradient_key

conf["targets"][target_key][gradient_key] = gradient_conf

        # If the user enables the virial gradient and leaves the stress section
        # untouched, we disable the stress gradients that are enabled by default.
base_stress_gradient_conf = base_gradient_conf.copy()
base_stress_gradient_conf["key"] = "stress"

if (
target_key == "energy"
and conf["targets"][target_key]["virial"]
and conf["targets"][target_key]["stress"] == base_stress_gradient_conf
):
conf["targets"][target_key]["stress"] = False

if (
conf["targets"][target_key]["stress"]
and conf["targets"][target_key]["virial"]
):
raise ValueError(
f"Cannot perform training with respect to virials and stress as in "
f"section {target_key}. Set either `virials: off` or `stress: off`."
)

return conf


@hydra.main(config_path=str(CONFIG_PATH), version_base=None)
def train_model(options: DictConfig) -> None:
"""Train an atomistic machine learning model using configurations provided by Hydra.
Expand All @@ -87,13 +188,82 @@ def train_model(options: DictConfig) -> None:
necessary options for dataset preparation, model hyperparameters, and training.
"""

logger.info("Setting up dataset")
structures = read_structures(options["dataset"]["structure_path"])
targets = read_targets(
options["dataset"]["targets_path"],
target_values=options["dataset"]["target_value"],
# TODO load seed from config
generator = torch.Generator()

logger.info("Setting up training set")
conf_training_set = expand_dataset_config(options["training_set"])
structures_train = read_structures(
filename=conf_training_set["structures"]["read_from"],
fileformat=conf_training_set["structures"]["file_format"],
)
dataset = Dataset(structures, targets)
targets_train = read_targets(conf_training_set["targets"])
train_dataset = Dataset(structures_train, targets_train)

logger.info("Setting up test set")
conf_test_set = options["test_set"]
if not isinstance(conf_test_set, float):
conf_test_set = expand_dataset_config(conf_test_set)
structures_test = read_structures(
            filename=conf_test_set["structures"]["read_from"],
            fileformat=conf_test_set["structures"]["file_format"],
)
targets_test = read_targets(conf_test_set["targets"])
test_dataset = Dataset(structures_test, targets_test)
fraction_test_set = 0.0
else:
if conf_test_set < 0 or conf_test_set >= 1:
raise ValueError("Test set split must be between 0 and 1.")
fraction_test_set = conf_test_set

logger.info("Setting up validation set")
conf_validation_set = options["validation_set"]
if not isinstance(conf_validation_set, float):
conf_validation_set = expand_dataset_config(conf_validation_set)
structures_validation = read_structures(
            filename=conf_validation_set["structures"]["read_from"],
            fileformat=conf_validation_set["structures"]["file_format"],
)
targets_validation = read_targets(conf_validation_set["targets"])
validation_dataset = Dataset(structures_validation, targets_validation)
fraction_validation_set = 0.0
else:
if conf_validation_set < 0 or conf_validation_set >= 1:
raise ValueError("Validation set split must be between 0 and 1.")
fraction_validation_set = conf_validation_set

# Split train dataset if requested
if fraction_test_set or fraction_validation_set:
fraction_train_set = 1 - fraction_test_set - fraction_validation_set
if fraction_train_set < 0:
raise ValueError("fraction of the train set is smaller then 0!")

        # ignore the warning about a possibly empty dataset
with warnings.catch_warnings():
warnings.simplefilter("ignore")
subsets = torch.utils.data.random_split(
dataset=train_dataset,
lengths=[
fraction_train_set,
fraction_test_set,
fraction_validation_set,
],
generator=generator,
)

train_dataset = subsets[0]
if fraction_test_set and not fraction_validation_set:
test_dataset = subsets[1]
    elif not fraction_test_set and fraction_validation_set:
validation_dataset = subsets[1]
else:
test_dataset = subsets[1]
validation_dataset = subsets[2]

    # TODO: Perform section and unit consistency checks between the
    # test/train/validation sets
test_dataset
validation_dataset

logger.info("Setting up model")
    architecture_name = options["architecture"]["name"]
@@ -102,11 +272,18 @@ def train_model(options: DictConfig) -> None:
logger.info("Run training")
output_dir = hydra.core.hydra_config.HydraConfig.get().runtime.output_dir

print(OmegaConf.to_container(options))
model = architecture.train(
train_dataset=dataset,
hypers=OmegaConf.to_container(options["architecture"]),
output_dir=output_dir,
)
    # HACK: Avoid passing a Subset, which we cannot handle yet; for now we
    # unwrap it and train on the underlying full dataset.
if isinstance(train_dataset, torch.utils.data.Subset):
model = architecture.train(
train_dataset=train_dataset.dataset,
hypers=OmegaConf.to_container(options["architecture"]),
output_dir=output_dir,
)
else:
model = architecture.train(
train_dataset=train_dataset,
hypers=OmegaConf.to_container(options["architecture"]),
output_dir=output_dir,
)

save_model(model, options["output_path"])
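
Taken together, a hedged usage sketch of `expand_dataset_config` as defined above (import path assumed from the file location; the assertions follow the code shown, with gradient defaults coming from the conf/dataset YAML files):

```python
from metatensor.models.cli.train_model import expand_dataset_config

# A bare string is the shortest shorthand: it is used both as the structures
# file and as the source of an "energy" target.
conf = expand_dataset_config("qm9_reduced_100.xyz")

assert conf["structures"]["read_from"] == "qm9_reduced_100.xyz"
assert conf["structures"]["file_format"] == ".xyz"  # inferred from the suffix
assert conf["targets"]["energy"]["key"] == "energy"  # defaults to section name

# "forces" and "stress" gradient sections are enabled by default for "energy";
# "virial" stays off. Enabling virials instead drops the default stress
# section, and enabling both raises a ValueError.
assert conf["targets"]["energy"]["virial"] is False
```

The float forms of `test_set` and `validation_set` bypass this expansion entirely and are passed as fractional lengths to `torch.utils.data.random_split` (supported in recent PyTorch versions).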
21 changes: 15 additions & 6 deletions src/metatensor/models/soap_bpnn/tests/test_regression.py
@@ -2,6 +2,7 @@
import rascaline.torch
import torch
from metatensor.torch.atomistic import ModelCapabilities, ModelOutput
from omegaconf import OmegaConf

from metatensor.models.soap_bpnn import DEFAULT_HYPERS, Model, train
from metatensor.models.utils.data import Dataset
@@ -39,8 +40,6 @@ def test_regression_init():
dtype=torch.float64,
)

print(output["energy"].block().values)

assert torch.allclose(output["energy"].block().values, expected_output, rtol=1e-3)


@@ -49,7 +48,19 @@ def test_regression_train():
    trained for 2 epochs on a small dataset"""

structures = read_structures(DATASET_PATH)
targets = read_targets(DATASET_PATH, "U0")

conf = {
"energy": {
"quantity": "energy",
"read_from": DATASET_PATH,
"file_format": ".xyz",
"key": "U0",
"forces": False,
"stress": False,
"virial": False,
}
}
targets = read_targets(OmegaConf.create(conf))

dataset = Dataset(structures, targets)

@@ -65,6 +76,4 @@
dtype=torch.float64,
)

print(output["U0"].block().values)

assert torch.allclose(output["U0"].block().values, expected_output, rtol=1e-3)
assert torch.allclose(output["energy"].block().values, expected_output, rtol=1e-3)
9 changes: 8 additions & 1 deletion src/metatensor/models/utils/data/__init__.py
@@ -1,3 +1,10 @@
from .dataset import Dataset, collate_fn # noqa: F401
from .readers import read_structures, read_targets # noqa: F401
from .readers import ( # noqa: F401
read_energy,
read_forces,
read_stress,
read_structures,
read_targets,
read_virial,
)
from .writers import write_predictions # noqa: F401
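
As exercised in the regression test above, a minimal sketch of the new `read_targets` entry point driven by an explicit per-target config (dataset path assumed):

```python
from omegaconf import OmegaConf

from metatensor.models.utils.data import read_structures, read_targets

conf = OmegaConf.create(
    {
        "energy": {
            "quantity": "energy",
            "read_from": "qm9_reduced_100.xyz",  # assumed dataset path
            "file_format": ".xyz",
            "key": "U0",
            "forces": False,  # no gradient sections requested
            "stress": False,
            "virial": False,
        }
    }
)

structures = read_structures("qm9_reduced_100.xyz")
targets = read_targets(conf)  # one entry per target section, e.g. "energy"
```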