Skip to content

Commit

Permalink
Added evaluation function
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Dec 11, 2023
1 parent 5b9b2ad commit 00bdd0f
Show file tree
Hide file tree
Showing 18 changed files with 401 additions and 25 deletions.
21 changes: 13 additions & 8 deletions src/metatensor/models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@

from . import __version__
from .cli import eval_model, export_model, train_model
from .cli.eval_model import _eval_model_cli


def main():
ap = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

ap.add_argument(
Expand All @@ -21,24 +22,25 @@ def main():
subparser = ap.add_subparsers(help="sub-command help")
evaluate_parser = subparser.add_parser(
"eval",
help=eval_model.__doc__,
help="Evaluate a pretrained model.",
description="eval model",
formatter_class=argparse.RawDescriptionHelpFormatter,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
evaluate_parser.set_defaults(callable="eval_model")
_eval_model_cli(evaluate_parser)

export_parser = subparser.add_parser(
"export",
help=export_model.__doc__,
description="export model",
formatter_class=argparse.RawDescriptionHelpFormatter,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
export_parser.set_defaults(callable="export_model")
train_parser = subparser.add_parser(
"train",
help=train_model.__doc__,
description="train model",
formatter_class=argparse.RawDescriptionHelpFormatter,
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
train_parser.set_defaults(callable="train_model")

Expand All @@ -56,11 +58,14 @@ def main():
train_model()

else:
ap.parse_args([callable])
args = ap.parse_args()

# Remove `callable` because it is not an argument of the two functions
args.__dict__.pop("callable")
if callable == "eval":
eval_model()
eval_model(**args.__dict__)
elif callable == "export":
export_model()
export_model(**args.__dict__)


if __name__ == "__main__":
Expand Down
58 changes: 55 additions & 3 deletions src/metatensor/models/cli/eval_model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,55 @@
def eval_model():
"""evaluate a model"""
print("Run evaluate...")
import argparse

from ..utils.data.readers import read_structures
from ..utils.data.writers import write_predictions
from ..utils.model_io import load_model


def _eval_model_cli(parser: argparse.ArgumentParser) -> None:
"""Add the `eval_model` paramaters to an argparse (sub)-parser"""
parser.add_argument(
"-m",
"--model",
dest="model_path",
type=str,
required=True,
help="Path to a saved model",
)
parser.add_argument(
"-s",
"--structure",
dest="structure_path",
type=str,
required=True,
help="Path to a structure file which should be considered for the evaluation.",
)
parser.add_argument(
"-o",
"--output",
dest="output_path",
type=str,
required=False,
default="output.xyz",
help="Path to save the predicted values.",
)


def eval_model(
    model_path: str, structure_path: str, output_path: str = "output.xyz"
) -> None:
    """Evaluate a pretrained model on a set of structures.

    The model stored at ``model_path`` is applied to the structures read from
    ``structure_path`` and the resulting predictions are written to
    ``output_path``.

    :param model_path: Path to a saved model
    :param structure_path: Path to a structure file which should be considered for
        the evaluation.
    :param output_path: Path to save the predicted values
    """
    # The model is loaded first, then the structures, so a bad model path is
    # reported before the structure file is touched.
    trained_model = load_model(model_path)
    systems = read_structures(structure_path)

    write_predictions(output_path, trained_model(systems), systems)
5 changes: 2 additions & 3 deletions src/metatensor/models/cli/train_model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os

import importlib
import logging
import os

import hydra
from omegaconf import DictConfig, OmegaConf
Expand All @@ -23,7 +22,7 @@ def train_model(config: DictConfig) -> None:
structures = read_structures(config["dataset"]["structure_path"])
targets = read_targets(
config["dataset"]["targets_path"],
target_value=config["dataset"]["target_value"],
target_values=config["dataset"]["target_value"],
)
dataset = Dataset(structures, targets)

Expand Down
11 changes: 10 additions & 1 deletion src/metatensor/models/soap_bpnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,16 @@ def loss_function(predicted, target):

def train(model, train_dataset, hypers=DEFAULT_TRAINING_HYPERS):
# Calculate and set the composition weights:
composition_weights = calculate_composition_weights(train_dataset, "U0")

if len(train_dataset.targets) > 1:
raise ValueError(
f"`train_dataset` contains {len(train_dataset.targets)} targets but we "
"currently only support a single target value!"
)
else:
target = list(train_dataset.targets.keys())[0]

composition_weights = calculate_composition_weights(train_dataset, target)
model.set_composition_weights(composition_weights)

# Create a dataloader for the training dataset:
Expand Down
1 change: 1 addition & 0 deletions src/metatensor/models/utils/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from .dataset import Dataset, collate_fn # noqa: F401
from .readers import read_structures, read_targets # noqa: F401
from .writers import write_predictions # noqa: F401
23 changes: 18 additions & 5 deletions src/metatensor/models/utils/data/readers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Readers for structures and target values."""

from typing import List, Dict, Optional
from typing import List, Dict, Optional, Union

from pathlib import Path

Expand All @@ -13,7 +13,13 @@


def read_structures(filename: str, fileformat: Optional[str] = None) -> List[System]:
"""Reads a structure information from file."""
"""Read structure informations from a file.
:param filename: name of the file to read
:param fileformat: format of the structure file. If :py:obj:`None` the format is
determined from the suffix.
:returns: list of structures
"""

if fileformat is None:
fileformat = Path(filename).suffix
Expand All @@ -28,10 +34,17 @@ def read_structures(filename: str, fileformat: Optional[str] = None) -> List[Sys

def read_targets(
filename: str,
target_value: str,
target_values: Union[List[str], str],
fileformat: Optional[str] = None,
) -> Dict[str, TensorMap]:
"""Reads target information from file."""
"""Read target informations from a file.
:param filename: name of the file to read
:param target_values: target values to be parsed from the file.
:param fileformat: format of the target value file. If :py:obj:`None` the format is
determined from the suffix.
:returns: dictionary containing one key per ``target_value``.
"""

if fileformat is None:
fileformat = Path(filename).suffix
Expand All @@ -41,4 +54,4 @@ def read_targets(
except KeyError:
raise ValueError(f"fileformat '{fileformat}' is not supported")

return reader(filename, target_value)
return reader(filename, target_values)
7 changes: 7 additions & 0 deletions src/metatensor/models/utils/data/readers/structures/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,13 @@


def read_ase(filename: str) -> List[System]:
    """Read structures from a file using ASE.

    :param filename: name of the file to read
    :returns: the structures, converted with ``systems_to_torch``
    """
    # The ":" index asks ``ase.io.read`` for every frame stored in the file.
    frames = ase.io.read(filename, ":")
    wrapped = [AseSystem(frame) for frame in frames]

    return systems_to_torch(wrapped)
7 changes: 5 additions & 2 deletions src/metatensor/models/utils/data/readers/targets/ase.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ def read_ase(
filename: str,
target_values: Union[List[str], str],
) -> Dict[str, TensorMap]:
"""Store target informations from file in a :class:`metatensor.TensorMap`.
"""Store target informations with ase in a :class:`metatensor.TensorMap`.
:param filename: name of the file to read
:param target_values: target values to be parsed from the file.
:returns:
TensorMap containing the given information
Expand All @@ -30,7 +33,7 @@ def read_ase(
values=torch.tensor(values).reshape(-1, 1),
samples=Labels(["structure"], torch.arange(n_structures).reshape(-1, 1)),
components=[],
properties=Labels([target_value], torch.tensor([(0,)])),
properties=Labels(["energy"], torch.tensor([(0,)])),
)

target_dictionary[target_value] = TensorMap(Labels.single(), [block])
Expand Down
41 changes: 41 additions & 0 deletions src/metatensor/models/utils/data/writers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Writers for predictions"""

from typing import List, Optional

from pathlib import Path
from metatensor.torch import TensorMap
from rascaline.torch.system import System

from .xyz import write_xyz


PREDICTIONS_WRITERS = {".xyz": write_xyz}
""":py:class:`dict`: dictionary mapping file suffixes to a prediction writer"""


def write_predictions(
    filename: str,
    predictions: TensorMap,
    structures: List[System],
    fileformat: Optional[str] = None,
) -> None:
    """Write predictions to a file.

    The writer is looked up in ``PREDICTIONS_WRITERS`` by file format. Some
    writers (e.g. ``xyz``) also store the structures next to the predictions.

    :param filename: name of the file to write
    :param predictions: :py:class:`metatensor.torch.TensorMap` containing the
        predictions that should be written
    :param structures: list of structures that for some writers will also be
        written
    :param fileformat: format of the output file. If :py:obj:`None` the format is
        determined from the suffix of ``filename``.
    :raises ValueError: if no writer is registered for the file format
    """
    suffix = Path(filename).suffix if fileformat is None else fileformat

    try:
        write = PREDICTIONS_WRITERS[suffix]
    except KeyError:
        raise ValueError(f"fileformat '{suffix}' is not supported")

    return write(filename, predictions, structures)
22 changes: 22 additions & 0 deletions src/metatensor/models/utils/data/writers/xyz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import List

import ase
import ase.io
import torch
from metatensor.torch import TensorMap
from rascaline.torch.system import System


def write_xyz(filename: str, predictions: TensorMap, structures: List[System]) -> None:
    """Write structures and their predicted energies to an xyz file using ASE.

    One frame is written per structure. The predicted value of each structure is
    stored under the ``"energy"`` key of the frame's ``info`` dictionary.

    :param filename: name of the file to write
    :param predictions: predictions to write; the values are looked up as
        ``predictions["energy"].block().values`` with one row per structure
    :param structures: structures the predictions belong to, in the same order as
        the rows of the predicted values
    """
    frames = []
    for i_system, system in enumerate(structures):
        info = {"energy": float(predictions["energy"].block().values[i_system, 0])}
        atoms = ase.Atoms(symbols=system.species, positions=system.positions, info=info)

        # Attach the cell only when the structure defines one. The check has to
        # look at the structure's own cell; comparing a constant zero tensor is
        # always False and would silently drop every cell.
        # NOTE(review): assumes an all-zero cell marks a non-periodic structure
        # and that any non-zero cell is periodic in all directions; confirm
        # against the ``System`` documentation.
        if torch.any(system.cell != 0):
            atoms.pbc = True
            atoms.cell = system.cell

        frames.append(atoms)

    ase.io.write(filename, frames)
36 changes: 36 additions & 0 deletions tests/cli/test_eval_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import shutil
import subprocess
from pathlib import Path

import ase.io
import pytest


RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"


@pytest.mark.parametrize("output", [None, "foo.xyz"])
def test_eval(output, monkeypatch, tmp_path):
    """Test that evaluation via the eval cli runs and writes the predictions.

    With ``output=None`` no ``-o`` option is passed and the predictions are
    expected in the default ``output.xyz`` file.
    """
    # Run inside a temporary directory so the output file does not pollute the
    # repository and relative paths on the command line resolve.
    monkeypatch.chdir(tmp_path)
    shutil.copy(RESOURCES_PATH / "qm9_reduced_100.xyz", "qm9_reduced_100.xyz")
    shutil.copy(RESOURCES_PATH / "bpnn-model.pt", "bpnn-model.pt")

    command = [
        "metatensor-models",
        "eval",
        "-m",
        "bpnn-model.pt",
        "-s",
        "qm9_reduced_100.xyz",
    ]

    if output is not None:
        command += ["-o", output]
    else:
        output = "output.xyz"

    subprocess.check_call(command)

    # The written frames must carry the predicted energy.
    frames = ase.io.read(output, ":")
    assert "energy" in frames[0].info
Binary file added tests/resources/bpnn-model.pt
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/resources/parameters.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ architecture:

training:
batch_size: 8
num_epochs: 5
num_epochs: 1

dataset:
structure_path: "qm9_reduced_100.xyz"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)


RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources"
RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / ".." / "resources"


def test_dataset():
Expand Down
Loading

0 comments on commit 00bdd0f

Please sign in to comment.