Skip to content

Commit

Permalink
Handle empty datasets in check_datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
Luthaf committed Sep 23, 2024
1 parent 2b3d2f2 commit 507316a
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 8 deletions.
26 changes: 22 additions & 4 deletions src/metatrain/utils/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,17 +389,35 @@ def check_datasets(train_datasets: List[Dataset], val_datasets: List[Dataset]):
or targets that are not present in the training set
"""
# Check that system `dtypes` are consistent within datasets
desired_dtype = train_datasets[0][0].system.positions.dtype
msg = f"`dtype` between datasets is inconsistent, found {desired_dtype} and "
desired_dtype = None
for train_dataset in train_datasets:
if len(train_dataset) == 0:
continue

actual_dtype = train_dataset[0].system.positions.dtype
if desired_dtype is None:
desired_dtype = actual_dtype

if actual_dtype != desired_dtype:
raise TypeError(f"{msg}{actual_dtype} found in `train_datasets`")
raise TypeError(
"`dtype` between datasets is inconsistent, "
f"found {desired_dtype} and {actual_dtype} in training datasets"
)

for val_dataset in val_datasets:
if len(val_dataset) == 0:
continue

actual_dtype = val_dataset[0].system.positions.dtype

if desired_dtype is None:
desired_dtype = actual_dtype

if actual_dtype != desired_dtype:
raise TypeError(f"{msg}{actual_dtype} found in `val_datasets`")
raise TypeError(
"`dtype` between datasets is inconsistent, "
f"found {desired_dtype} and {actual_dtype} in validation datasets"
)

# Get all targets in the training and validation sets:
train_targets = get_all_targets(train_datasets)
Expand Down
9 changes: 5 additions & 4 deletions tests/utils/data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,25 +557,26 @@ def test_check_datasets():
# wrong dtype
systems_qm9_32bit = [system.to(dtype=torch.float32) for system in systems_qm9]
targets_qm9_32bit = {
k: [v.to(dtype=torch.float32) for v in l] for k, l in targets_qm9.items()
name: [tensor.to(dtype=torch.float32) for tensor in values]
for name, values in targets_qm9.items()
}
train_set_32_bit = Dataset.from_dict(
{"system": systems_qm9_32bit, **targets_qm9_32bit}
)

match = (
"`dtype` between datasets is inconsistent, found torch.float64 and "
"torch.float32 found in `val_datasets`"
"torch.float32 in validation datasets"
)
with pytest.raises(TypeError, match=match):
check_datasets([train_set], [train_set_32_bit])

match = (
"`dtype` between datasets is inconsistent, found torch.float64 and "
"torch.float32 found in `train_datasets`"
"torch.float32 in training datasets"
)
with pytest.raises(TypeError, match=match):
check_datasets([train_set, train_set_32_bit], [val_set])
check_datasets([train_set, train_set_32_bit], [])


def test_collate_fn():
Expand Down

0 comments on commit 507316a

Please sign in to comment.