Skip to content

Commit

Permalink
Add tests for errors
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Sep 1, 2024
1 parent 32c24be commit 7eda27b
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/metatrain/utils/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,20 @@ def train_model(
if fixed_weights is None:
fixed_weights = {}

missing_types = sorted(set(get_atomic_types(datasets)) - set(self.atomic_types))
additional_types = sorted(
set(get_atomic_types(datasets)) - set(self.atomic_types)
)
if additional_types:
raise ValueError(
"Provided `datasets` contains unknown "
f"atomic types {additional_types}. "
f"Known types from initilaization are {self.atomic_types}."
)

missing_types = sorted(set(self.atomic_types) - set(get_atomic_types(datasets)))
if missing_types:
raise ValueError(
f"Provided `datasets` contains unknown atomic types {missing_types}. "
f"Provided `datasets` do not contain atomic types {missing_types}. "
f"Known types from initilaization are {self.atomic_types}."
)

Expand Down
102 changes: 102 additions & 0 deletions tests/utils/test_composition.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from pathlib import Path

import metatensor.torch
import pytest
import torch
from metatensor.torch import Labels, TensorBlock, TensorMap
from metatensor.torch.atomistic import ModelOutput, System
Expand Down Expand Up @@ -239,3 +240,104 @@ def test_remove_composition():
# In QM9 the composition contribution is very large: the standard deviation
# of the energies is reduced by a factor of over 100 upon removing the composition
assert std_after < 100.0 * std_before


def test_composition_model_missing_types():
    """
    Test the errors raised when the dataset contains too many or too few
    atomic types compared to those declared at initialization.
    """

    def make_system(positions, types):
        # All synthetic structures share the same unit cell and dtype.
        return System(
            positions=torch.tensor(positions, dtype=torch.float64),
            types=torch.tensor(types),
            cell=torch.eye(3, dtype=torch.float64),
        )

    def make_energy(index, value):
        # Single-block TensorMap holding one energy value for system `index`.
        block = TensorBlock(
            values=torch.tensor([[value]], dtype=torch.float64),
            samples=Labels(names=["system"], values=torch.tensor([[index]])),
            components=[],
            properties=Labels(names=["energy"], values=torch.tensor([[0]])),
        )
        return TensorMap(
            keys=Labels(names=["_"], values=torch.tensor([[0]])),
            blocks=[block],
        )

    # Three synthetic structures, containing only atomic types 1 and 8:
    # - O atom, with an energy of 1.0
    # - H2O molecule, with an energy of 5.0
    # - H4O2 molecule, with an energy of 10.0
    systems = [
        make_system([[0.0, 0.0, 0.0]], [8]),
        make_system(
            [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
            [1, 1, 8],
        ),
        make_system(
            [
                [0.0, 0.0, 0.0],
                [1.0, 0.0, 0.0],
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0],
                [1.0, 0.0, 1.0],
                [0.0, 1.0, 1.0],
            ],
            [1, 1, 8, 1, 1, 8],
        ),
    ]
    energies = [
        make_energy(index, value) for index, value in enumerate([1.0, 5.0, 10.0])
    ]
    dataset = Dataset({"system": systems, "energy": energies})

    # Each case pairs the atomic types declared at initialization with the
    # error message expected from training on the dataset above:
    # - [1]: the dataset also contains type 8, which the model does not know
    # - [1, 8, 100]: the model declares type 100, which the dataset lacks
    cases = [
        ([1], "unknown atomic types"),
        ([1, 8, 100], "do not contain atomic types"),
    ]
    for declared_types, expected_message in cases:
        composition_model = CompositionModel(
            model_hypers={},
            dataset_info=DatasetInfo(
                length_unit="angstrom",
                atomic_types=declared_types,
                targets=TargetInfoDict(
                    {
                        "energy": TargetInfo(
                            quantity="energy",
                            per_atom=False,
                        )
                    }
                ),
            ),
        )
        with pytest.raises(ValueError, match=expected_message):
            composition_model.train_model(dataset)

0 comments on commit 7eda27b

Please sign in to comment.