Skip to content

Commit

Permalink
Replace the `paths` list in `BaseDataset` with a single `path` attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
misko committed Jul 26, 2024
1 parent 4426870 commit 235965f
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 32 deletions.
37 changes: 15 additions & 22 deletions src/fairchem/core/datasets/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ def __init__(self, config: dict):
config (dict): dataset configuration
"""
self.config = config
self.paths = []
self.path = None

if "src" in self.config:
if isinstance(config["src"], str):
self.paths = [Path(self.config["src"])]
self.path = Path(self.config["src"])
else:
self.paths = tuple(Path(path) for path in config["src"])
raise ValueError("path is required to be a single path")

self.lin_ref = None
if self.config.get("lin_ref", False):
Expand All @@ -82,34 +82,27 @@ def indices(self):
@cached_property
def _metadata(self) -> DatasetMetadata:
# logic to read metadata file here
metadata_npzs = []
metadata_npz = None
if self.config.get("metadata_path", None) is not None:
metadata_npzs.append(
np.load(self.config["metadata_path"], allow_pickle=True)
)
metadata_npz = np.load(self.config["metadata_path"], allow_pickle=True)

else:
for path in self.paths:
if path.is_file():
metadata_file = path.parent / "metadata.npz"
else:
metadata_file = path / "metadata.npz"
if metadata_file.is_file():
metadata_npzs.append(np.load(metadata_file, allow_pickle=True))

if len(metadata_npzs) == 0:
if self.path.is_file():
metadata_file = self.path.parent / "metadata.npz"
else:
metadata_file = self.path / "metadata.npz"
if metadata_file.is_file():
metadata_npz = np.load(metadata_file, allow_pickle=True)

if metadata_npz is None:
logging.warning(
f"Could not find dataset metadata.npz files in '{self.paths}'"
f"Could not find dataset metadata.npz file for '{self.path}'"
)
return None

metadata = DatasetMetadata(
**{
field: np.concatenate([metadata[field] for metadata in metadata_npzs])
for field in DatasetMetadata._fields
}
**{field: metadata_npz[field] for field in DatasetMetadata._fields}
)

assert metadata.natoms.shape[0] == len(
self
), "Loaded metadata and dataset size mismatch."
Expand Down
5 changes: 0 additions & 5 deletions src/fairchem/core/datasets/lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ def __init__(self, config) -> None:
"train_on_oc20_total_energies", False
), "For training on total energies set dataset=oc22_lmdb"

assert (
len(self.paths) == 1
), f"{type(self)} does not support a list of src paths."
self.path = self.paths[0]

if not self.path.is_file():
db_paths = sorted(self.path.glob("*.lmdb"))
assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'"
Expand Down
5 changes: 0 additions & 5 deletions src/fairchem/core/datasets/oc22_lmdb_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ class OC22LmdbDataset(BaseDataset):
def __init__(self, config, transform=None) -> None:
super().__init__(config)

assert (
len(self.paths) == 1
), f"{type(self)} does not support a list of src paths."
self.path = self.paths[0]

self.data2train = self.config.get("data2train", "all")
if not self.path.is_file():
db_paths = sorted(self.path.glob("*.lmdb"))
Expand Down

0 comments on commit 235965f

Please sign in to comment.