Skip to content

Commit

Permalink
added test for dataset shuffling
Browse files Browse the repository at this point in the history
  • Loading branch information
BowenD-UCB committed Apr 22, 2024
1 parent 72f9108 commit 9e4bd5e
Showing 1 changed file with 40 additions and 3 deletions.
43 changes: 40 additions & 3 deletions tests/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import random

import numpy as np
import pytest
import torch
Expand All @@ -17,6 +19,7 @@
@pytest.fixture()
def structure_data() -> StructureData:
"""Create a graph with 3 nodes and 3 directed edges."""
random.seed(42)
structures, energies, forces, stresses, magmoms, structure_ids = (
[],
[],
Expand All @@ -25,15 +28,15 @@ def structure_data() -> StructureData:
[],
[],
)
for _ in range(100):
for index in range(100):
struct = NaCl.copy()
struct.perturb(0.1)
structures.append(struct)
energies.append(np.random.random(1))
forces.append(np.random.random([2, 3]))
stresses.append(np.random.random([3, 3]))
magmoms.append(np.random.random([2, 1]))
structure_ids.append("tmp_id")
structure_ids.append(index)
return StructureData(
structures=structures,
energies=energies,
Expand All @@ -47,7 +50,8 @@ def structure_data() -> StructureData:
def test_structure_data(structure_data: StructureData) -> None:
get_one = structure_data[0]
assert isinstance(get_one[0], CrystalGraph)
assert get_one[0].mp_id == "tmp_id"
assert isinstance(get_one[0].mp_id, int)
assert get_one[0].mp_id == 42
assert isinstance(get_one[1], dict)
assert isinstance(get_one[1]["e"], torch.Tensor)
assert isinstance(get_one[1]["f"], torch.Tensor)
Expand Down Expand Up @@ -85,3 +89,36 @@ def test_structure_data_inconsistent_length():
== f"Inconsistent number of structures and labels: {len(structures)=}, "
f"{len(forces)=}"
)


def test_dataset_no_shuffling():
structures, energies, forces, stresses, magmoms, structure_ids = (
[],
[],
[],
[],
[],
[],
)
for index in range(100):
struct = NaCl.copy()
struct.perturb(0.1)
structures.append(struct)
energies.append(np.random.random(1))
forces.append(np.random.random([2, 3]))
stresses.append(np.random.random([3, 3]))
magmoms.append(np.random.random([2, 1]))
structure_ids.append(index)
structure_data = StructureData(
structures=structures,
energies=energies,
forces=forces,
stresses=stresses,
magmoms=magmoms,
structure_ids=structure_ids,
shuffle=False,
)

assert structure_data[0][0].mp_id == 0
assert structure_data[1][0].mp_id == 1
assert structure_data[2][0].mp_id == 2

0 comments on commit 9e4bd5e

Please sign in to comment.