diff --git a/TRAIN.md b/TRAIN.md index 159f7c2ad..165f853d5 100644 --- a/TRAIN.md +++ b/TRAIN.md @@ -16,6 +16,12 @@ - [Joint Training](#joint-training) - [Create EvalAI submission files](#create-evalai-oc22-submission-files) - [S2EF-Total/IS2RE-Total](#s2ef-totalis2re-total) +- [Using Your Own Data](#using-your-own-data) + - [Writing an LMDB](#writing-an-lmdb) + - [Using an ASE Database](#using-an-ase-database) + - [Using ASE-Readable Files](#using-ase-readable-files) + - [Single-Structure Files](#single-structure-files) + - [Multi-Structure Files](#multi-structure-files) ## Getting Started @@ -323,3 +329,103 @@ EvalAI expects results to be structured in a specific format for a submission to ``` Where `file.npz` corresponds to the respective `[s2ef/is2re]_predictions.npz` files generated for the corresponding task. The final submission file will be written to `submission_file.npz` (rename accordingly). The `dataset` argument specifies which dataset is being considered — this only needs to be set for OC22 predictions because OC20 is the default. 3. Upload `submission_file.npz` to EvalAI. + + +# Using Your Own Data + +There are multiple ways to train and evaluate OCP models on data other than OC20 and OC22. Writing an LMDB is the most performant option. However, ASE-based dataset formats are also included as a convenience for people with existing data who simply want to try OCP tools without needing to learn about LMDBs. + +This tutorial will briefly discuss the basic use of these dataset formats. For more detailed information about the ASE datasets, see the [source code and docstrings](ocpmodels/datasets/ase_datasets.py). + +## Writing an LMDB + +Storing your data in an LMDB ensures very fast random read speeds for the fastest supported throughput. This is the recommended option for the majority of OCP use cases. For more information about writing your data to an LMDB, please see the [LMDB Dataset Tutorial](https://github.com/Open-Catalyst-Project/ocp/blob/main/tutorials/lmdb_dataset_creation.ipynb). + +## Using an ASE Database + +If your data is already in an [ASE Database](https://databases.fysik.dtu.dk/ase/ase/db/db.html), no additional preprocessing is necessary before running training/prediction! Although the ASE DB backends may not be sufficiently high throughput for all use cases, they are generally considered "fast enough" to train on a reasonably-sized dataset with 1-2 GPUs or predict with a single GPU. If you want to effictively utilize more resources than this, please be aware of the potential for this bottleneck and consider writing your data to an LMDB. If your dataset is small enough to fit in CPU memory, use the `keep_in_memory: True` option to avoid this bottleneck. + +To use this dataset, we will just have to change our config files to use the ASE DB Dataset rather than the LMDB Dataset: + +``` +task: + dataset: ase_db + +dataset: + train: + src: # The path/address to your ASE DB + connect_args: + # Keyword arguments for ase.db.connect() + select_args: + # Keyword arguments for ase.db.select() + # These can be used to query/filter the ASE DB + a2g_args: + r_energy: True + r_forces: True + # Set these if you want to train on energy/forces + # Energy/force information must be in the ASE DB! + keep_in_memory: False # Keeping the dataset in memory reduces random reads and is extremely fast, but this is only feasible for relatively small datasets! + val: + src: + a2g_args: + r_energy: True + r_forces: True + test: + src: + a2g_args: + r_energy: False + r_forces: False + # It is not necessary to have energy or forces if you are just making predictions. +``` +## Using ASE-Readable Files + +It is possible to train/predict directly on ASE-readable files. This is only recommended for smaller datasets, as directories of many small files do not scale efficiently on all computing infrastructures. There are two options for loading data with the ASE reader: + +### Single-Structure Files +This dataset assumes a single structure will be obtained from each file: + +``` +task: + dataset: ase_read + +dataset: + train: + src: # The folder that contains ASE-readable files + pattern: # Pattern matching each file you want to read (e.g. "*/POSCAR"). Search recursively with two wildcards: "**/*.cif". + + ase_read_args: + # Keyword arguments for ase.io.read() + a2g_args: + # Include energy and forces for training purposes + # If True, the energy/forces must be readable from the file (ex. OUTCAR) + r_energy: True + r_forces: True + keep_in_memory: False +``` + +### Multi-structure Files +This dataset supports reading files that each contain multiple structure (for example, an ASE .traj file). Using an index file, which tells the dataset how many structures each file contains, is recommended. Otherwise, the dataset is forced to load every file at startup and count the number of structures! + +``` +task: + dataset: ase_read_multi + +dataset: + train: + index_file: Filepath to an index file which contains each filename and the number of structures in each file. e.g.: + /path/to/relaxation1.traj 200 + /path/to/relaxation2.traj 150 + ... + + # If using an index file, the src and pattern are not necessary + src: # The folder that contains ASE-readable files + pattern: # Pattern matching each file you want to read (e.g. "*.traj"). Search recursively with two wildcards: "**/*.xyz". + + ase_read_args: + # Keyword arguments for ase.io.read() + a2g_args: + # Include energy and forces for training purposes + r_energy: True + r_forces: True + keep_in_memory: False +``` diff --git a/ocpmodels/datasets/__init__.py b/ocpmodels/datasets/__init__.py index 9ed38d832..36d8881ec 100644 --- a/ocpmodels/datasets/__init__.py +++ b/ocpmodels/datasets/__init__.py @@ -10,3 +10,9 @@ data_list_collater, ) from .oc22_lmdb_dataset import OC22LmdbDataset + +from .ase_datasets import ( + AseReadDataset, + AseReadMultiStructureDataset, + AseDBDataset, +) diff --git a/ocpmodels/datasets/ase_datasets.py b/ocpmodels/datasets/ase_datasets.py new file mode 100644 index 000000000..fd37cd55f --- /dev/null +++ b/ocpmodels/datasets/ase_datasets.py @@ -0,0 +1,374 @@ +import ase +import warnings +import numpy as np + +from pathlib import Path +from torch import tensor +from torch.utils.data import Dataset +from tqdm import tqdm + +from ocpmodels.common.registry import registry +from ocpmodels.preprocessing import AtomsToGraphs + + +def apply_one_tags(atoms, skip_if_nonzero=True, skip_always=False): + """ + This function will apply tags of 1 to an ASE atoms object. + It is used as an atoms_transform in the datasets contained in this file. + + Certain models will treat atoms differently depending on their tags. + For example, GemNet-OC by default will only compute triplet and quadruplet interactions + for atoms with non-zero tags. This model throws an error if there are no tagged atoms. + For this reason, the default behavior is to tag atoms in structures with no tags. + + args: + skip_if_nonzero (bool): If at least one atom has a nonzero tag, do not tag any atoms + + skip_always (bool): Do not apply any tags. This arg exists so that this function can be disabled + without needing to pass a callable (which is currently difficult to do with main.py) + """ + if skip_always: + return atoms + + if np.all(atoms.get_tags() == 0) or not skip_if_nonzero: + atoms.set_tags(np.ones(len(atoms))) + + return atoms + + +class AseAtomsDataset(Dataset): + """ + This is a base Dataset that includes helpful utilities for turning + ASE atoms objects into OCP-usable data objects. + + Derived classes must add at least two things: + self.get_atoms_object(): a function that takes an identifier and returns a corresponding atoms object + self.id: a list of all possible identifiers that can be passed into self.get_atoms_object() + Identifiers need not be any particular type. + """ + + def __init__(self, config, transform=None, atoms_transform=apply_one_tags): + self.config = config + + a2g_args = config.get("a2g_args", {}) + self.a2g = AtomsToGraphs(**a2g_args) + + self.transform = transform + self.atoms_transform = atoms_transform + + if self.config.get("keep_in_memory", False): + self.data_objects = {} + + # Derived classes should extend this functionality to also create self.id, + # a list of identifiers that can be passed to get_atoms_object() + + def __len__(self): + return len(self.id) + + def __getitem__(self, idx): + # Handle slicing + if isinstance(idx, slice): + return [self[i] for i in range(len(self.id))[idx]] + + # Check if data object is already in memory + if self.config.get("keep_in_memory", False): + if self.id[idx] in self.data_objects: + return self.data_objects[self.id[idx]] + + # Get atoms object via derived class method + atoms = self.get_atoms_object(self.id[idx]) + + # Transform atoms object + if self.atoms_transform is not None: + atoms = self.atoms_transform( + atoms, **self.config.get("atoms_transform_args", {}) + ) + + # Convert to data object + data_object = self.a2g.convert(atoms) + data_object.sid = tensor([idx]) + data_object.pbc = tensor(atoms.pbc) + + # Transform data object + if self.transform is not None: + data_object = self.transform( + data_object, **self.config.get("transform_args", {}) + ) + + # Save in memory, if specified + if self.config.get("keep_in_memory", False): + self.data_objects[self.id[idx]] = data_object + + return data_object + + def get_atoms_object(self, identifier): + raise NotImplementedError( + "Returns an ASE atoms object. Derived classes should implement this funciton." + ) + + def close_db(self): + pass + # This method is sometimes called by a trainer + + +@registry.register_dataset("ase_read") +class AseReadDataset(AseAtomsDataset): + """ + This Dataset uses ase.io.read to load data from a directory on disk. + This is intended for small-scale testing and demonstrations of OCP. + Larger datasets are better served by the efficiency of other dataset types + such as LMDB. + + For a full list of ASE-readable filetypes, see + https://wiki.fysik.dtu.dk/ase/ase/io/io.html + + args: + config (dict): + src (str): The source folder that contains your ASE-readable files + + pattern (str): Filepath matching each file you want to read + ex. "*/POSCAR", "*.cif", "*.xyz" + search recursively with two wildcards: "**/POSCAR" or "**/*.cif" + + a2g_args (dict): Keyword arguments for ocpmodels.preprocessing.AtomsToGraphs() + default options will work for most users + + If you are using this for a training dataset, set + "r_energy":True and/or "r_forces":True as appropriate + In that case, energy/forces must be in the files you read (ex. OUTCAR) + + ase_read_args (dict): Keyword arguments for ase.io.read() + + keep_in_memory (bool): Store data in memory. This helps avoid random reads if you need + to iterate over a dataset many times (e.g. training for many epochs). + Not recommended for large datasets. + + atoms_transform_args (dict): Additional keyword arguments for the atoms_transform callable + + transform_args (dict): Additional keyword arguments for the transform callable + + atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms + object. Useful for applying tags, for example. + + transform (callable, optional): Additional preprocessing function for the Data object + + """ + + def __init__(self, config, transform=None, atoms_transform=apply_one_tags): + super(AseReadDataset, self).__init__( + config, transform, atoms_transform + ) + self.ase_read_args = config.get("ase_read_args", {}) + + if ":" in self.ase_read_args.get("index", ""): + raise NotImplementedError( + "To read multiple structures from a single file, please use AseReadMultiStructureDataset." + ) + + self.path = Path(self.config["src"]) + if self.path.is_file(): + raise Exception("The specified src is not a directory") + self.id = list(self.path.glob(f'{self.config["pattern"]}')) + + def get_atoms_object(self, identifier): + try: + atoms = ase.io.read(identifier, **self.ase_read_args) + except Exception as err: + warnings.warn(f"{err} occured for: {identifier}") + raise err + + return atoms + + +@registry.register_dataset("ase_read_multi") +class AseReadMultiStructureDataset(AseAtomsDataset): + """ + This Dataset can read multiple structures from each file using ase.io.read. + The disadvantage is that all files must be read at startup. + This is a significant cost for large datasets. + + This is intended for small-scale testing and demonstrations of OCP. + Larger datasets are better served by the efficiency of other dataset types + such as LMDB. + + For a full list of ASE-readable filetypes, see + https://wiki.fysik.dtu.dk/ase/ase/io/io.html + + args: + config (dict): + src (str): The source folder that contains your ASE-readable files + + pattern (str): Filepath matching each file you want to read + ex. "*.traj", "*.xyz" + search recursively with two wildcards: "**/POSCAR" or "**/*.cif" + + index_file (str): Filepath to an indexing file, which contains each filename + and the number of structures contained in each file. For instance: + + /path/to/relaxation1.traj 200 + /path/to/relaxation2.traj 150 + + This will overrule the src and pattern that you specify! + + a2g_args (dict): Keyword arguments for ocpmodels.preprocessing.AtomsToGraphs() + default options will work for most users + + If you are using this for a training dataset, set + "r_energy":True and/or "r_forces":True as appropriate + In that case, energy/forces must be in the files you read (ex. OUTCAR) + + ase_read_args (dict): Keyword arguments for ase.io.read() + + keep_in_memory (bool): Store data in memory. This helps avoid random reads if you need + to iterate over a dataset many times (e.g. training for many epochs). + Not recommended for large datasets. + + use_tqdm (bool): Use TQDM progress bar when initializing dataset + + atoms_transform_args (dict): Additional keyword arguments for the atoms_transform callable + + transform_args (dict): Additional keyword arguments for the transform callable + + atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms + object. Useful for applying tags, for example. + + transform (callable, optional): Additional preprocessing function for the Data object + """ + + def __init__(self, config, transform=None, atoms_transform=apply_one_tags): + super(AseReadMultiStructureDataset, self).__init__( + config, transform, atoms_transform + ) + self.ase_read_args = self.config.get("ase_read_args", {}) + if not hasattr(self.ase_read_args, "index"): + self.ase_read_args["index"] = ":" + + if config.get("index_file", None) is not None: + f = open(config["index_file"], "r") + index = f.readlines() + + self.id = [] + for line in index: + filename = line.split(" ")[0] + for i in range(int(line.split(" ")[1])): + self.id.append(f"{filename} {i}") + + return + + self.path = Path(self.config["src"]) + if self.path.is_file(): + raise Exception("The specified src is not a directory") + filenames = list(self.path.glob(f'{self.config["pattern"]}')) + + self.id = [] + + if self.config.get("use_tqdm", True): + filenames = tqdm(filenames) + for filename in filenames: + try: + structures = ase.io.read(filename, **self.ase_read_args) + except Exception as err: + warnings.warn(f"{err} occured for: {filename}") + else: + for i, structure in enumerate(structures): + self.id.append(f"{filename} {i}") + + if self.config.get("keep_in_memory", False): + # Transform atoms object + if self.atoms_transform is not None: + atoms = self.atoms_transform( + structure, + **self.config.get("atoms_transform_args", {}), + ) + + # Convert to data object + data_object = self.a2g.convert(atoms) + + # Transform data object + if self.transform is not None: + data_object = self.transform( + data_object, + **self.config.get("transform_args", {}), + ) + + self.data_objects[f"{filename} {i}"] = data_object + + def get_atoms_object(self, identifier): + try: + atoms = ase.io.read( + "".join(identifier.split(" ")[:-1]), **self.ase_read_args + )[int(identifier.split(" ")[-1])] + except Exception as err: + warnings.warn(f"{err} occured for: {identifier}") + raise err + + return atoms + + +@registry.register_dataset("ase_db") +class AseDBDataset(AseAtomsDataset): + """ + This Dataset connects to an ASE Database, allowing the storage of atoms objects + with a variety of backends including JSON, SQLite, and database server options. + + For more information, see: + https://databases.fysik.dtu.dk/ase/ase/db/db.html + + args: + config (dict): + src (str): The path to or connection address of your ASE DB + + connect_args (dict): Keyword arguments for ase.db.connect() + + select_args (dict): Keyword arguments for ase.db.select() + You can use this to query/filter your database + + a2g_args (dict): Keyword arguments for ocpmodels.preprocessing.AtomsToGraphs() + default options will work for most users + + If you are using this for a training dataset, set + "r_energy":True and/or "r_forces":True as appropriate + In that case, energy/forces must be in the database + + keep_in_memory (bool): Store data in memory. This helps avoid random reads if you need + to iterate over a dataset many times (e.g. training for many epochs). + Not recommended for large datasets. + + atoms_transform_args (dict): Additional keyword arguments for the atoms_transform callable + + transform_args (dict): Additional keyword arguments for the transform callable + + atoms_transform (callable, optional): Additional preprocessing function applied to the Atoms + object. Useful for applying tags, for example. + + transform (callable, optional): Additional preprocessing function for the Data object + """ + + def __init__(self, config, transform=None, atoms_transform=apply_one_tags): + super(AseDBDataset, self).__init__(config, transform, atoms_transform) + + self.db = self.connect_db( + self.config["src"], self.config.get("connect_args", {}) + ) + + self.select_args = self.config.get("select_args", {}) + + self.id = [row.id for row in self.db.select(**self.select_args)] + + def get_atoms_object(self, identifier): + return self.db._get_row(identifier).toatoms() + + def connect_db(self, address, connect_args={}): + db_type = connect_args.get("type", "extract_from_name") + if db_type == "lmdb" or ( + db_type == "extract_from_name" and address.split(".")[-1] == "lmdb" + ): + from ocpmodels.datasets.lmdb_database import LMDBDatabase + + return LMDBDatabase(address, readonly=True, **connect_args) + else: + return ase.db.connect(address, **connect_args) + + def close_db(self): + if hasattr(self.db, "close"): + self.db.close() diff --git a/tests/datasets/test_ase_datasets.py b/tests/datasets/test_ase_datasets.py new file mode 100644 index 000000000..7585a8102 --- /dev/null +++ b/tests/datasets/test_ase_datasets.py @@ -0,0 +1,183 @@ +import pytest +from ase import build, db +from ase.io import write, Trajectory +import os +import numpy as np + +from ocpmodels.datasets import ( + AseReadDataset, + AseDBDataset, + AseReadMultiStructureDataset, +) + +structures = [ + build.molecule("H2O", vacuum=4), + build.bulk("Cu"), + build.fcc111("Pt", size=[2, 2, 3], vacuum=8, periodic=True), +] + + +def test_ase_read_dataset(): + for i, structure in enumerate(structures): + write( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), f"{i}.cif" + ), + structure, + ) + + dataset = AseReadDataset( + config={ + "src": os.path.join(os.path.dirname(os.path.abspath(__file__))), + "pattern": "*.cif", + } + ) + + assert len(dataset) == len(structures) + data = dataset[0] + + for i in range(len(structures)): + os.remove( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), f"{i}.cif" + ) + ) + + dataset.close_db() + + +def test_ase_db_dataset(): + try: + os.remove( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb.db" + ) + ) + except FileNotFoundError: + pass + + database = db.connect( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") + ) + for i, structure in enumerate(structures): + database.write(structure) + + dataset = AseDBDataset( + config={ + "src": os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb.db" + ), + } + ) + + assert len(dataset) == len(structures) + data = dataset[0] + + dataset = AseDBDataset( + config={ + "src": os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb.db" + ), + } + ) + + database.delete([1]) + + new_structures = [ + build.molecule("CH3COOH", vacuum=4), + build.bulk("Al"), + ] + + for i, structure in enumerate(new_structures): + database.write(structure) + + dataset = AseDBDataset( + config={ + "src": os.path.join( + os.path.dirname(os.path.abspath(__file__)), "asedb.db" + ), + } + ) + + assert len(dataset) == len(structures) + len(new_structures) - 1 + data = dataset[:] + + os.remove( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "asedb.db") + ) + + dataset.close_db() + + +def test_ase_multiread_dataset(): + try: + os.remove( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test.traj" + ) + ) + except FileNotFoundError: + pass + + try: + os.remove( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_index_file" + ) + ) + except FileNotFoundError: + pass + + atoms_objects = [build.bulk("Cu", a=a) for a in np.linspace(3.5, 3.7, 10)] + + traj = Trajectory( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.traj"), + mode="w", + ) + for atoms in atoms_objects: + traj.write(atoms) + + dataset = AseReadMultiStructureDataset( + config={ + "src": os.path.join(os.path.dirname(os.path.abspath(__file__))), + "pattern": "*.traj", + "keep_in_memory": True, + "atoms_transform_args": { + "skip_always": True, + }, + } + ) + + assert len(dataset) == len(atoms_objects) + [dataset[:]] + + f = open( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_index_file" + ), + "w", + ) + f.write( + f"{os.path.join(os.path.dirname(os.path.abspath(__file__)), 'test.traj')} {len(atoms_objects)}" + ) + f.close() + + dataset = AseReadMultiStructureDataset( + config={ + "index_file": os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_index_file" + ) + }, + ) + + assert len(dataset) == len(atoms_objects) + [dataset[:]] + + os.remove( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.traj") + ) + os.remove( + os.path.join( + os.path.dirname(os.path.abspath(__file__)), "test_index_file" + ) + )