Skip to content

Commit

Permalink
FM-v4 branch into main (#752)
Browse files Browse the repository at this point in the history
* Update BalancedBatchSampler to use datasets' `data_sizes` method
Replace BalancedBatchSampler's `force_balancing` and `throw_on_error` parameters with `on_error`

* Remove python 3.10 syntax

* Documentation

* Added set_epoch method

* Format

* Changed "resolved dataset" message to be a debug log to reduce log spam

* Minor changes to support multitask

* add in pickle data set; add in stat functions for combining mean and variance

* checksums for equiformer

* detach compute metrics and add checksum function for linear layer

* change name to dataset_configs

* add seed option

* remove pickle dataset

* remove pickle dataset

* add experimental datatransform to ase_dataset

* clean up batchsampler and tests

* base dataset class

* move lin_ref to base dataset

* inherit basedataset for ase dataset

* filter indices prop

* updated import for ase dataset

* added create_dataset fn

* yaml load fix

* create dataset function instead of filtering in base

* remove filtered_indices

* make create_dataset and LMDBDatabase importable from datasets

* create_dataset cleanup

* test create_dataset

* use metadata.natoms directly and add it to subset

* use self.indices to handle shard

* rename _data_sizes

* fix Subset of metadata

* fix up to be mergeable

* merge in monorepo

* small fix for import and keyerror

* minor change to metadata, added full path option

* import updates

* minor fix to base dataset

* skip force_balance and seed

* adding get_metadata to base_dataset

* implement get_metadata for datasets; add tests for max_atoms and balanced partitioning

* a[:len(a)+1] does not throw error, change to check for this

* bug fix for base_dataset

* max atoms branch

* fix typo

* do pbc per system

* add option to use single system pbc

* add multiple mapping

* lint and github workflow fixes

* track parent checkpoint for logger grouping

* add generator to basedataset

* check path relative to yaml file

* add load and exit flag to base_trainer

* add in merge mean and std code to utils

* add log when passing through mean or computing; check other paths for includes

* add qos flag

* use slurm_qos instead of qos

* fix includes

* fix set init

* adding new notebook for using fairchem models with NEBs without CatTSunami enumeration (#764)

* adding new notebook for using fairchem models with NEBs

* adding md tutorials

* blocking code cells that aren't needed or take too long

* remove files with diff whitespace

* add resolution flag to escn

* try to revert oxides

* revert typing

* remove white space

* extra line never reached

* move out of fmv4 into dev

* move avg num nodes

* optional import from experimental

* fix lint

* add comments, refactor common trainer args in a single dictionary

* add comments, refactor common trainer args in a single dictionary

* remove parent

---------

Co-authored-by: Nima Shoghi <[email protected]>
Co-authored-by: Nima Shoghi <[email protected]>
Co-authored-by: Abhishek Das <[email protected]>
Co-authored-by: lbluque <[email protected]>
Co-authored-by: Brandon <[email protected]>
Co-authored-by: Muhammed Shuaibi <[email protected]>
Co-authored-by: Ray Gao <[email protected]>
Co-authored-by: Brook Wander <[email protected]>
Co-authored-by: Muhammed Shuaibi <[email protected]>
  • Loading branch information
10 people committed Aug 21, 2024
1 parent 427fb8d commit 3899aac
Show file tree
Hide file tree
Showing 6 changed files with 187 additions and 63 deletions.
149 changes: 106 additions & 43 deletions src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ def save_checkpoint(
return filename


# Names of the top-level config sections associated with multitask runs.
# NOTE(review): the code that consumes this set is not visible in this hunk;
# presumably it is used to validate a multitask config -- confirm at the call site.
multitask_required_keys = {
"tasks",
"datasets",
"combined_dataset",
"model",
"optim",
}


class Complete:
def __call__(self, data):
device = data.edge_index.device
Expand Down Expand Up @@ -393,48 +402,83 @@ def create_dict_from_args(args: list, sep: str = "."):
return return_dict


def load_config(path: str, previous_includes: list | None = None):
if previous_includes is None:
previous_includes = []
def find_relative_file_in_paths(filename, include_paths):
    """Return the first existing location of ``filename``.

    ``filename`` is tried exactly as given first (so an absolute path, or a path
    relative to the current working directory, wins), and is then joined onto
    each entry of ``include_paths`` in order.

    Args:
        filename: name or path of the file to locate.
        include_paths: directories to search, in priority order.

    Returns:
        ``filename`` itself if it exists as given, otherwise the first joined
        candidate that exists.

    Raises:
        ValueError: if none of the candidates exists as a non-directory.
    """
    candidates = [filename]
    candidates.extend(os.path.join(path, filename) for path in include_paths)
    for candidate in candidates:
        # A directory that happens to share the include's name is not a usable
        # config file; skip it so the caller does not fail later inside open().
        if os.path.exists(candidate) and not os.path.isdir(candidate):
            return candidate
    searched = [str(candidate) for candidate in candidates]
    raise ValueError(f"Cannot find include YML {filename} (searched: {searched})")


def load_config(
    path: str,
    files_previously_included: list | None = None,
    include_paths: list | None = None,
):
    """
    Load a YAML config together with every file listed under its ``includes`` key.

    Included files are loaded by calling this function recursively. The chain of
    files currently being loaded is tracked so that a cyclic include is reported
    instead of recursing forever.

    Args:
        path: config file to load.
        files_previously_included: files already on the include chain leading to
            ``path``; used only for cycle detection, callers normally omit it.
        include_paths: extra directories searched for include files, after the
            name as given and the directory of the including file.

    Returns:
        A tuple ``(config, duplicates_warning, duplicates_error)``. ``config`` is
        the merged dictionary. The two lists accumulate what ``merge_dicts``
        reports as duplicates: between the includes and the including file
        (warnings) and between sibling includes (errors).

    Raises:
        ValueError: on a cyclic include, or when an include cannot be found.
        AttributeError: when ``includes`` is present but is not a list.
    """
    if include_paths is None:
        include_paths = []
    if files_previously_included is None:
        files_previously_included = []
    path = Path(path)
    if path in files_previously_included:
        raise ValueError(
            f"Cyclic config include detected. {path} included in sequence {files_previously_included}."
        )
    files_previously_included = [*files_previously_included, path]

    with open(path) as fp:
        current_config = yaml.load(fp, Loader=UniqueKeyLoader)
    # An empty YAML document loads as None; treat it as an empty config so the
    # membership test below does not fail with a TypeError.
    if current_config is None:
        current_config = {}

    # Load config from included files.
    includes_listed_in_config = (
        current_config.pop("includes") if "includes" in current_config else []
    )
    if not isinstance(includes_listed_in_config, list):
        raise AttributeError(
            f"Includes must be a list, '{type(includes_listed_in_config)}' provided"
        )

    config_from_includes = {}
    duplicates_warning = []
    duplicates_error = []

    for include in includes_listed_in_config:
        include_filename = find_relative_file_in_paths(
            include, [os.path.dirname(path), *include_paths]
        )
        # Forward include_paths so that includes nested inside an included file
        # are searched in the same directories as top-level includes.
        include_config, inc_dup_warning, inc_dup_error = load_config(
            include_filename,
            files_previously_included,
            include_paths=include_paths,
        )
        duplicates_warning += inc_dup_warning
        duplicates_error += inc_dup_error

        # Duplicates between includes causes an error
        config_from_includes, merge_dup_error = merge_dicts(
            config_from_includes, include_config
        )
        duplicates_error += merge_dup_error

    # Duplicates between included and main file causes warnings
    config_from_includes, merge_dup_warning = merge_dicts(
        config_from_includes, current_config
    )
    duplicates_warning += merge_dup_warning
    return config_from_includes, duplicates_warning, duplicates_error


def build_config(args, args_override):
config, duplicates_warning, duplicates_error = load_config(args.config_yml)
def build_config(args, args_override, include_paths=None):
config, duplicates_warning, duplicates_error = load_config(
args.config_yml, include_paths=include_paths
)
if len(duplicates_warning) > 0:
logging.warning(
f"Overwritten config parameters from included configs "
Expand Down Expand Up @@ -999,34 +1043,53 @@ class _TrainingContext:
task_name = "s2ef"
elif trainer_name in ["energy", "equiformerv2_energy"]:
task_name = "is2re"
elif "multitask" in trainer_name:
task_name = "multitask"
else:
task_name = "ocp"

trainer_cls = registry.get_trainer_class(trainer_name)
assert trainer_cls is not None, "Trainer not found"
trainer = trainer_cls(
task=config.get("task", {}),
model=config["model"],
outputs=config.get("outputs", {}),
dataset=config["dataset"],
optimizer=config["optim"],
loss_functions=config.get("loss_functions", {}),
evaluation_metrics=config.get("evaluation_metrics", {}),
identifier=config["identifier"],
timestamp_id=config.get("timestamp_id", None),
run_dir=config.get("run_dir", "./"),
is_debug=config.get("is_debug", False),
print_every=config.get("print_every", 10),
seed=config.get("seed", 0),
logger=config.get("logger", "wandb"),
local_rank=config["local_rank"],
amp=config.get("amp", False),
cpu=config.get("cpu", False),
slurm=config.get("slurm", {}),
noddp=config.get("noddp", False),
name=task_name,
gp_gpus=config.get("gp_gpus"),
)

trainer_config = {
"model": config["model"],
"optimizer": config["optim"],
"identifier": config["identifier"],
"timestamp_id": config.get("timestamp_id", None),
"run_dir": config.get("run_dir", "./"),
"is_debug": config.get("is_debug", False),
"print_every": config.get("print_every", 10),
"seed": config.get("seed", 0),
"logger": config.get("logger", "wandb"),
"local_rank": config["local_rank"],
"amp": config.get("amp", False),
"cpu": config.get("cpu", False),
"slurm": config.get("slurm", {}),
"noddp": config.get("noddp", False),
"name": task_name,
"gp_gpus": config.get("gp_gpus"),
}

if task_name == "multitask":
trainer_config.update(
{
"tasks": config.get("tasks", {}),
"dataset_configs": config["datasets"],
"combined_dataset_config": config.get("combined_dataset", {}),
"evaluations": config.get("evaluations", {}),
}
)
else:
trainer_config.update(
{
"task": config.get("task", {}),
"outputs": config.get("outputs", {}),
"dataset": config["dataset"],
"loss_functions": config.get("loss_functions", {}),
"evaluation_metrics": config.get("evaluation_metrics", {}),
}
)
trainer = trainer_cls(**trainer_config)

task_cls = registry.get_task_class(config["mode"])
assert task_cls is not None, "Task not found"
Expand Down
21 changes: 17 additions & 4 deletions src/fairchem/core/datasets/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,32 @@
from torch_geometric.data import Data


def rename_data_object_keys(
    data_object: Data, key_mapping: dict[str, str | list[str]]
) -> Data:
    """Rename data object keys

    Args:
        data_object: data object
        key_mapping: dictionary specifying keys to rename and new names {prev_key: new_key}

            new_key can be a list of new keys, for example,
            prev_key: energy
            new_key: [common_energy, oc20_energy]

            This is currently required when we use a single target/label for multiple tasks

    Returns:
        The same ``data_object``, modified in place.

    Raises:
        AssertionError: if a new key (other than ``prev_key`` itself) is already
            present in ``data_object``.
    """
    for _property, new_keys in key_mapping.items():
        # catch for test data not containing labels
        if _property not in data_object:
            continue
        if isinstance(new_keys, str):
            new_keys = [new_keys]
        for new_property in new_keys:
            # mapping a key onto itself means "keep it", nothing to copy
            if new_property == _property:
                continue
            # Raised explicitly (same exception type as the former assert) so
            # the collision check is not stripped when running under `python -O`.
            if new_property in data_object:
                raise AssertionError(
                    f"Cannot rename '{_property}' to '{new_property}': "
                    "key already exists in data object"
                )
            data_object[new_property] = data_object[_property]
        # only drop the old key when it is not also one of the requested names
        if _property not in new_keys:
            del data_object[_property]

    return data_object
3 changes: 2 additions & 1 deletion src/fairchem/core/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
return torch.mean(dists)
elif self.reduction == "sum":
return torch.sum(dists)
return None

return dists


class AtomwiseL2Loss(nn.Module):
Expand Down
6 changes: 6 additions & 0 deletions src/fairchem/core/modules/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@

if TYPE_CHECKING:
from torch_geometric.data import Data
from contextlib import suppress

with suppress(ImportError):
# TODO remove this in favor of a better solution
# We should never be importing * from a module
from fairchem.experimental.foundation_models.multi_task_dataloader.transforms.data_object import * # noqa


class DataTransforms:
Expand Down
34 changes: 20 additions & 14 deletions src/fairchem/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging
import os
import random
import sys
from abc import ABC, abstractmethod
from itertools import chain
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -232,6 +233,8 @@ def load(self) -> None:
self.load_loss()
self.load_optimizer()
self.load_extras()
if self.config["optim"].get("load_datasets_and_model_then_exit", False):
sys.exit(0)

def set_seed(self, seed) -> None:
# https://pytorch.org/docs/stable/notes/randomness.html
Expand Down Expand Up @@ -792,6 +795,22 @@ def update_best(
disable_tqdm=disable_eval_tqdm,
)

def _aggregate_metrics(self, metrics):
    """Combine per-process metric accumulators into a single result per metric.

    For every entry of ``metrics`` the ``"total"`` and ``"numel"`` values are
    passed through ``distutils.all_reduce`` with ``average=False``, and the
    ``"metric"`` value is recomputed as reduced total divided by reduced numel.

    Args:
        metrics: mapping of metric name to a dict holding ``"total"`` and ``"numel"``.

    Returns:
        A new dict with the same metric names, each mapping to a dict with
        ``"total"``, ``"numel"`` and ``"metric"``.
    """

    def _reduce(value):
        return distutils.all_reduce(value, average=False, device=self.device)

    reduced = {}
    for name, entry in metrics.items():
        # keep the call order (total, then numel) identical for every metric
        total = _reduce(entry["total"])
        numel = _reduce(entry["numel"])
        reduced[name] = {"total": total, "numel": numel}
        reduced[name]["metric"] = total / numel
    return reduced

@torch.no_grad()
def validate(self, split: str = "val", disable_tqdm: bool = False):
ensure_fitted(self._unwrapped_model, warn=True)
Expand Down Expand Up @@ -833,20 +852,7 @@ def validate(self, split: str = "val", disable_tqdm: bool = False):
metrics = self._compute_metrics(out, batch, evaluator, metrics)
metrics = evaluator.update("loss", loss.item(), metrics)

aggregated_metrics = {}
for k in metrics:
aggregated_metrics[k] = {
"total": distutils.all_reduce(
metrics[k]["total"], average=False, device=self.device
),
"numel": distutils.all_reduce(
metrics[k]["numel"], average=False, device=self.device
),
}
aggregated_metrics[k]["metric"] = (
aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]
)
metrics = aggregated_metrics
metrics = self._aggregate_metrics(metrics)

log_dict = {k: metrics[k]["metric"] for k in metrics}
log_dict.update({"epoch": self.epoch})
Expand Down
37 changes: 36 additions & 1 deletion tests/core/common/test_yaml_loader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import os
import tempfile

import pytest
import yaml

from fairchem.core.common.utils import UniqueKeyLoader
from fairchem.core.common.utils import UniqueKeyLoader, load_config


@pytest.fixture(scope="class")
Expand All @@ -32,6 +33,14 @@ def valid_yaml_config():
"""


@pytest.fixture(scope="class")
def include_path_in_yaml_config():
    """YAML text whose only content is an ``includes`` list naming ``other.yml``."""
    return """
includes:
- other.yml
"""


def test_invalid_config(invalid_yaml_config):
with tempfile.NamedTemporaryFile(delete=False) as fp:
fp.write(invalid_yaml_config.encode())
Expand All @@ -49,3 +58,29 @@ def test_valid_config(valid_yaml_config):

with open(fname) as fp:
yaml.load(fp, Loader=UniqueKeyLoader)


def test_load_config_with_include_path(include_path_in_yaml_config, valid_yaml_config):
    """An include is found only once its directory is passed via ``include_paths``."""
    with tempfile.TemporaryDirectory() as tempdirname:
        this_yml_path = f"{tempdirname}/this.yml"
        with open(this_yml_path, "w") as fp:
            fp.write(include_path_in_yaml_config)

        # other.yml has not been written anywhere yet, so loading must fail
        with pytest.raises(ValueError):
            load_config(this_yml_path)

        other_yml_path = f"{tempdirname}/subfolder"
        os.mkdir(other_yml_path)
        other_yml_full_filename = f"{other_yml_path}/other.yml"
        with open(other_yml_full_filename, "w") as fp:
            fp.write(valid_yaml_config)

        # other.yml now exists, but only inside a subfolder that is not passed
        # as an include path, so loading must still fail
        with pytest.raises(ValueError):
            load_config(this_yml_path)

        # once the subfolder is passed as an include path the include resolves
        loaded_config = load_config(this_yml_path, include_paths=[other_yml_path])
        assert set(loaded_config[0]) == {"key1", "key2"}

0 comments on commit 3899aac

Please sign in to comment.