Skip to content

Commit

Permalink
Bug Fixes for Tutorial (#525)
Browse files Browse the repository at this point in the history
* Patch for ASE datasets missing fid

* Handle empty configurations of connect_args, select_args, and a2g_args

* Use more meaningful sid and fid values

* Fallback for non-numeric sid values

* Update for new versions of numpy

* More fixes for new numpy versions

* Address review comments
  • Loading branch information
emsunshine committed Aug 2, 2023
1 parent e2c5bcf commit 907edd9
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 14 deletions.
28 changes: 21 additions & 7 deletions ocpmodels/datasets/ase_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ def __init__(
self.config = config

a2g_args = config.get("a2g_args", {})
if a2g_args is None:
a2g_args = {}

# Make sure we always include PBC info in the resulting atoms objects
a2g_args["r_pbc"] = True
Expand All @@ -81,8 +83,6 @@ def __init__(
if self.config.get("keep_in_memory", False):
self.__getitem__ = functools.cache(self.__getitem__)

# Derived classes should extend this functionality to also create self.ids,
# a list of identifiers that can be passed to get_atoms_object()
self.ids = self.load_dataset_get_ids(config)

def __len__(self) -> int:
Expand All @@ -102,15 +102,20 @@ def __getitem__(self, idx):
atoms, **self.config.get("atoms_transform_args", {})
)

if "sid" in atoms.info:
sid = atoms.info["sid"]
else:
sid = atoms.info.get("sid", self.ids[idx])
try:
sid = tensor([sid])
warnings.warn(
"Supplied sid is not numeric (or missing). Using dataset indices instead."
)
except:
sid = tensor([idx])

fid = atoms.info.get("fid", tensor([0]))

# Convert to data object
data_object = self.a2g.convert(atoms, sid)

data_object.pbc = tensor(atoms.pbc)
data_object.fid = fid

# Transform data object
if self.transform is not None:
Expand Down Expand Up @@ -332,6 +337,11 @@ def get_atoms_object(self, identifier):
warnings.warn(f"{err} occured for: {identifier}")
raise err

if "sid" not in atoms.info:
atoms.info["sid"] = "".join(identifier.split(" ")[:-1])
if "fid" not in atoms.info:
atoms.info["fid"] = int(identifier.split(" ")[-1])

return atoms

def get_metadata(self):
Expand Down Expand Up @@ -439,6 +449,8 @@ def load_dataset_get_ids(self, config) -> dummy_list:
)

self.select_args = config.get("select_args", {})
if self.select_args is None:
self.select_args = {}

# In order to get all of the unique IDs using the default ASE db interface
# we have to load all the data and check ids using a select. This is extremely
Expand Down Expand Up @@ -478,6 +490,8 @@ def get_atoms_object(self, idx):
return atoms

def connect_db(self, address, connect_args={}):
if connect_args is None:
connect_args = {}
db_type = connect_args.get("type", "extract_from_name")
if db_type == "lmdb" or (
db_type == "extract_from_name" and address.split(".")[-1] == "lmdb"
Expand Down
14 changes: 10 additions & 4 deletions ocpmodels/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,18 +790,24 @@ def save_results(
# Because of how distributed sampler works, some system ids
# might be repeated to make no. of samples even across GPUs.
_, idx = np.unique(gather_results["ids"], return_index=True)
gather_results["ids"] = np.array(gather_results["ids"])[idx]
gather_results["ids"] = np.array(
gather_results["ids"],
)[idx]
for k in keys:
if k == "forces":
gather_results[k] = np.concatenate(
np.array(gather_results[k])[idx]
np.array(gather_results[k], dtype=object)[idx]
)
elif k == "chunk_idx":
gather_results[k] = np.cumsum(
np.array(gather_results[k])[idx]
np.array(
gather_results[k],
)[idx]
)[:-1]
else:
gather_results[k] = np.array(gather_results[k])[idx]
gather_results[k] = np.array(
gather_results[k],
)[idx]

logging.info(f"Writing results to {full_path}")
np.savez_compressed(full_path, **gather_results)
10 changes: 7 additions & 3 deletions ocpmodels/trainers/forces_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,14 @@ def predict(
self.ema.restore()
return predictions

predictions["forces"] = np.array(predictions["forces"])
predictions["chunk_idx"] = np.array(predictions["chunk_idx"])
predictions["forces"] = np.array(predictions["forces"], dtype=object)
predictions["chunk_idx"] = np.array(
predictions["chunk_idx"],
)
predictions["energy"] = np.array(predictions["energy"])
predictions["id"] = np.array(predictions["id"])
predictions["id"] = np.array(
predictions["id"],
)
self.save_results(
predictions, results_file, keys=["energy", "forces", "chunk_idx"]
)
Expand Down

0 comments on commit 907edd9

Please sign in to comment.