diff --git a/src/fairchem/core/datasets/ase_datasets.py b/src/fairchem/core/datasets/ase_datasets.py
index 1d7c80c990..0b83b0ccc0 100644
--- a/src/fairchem/core/datasets/ase_datasets.py
+++ b/src/fairchem/core/datasets/ase_datasets.py
@@ -183,6 +183,28 @@ def sample_property_metadata(self, num_samples: int = 100) -> dict:
 
         return metadata
 
+    def get_metadata(self, attr, idx):
+        """Return metadata for the sample(s) at ``idx``.
+
+        Only ``"natoms"`` is supported; any other attribute yields ``None``.
+
+        Args:
+            attr: name of the metadata attribute to look up.
+            idx: a single sample index, or a list/tuple/range of indices.
+
+        Returns:
+            The number of atoms for a single index, a list of atom counts
+            for a collection of indices, or ``None`` if ``attr`` is not
+            ``"natoms"``.
+        """
+        if attr != "natoms":
+            return None
+        # Accept any builtin index collection, not just ``list``, so callers
+        # can pass e.g. ``range(len(dataset))`` without converting first.
+        if isinstance(idx, (list, tuple, range)):
+            return [self.get_metadata(attr, i) for i in idx]
+        return len(self.get_atoms(idx))
+
 
 @registry.register_dataset("ase_read")
 class AseReadDataset(AseAtomsDataset):
diff --git a/src/fairchem/core/scripts/make_lmdb_sizes.py b/src/fairchem/core/scripts/make_lmdb_sizes.py
index 682fb58e65..1e68996502 100644
--- a/src/fairchem/core/scripts/make_lmdb_sizes.py
+++ b/src/fairchem/core/scripts/make_lmdb_sizes.py
@@ -28,7 +28,7 @@ def get_data(index):
     return index, natoms, neighbors
 
 
-def main(args) -> None:
+def make_lmdb_sizes(args) -> None:
     path = assert_is_instance(args.data_path, str)
     global dataset
     if os.path.isdir(path):
@@ -63,7 +63,8 @@ def main(args) -> None:
     np.savez(outpath, natoms=sorted_natoms, neighbors=sorted_neighbors)
 
 
-if __name__ == "__main__":
+def get_lmdb_sizes_parser() -> argparse.ArgumentParser:
+    """Build the command-line parser for the LMDB sizes script."""
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--data-path",
@@ -77,5 +78,10 @@ def main(args) -> None:
         type=int,
         help="Num of workers to parallelize across",
     )
+    return parser
+
+
+if __name__ == "__main__":
+    parser = get_lmdb_sizes_parser()
     args: argparse.Namespace = parser.parse_args()
-    main(args)
+    make_lmdb_sizes(args)
diff --git a/tests/core/datasets/test_ase_datasets.py b/tests/core/datasets/test_ase_datasets.py
index 01ddfb0d9f..7b114d877f 100644
--- a/tests/core/datasets/test_ase_datasets.py
+++ b/tests/core/datasets/test_ase_datasets.py
@@ -119,6 +119,16 @@ def test_ase_read_dataset(tmp_path, structures):
     del data
 
 
+def test_ase_get_metadata(ase_dataset):
+    dataset, _ = ase_dataset
+
+    assert dataset.get_metadata("natoms", 0) == 3
+    assert dataset.get_metadata("natoms", [0])[0] == 3
+    # non-list index collections are handled like lists
+    assert dataset.get_metadata("natoms", range(1)) == [3]
+    assert dataset.get_metadata("unknown", 0) is None
+
+
 def test_ase_metadata_guesser(ase_dataset):
     dataset, _ = ase_dataset
 
diff --git a/tests/core/datasets/test_lmdb_dataset.py b/tests/core/datasets/test_lmdb_dataset.py
new file mode 100644
index 0000000000..f922e32ce3
--- /dev/null
+++ b/tests/core/datasets/test_lmdb_dataset.py
@@ -0,0 +1,33 @@
+import numpy as np
+
+from fairchem.core.datasets.base_dataset import create_dataset
+from fairchem.core.scripts.make_lmdb_sizes import (
+    get_lmdb_sizes_parser,
+    make_lmdb_sizes,
+)
+
+
+def test_load_lmdb_dataset(tutorial_dataset_path):
+    """Metadata written by ``make_lmdb_sizes`` must match the stored samples."""
+    lmdb_path = str(tutorial_dataset_path / "s2ef/val_20")
+
+    # make dataset metadata
+    parser = get_lmdb_sizes_parser()
+    args = parser.parse_args(["--data-path", lmdb_path])
+    make_lmdb_sizes(args)
+
+    config = {
+        "format": "lmdb",
+        "src": lmdb_path,
+    }
+
+    dataset = create_dataset(config, split="val")
+
+    assert dataset.get_metadata("natoms", 0) == dataset[0].natoms
+
+    all_metadata_natoms = np.array(
+        dataset.get_metadata("natoms", range(len(dataset)))
+    )
+    all_natoms = np.array([datapoint.natoms for datapoint in dataset])
+
+    assert (all_natoms == all_metadata_natoms).all()