Skip to content

Commit

Permalink
Make relaxation data more general (#714)
Browse files Browse the repository at this point in the history
* Changes to relaxation dataset

* Changes to relaxation dataset

---------

Co-authored-by: Muhammed Shuaibi <[email protected]>
Co-authored-by: Luis Barroso-Luque <[email protected]>
  • Loading branch information
3 people authored Jul 10, 2024
1 parent 712e723 commit 51a439e
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions src/fairchem/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def __init__(
self.config["dataset"] = dataset.get("train", None)
self.config["val_dataset"] = dataset.get("val", None)
self.config["test_dataset"] = dataset.get("test", None)
self.config["relax_dataset"] = dataset.get("relax", None)
else:
self.config["dataset"] = dataset

Expand Down Expand Up @@ -339,22 +340,27 @@ def load_datasets(self) -> None:
self.test_sampler,
)

# load relaxation dataset
if "relax_dataset" in self.config["task"]:
self.relax_dataset = registry.get_dataset_class("lmdb")(
self.config["task"]["relax_dataset"]
)
self.relax_sampler = self.get_sampler(
self.relax_dataset,
self.config["optim"].get(
"eval_batch_size", self.config["optim"]["batch_size"]
),
shuffle=False,
)
self.relax_loader = self.get_dataloader(
self.relax_dataset,
self.relax_sampler,
)
if self.config.get("relax_dataset", None):
if self.config["relax_dataset"].get("use_train_settings", True):
relax_config = self.config["dataset"].copy()
relax_config.update(self.config["relax_dataset"])
else:
relax_config = self.config["relax_dataset"]

self.relax_dataset = registry.get_dataset_class(
relax_config.get("format", "lmdb")
)(relax_config)
self.relax_sampler = self.get_sampler(
self.relax_dataset,
self.config["optim"].get(
"eval_batch_size", self.config["optim"]["batch_size"]
),
shuffle=False,
)
self.relax_loader = self.get_dataloader(
self.relax_dataset,
self.relax_sampler,
)

def load_task(self):
# Normalizer for the dataset.
Expand Down

0 comments on commit 51a439e

Please sign in to comment.