diff --git a/src/fairchem/core/datasets/base_dataset.py b/src/fairchem/core/datasets/base_dataset.py index 2ca26596c..ceaade573 100644 --- a/src/fairchem/core/datasets/base_dataset.py +++ b/src/fairchem/core/datasets/base_dataset.py @@ -52,13 +52,13 @@ def __init__(self, config: dict): config (dict): dataset configuration """ self.config = config - self.paths = [] + self.path = None if "src" in self.config: if isinstance(config["src"], str): - self.paths = [Path(self.config["src"])] + self.path = Path(self.config["src"]) else: - self.paths = tuple(Path(path) for path in config["src"]) + raise ValueError("path is required to be a single path") self.lin_ref = None if self.config.get("lin_ref", False): @@ -82,34 +82,27 @@ def indices(self): @cached_property def _metadata(self) -> DatasetMetadata: # logic to read metadata file here - metadata_npzs = [] + metadata_npz = None if self.config.get("metadata_path", None) is not None: - metadata_npzs.append( - np.load(self.config["metadata_path"], allow_pickle=True) - ) + metadata_npz = np.load(self.config["metadata_path"], allow_pickle=True) else: - for path in self.paths: - if path.is_file(): - metadata_file = path.parent / "metadata.npz" - else: - metadata_file = path / "metadata.npz" - if metadata_file.is_file(): - metadata_npzs.append(np.load(metadata_file, allow_pickle=True)) - - if len(metadata_npzs) == 0: + if self.path.is_file(): + metadata_file = self.path.parent / "metadata.npz" + else: + metadata_file = self.path / "metadata.npz" + if metadata_file.is_file(): + metadata_npz = np.load(metadata_file, allow_pickle=True) + + if metadata_npz is None: logging.warning( - f"Could not find dataset metadata.npz files in '{self.paths}'" + f"Could not find dataset metadata.npz file for '{self.path}'" ) return None metadata = DatasetMetadata( - **{ - field: np.concatenate([metadata[field] for metadata in metadata_npzs]) - for field in DatasetMetadata._fields - } + **{field: metadata_npz[field] for field in DatasetMetadata._fields} ) - assert metadata.natoms.shape[0] == len( self ), "Loaded metadata and dataset size mismatch." diff --git a/src/fairchem/core/datasets/lmdb_dataset.py b/src/fairchem/core/datasets/lmdb_dataset.py index b70dc0902..900d4bf7c 100644 --- a/src/fairchem/core/datasets/lmdb_dataset.py +++ b/src/fairchem/core/datasets/lmdb_dataset.py @@ -58,11 +58,6 @@ def __init__(self, config) -> None: "train_on_oc20_total_energies", False ), "For training on total energies set dataset=oc22_lmdb" - assert ( - len(self.paths) == 1 - ), f"{type(self)} does not support a list of src paths." - self.path = self.paths[0] - if not self.path.is_file(): db_paths = sorted(self.path.glob("*.lmdb")) assert len(db_paths) > 0, f"No LMDBs found in '{self.path}'" diff --git a/src/fairchem/core/datasets/oc22_lmdb_dataset.py b/src/fairchem/core/datasets/oc22_lmdb_dataset.py index 15c627a5a..f9368e4f8 100644 --- a/src/fairchem/core/datasets/oc22_lmdb_dataset.py +++ b/src/fairchem/core/datasets/oc22_lmdb_dataset.py @@ -44,11 +44,6 @@ class OC22LmdbDataset(BaseDataset): def __init__(self, config, transform=None) -> None: super().__init__(config) - assert ( - len(self.paths) == 1 - ), f"{type(self)} does not support a list of src paths." - self.path = self.paths[0] - self.data2train = self.config.get("data2train", "all") if not self.path.is_file(): db_paths = sorted(self.path.glob("*.lmdb"))