Fix dataset logic (#771)
* adding new notebook for using fairchem models with NEBs without CatTSunami enumeration (#764)

* adding new notebook for using fairchem models with NEBs

* adding md tutorials

* blocking code cells that aren't needed or take too long

* fix dataset config logic

* add empty val/test if not defined

* add empty dicts for all missing datasets

---------

Co-authored-by: Brook Wander <[email protected]>
Co-authored-by: Muhammed Shuaibi <[email protected]>
Co-authored-by: zulissimeta <[email protected]>
4 people authored Jul 19, 2024
1 parent 22db7bd commit c41a960
Showing 1 changed file with 23 additions and 33 deletions.
56 changes: 23 additions & 33 deletions src/fairchem/core/trainers/base_trainer.py
@@ -159,12 +159,18 @@ def __init__(
             if len(dataset) > 2:
                 self.config["test_dataset"] = dataset[2]
         elif isinstance(dataset, dict):
-            self.config["dataset"] = dataset.get("train", {})
-            self.config["val_dataset"] = dataset.get("val", {})
-            self.config["test_dataset"] = dataset.get("test", {})
-            self.config["relax_dataset"] = dataset.get("relax", {})
+            # or {} in cases where "dataset": None is explicitly defined
+            self.config["dataset"] = dataset.get("train", {}) or {}
+            self.config["val_dataset"] = dataset.get("val", {}) or {}
+            self.config["test_dataset"] = dataset.get("test", {}) or {}
+            self.config["relax_dataset"] = dataset.get("relax", {}) or {}
         else:
-            self.config["dataset"] = dataset
+            self.config["dataset"] = dataset or {}
+
+        # add empty dicts for missing datasets
+        for dataset_name in ("val_dataset", "test_dataset", "relax_dataset"):
+            if dataset_name not in self.config:
+                self.config[dataset_name] = {}
 
         if not is_debug and distutils.is_master():
             os.makedirs(self.config["cmd"]["checkpoint_dir"], exist_ok=True)
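After this hunk, every split key in self.config is guaranteed to hold a dict by the end of __init__, even when a split is omitted or explicitly set to None (e.g. a null entry in a YAML config). A minimal standalone sketch of that behavior, using hypothetical paths rather than anything from the repository:

# Sketch of the normalization above; "val": None mimics an explicit null.
dataset = {"train": {"src": "data/train.lmdb"}, "val": None}

config = {}
# "or {}" catches splits that are explicitly defined as None
config["dataset"] = dataset.get("train", {}) or {}
config["val_dataset"] = dataset.get("val", {}) or {}
config["test_dataset"] = dataset.get("test", {}) or {}
config["relax_dataset"] = dataset.get("relax", {}) or {}

# every split is now a dict, so downstream code needs no None checks
assert config["val_dataset"] == {}
assert config["test_dataset"] == {}
assert config["relax_dataset"] == {}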
@@ -277,19 +283,8 @@ def load_datasets(self) -> None:
         self.val_loader = None
         self.test_loader = None
 
-        # Default all of the dataset portions to {} if
-        # they don't exist, or are null
-        if not self.config.get("dataset", None):
-            self.config["dataset"] = {}
-        if not self.config.get("val_dataset", None):
-            self.config["val_dataset"] = {}
-        if not self.config.get("test_dataset", None):
-            self.config["test_dataset"] = {}
-        if not self.config.get("relax_dataset", None):
-            self.config["relax_dataset"] = {}
-
         # load train, val, test datasets
-        if self.config["dataset"] and self.config["dataset"].get("src", None):
+        if "src" in self.config["dataset"]:
             logging.info(
                 f"Loading dataset: {self.config['dataset'].get('format', 'lmdb')}"
             )
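With the defaulting moved into __init__, load_datasets can rely on every split being a dict, so a plain membership test decides whether each loader gets built. A quick sketch with hypothetical values:

# An empty split dict fails the "src" test, so its loader is skipped.
train_split = {}                      # nothing configured for training
val_split = {"src": "data/val.lmdb"}  # hypothetical path

assert "src" not in train_split  # train loader skipped
assert "src" in val_split        # val loader built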
@@ -307,7 +302,7 @@
                 self.train_sampler,
             )
 
-        if self.config["val_dataset"]:
+        if "src" in self.config["val_dataset"]:
             if self.config["val_dataset"].get("use_train_settings", True):
                 val_config = self.config["dataset"].copy()
                 val_config.update(self.config["val_dataset"])
@@ -329,13 +324,8 @@
                 self.val_sampler,
             )
 
-        if self.config["test_dataset"]:
-            if (
-                self.config["test_dataset"].get("use_train_settings", True)
-                and self.config[
-                    "dataset"
-                ]  # if there's no training dataset, we have nothing to copy
-            ):
+        if "src" in self.config["test_dataset"]:
+            if self.config["test_dataset"].get("use_train_settings", True):
                 test_config = self.config["dataset"].copy()
                 test_config.update(self.config["test_dataset"])
             else:
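The use_train_settings branch keeps its copy-then-override merge, and because the train split is now always a dict (possibly empty), the old check that a training dataset exists can be dropped: copying an empty dict is harmless. A standalone sketch of the merge, with hypothetical keys:

# Copy-then-override: the test split inherits train settings unless it
# redefines them (hypothetical paths and keys, for illustration only).
train_config = {"format": "lmdb", "src": "data/train.lmdb"}
test_split = {"src": "data/test.lmdb"}

test_config = train_config.copy()
test_config.update(test_split)

assert test_config["src"] == "data/test.lmdb"  # overridden by the split
assert test_config["format"] == "lmdb"         # inherited from train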
@@ -407,16 +397,16 @@ def load_task(self):
                         "outputs"
                     ][target_name].get("level", "system")
                     if "train_on_free_atoms" not in self.output_targets[subtarget]:
-                        self.output_targets[subtarget][
-                            "train_on_free_atoms"
-                        ] = self.config["outputs"][target_name].get(
-                            "train_on_free_atoms", True
+                        self.output_targets[subtarget]["train_on_free_atoms"] = (
+                            self.config[
+                                "outputs"
+                            ][target_name].get("train_on_free_atoms", True)
                         )
                     if "eval_on_free_atoms" not in self.output_targets[subtarget]:
-                        self.output_targets[subtarget][
-                            "eval_on_free_atoms"
-                        ] = self.config["outputs"][target_name].get(
-                            "eval_on_free_atoms", True
+                        self.output_targets[subtarget]["eval_on_free_atoms"] = (
+                            self.config[
+                                "outputs"
+                            ][target_name].get("eval_on_free_atoms", True)
                         )
 
         # TODO: Assert that all targets, loss fn, metrics defined are consistent
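This final hunk only reflows the two assignments (consistent with an autoformatter pass); behavior is unchanged. The fill-in-a-default pattern it touches is equivalent to dict.setdefault, sketched here with hypothetical target names:

# Equivalent defaulting via dict.setdefault (hypothetical target names).
outputs = {"forces": {"train_on_free_atoms": False}}  # per-target config
subtarget = {}                                        # derived output target

subtarget.setdefault(
    "train_on_free_atoms",
    outputs["forces"].get("train_on_free_atoms", True),
)
subtarget.setdefault(
    "eval_on_free_atoms",
    outputs["forces"].get("eval_on_free_atoms", True),
)

assert subtarget == {"train_on_free_atoms": False, "eval_on_free_atoms": True}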
