diff --git a/src/metatrain/utils/data/dataset.py b/src/metatrain/utils/data/dataset.py index 6debbe4e..f8759892 100644 --- a/src/metatrain/utils/data/dataset.py +++ b/src/metatrain/utils/data/dataset.py @@ -389,17 +389,35 @@ def check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset]): or targets that are not present in the training set """ # Check that system `dtypes` are consistent within datasets - desired_dtype = train_datasets[0][0].system.positions.dtype - msg = f"`dtype` between datasets is inconsistent, found {desired_dtype} and " + desired_dtype = None for train_dataset in train_datasets: + if len(train_dataset) == 0: + continue + actual_dtype = train_dataset[0].system.positions.dtype + if desired_dtype is None: + desired_dtype = actual_dtype + if actual_dtype != desired_dtype: - raise TypeError(f"{msg}{actual_dtype} found in `train_datasets`") + raise TypeError( + "`dtype` between datasets is inconsistent, " + f"found {desired_dtype} and {actual_dtype} in training datasets" + ) for val_dataset in val_datasets: + if len(val_dataset) == 0: + continue + actual_dtype = val_dataset[0].system.positions.dtype + + if desired_dtype is None: + desired_dtype = actual_dtype + if actual_dtype != desired_dtype: - raise TypeError(f"{msg}{actual_dtype} found in `val_datasets`") + raise TypeError( + "`dtype` between datasets is inconsistent, " + f"found {desired_dtype} and {actual_dtype} in validation datasets" + ) # Get all targets in the training and validation sets: train_targets = get_all_targets(train_datasets) diff --git a/tests/utils/data/test_dataset.py b/tests/utils/data/test_dataset.py index 24c54bac..956c4d8a 100644 --- a/tests/utils/data/test_dataset.py +++ b/tests/utils/data/test_dataset.py @@ -557,7 +557,8 @@ def test_check_datasets(): # wrong dtype systems_qm9_32bit = [system.to(dtype=torch.float32) for system in systems_qm9] targets_qm9_32bit = { - k: [v.to(dtype=torch.float32) for v in l] for k, l in targets_qm9.items() + name: [tensor.to(dtype=torch.float32) for tensor in values] + for name, values in targets_qm9.items() } train_set_32_bit = Dataset.from_dict( {"system": systems_qm9_32bit, **targets_qm9_32bit} @@ -565,17 +566,17 @@ def test_check_datasets(): match = ( "`dtype` between datasets is inconsistent, found torch.float64 and " - "torch.float32 found in `val_datasets`" + "torch.float32 in validation datasets" ) with pytest.raises(TypeError, match=match): check_datasets([train_set], [train_set_32_bit]) match = ( "`dtype` between datasets is inconsistent, found torch.float64 and " - "torch.float32 found in `train_datasets`" + "torch.float32 in training datasets" ) with pytest.raises(TypeError, match=match): - check_datasets([train_set, train_set_32_bit], [val_set]) + check_datasets([train_set, train_set_32_bit], []) def test_collate_fn():