-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5b9b2ad
commit 00bdd0f
Showing
18 changed files
with
401 additions
and
25 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,55 @@ | ||
def eval_model(): | ||
"""evaluate a model""" | ||
print("Run evaluate...") | ||
import argparse | ||
|
||
from ..utils.data.readers import read_structures | ||
from ..utils.data.writers import write_predictions | ||
from ..utils.model_io import load_model | ||
|
||
|
||
def _eval_model_cli(parser: argparse.ArgumentParser) -> None: | ||
"""Add the `eval_model` paramaters to an argparse (sub)-parser""" | ||
parser.add_argument( | ||
"-m", | ||
"--model", | ||
dest="model_path", | ||
type=str, | ||
required=True, | ||
help="Path to a saved model", | ||
) | ||
parser.add_argument( | ||
"-s", | ||
"--structure", | ||
dest="structure_path", | ||
type=str, | ||
required=True, | ||
help="Path to a structure file which should be considered for the evaluation.", | ||
) | ||
parser.add_argument( | ||
"-o", | ||
"--output", | ||
dest="output_path", | ||
type=str, | ||
required=False, | ||
default="output.xyz", | ||
help="Path to save the predicted values.", | ||
) | ||
|
||
|
||
def eval_model( | ||
model_path: str, structure_path: str, output_path: str = "output.xyz" | ||
) -> None: | ||
"""Evaluate a pretrained model. | ||
``target_property`` wil be predicted on a provided set of structures. Predicted | ||
values will be written ``output_path``. | ||
:param model_path: Path to a saved model | ||
:param structure_path: Path to a structure file which should be considered for the | ||
evaluation. | ||
:param output_path: Path to save the predicted values | ||
""" | ||
|
||
model = load_model(model_path) | ||
structures = read_structures(structure_path) | ||
predictions = model(structures) | ||
write_predictions(output_path, predictions, structures) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,3 @@ | ||
from .dataset import Dataset, collate_fn # noqa: F401 | ||
from .readers import read_structures, read_targets # noqa: F401 | ||
from .writers import write_predictions # noqa: F401 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
""""Writers for predictions""" | ||
|
||
from typing import List, Optional | ||
|
||
from pathlib import Path | ||
from metatensor.torch import TensorMap | ||
from rascaline.torch.system import System | ||
|
||
from .xyz import write_xyz | ||
|
||
|
||
PREDICTIONS_WRITERS = {".xyz": write_xyz} | ||
""":py:class:`dict`: dictionary mapping file suffixes to a preidction writers""" | ||
|
||
|
||
def write_predictions( | ||
filename: str, | ||
predictions: TensorMap, | ||
structures: List[System], | ||
fileformat: Optional[str] = None, | ||
) -> None: | ||
"""Writes predictions to a file. | ||
For certain file suffixes also the structures will be written (i.e ``xyz``). | ||
:param filename: name of the file to write | ||
:param predictions: :py:class`metatensor.torch.TensorMap` containinb the predictions | ||
that should be written | ||
:param structures: list of structures that for some writers will also be written | ||
:param fileformat: format of the target value file. If :py:obj:`None` the format is | ||
determined from the suffix. | ||
""" | ||
if fileformat is None: | ||
fileformat = Path(filename).suffix | ||
|
||
try: | ||
writer = PREDICTIONS_WRITERS[fileformat] | ||
except KeyError: | ||
raise ValueError(f"fileformat '{fileformat}' is not supported") | ||
|
||
return writer(filename, predictions, structures) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
from typing import List | ||
|
||
import ase | ||
import ase.io | ||
import torch | ||
from metatensor.torch import TensorMap | ||
from rascaline.torch.system import System | ||
|
||
|
||
def write_xyz(filename: str, predictions: TensorMap, structures: List[System]) -> None: | ||
frames = [] | ||
for i_system, system in enumerate(structures): | ||
info = {"energy": float(predictions["energy"].block().values[i_system, 0])} | ||
atoms = ase.Atoms(symbols=system.species, positions=system.positions, info=info) | ||
|
||
if torch.any(torch.zeros(10) != 0): | ||
atoms.pbc = True | ||
atoms.cell = system.cell | ||
|
||
frames.append(atoms) | ||
|
||
ase.io.write(filename, frames) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import shutil | ||
import subprocess | ||
from pathlib import Path | ||
|
||
import ase.io | ||
import pytest | ||
|
||
|
||
RESOURCES_PATH = Path(__file__).parent.resolve() / ".." / "resources" | ||
|
||
|
||
@pytest.mark.parametrize("output", [None, "foo.xyz"]) | ||
def test_eval(output, monkeypatch, tmp_path): | ||
"""Test that training via the training cli runs without an error raise.""" | ||
monkeypatch.chdir(tmp_path) | ||
shutil.copy(RESOURCES_PATH / "qm9_reduced_100.xyz", "qm9_reduced_100.xyz") | ||
shutil.copy(RESOURCES_PATH / "bpnn-model.pt", "bpnn-model.pt") | ||
|
||
command = [ | ||
"metatensor-models", | ||
"eval", | ||
"-m", | ||
"bpnn-model.pt", | ||
"-s", | ||
"qm9_reduced_100.xyz", | ||
] | ||
|
||
if output is not None: | ||
command += ["-o", output] | ||
else: | ||
output = "output.xyz" | ||
|
||
subprocess.check_call(command) | ||
|
||
frames = ase.io.read(output, ":") | ||
frames[0].info["energy"] |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.