Skip to content

Commit

Permalink
Add xyz structure and target reader
Browse files Browse the repository at this point in the history
  • Loading branch information
PicoCentauri committed Nov 30, 2023
1 parent aa3c55e commit c7b7f8d
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 3 deletions.
5 changes: 2 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ authors = [
dependencies = [
"ase",
"torch",
"rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch",
"metatensor-core",
"metatensor-torch"
]
Expand Down Expand Up @@ -54,9 +55,7 @@ requires = [
build-backend = "setuptools.build_meta"

[project.optional-dependencies]
soap-bpnn = [
"rascaline-torch @ git+https://github.com/luthaf/rascaline#subdirectory=python/rascaline-torch",
]
soap-bpnn = []

[tool.setuptools.packages.find]
where = ["src"]
Expand Down
46 changes: 46 additions & 0 deletions src/metatensor_models/readers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
""""Readers for structures and target values."""

from typing import List, Dict, Optional

from pathlib import Path

from metatensor.torch import TensorMap

from .structures import STRUCTURE_READERS
from .targets import TARGET_READERS

from rascaline.systems import SystemBase


def read_structures(
filename: str, fileformat: Optional[str] = None
) -> List[SystemBase]:
"""Reads a structure information from file."""

if fileformat is None:
fileformat = Path(filename).suffix

try:
reader = STRUCTURE_READERS[fileformat]
except KeyError:
raise ValueError(f"fileformat '{fileformat}' is not supported")

return reader(filename)


def read_targets(
filename: str,
target_value: str,
fileformat: Optional[str] = None,
) -> Dict[str, TensorMap]:
"""Reads target information from file."""

if fileformat is None:
fileformat = Path(filename).suffix

try:
reader = TARGET_READERS[fileformat]
except KeyError:
raise ValueError(f"fileformat '{fileformat}' is not supported")

return reader(filename, target_value)
3 changes: 3 additions & 0 deletions src/metatensor_models/readers/structures/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .ase import read_ase

STRUCTURE_READERS = {".xyz": read_ase}
9 changes: 9 additions & 0 deletions src/metatensor_models/readers/structures/ase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from typing import List

import ase.io
from rascaline.systems import SystemBase
from rascaline.systems.ase import AseSystem


def read_ase(filename: str) -> List[SystemBase]:
return [AseSystem(atoms) for atoms in ase.io.read(filename, ":")]
3 changes: 3 additions & 0 deletions src/metatensor_models/readers/targets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .ase import read_ase

TARGET_READERS = {".xyz": read_ase}
38 changes: 38 additions & 0 deletions src/metatensor_models/readers/targets/ase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import Dict, List, Union

import ase.io
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap


def read_ase(
filename: str,
target_values: Union[List[str], str],
) -> Dict[str, TensorMap]:
"""Store target informations from file in a :class:`metatensor.TensorMap`.
:returns:
TensorMap containing the given information
"""

if type(target_values) is str:
target_values = [target_values]

frames = ase.io.read(filename, ":")

target_dictionary = {}
for target_value in target_values:
values = [f.info[target_value] for f in frames]

n_structures = len(values)

block = TensorBlock(
values=torch.tensor(values).reshape(-1, 1),
samples=Labels(["structure"], torch.arange(n_structures).reshape(-1, 1)),
components=[],
properties=Labels([target_value], torch.tensor([(0,)])),
)

target_dictionary[target_value] = TensorMap(Labels.single(), [block])

return target_dictionary

0 comments on commit c7b7f8d

Please sign in to comment.