diff --git a/env.common.yml b/env.common.yml index 32486a4ed..9550de837 100644 --- a/env.common.yml +++ b/env.common.yml @@ -8,6 +8,7 @@ dependencies: - black==22.3.0 - matplotlib - numba +- orjson - pip - pre-commit=2.10.* - pyg=2.2.0 diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py index fd37cd55f..4334033dc 100644 --- a/ocpmodels/datasets/ase_datasets.py +++ b/ocpmodels/datasets/ase_datasets.py @@ -1,13 +1,22 @@ -import ase +import bisect +import copy +import functools +import glob +import logging +import os import warnings -import numpy as np - from pathlib import Path +from abc import ABC, abstractmethod + +import ase +import numpy as np from torch import tensor from torch.utils.data import Dataset from tqdm import tqdm from ocpmodels.common.registry import registry +from ocpmodels.datasets.lmdb_database import LMDBDatabase +from ocpmodels.datasets.target_metadata_guesser import guess_property_metadata from ocpmodels.preprocessing import AtomsToGraphs @@ -36,14 +45,19 @@ def apply_one_tags(atoms, skip_if_nonzero=True, skip_always=False): return atoms -class AseAtomsDataset(Dataset): +class AseAtomsDataset(Dataset, ABC): """ - This is a base Dataset that includes helpful utilities for turning - ASE atoms objects into OCP-usable data objects. + This is an abstract Dataset that includes helpful utilities for turning + ASE atoms objects into OCP-usable data objects. This should not be instantiated directly + as get_atoms_object and load_dataset_get_ids are not implemented in this base class. Derived classes must add at least two things: - self.get_atoms_object(): a function that takes an identifier and returns a corresponding atoms object - self.id: a list of all possible identifiers that can be passed into self.get_atoms_object() + self.get_atoms_object(id): a function that takes an identifier and returns a corresponding atoms object + + self.load_dataset_get_ids(config: dict): This function is responsible for any initialization/loads + of the dataset and importantly must return a list of all possible identifiers that can be passed into + self.get_atoms_object(id) + Identifiers need not be any particular type. 
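To make the new abstract contract concrete, here is a minimal sketch of a derived dataset that implements both hooks. The class name, the in-memory storage, and the empty config are illustrative assumptions, not code from this diff:

# Illustrative only: the smallest dataset that satisfies the AseAtomsDataset contract.
from ase import build

from ocpmodels.datasets.ase_datasets import AseAtomsDataset


class ToyAtomsDataset(AseAtomsDataset):
    def load_dataset_get_ids(self, config):
        # Do any loading/initialization here; must return the list of identifiers.
        self._structures = {
            "water": build.molecule("H2O", vacuum=4),
            "copper": build.bulk("Cu"),
        }
        return list(self._structures.keys())

    def get_atoms_object(self, identifier):
        # Map an identifier from load_dataset_get_ids back to an ASE atoms object.
        return self._structures[identifier]


# dataset = ToyAtomsDataset(config={})  # dataset[0] would then be a torch_geometric Data object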
""" @@ -51,32 +65,31 @@ def __init__(self, config, transform=None, atoms_transform=apply_one_tags): self.config = config a2g_args = config.get("a2g_args", {}) + + # Make sure we always include PBC info in the resulting atoms objects + a2g_args["r_pbc"] = True self.a2g = AtomsToGraphs(**a2g_args) self.transform = transform self.atoms_transform = atoms_transform if self.config.get("keep_in_memory", False): - self.data_objects = {} + self.__getitem__ = functools.cache(self.__getitem__) - # Derived classes should extend this functionality to also create self.id, + # Derived classes should extend this functionality to also create self.ids, # a list of identifiers that can be passed to get_atoms_object() + self.ids = self.load_dataset_get_ids(config) def __len__(self): - return len(self.id) + return len(self.ids) def __getitem__(self, idx): # Handle slicing if isinstance(idx, slice): - return [self[i] for i in range(len(self.id))[idx]] - - # Check if data object is already in memory - if self.config.get("keep_in_memory", False): - if self.id[idx] in self.data_objects: - return self.data_objects[self.id[idx]] + return [self[i] for i in range(*idx.indices(len(self.ids)))] # Get atoms object via derived class method - atoms = self.get_atoms_object(self.id[idx]) + atoms = self.get_atoms_object(self.ids[idx]) # Transform atoms object if self.atoms_transform is not None: @@ -84,9 +97,14 @@ def __getitem__(self, idx): atoms, **self.config.get("atoms_transform_args", {}) ) + if "sid" in atoms.info: + sid = atoms.info["sid"] + else: + sid = tensor([idx]) + # Convert to data object - data_object = self.a2g.convert(atoms) - data_object.sid = tensor([idx]) + data_object = self.a2g.convert(atoms, sid) + data_object.pbc = tensor(atoms.pbc) # Transform data object @@ -95,20 +113,50 @@ def __getitem__(self, idx): data_object, **self.config.get("transform_args", {}) ) - # Save in memory, if specified - if self.config.get("keep_in_memory", False): - self.data_objects[self.id[idx]] = data_object - return data_object + @abstractmethod def get_atoms_object(self, identifier): + # This function should return an ASE atoms object. raise NotImplementedError( - "Returns an ASE atoms object. Derived classes should implement this funciton." + "Returns an ASE atoms object. Derived classes should implement this function." + ) + + @abstractmethod + def load_dataset_get_ids(self, config): + # This function should return a list of ids that can be used to index into the database + raise NotImplementedError( + "Every ASE dataset needs to declare a function to load the dataset and return a list of ids." 
) def close_db(self): - pass # This method is sometimes called by a trainer + pass + + def guess_target_metadata(self, num_samples=100): + metadata = {} + + if num_samples < len(self): + metadata["targets"] = guess_property_metadata( + [ + self.get_atoms_object(self.ids[idx]) + for idx in np.random.choice( + len(self), size=(num_samples,), replace=False + ) + ] + ) + else: + metadata["targets"] = guess_property_metadata( + [ + self.get_atoms_object(self.ids[idx]) + for idx in range(len(self)) + ] + ) + + return metadata + + def get_metadata(self): + return self.guess_target_metadata() @registry.register_dataset("ase_read") @@ -154,10 +202,7 @@ class AseReadDataset(AseAtomsDataset): """ - def __init__(self, config, transform=None, atoms_transform=apply_one_tags): - super(AseReadDataset, self).__init__( - config, transform, atoms_transform - ) + def load_dataset_get_ids(self, config): self.ase_read_args = config.get("ase_read_args", {}) if ":" in self.ase_read_args.get("index", ""): @@ -165,10 +210,11 @@ def __init__(self, config, transform=None, atoms_transform=apply_one_tags): "To read multiple structures from a single file, please use AseReadMultiStructureDataset." ) - self.path = Path(self.config["src"]) + self.path = Path(config["src"]) if self.path.is_file(): raise Exception("The specified src is not a directory") - self.id = list(self.path.glob(f'{self.config["pattern"]}')) + + return list(self.path.glob(f'{config["pattern"]}')) def get_atoms_object(self, identifier): try: @@ -235,11 +281,8 @@ class AseReadMultiStructureDataset(AseAtomsDataset): transform (callable, optional): Additional preprocessing function for the Data object """ - def __init__(self, config, transform=None, atoms_transform=apply_one_tags): - super(AseReadMultiStructureDataset, self).__init__( - config, transform, atoms_transform - ) - self.ase_read_args = self.config.get("ase_read_args", {}) + def load_dataset_get_ids(self, config): + self.ase_read_args = config.get("ase_read_args", {}) if not hasattr(self.ase_read_args, "index"): self.ase_read_args["index"] = ":" @@ -247,22 +290,22 @@ def __init__(self, config, transform=None, atoms_transform=apply_one_tags): f = open(config["index_file"], "r") index = f.readlines() - self.id = [] + ids = [] for line in index: filename = line.split(" ")[0] for i in range(int(line.split(" ")[1])): - self.id.append(f"{filename} {i}") + ids.append(f"{filename} {i}") - return + return ids - self.path = Path(self.config["src"]) + self.path = Path(config["src"]) if self.path.is_file(): raise Exception("The specified src is not a directory") - filenames = list(self.path.glob(f'{self.config["pattern"]}')) + filenames = list(self.path.glob(f'{config["pattern"]}')) - self.id = [] + ids = [] - if self.config.get("use_tqdm", True): + if config.get("use_tqdm", True): filenames = tqdm(filenames) for filename in filenames: try: @@ -271,27 +314,9 @@ def __init__(self, config, transform=None, atoms_transform=apply_one_tags): warnings.warn(f"{err} occured for: {filename}") else: for i, structure in enumerate(structures): - self.id.append(f"{filename} {i}") + ids.append(f"{filename} {i}") - if self.config.get("keep_in_memory", False): - # Transform atoms object - if self.atoms_transform is not None: - atoms = self.atoms_transform( - structure, - **self.config.get("atoms_transform_args", {}), - ) - - # Convert to data object - data_object = self.a2g.convert(atoms) - - # Transform data object - if self.transform is not None: - data_object = self.transform( - data_object, - 
**self.config.get("transform_args", {}), - ) - - self.data_objects[f"{filename} {i}"] = data_object + return ids def get_atoms_object(self, identifier): try: @@ -304,6 +329,35 @@ def get_atoms_object(self, identifier): return atoms + def get_metadata(self): + return {} + + +class dummy_list(list): + def __init__(self, max): + self.max = max + return + + def __len__(self): + return self.max + + def __getitem__(self, idx): + # Handle slicing + if isinstance(idx, slice): + return [self[i] for i in range(*idx.indices(self.max))] + + # Cast idx as int since it could be a tensor index + idx = int(idx) + + # Handle negative indices (referenced from end) + if idx < 0: + idx += self.max + + if 0 <= idx < self.max: + return idx + else: + raise IndexError + @registry.register_dataset("ase_db") class AseDBDataset(AseAtomsDataset): @@ -316,7 +370,20 @@ class AseDBDataset(AseAtomsDataset): args: config (dict): - src (str): The path to or connection address of your ASE DB + src (str): Either + - the path an ASE DB, + - the connection address of an ASE DB, + - a folder with multiple ASE DBs, + - a glob string to use to find ASE DBs, or + - a list of ASE db paths/addresses. + If a folder, every file will be attempted as an ASE DB, and warnings + are raised for any files that can't connect cleanly + + Note that for large datasets, ID loading can be slow and there can be many + ids, so it's advised to make loading the id list as easy as possible. There is not + an obvious way to get a full list of ids from most ASE dbs besides simply looping + through the entire dataset. See the AseLMDBDataset which was written with this usecase + in mind. connect_args (dict): Keyword arguments for ase.db.connect() @@ -344,31 +411,87 @@ class AseDBDataset(AseAtomsDataset): transform (callable, optional): Additional preprocessing function for the Data object """ - def __init__(self, config, transform=None, atoms_transform=apply_one_tags): - super(AseDBDataset, self).__init__(config, transform, atoms_transform) + def load_dataset_get_ids(self, config): - self.db = self.connect_db( - self.config["src"], self.config.get("connect_args", {}) - ) + if isinstance(config["src"], list): + filepaths = config["src"] + elif os.path.isfile(config["src"]): + filepaths = [config["src"]] + elif os.path.isdir(config["src"]): + filepaths = glob.glob(f'{config["src"]}/*') + else: + filepaths = glob.glob(config["src"]) + + self.dbs = [] + + for path in filepaths: + try: + self.dbs.append( + self.connect_db(path, config.get("connect_args", {})) + ) + except ValueError: + logging.warning( + f"Tried to connect to {path} but it's not an ASE database!" + ) + + self.select_args = config.get("select_args", {}) + + # In order to get all of the unique IDs using the default ASE db interface + # we have to load all the data and check ids using a select. This is extremely + # inefficient for large dataset. If the db we're using already presents a list of + # ids and there is no query, we can just use that list instead and save ourselves + # a lot of time! 
+ self.db_ids = [] + for db in self.dbs: + if hasattr(db, "ids") and self.select_args == {}: + self.db_ids.append(db.ids) + else: + self.db_ids.append( + [row.id for row in db.select(**self.select_args)] + ) - self.select_args = self.config.get("select_args", {}) + idlens = [len(ids) for ids in self.db_ids] + self._idlen_cumulative = np.cumsum(idlens).tolist() - self.id = [row.id for row in self.db.select(**self.select_args)] + return dummy_list(sum(idlens)) - def get_atoms_object(self, identifier): - return self.db._get_row(identifier).toatoms() + def get_atoms_object(self, idx): + # Figure out which db this should be indexed from. + db_idx = bisect.bisect(self._idlen_cumulative, idx) + + # Extract index of element within that db + el_idx = idx + if db_idx != 0: + el_idx = idx - self._idlen_cumulative[db_idx - 1] + assert el_idx >= 0 + + atoms_row = self.dbs[db_idx]._get_row(self.db_ids[db_idx][el_idx]) + atoms = atoms_row.toatoms() + + if isinstance(atoms_row.data, dict): + atoms.info.update(atoms_row.data) + + return atoms def connect_db(self, address, connect_args={}): db_type = connect_args.get("type", "extract_from_name") if db_type == "lmdb" or ( db_type == "extract_from_name" and address.split(".")[-1] == "lmdb" ): - from ocpmodels.datasets.lmdb_database import LMDBDatabase - return LMDBDatabase(address, readonly=True, **connect_args) else: return ase.db.connect(address, **connect_args) def close_db(self): - if hasattr(self.db, "close"): - self.db.close() + for db in self.dbs: + if hasattr(db, "close"): + db.close() + + def get_metadata(self): + logging.warning( + "You specific a folder of ASE dbs, so it's impossible to know which metadata to use. Using the first!" + ) + if self.dbs[0].metadata == {}: + return self.guess_target_metadata() + else: + return copy.deepcopy(self.dbs[0].metadata) diff --git a/ocpmodels/datasets/lmdb_database.py b/ocpmodels/datasets/lmdb_database.py new file mode 100644 index 000000000..21e4ce565 --- /dev/null +++ b/ocpmodels/datasets/lmdb_database.py @@ -0,0 +1,357 @@ +""" +Copyright (c) Facebook, Inc. and its affiliates. + +This source code is modified from the ASE db json backend +and is thus licensed under the corresponding LGPL2.1 license + +The ASE notice for the LGPL2.1 license is available here: +https://gitlab.com/ase/ase/-/blob/master/LICENSE +""" + + +import base64 +import json +import os +import sys +import zlib +from contextlib import ExitStack + +import lmdb +import numpy as np +import orjson +from ase.db.core import Database, lock, now, ops +from ase.db.row import AtomsRow +from ase.io.jsonio import decode, encode + +# These are special keys in the ASE LMDB that hold +# metadata and other info +RESERVED_KEYS = ["nextid", "metadata", "deleted_ids"] + + +class LMDBDatabase(Database): + def __enter__(self): + return self + + def __init__( + self, + filename=None, + create_indices=True, + use_lock_file=False, + serial=False, + readonly=False, + *args, + **kwargs, + ): + """ + For the most part, this is identical to the standard ase db initiation + arguments, except that we add a readonly flag. + """ + super().__init__( + filename, create_indices, use_lock_file, serial, *args, **kwargs + ) + + # Add a readonly mode for when we're only training + # to make sure there's no parallel locks + self.readonly = readonly + + if self.readonly: + # Open a new env + self.env = lmdb.open( + self.filename, + subdir=False, + meminit=False, + map_async=True, + readonly=True, + lock=False, + ) + + # Open a transaction and keep it open for fast read/writes! 
+ self.txn = self.env.begin(write=False) + + else: + # Open a new env with write access + self.env = lmdb.open( + self.filename, + map_size=1099511627776 * 2, + subdir=False, + meminit=False, + map_async=True, + ) + + self.txn = self.env.begin(write=True) + + # Load all ids based on keys in the DB. + self._load_ids() + + return + + def __exit__(self, exc_type, exc_value, tb): + self.close() + + pass + + def close(self): + # Close the lmdb environment and transaction + self.txn.commit() + self.env.close() + + return + + def _write(self, atoms, key_value_pairs, data, id): + Database._write(self, atoms, key_value_pairs, data) + + mtime = now() + + if isinstance(atoms, AtomsRow): + row = atoms + else: + row = AtomsRow(atoms) + row.ctime = mtime + row.user = os.getenv("USER") + + dct = {} + for key in row.__dict__: + if key[0] == "_" or key in row._keys or key == "id": + continue + dct[key] = row[key] + + dct["mtime"] = mtime + + if key_value_pairs: + dct["key_value_pairs"] = key_value_pairs + + if data: + dct["data"] = data + + constraints = row.get("constraints") + if constraints: + dct["constraints"] = [ + constraint.todict() for constraint in constraints + ] + + # json doesn't like Cell objects, so convert it to a plain numpy array + dct["cell"] = np.asarray(dct["cell"]) + + nextid = self._get_nextid() + if id is None: + id = nextid + nextid += 1 + else: + data = self.txn.get(f"{id}".encode("ascii")) + assert data is not None + + # Add the new entry, then add the id and write the nextid + self.txn.put( + f"{id}".encode("ascii"), + zlib.compress( + orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY) + ), + ) + self.ids.append(id) + self.txn.put( + "nextid".encode("ascii"), + zlib.compress( + orjson.dumps(nextid, option=orjson.OPT_SERIALIZE_NUMPY) + ), + ) + + return id + + def delete(self, ids): + for id in ids: + self.txn.delete(f"{id}".encode("ascii")) + self.ids.remove(id) + + self.deleted_ids += ids + self.txn.put( + "deleted_ids".encode("ascii"), + zlib.compress( + orjson.dumps( + self.deleted_ids, option=orjson.OPT_SERIALIZE_NUMPY + ) + ), + ) + + def _get_row(self, id, include_data=True): + if id is None: + assert len(self.ids) == 1 + id = self.ids[0] + data = self.txn.get(f"{id}".encode("ascii")) + + if data is not None: + dct = orjson.loads(zlib.decompress(data)) + else: + raise KeyError(f"Id {id} missing from the database!") + + if not include_data: + dct.pop("data", None) + + dct["id"] = id + return AtomsRow(dct) + + def _get_row_by_index(self, index, include_data=True): + """Auxiliary function to get the ith entry, rather than + a specific id + """ + id = self.ids[index] + data = self.txn.get(f"{id}".encode("ascii")) + + if data is not None: + dct = orjson.loads(zlib.decompress(data)) + else: + raise KeyError(f"Id {id} missing from the database!") + + if not include_data: + dct.pop("data", None) + + dct["id"] = id + return AtomsRow(dct) + + def _select( + self, + keys, + cmps, + explain=False, + verbosity=0, + limit=None, + offset=0, + sort=None, + include_data=True, + columns="all", + ): + if explain: + yield {"explain": (0, 0, 0, "scan table")} + return + + if sort: + if sort[0] == "-": + reverse = True + sort = sort[1:] + else: + reverse = False + + def f(row): + return row.get(sort, missing) + + rows = [] + missing = [] + for row in self._select(keys, cmps): + key = row.get(sort) + if key is None: + missing.append((0, row)) + else: + rows.append((key, row)) + + rows.sort(reverse=reverse, key=lambda x: x[0]) + rows += missing + + if limit: + rows = rows[offset : offset + limit] + for key, row in
rows: + yield row + return + + if not limit: + limit = -offset - 1 + + cmps = [(key, ops[op], val) for key, op, val in cmps] + n = 0 + for id in self.ids: + if n - offset == limit: + return + row = self._get_row(id, include_data=False) + + for key in keys: + if key not in row: + break + else: + for key, op, val in cmps: + if isinstance(key, int): + value = np.equal(row.numbers, key).sum() + else: + value = row.get(key) + if key == "pbc": + assert op in [ops["="], ops["!="]] + value = "".join("FT"[x] for x in value) + if value is None or not op(value, val): + break + else: + if n >= offset: + yield row + n += 1 + + @property + def metadata(self): + """Load the metadata from the DB if present""" + if self._metadata is None: + metadata = self.txn.get("metadata".encode("ascii")) + if metadata is None: + self._metadata = {} + else: + self._metadata = orjson.loads(zlib.decompress(metadata)) + + return self._metadata.copy() + + @metadata.setter + def metadata(self, dct): + self._metadata = dct + + # Put the updated metadata dictionary + self.txn.put( + "metadata".encode("ascii"), + zlib.compress( + orjson.dumps(dct, option=orjson.OPT_SERIALIZE_NUMPY) + ), + ) + + def _get_nextid(self): + """Get the id of the next row to be written""" + # Get the nextid + nextid_data = self.txn.get("nextid".encode("ascii")) + if nextid_data is not None: + nextid = orjson.loads(zlib.decompress(nextid_data)) + else: + # This db is empty; start at 1! + nextid = 1 + + return nextid + + def count(self, selection=None, **kwargs): + """Count rows. + + See the select() method for the selection syntax. Use db.count() or + len(db) to count all rows. + """ + if selection is not None: + n = 0 + for row in self.select(selection, **kwargs): + n += 1 + return n + else: + # Fast count if there's no queries! Just get number of ids + return len(self.ids) + + def _load_ids(self): + """Load ids from the DB + + ASE db ids are mostly 1-N integers, but there can be missing entries + if ids have been deleted. To save space, and operating under the assumption + that there will probably not be many deletions in most OCP datasets, + we just store the deleted ids. + """ + + # Load the deleted ids + deleted_ids_data = self.txn.get("deleted_ids".encode("ascii")) + if deleted_ids_data is None: + self.deleted_ids = [] + else: + self.deleted_ids = orjson.loads(zlib.decompress(deleted_ids_data)) + + # Reconstruct the full id list + self.ids = [ + i + for i in range(1, self._get_nextid()) + if i not in set(self.deleted_ids) + ] + + return diff --git a/ocpmodels/datasets/lmdb_dataset.py b/ocpmodels/datasets/lmdb_dataset.py index c2020f03e..12f2423a2 100644 --- a/ocpmodels/datasets/lmdb_dataset.py +++ b/ocpmodels/datasets/lmdb_dataset.py @@ -1,6 +1,5 @@ """ Copyright (c) Facebook, Inc. and its affiliates. - This source code is licensed under the MIT license found in the LICENSE file in the root directory of this source tree. """ @@ -22,6 +21,7 @@ from ocpmodels.common import distutils from ocpmodels.common.registry import registry from ocpmodels.common.utils import pyg2_data_transform +from ocpmodels.datasets.target_metadata_guesser import guess_target_metadata @@ -30,10 +30,12 @@ class LmdbDataset(Dataset): r"""Dataset class to load from LMDB files containing relaxation trajectories or single point computations. - Useful for Structure to Energy & Force (S2EF), Initial State to Relaxed State (IS2RS), and Initial State to Relaxed Energy (IS2RE) tasks.
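For orientation, a minimal usage sketch of the LMDBDatabase backend added in the new file above; the file name and the data payload are arbitrary choices for illustration:

# Sketch: write one structure into the LMDB-backed ASE database, then read it back.
from ase import build

from ocpmodels.datasets.lmdb_database import LMDBDatabase

atoms = build.bulk("Cu")

# Write mode opens a write transaction that is committed when the database is closed.
with LMDBDatabase("example.lmdb") as db:
    db.write(atoms, data={"note": "anything orjson can serialize"})

# readonly=True opens the environment without locks, which is what the datasets rely on
# so that many training workers can read in parallel.
with LMDBDatabase("example.lmdb", readonly=True) as db:
    print(db.count())                         # number of stored rows
    print(db._get_row_by_index(0).toatoms())  # back to an ASE Atoms object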
- + The keys in the LMDB must be integers (stored as ascii objects) starting + from 0 through the length of the LMDB. For historical reasons any key named + "length" is ignored since that was used to infer length of many lmdbs in the same + folder, but lmdb lengths are now calculated directly from the number of keys. Args: config (dict): Dataset configuration transform (callable, optional): Data transform function. @@ -57,11 +59,20 @@ def __init__(self, config, transform=None): self._keys, self.envs = [], [] for db_path in db_paths: - self.envs.append(self.connect_db(db_path)) - length = pickle.loads( - self.envs[-1].begin().get("length".encode("ascii")) - ) - self._keys.append(list(range(length))) + cur_env = self.connect_db(db_path) + self.envs.append(cur_env) + + # If "length" encoded as ascii is present, use that + length_entry = cur_env.begin().get("length".encode("ascii")) + if length_entry is not None: + num_entries = pickle.loads(length_entry) + else: + # Get the number of stored datapoints from the number of entries + # in the LMDB + num_entries = cur_env.stat()["entries"] + + # Append the keys (0->num_entries) as a list + self._keys.append(list(range(num_entries))) keylens = [len(k) for k in self._keys] self._keylen_cumulative = np.cumsum(keylens).tolist() @@ -69,11 +80,18 @@ def __init__(self, config, transform=None): else: self.metadata_path = self.path.parent / "metadata.npz" self.env = self.connect_db(self.path) - self._keys = [ - f"{j}".encode("ascii") - for j in range(self.env.stat()["entries"]) - ] - self.num_samples = len(self._keys) + + # If "length" encoded as ascii is present, use that + length_entry = self.env.begin().get("length".encode("ascii")) + if length_entry is not None: + num_entries = pickle.loads(length_entry) + else: + # Get the number of stored datapoints from the number of entries + # in the LMDB + num_entries = self.env.stat()["entries"] + + self._keys = list(range(num_entries)) + self.num_samples = num_entries # If specified, limit dataset to only a portion of the entire dataset # total_shards: defines total chunks to partition dataset @@ -117,7 +135,9 @@ def __getitem__(self, idx): data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) data_object.id = f"{db_idx}_{el_idx}" else: - datapoint_pickled = self.env.begin().get(self._keys[idx]) + datapoint_pickled = self.env.begin().get( + f"{self._keys[idx]}".encode("ascii") + ) data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) if self.transform is not None: @@ -131,7 +151,7 @@ def connect_db(self, lmdb_path=None): subdir=False, readonly=True, lock=False, - readahead=False, + readahead=True, meminit=False, max_readers=1, ) @@ -144,6 +164,47 @@ def close_db(self): else: self.env.close() + def get_metadata(self, num_samples=100): + # This will interrogate the classic OCP LMDB format to determine + # which properties are present and attempt to guess their shapes + # and whether they are intensive or extensive.
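For a sense of what this produces, the dictionary assembled below ends up looking roughly like the hedged sketch that follows; the exact keys and flags depend on which properties the LMDB actually stores:

# Illustrative shape of the guessed metadata for an S2EF-style LMDB (values not guaranteed):
# {
#     "targets": {
#         "y":      {"shape": (),   "type": "per-image", "extensive": False, ...},
#         "forces": {"shape": (3,), "type": "per-atom",  "extensive": True,  ...},
#     }
# }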
+ + # Grab an example data point + example_pyg_data = self.__getitem__(0) + + # Check for all properties we've used for OCP datasets in the past + props = [] + for potential_prop in [ + "y", + "y_relaxed", + "stress", + "stresses", + "force", + "forces", + ]: + if hasattr(example_pyg_data, potential_prop): + props.append(potential_prop) + + # Get a bunch of random data samples and the number of atoms + sample_pyg = [ + self[i] + for i in np.random.choice( + self.__len__(), size=(num_samples,), replace=False + ) + ] + atoms_lens = [data.natoms for data in sample_pyg] + + # Guess the metadata for targets for each found property + metadata = {} + metadata["targets"] = { + prop: guess_target_metadata( + atoms_lens, [getattr(data, prop) for data in sample_pyg] + ) + for prop in props + } + + return metadata + + class SinglePointLmdbDataset(LmdbDataset): def __init__(self, config, transform=None): diff --git a/ocpmodels/datasets/oc22_lmdb_dataset.py b/ocpmodels/datasets/oc22_lmdb_dataset.py index 1170f6a89..2709e8592 100644 --- a/ocpmodels/datasets/oc22_lmdb_dataset.py +++ b/ocpmodels/datasets/oc22_lmdb_dataset.py @@ -32,6 +32,11 @@ class OC22LmdbDataset(Dataset): Useful for Structure to Energy & Force (S2EF), Initial State to Relaxed State (IS2RS), and Initial State to Relaxed Energy (IS2RE) tasks. + The keys in the LMDB must be integers (stored as ascii objects) starting + from 0 through the length of the LMDB. For historical reasons any key named + "length" is ignored since that was used to infer length of many lmdbs in the same + folder, but lmdb lengths are now calculated directly from the number of keys. + Args: config (dict): Dataset configuration transform (callable, optional): Data transform function. @@ -52,14 +57,20 @@ def __init__(self, config, transform=None): self._keys, self.envs = [], [] for db_path in db_paths: - self.envs.append(self.connect_db(db_path)) - try: - length = pickle.loads( - self.envs[-1].begin().get("length".encode("ascii")) - ) - except TypeError: - length = self.envs[-1].stat()["entries"] - self._keys.append(list(range(length))) + cur_env = self.connect_db(db_path) + self.envs.append(cur_env) + + # Get the number of stored datapoints from the number of entries + # in the LMDB + num_entries = cur_env.stat()["entries"] + + # If "length" encoded as ascii is present, we have one fewer + # data point than the stats suggest + if cur_env.begin().get("length".encode("ascii")) is not None: + num_entries -= 1 + + # Append the keys (0->num_entries) as a list + self._keys.append(list(range(num_entries))) keylens = [len(k) for k in self._keys] self._keylen_cumulative = np.cumsum(keylens).tolist() @@ -83,11 +94,16 @@ def __init__(self, config, transform=None): else: self.metadata_path = self.path.parent / "metadata.npz" self.env = self.connect_db(self.path) - self._keys = [ - f"{j}".encode("ascii") - for j in range(self.env.stat()["entries"]) - ] - self.num_samples = len(self._keys) + + num_entries = self.env.stat()["entries"] + + # If "length" encoded as ascii is present, we have one fewer + # data point than the stats suggest + if self.env.begin().get("length".encode("ascii")) is not None: + num_entries -= 1 + + self._keys = list(range(num_entries)) + self.num_samples = num_entries self.transform = transform self.lin_ref = self.oc20_ref = False @@ -130,7 +146,9 @@ def __getitem__(self, idx): data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) data_object.id = f"{db_idx}_{el_idx}" else: - datapoint_pickled =
self.env.begin().get( + f"{self._keys[idx]}".encode("ascii") + ) data_object = pyg2_data_transform(pickle.loads(datapoint_pickled)) if self.transform is not None: @@ -200,7 +218,7 @@ def connect_db(self, lmdb_path=None): subdir=False, readonly=True, lock=False, - readahead=False, + readahead=True, meminit=False, max_readers=1, ) diff --git a/ocpmodels/datasets/target_metadata_guesser.py b/ocpmodels/datasets/target_metadata_guesser.py new file mode 100644 index 000000000..8b2168f1a --- /dev/null +++ b/ocpmodels/datasets/target_metadata_guesser.py @@ -0,0 +1,197 @@ +import logging + +import numpy as np + + +def uniform_atoms_lengths(atoms_lens): + # If all of the structures have the same number of atoms, it's really hard to know + # whether the entries are intensive or extensive, and whether + # some of the entries are per-atom or not + return len(set(atoms_lens)) == 1 + + +def target_constant_shape(atoms_lens, target_samples): + # Given a bunch of atoms lengths, and the corresponding samples for the target, + # determine whether the shape is always the same regardless of atom size + return len(set([sample.shape for sample in target_samples])) == 1 + + +def target_per_atom(atoms_lens, target_samples): + # Given a bunch of atoms lengths, and the corresponding samples for the target, + # determine whether the target is per-atom (first dimension == # atoms, others constant) + + # If a sample target is just a number/float/etc, it can't be per-atom + if len(np.array(target_samples[0]).shape) == 0: + return False + + first_dim_proportional = all( + [ + np.array(sample).shape[0] == alen + for alen, sample in zip(atoms_lens, target_samples) + ] + ) + + if len(np.array(target_samples[0]).shape) == 1: + other_dim_constant = True + else: + other_dim_constant = ( + len(set([np.array(sample).shape[1:] for sample in target_samples])) + == 1 + ) + + if first_dim_proportional and other_dim_constant: + return True + else: + return False + + +def target_extensive(atoms_lens, target_samples, threshold=0.2): + # Guess whether a property is intensive or extensive. + # We guess by checking whether the standard deviation of the per-atom + # properties captures >20% of the variation in the property + # Of course, with a small amount of data! + + # If the targets don't all have the same shape, we shouldn't be asking if the property + # is intensive or extensive! + assert target_constant_shape( + atoms_lens, target_samples + ), "The shapes of this target are not constant!" + + # Get the per-atom normalized properties + try: + compiled_target_array = np.array( + [ + sample / atom_len + for sample, atom_len in zip(target_samples, atoms_lens) + ] + ) + except TypeError: + return False + + # Calculate the normalized standard deviation of each element in the property output + target_samples_mean = np.mean(compiled_target_array, axis=0) + target_samples_normalized = compiled_target_array / target_samples_mean + + # If there's not much variation in the per-atom normalized properties, + # guess extensive!
+ extensive_guess = target_samples_normalized.std(axis=0) < ( + threshold * target_samples_normalized.mean(axis=0) + ) + if extensive_guess.shape == (): + return extensive_guess + elif ( + target_samples_normalized.std(axis=0) + < (threshold * target_samples_normalized.mean(axis=0)) + ).all(): + return True + else: + return False + + +def guess_target_metadata(atoms_len, target_samples): + example_array = np.array(target_samples[0]) + if example_array.dtype == object or example_array.dtype == str: + return { + "shape": None, + "type": "unknown", + "extensive": None, + "units": "unknown", + "comment": "Guessed property metadata. The property didn't seem to be a numpy array with any numeric type, so we don't know what to do.", + } + elif target_constant_shape(atoms_len, target_samples): + target_shape = np.array(target_samples[0]).shape + + if uniform_atoms_lengths(atoms_len): + if atoms_len[0] > 3 and target_per_atom(atoms_len, target_samples): + target_shape = list(target_samples[0].shape) + target_shape[0] = "N" + return { + "shape": tuple(target_shape), + "type": "per-atom", + "extensive": True, + "units": "unknown", + "comment": "Guessed property metadata. Because all the sampled atoms are the same length, we can't really know if it is per-atom or per-frame, but the first dimension happens to match the number of atoms.", + } + else: + return { + "shape": tuple(target_shape), + "type": "per-image", + "extensive": True, + "units": "unknown", + "comment": "Guessed property metadata. Because all the sampled atoms are the same length, we can't know if this is intensive or extensive, or per-image or per-atom.", + } + + elif target_extensive(atoms_len, target_samples): + return { + "shape": tuple(target_shape), + "type": "per-image", + "extensive": True, + "comment": "Guessed property metadata. It appears to be extensive based on a quick correlation with atom sizes", + } + else: + return { + "shape": tuple(target_shape), + "type": "per-image", + "extensive": False, + "units": "unknown", + "comment": "Guessed property metadata. It appears to be intensive based on a quick correlation with atom sizes.", + } + elif target_per_atom(atoms_len, target_samples): + target_shape = list(target_samples[0].shape)[1:] + return { + "shape": tuple(target_shape), + "type": "per-atom", + "extensive": True, + "units": "unknown", + "comment": "Guessed property metadata. It appears to be a per-atom property.", + } + else: + return { + "shape": None, + "type": "unknown", + "extensive": None, + "units": "unknown", + "comment": "Guessed property metadata. 
The property was variable across different samples and didn't seem to be a per-atom property", + } + + +def guess_property_metadata(atoms_list): + atoms = atoms_list[0] + atoms_len = [len(atoms) for atoms in atoms_list] + + targets = {} + + if hasattr(atoms, "info"): + for key in atoms.info: + # Grab the property samples from the list of atoms + target_samples = [ + np.array(atoms.info[key]) for atoms in atoms_list + ] + + # Guess the metadata + targets[f"info.{key}"] = guess_target_metadata( + atoms_len, target_samples + ) + + # Log a warning so the user knows what's happening + logging.warning( + f'Guessed metadata for atoms.info["{key}"]: {str(targets[f"info.{key}"])}' + ) + if hasattr(atoms, "calc") and atoms.calc is not None: + for key in atoms.calc.results: + # Grab the property samples from the list of atoms + target_samples = [ + np.array(atoms.calc.results[key]) for atoms in atoms_list + ] + + # Guess the metadata + targets[f"{key}"] = guess_target_metadata( + atoms_len, target_samples + ) + + # Log a warning so the user knows what's happening + logging.warning( + f'Guessed metadata for ASE calculator property ["{key}"]: {str(targets[key])}' + ) + + return targets diff --git a/ocpmodels/preprocessing/atoms_to_graphs.py b/ocpmodels/preprocessing/atoms_to_graphs.py index 55b9058d3..973a3e7d0 100644 --- a/ocpmodels/preprocessing/atoms_to_graphs.py +++ b/ocpmodels/preprocessing/atoms_to_graphs.py @@ -126,15 +126,15 @@ def _reshape_features(self, c_index, n_index, n_distance, offsets): return edge_index, edge_distances, cell_offsets - def convert( - self, - atoms, - ): + def convert(self, atoms, sid=None): """Convert a single atomic stucture to a graph. Args: atoms (ase.atoms.Atoms): An ASE atoms object. + sid (uniquely identifying object): An identifier that can be used to track the structure in downstream + tasks. Common sids used in OCP datasets include unique strings or integers. + Returns: data (torch_geometric.data.Data): A torch geometic data object with positions, atomic_numbers, tags, and optionally, energy, forces, distances, edges, and periodic boundary conditions. 
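Because the sid plumbing is easiest to see end to end, here is a short hedged sketch of the new convert signature in use; the keyword choices and the sid value are illustrative:

# Sketch: attach a system identifier while converting an Atoms object.
from ase import build

from ocpmodels.preprocessing import AtomsToGraphs

a2g = AtomsToGraphs(r_edges=False, r_pbc=True)  # keyword choices here are illustrative
atoms = build.molecule("H2O", vacuum=4)

data = a2g.convert(atoms, sid="water_0001")  # any unique string or integer works as a sid
# data.sid == "water_0001"; the ASE datasets above fall back to tensor([idx])
# when atoms.info carries no "sid" entry.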
@@ -159,6 +159,10 @@ def convert( tags=tags, ) + # Optionally add a systemid (sid) to the object + if sid is not None: + data.sid = sid + # optionally include other properties if self.r_edges: # run internal functions to get padded indices and distances diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py index 7585a8102..6886f78fd 100644 --- a/tests/datasets/test_ase_datasets.py +++ b/tests/datasets/test_ase_datasets.py @@ -1,20 +1,29 @@ -import pytest -from ase import build, db -from ase.io import write, Trajectory import os + import numpy as np +import pytest +from ase import build, db +from ase.calculators.singlepoint import SinglePointCalculator +from ase.io import Trajectory, write from ocpmodels.datasets import ( - AseReadDataset, AseDBDataset, + AseReadDataset, AseReadMultiStructureDataset, ) +from ocpmodels.datasets.lmdb_database import LMDBDatabase structures = [ build.molecule("H2O", vacuum=4), build.bulk("Cu"), build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), ] +for atoms in structures: + calc = SinglePointCalculator(atoms, energy=1, forces=atoms.positions) + atoms.calc = calc + atoms.info["test_extensive_property"] = 3 * len(atoms) + +structures[2].set_pbc(True) def test_ase_read_dataset(): @@ -35,6 +44,7 @@ def test_ase_read_dataset(): assert len(dataset) == len(structures) data = dataset[0] + del data for i in range(len(structures)): os.remove( @@ -56,11 +66,11 @@ def test_ase_db_dataset(): except FileNotFoundError: pass - database = db.connect( + with db.connect( os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") - ) - for i, structure in enumerate(structures): - database.write(structure) + ) as database: + for i, structure in enumerate(structures): + database.write(structure) dataset = AseDBDataset( config={ @@ -73,6 +83,246 @@ def test_ase_db_dataset(): assert len(dataset) == len(structures) data = dataset[0] + del data + + os.remove( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") + ) + + +def test_ase_db_dataset_folder(): + try: + os.remove( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb1.db" + ) + ) + os.remove( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb2.db" + ) + ) + except FileNotFoundError: + pass + + with db.connect( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb1.db") + ) as database: + for i, structure in enumerate(structures): + database.write(structure) + + with db.connect( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb2.db") + ) as database: + for i, structure in enumerate(structures): + database.write(structure) + + dataset = AseDBDataset( + config={ + "src": os.path.join( + os.path.dirname(os.path.abspath(__file__)), "./" + ), + } + ) + + assert len(dataset) == len(structures) * 2 + data = dataset[0] + del data + + os.remove( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb1.db") + ) + os.remove( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb2.db") + ) + + +def test_ase_db_dataset_list(): + try: + os.remove( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb1.db" + ) + ) + os.remove( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb2.db" + ) + ) + except FileNotFoundError: + pass + + with db.connect( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb1.db") + ) as database: + for i, structure in enumerate(structures): + database.write(structure) + + with db.connect( + 
os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb2.db") + ) as database: + for i, structure in enumerate(structures): + database.write(structure) + + dataset = AseDBDataset( + config={ + "src": [ + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb1.db" + ), + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb2.db" + ), + ] + } + ) + + assert len(dataset) == len(structures) * 2 + data = dataset[0] + del data + + os.remove( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb1.db") + ) + os.remove( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb2.db") + ) + + +def test_ase_lmdb_dataset(): + try: + os.remove( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb" + ) + ) + except FileNotFoundError: + pass + + with LMDBDatabase( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") + ) as database: + for i, structure in enumerate(structures): + database.write(structure) + + dataset = AseDBDataset( + config={ + "src": os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb" + ), + } + ) + + assert len(dataset) == len(structures) + data = dataset[0] + del data + + os.remove( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") + ) + + +def test_lmdb_metadata_guesser(): + # Cleanup old lmdb in case it's left over from previous tests + try: + os.remove( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb" + ) + ) + except FileNotFoundError: + pass + + # Write an LMDB + with LMDBDatabase( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") + ) as database: + for i, structure in enumerate(structures): + database.write(structure, data=structure.info) + + dataset = AseDBDataset( + config={ + "src": os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb" + ), + } + ) + + metadata = dataset.get_metadata() + + # Confirm energy metadata guessed properly + assert metadata["targets"]["energy"]["extensive"] is False + assert metadata["targets"]["energy"]["shape"] == () + assert metadata["targets"]["energy"]["type"] == "per-image" + + # Confirm forces metadata guessed properly + assert metadata["targets"]["forces"]["shape"] == (3,) + assert metadata["targets"]["forces"]["extensive"] is True + assert metadata["targets"]["forces"]["type"] == "per-atom" + + # Confirm forces metadata guessed properly + assert ( + metadata["targets"]["info.test_extensive_property"]["extensive"] + is True + ) + assert metadata["targets"]["info.test_extensive_property"]["shape"] == () + assert ( + metadata["targets"]["info.test_extensive_property"]["type"] + == "per-image" + ) + + os.remove( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.lmdb") + ) + + +def test_ase_metadata_guesser(): + try: + os.remove( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb.db" + ) + ) + except FileNotFoundError: + pass + + with db.connect( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") + ) as database: + for i, structure in enumerate(structures): + database.write(structure, data=structure.info) + + dataset = AseDBDataset( + config={ + "src": os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb.db" + ), + } + ) + + metadata = dataset.get_metadata() + + # Confirm energy metadata guessed properly + assert metadata["targets"]["energy"]["extensive"] is False + assert metadata["targets"]["energy"]["shape"] == () + assert metadata["targets"]["energy"]["type"] == "per-image" 
+ + # Confirm forces metadata guessed properly + assert metadata["targets"]["forces"]["shape"] == (3,) + assert metadata["targets"]["forces"]["extensive"] is True + assert metadata["targets"]["forces"]["type"] == "per-atom" + + # Confirm forces metadata guessed properly + assert ( + metadata["targets"]["info.test_extensive_property"]["extensive"] + is True + ) + assert metadata["targets"]["info.test_extensive_property"]["shape"] == () + assert ( + metadata["targets"]["info.test_extensive_property"]["type"] + == "per-image" + ) + dataset = AseDBDataset( config={ "src": os.path.join( @@ -101,6 +351,7 @@ def test_ase_db_dataset(): assert len(dataset) == len(structures) + len(new_structures) - 1 data = dataset[:] + assert data os.remove( os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") diff --git a/tests/datasets/test_ase_lmdb.py b/tests/datasets/test_ase_lmdb.py new file mode 100644 index 000000000..b6c0b26e1 --- /dev/null +++ b/tests/datasets/test_ase_lmdb.py @@ -0,0 +1,183 @@ +import os +from pathlib import Path + +import numpy as np +import pytest +import tqdm +from ase import build +from ase.calculators.singlepoint import SinglePointCalculator +from ase.constraints import FixAtoms +from ase.io import write + +from ocpmodels.datasets.lmdb_database import LMDBDatabase + +DB_NAME = "ase_lmdb.lmdb" +N_WRITES = 100 +N_READS = 200 + + +def cleanup_asedb(): + if Path(DB_NAME).is_file(): + Path(DB_NAME).unlink() + if Path(f"{DB_NAME}-lock").is_file(): + Path(f"{DB_NAME}-lock").unlink() + + +test_structures = [ + build.molecule("H2O", vacuum=4), + build.bulk("Cu"), + build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), +] + +test_structures[2].set_constraint(FixAtoms(indices=[0, 1])) + + +def generate_random_structure(): + + # Make base slab + slab = build.fcc111("Cu", size=(4, 4, 3), vacuum=10.0) + + # Randomly set some elements + slab.set_chemical_symbols( + np.random.choice(["Cu", "Ag", "Au", "Pd"], size=(len(slab))) + ) + + # Randomly set some positions + slab.positions = np.random.normal(size=slab.positions.shape) + + # Add entries for energy/forces/stress/magmom/etc. 
+ # Property must be one of the ASE core properties to + # go in to a singlepointcalculator and get stored as + # fields correctly + spc = SinglePointCalculator( + slab, + energy=np.random.normal(), + forces=np.random.normal(size=slab.positions.shape), + stress=np.random.normal(size=(3, 3)), + magmom=np.random.normal(size=(len(slab))), + ) + slab.set_calculator(spc) + + # Make up some other properties to show how we can include arbitrary outputs + slab.info["test_info_property_1"] = np.random.normal(size=(3, 3)) + slab.info["test_info_property_2"] = np.random.normal(size=(len(slab), 3)) + + return slab + + +def write_random_atoms(): + + slab = build.fcc111("Cu", size=(4, 4, 3), vacuum=10.0) + with LMDBDatabase(DB_NAME) as db: + + for structure in test_structures: + db.write(structure) + + for i in tqdm.tqdm(range(N_WRITES)): + slab = generate_random_structure() + + # Save the slab info, and make sure the info gets put in as data + db.write(slab, data=slab.info) + + +def test_aselmdb_write(): + + # Representative structure + write_random_atoms() + + with LMDBDatabase(DB_NAME, readonly=True) as db: + for i, structure in enumerate(test_structures): + assert str(structure) == str(db._get_row_by_index(i).toatoms()) + + cleanup_asedb() + + +def test_aselmdb_count(): + + # Representative structure + write_random_atoms() + + with LMDBDatabase(DB_NAME, readonly=True) as db: + assert db.count() == N_WRITES + len(test_structures) + + cleanup_asedb() + + +def test_aselmdb_delete(): + cleanup_asedb() + + # Representative structure + write_random_atoms() + + with LMDBDatabase(DB_NAME) as db: + for i in range(5): + + # Note the available ids list is updating + # but the ids themselves are fixed. + db.delete([db.ids[0]]) + + assert db.count() == N_WRITES + len(test_structures) - 5 + + cleanup_asedb() + + +def test_aselmdb_randomreads(): + + write_random_atoms() + + with LMDBDatabase(DB_NAME, readonly=True) as db: + for i in tqdm.tqdm(range(N_READS)): + total_size = db.count() + row = db._get_row_by_index(np.random.choice(total_size)).toatoms() + del row + cleanup_asedb() + + +def test_aselmdb_constraintread(): + + write_random_atoms() + + with LMDBDatabase(DB_NAME, readonly=True) as db: + atoms = db._get_row_by_index(2).toatoms() + + assert type(atoms.constraints[0]) == FixAtoms + + cleanup_asedb() + + +def update_keyvalue_pair(): + + write_random_atoms() + with LMDBDatabase(DB_NAME) as db: + db.update(1, test=5) + + with LMDBDatabase(DB_NAME) as db: + row = db.get_row_by_id(1) + assert row.test == 5 + + cleanup_asedb() + + +def update_atoms(): + + write_random_atoms() + with LMDBDatabase(DB_NAME) as db: + db.update(40, atoms=test_structures[-1]) + + with LMDBDatabase(DB_NAME) as db: + row = db.get_row_by_id(40) + assert str(row.toatoms()) == str(test_structures[-1]) + + cleanup_asedb() + + +def test_metadata(): + write_random_atoms() + + with LMDBDatabase(DB_NAME) as db: + db.metadata = {"test": True} + + with LMDBDatabase(DB_NAME, readonly=True) as db: + assert db.metadata["test"] is True + + cleanup_asedb()
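To tie the pieces together, a hedged end-to-end sketch in the spirit of the tests above: write an LMDB-backed ASE database, point AseDBDataset at it, and ask for guessed metadata. The file name and values are illustrative:

# End-to-end sketch (illustrative paths and values), mirroring the tests above.
import os

from ase import build
from ase.calculators.singlepoint import SinglePointCalculator

from ocpmodels.datasets import AseDBDataset
from ocpmodels.datasets.lmdb_database import LMDBDatabase

atoms = build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True)
atoms.calc = SinglePointCalculator(atoms, energy=1.0, forces=atoms.positions)
atoms.info["test_extensive_property"] = 3 * len(atoms)

with LMDBDatabase("demo.lmdb") as db:
    db.write(atoms, data=atoms.info)

dataset = AseDBDataset(config={"src": "demo.lmdb"})
assert len(dataset) == 1
data = dataset[0]                  # torch_geometric Data object with .sid and .pbc attached

metadata = dataset.get_metadata()  # guessed shapes/extensivity, as exercised in the tests
dataset.close_db()

os.remove("demo.lmdb")
if os.path.exists("demo.lmdb-lock"):
    os.remove("demo.lmdb-lock")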