Skip to content

Commit

Permalink
Only warn if atomic types are present in the validation dataset but n…
Browse files Browse the repository at this point in the history
…ot in the training dataset
  • Loading branch information
frostedoyster committed Sep 2, 2024
1 parent 0b35fed commit 661ebe5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 6 deletions.
6 changes: 4 additions & 2 deletions src/metatrain/utils/composition.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import warnings
from typing import Dict, List, Optional, Union

import torch
Expand Down Expand Up @@ -85,9 +86,10 @@ def train_model(

missing_types = sorted(set(self.atomic_types) - set(get_atomic_types(datasets)))
if missing_types:
raise ValueError(
warnings.warn(
f"Provided `datasets` do not contain atomic types {missing_types}. "
f"Known types from initilaization are {self.atomic_types}."
f"Known types from initilaization are {self.atomic_types}.",
stacklevel=2,
)

# Fill the weights for each target in the dataset info
Expand Down
4 changes: 2 additions & 2 deletions tests/resources/generate-outputs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,5 @@ ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)

cd $ROOT_DIR

mtt train options.yaml -o model-32-bit.pt -r base_precision=32 > /dev/null
mtt train options.yaml -o model-64-bit.pt -r base_precision=64 > /dev/null
mtt train options.yaml -o model-32-bit.pt -r base_precision=32 #> /dev/null
mtt train options.yaml -o model-64-bit.pt -r base_precision=64 #> /dev/null
4 changes: 2 additions & 2 deletions tests/utils/test_composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ def test_composition_model_missing_types():
),
),
)
with pytest.raises(
ValueError,
with pytest.warns(
UserWarning,
match="do not contain atomic types",
):
composition_model.train_model(dataset)

0 comments on commit 661ebe5

Please sign in to comment.