Skip to content

Commit

Permalink
Support IS2RE-Direct Training with ASE Read Datasets (#579)
Browse files Browse the repository at this point in the history
* Support IS2RE direct with ase read datasets

* Make isort happy

* Make isort happier

* Add more ASE dataset tests

* Adjust error message

* Update TRAIN.md
  • Loading branch information
emsunshine committed Sep 30, 2023
1 parent 936a7be commit 82237be
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 36 deletions.
68 changes: 35 additions & 33 deletions TRAIN.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ final, relaxed state. This can be done by training a model to predict per-atom f
task and then running an iterative relaxation. Although we present an iterative approach, models that directly predict relaxed states are also possible. The iterative approach IS2RS task uses the same configuration files as the S2EF task `configs/s2ef` and follows the same training scheme above.

To perform an iterative relaxation, ensure the following is added to the configuration files of the models you wish to run relaxations on:
```
```yaml
# Relaxation options
relax_dataset:
src: data/is2re/all/val_id/data.lmdb # path to lmdb of systems to be relaxed (uses same lmdbs as is2re)
Expand All @@ -268,7 +268,7 @@ relax_opt:
```

After training, relaxations can be run by:
```
```bash
python main.py --mode run-relaxations --config-yml configs/s2ef/2M/schnet/schnet.yml \
--checkpoint checkpoints/[TIMESTAMP]/checkpoint.pt
```
Expand All @@ -281,7 +281,7 @@ EvalAI expects results to be structured in a specific format for a submission to
### S2EF/IS2RE:
1. Run predictions `--mode predict` on all 4 splits, generating `[s2ef/is2re]_predictions.npz` files for each split.
2. Run the following command:
```
```bash
python make_submission_file.py --id path/to/id/file.npz --ood-ads path/to/ood_ads/file.npz \
--ood-cat path/to/ood_cat/file.npz --ood-both path/to/ood_both/file.npz --out-path submission_file.npz
```
Expand All @@ -292,7 +292,7 @@ EvalAI expects results to be structured in a specific format for a submission to
### IS2RS:
1. Ensure `write_pos: True` is included in your configuration file. Run relaxations `--mode run-relaxations` on all 4 splits, generating `relaxed_positions.npz` files for each split.
2. Run the following command:
```
```bash
python make_submission_file.py --id path/to/id/relaxed_positions.npz --ood-ads path/to/ood_ads/relaxed_positions.npz \
--ood-cat path/to/ood_cat/relaxed_positions.npz --ood-both path/to/ood_both/relaxed_positions.npz --out-path is2rs_submission.npz
```
Expand All @@ -305,7 +305,7 @@ EvalAI expects results to be structured in a specific format for a submission to

For the IS2RE-Total task, the model takes the initial structure as input and predicts the total DFT energy of the relaxed structure. This task is more general and more challenging than the original OC20 IS2RE task that predicts adsorption energy. To train an OC22 IS2RE-Total model use the `EnergyTrainer` with the `OC22LmdbDataset` by including these lines in your configuration file:

```
```yaml
trainer: energy # Use the EnergyTrainer
task:
Expand All @@ -318,7 +318,7 @@ You can find examples configuration files in [`configs/oc22/is2re`](https://gith

The S2EF-Total task takes a structure and predicts the total DFT energy and per-atom forces. This differs from the original OC20 S2EF task because it predicts total energy instead of adsorption energy. To train an OC22 S2EF-Total model use the ForcesTrainer with the OC22LmdbDataset by including these lines in your configuration file:

```
```yaml
trainer: forces # Use the ForcesTrainer
task:
Expand All @@ -331,7 +331,7 @@ You can find examples configuration files in [`configs/oc22/s2ef`](https://githu

Training on OC20 total energies whether independently or jointly with OC22 requires a path to the `oc20_ref` (download link provided below) to be specified in the configuration file. These are necessary to convert OC20 adsorption energies into their corresponding total energies. The following changes in the configuration file capture these changes:

```
```yaml
task:
dataset: oc22_lmdb
...
Expand All @@ -358,7 +358,7 @@ EvalAI expects results to be structured in a specific format for a submission to
### S2EF-Total/IS2RE-Total:
1. Run predictions `--mode predict` on both the id and ood splits, generating `[s2ef/is2re]_predictions.npz` files for each split.
2. Run the following command:
```
```bash
python make_submission_file.py --dataset OC22 --id path/to/id/file.npz --ood path/to/ood_ads/file.npz --out-path submission_file.npz
```
Where `file.npz` corresponds to the respective `[s2ef/is2re]_predictions.npz` files generated for the corresponding task. The final submission file will be written to `submission_file.npz` (rename accordingly). The `dataset` argument specifies which dataset is being considered — this only needs to be set for OC22 predictions because OC20 is the default.
Expand All @@ -381,7 +381,7 @@ If your data is already in an [ASE Database](https://databases.fysik.dtu.dk/ase/

To use this dataset, we will just have to change our config files to use the ASE DB Dataset rather than the LMDB Dataset:

```
```yaml
task:
dataset: ase_db
Expand All @@ -399,6 +399,7 @@ dataset:
# Set these if you want to train on energy/forces
# Energy/force information must be in the ASE DB!
keep_in_memory: False # Keeping the dataset in memory reduces random reads and is extremely fast, but this is only feasible for relatively small datasets!
include_relaxed_energy: False # Read the last structure's energy and save as "y_relaxed" for IS2RE-Direct training
val:
src:
a2g_args:
Expand All @@ -418,29 +419,30 @@ It is possible to train/predict directly on ASE-readable files. This is only rec
### Single-Structure Files
This dataset assumes a single structure will be obtained from each file:

```
```yaml
task:
dataset: ase_read
dataset:
train:
src: # The folder that contains ASE-readable files
pattern: # Pattern matching each file you want to read (e.g. "*/POSCAR"). Search recursively with two wildcards: "**/*.cif".
ase_read_args:
# Keyword arguments for ase.io.read()
a2g_args:
# Include energy and forces for training purposes
# If True, the energy/forces must be readable from the file (ex. OUTCAR)
r_energy: True
r_forces: True
keep_in_memory: False
src: # The folder that contains ASE-readable files
pattern: # Pattern matching each file you want to read (e.g. "*/POSCAR"). Search recursively with two wildcards: "**/*.cif".
include_relaxed_energy: False # Read the last structure's energy and save as "y_relaxed" for IS2RE-Direct training
ase_read_args:
# Keyword arguments for ase.io.read()
a2g_args:
# Include energy and forces for training purposes
# If True, the energy/forces must be readable from the file (ex. OUTCAR)
r_energy: True
r_forces: True
keep_in_memory: False
```

### Multi-structure Files
This dataset supports reading files that each contain multiple structure (for example, an ASE .traj file). Using an index file, which tells the dataset how many structures each file contains, is recommended. Otherwise, the dataset is forced to load every file at startup and count the number of structures!

```
```yaml
task:
dataset: ase_read_multi
Expand All @@ -451,15 +453,15 @@ dataset:
/path/to/relaxation2.traj 150
...
# If using an index file, the src and pattern are not necessary
src: # The folder that contains ASE-readable files
pattern: # Pattern matching each file you want to read (e.g. "*.traj"). Search recursively with two wildcards: "**/*.xyz".

ase_read_args:
# Keyword arguments for ase.io.read()
a2g_args:
# Include energy and forces for training purposes
r_energy: True
r_forces: True
keep_in_memory: False
# If using an index file, the src and pattern are not necessary
src: # The folder that contains ASE-readable files
pattern: # Pattern matching each file you want to read (e.g. "*.traj"). Search recursively with two wildcards: "**/*.xyz".
ase_read_args:
# Keyword arguments for ase.io.read()
a2g_args:
# Include energy and forces for training purposes
r_energy: True
r_forces: True
keep_in_memory: False
```
30 changes: 30 additions & 0 deletions ocpmodels/datasets/ase_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,9 @@ def __getitem__(self, idx):
data_object, **self.config.get("transform_args", {})
)

if self.config.get("include_relaxed_energy", False):
data_object.y_relaxed = self.get_relaxed_energy(self.ids[idx])

return data_object

@abstractmethod
Expand Down Expand Up @@ -201,6 +204,10 @@ class AseReadDataset(AseAtomsDataset):
to iterate over a dataset many times (e.g. training for many epochs).
Not recommended for large datasets.
include_relaxed_energy (bool): Include the relaxed energy in the resulting data object.
The relaxed structure is assumed to be the final structure in the file
(e.g. the last frame of a .traj).
atoms_transform_args (dict): Additional keyword arguments for the atoms_transform callable
transform_args (dict): Additional keyword arguments for the transform callable
Expand All @@ -224,6 +231,10 @@ def load_dataset_get_ids(self, config) -> List[Path]:
if self.path.is_file():
raise Exception("The specified src is not a directory")

if self.config.get("include_relaxed_energy", False):
self.relaxed_ase_read_args = copy.deepcopy(self.ase_read_args)
self.relaxed_ase_read_args["index"] = "-1"

return list(self.path.glob(f'{config["pattern"]}'))

def get_atoms_object(self, identifier):
Expand All @@ -235,6 +246,10 @@ def get_atoms_object(self, identifier):

return atoms

def get_relaxed_energy(self, identifier):
relaxed_atoms = ase.io.read(identifier, **self.relaxed_ase_read_args)
return relaxed_atoms.get_potential_energy(apply_constraint=False)


@registry.register_dataset("ase_read_multi")
class AseReadMultiStructureDataset(AseAtomsDataset):
Expand Down Expand Up @@ -279,6 +294,10 @@ class AseReadMultiStructureDataset(AseAtomsDataset):
to iterate over a dataset many times (e.g. training for many epochs).
Not recommended for large datasets.
include_relaxed_energy (bool): Include the relaxed energy in the resulting data object.
The relaxed structure is assumed to be the final structure in the file
(e.g. the last frame of a .traj).
use_tqdm (bool): Use TQDM progress bar when initializing dataset
atoms_transform_args (dict): Additional keyword arguments for the atoms_transform callable
Expand Down Expand Up @@ -347,6 +366,12 @@ def get_atoms_object(self, identifier):
def get_metadata(self):
return {}

def get_relaxed_energy(self, identifier):
relaxed_atoms = ase.io.read(
"".join(identifier.split(" ")[:-1]), **self.ase_read_args
)[-1]
return relaxed_atoms.get_potential_energy(apply_constraint=False)


class dummy_list(list):
def __init__(self, max) -> None:
Expand Down Expand Up @@ -513,3 +538,8 @@ def get_metadata(self):
return self.guess_target_metadata()
else:
return copy.deepcopy(self.dbs[0].metadata)

def get_relaxed_energy(self, identifier):
raise NotImplementedError(
"IS2RE-Direct training with an ASE DB is not currently supported."
)
53 changes: 50 additions & 3 deletions tests/datasets/test_ase_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,15 @@ def test_ase_read_dataset() -> None:
data = dataset[0]
del data

dataset.close_db()

for i in range(len(structures)):
os.remove(
os.path.join(
os.path.dirname(os.path.abspath(__file__)), f"{i}.cif"
)
)

dataset.close_db()


def test_ase_db_dataset() -> None:
try:
Expand Down Expand Up @@ -380,11 +380,18 @@ def test_ase_multiread_dataset() -> None:

atoms_objects = [build.bulk("Cu", a=a) for a in np.linspace(3.5, 3.7, 10)]

energies = np.linspace(1, 0, len(atoms_objects))

traj = Trajectory(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.traj"),
mode="w",
)
for atoms in atoms_objects:

for atoms, energy in zip(atoms_objects, energies):
calc = SinglePointCalculator(
atoms, energy=energy, forces=atoms.positions
)
atoms.calc = calc
traj.write(atoms)

dataset = AseReadMultiStructureDataset(
Expand Down Expand Up @@ -423,6 +430,46 @@ def test_ase_multiread_dataset() -> None:
assert len(dataset) == len(atoms_objects)
[dataset[:]]

dataset = AseReadMultiStructureDataset(
config={
"index_file": os.path.join(
os.path.dirname(os.path.abspath(__file__)), "test_index_file"
),
"a2g_args": {
"r_energy": True,
"r_forces": True,
},
"include_relaxed_energy": True,
}
)

assert len(dataset) == len(atoms_objects)
[dataset[:]]

assert hasattr(dataset[0], "y_relaxed")
assert dataset[0].y_relaxed != dataset[0].y
assert dataset[-1].y_relaxed == dataset[-1].y

dataset = AseReadDataset(
config={
"src": os.path.join(os.path.dirname(os.path.abspath(__file__))),
"pattern": "*.traj",
"ase_read_args": {
"index": "0",
},
"a2g_args": {
"r_energy": True,
"r_forces": True,
},
"include_relaxed_energy": True,
}
)

[dataset[:]]

assert hasattr(dataset[0], "y_relaxed")
assert dataset[0].y_relaxed != dataset[0].y

os.remove(
os.path.join(os.path.dirname(os.path.abspath(__file__)), "test.traj")
)
Expand Down

0 comments on commit 82237be

Please sign in to comment.