Skip to content

Commit

Permalink
add slow get_metadata for ase; add tests for get_metadata (ase+lmdb);…
Browse files Browse the repository at this point in the history
… add test for make lmdb metadata sizes
  • Loading branch information
misko committed Jul 23, 2024
1 parent 65b84ec commit 0219e36
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 3 deletions.
7 changes: 7 additions & 0 deletions src/fairchem/core/datasets/ase_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,13 @@ def sample_property_metadata(self, num_samples: int = 100) -> dict:

return metadata

def get_metadata(self, attr, idx):
if attr != "natoms":
return None
if isinstance(idx, list):
return [self.get_metadata(attr, i) for i in idx]
return len(self.get_atoms(idx))


@registry.register_dataset("ase_read")
class AseReadDataset(AseAtomsDataset):
Expand Down
11 changes: 8 additions & 3 deletions src/fairchem/core/scripts/make_lmdb_sizes.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def get_data(index):
return index, natoms, neighbors


def main(args) -> None:
def make_lmdb_sizes(args) -> None:
path = assert_is_instance(args.data_path, str)
global dataset
if os.path.isdir(path):
Expand Down Expand Up @@ -63,7 +63,7 @@ def main(args) -> None:
np.savez(outpath, natoms=sorted_natoms, neighbors=sorted_neighbors)


if __name__ == "__main__":
def get_lmdb_sizes_parser():
parser = argparse.ArgumentParser()
parser.add_argument(
"--data-path",
Expand All @@ -77,5 +77,10 @@ def main(args) -> None:
type=int,
help="Num of workers to parallelize across",
)
return parser


if __name__ == "__main__":
parser = get_lmdb_sizes_parser()
args: argparse.Namespace = parser.parse_args()
main(args)
make_lmdb_sizes(args)
4 changes: 4 additions & 0 deletions tests/core/datasets/test_ase_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ def test_ase_read_dataset(tmp_path, structures):
del data


def test_ase_get_metadata(ase_dataset):
assert ase_dataset[0].get_metadata("natoms", [0])[0] == 3


def test_ase_metadata_guesser(ase_dataset):
dataset, _ = ase_dataset

Expand Down
29 changes: 29 additions & 0 deletions tests/core/datasets/test_lmdb_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from fairchem.core.datasets.base_dataset import create_dataset

import numpy as np

from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes


def test_load_lmdb_dataset(tutorial_dataset_path):

lmdb_path = str(tutorial_dataset_path / "s2ef/val_20")

# make dataset metadata
parser = get_lmdb_sizes_parser()
args, override_args = parser.parse_known_args(["--data-path", lmdb_path])
make_lmdb_sizes(args)

config = {
"format": "lmdb",
"src": lmdb_path,
}

dataset = create_dataset(config, split="val")

assert dataset.get_metadata("natoms", 0) == dataset[0].natoms

all_metadata_natoms = np.array(dataset.get_metadata("natoms", range(len(dataset))))
all_natoms = np.array([datapoint.natoms for datapoint in dataset])

assert (all_natoms == all_metadata_natoms).all()

0 comments on commit 0219e36

Please sign in to comment.