From 661ebe5d066f86ac824c82cca485aa3e946f9671 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Mon, 2 Sep 2024 08:08:51 +0200 Subject: [PATCH] Only warn if atomic types are present in the validation dataset but not in the training dataset --- src/metatrain/utils/composition.py | 6 ++++-- tests/resources/generate-outputs.sh | 4 ++-- tests/utils/test_composition.py | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/metatrain/utils/composition.py b/src/metatrain/utils/composition.py index 6a188bfc..da76361e 100644 --- a/src/metatrain/utils/composition.py +++ b/src/metatrain/utils/composition.py @@ -1,3 +1,4 @@ +import warnings from typing import Dict, List, Optional, Union import torch @@ -85,9 +86,10 @@ def train_model( missing_types = sorted(set(self.atomic_types) - set(get_atomic_types(datasets))) if missing_types: - raise ValueError( + warnings.warn( f"Provided `datasets` do not contain atomic types {missing_types}. " - f"Known types from initilaization are {self.atomic_types}." + f"Known types from initilaization are {self.atomic_types}.", + stacklevel=2, ) # Fill the weights for each target in the dataset info diff --git a/tests/resources/generate-outputs.sh b/tests/resources/generate-outputs.sh index 7d99a7c8..6ddd140b 100755 --- a/tests/resources/generate-outputs.sh +++ b/tests/resources/generate-outputs.sh @@ -7,5 +7,5 @@ ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) cd $ROOT_DIR -mtt train options.yaml -o model-32-bit.pt -r base_precision=32 > /dev/null -mtt train options.yaml -o model-64-bit.pt -r base_precision=64 > /dev/null +mtt train options.yaml -o model-32-bit.pt -r base_precision=32 #> /dev/null +mtt train options.yaml -o model-64-bit.pt -r base_precision=64 #> /dev/null diff --git a/tests/utils/test_composition.py b/tests/utils/test_composition.py index e9472197..c1787d24 100644 --- a/tests/utils/test_composition.py +++ b/tests/utils/test_composition.py @@ -336,8 +336,8 @@ def test_composition_model_missing_types(): ), ), ) - with pytest.raises( - ValueError, + with pytest.warns( + UserWarning, match="do not contain atomic types", ): composition_model.train_model(dataset)