Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FM-v4 branch into main #752

Merged
merged 101 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
101 commits
Select commit Hold shift + click to select a range
ae4add3
Update BalancedBatchSampler to use datasets' `data_sizes` method
nimashoghi Aug 24, 2023
01fe2b4
Remove python 3.10 syntax
nimashoghi Aug 24, 2023
2bf8213
Documentation
nimashoghi Aug 24, 2023
7ba5b8a
Added set_epoch method
Aug 28, 2023
a367d1e
Format
Aug 30, 2023
46e3c57
Changed "resolved dataset" message to be a debug log to reduce log spam
Aug 30, 2023
87714f5
Minor changes to support multitask
abhshkdz Mar 16, 2024
3105359
add in pickle data set; add in stat functions for combining mean and …
misko Apr 5, 2024
e170f53
checksums for equiformer
misko Apr 8, 2024
3ea4dc4
detach compute metrics and add checksum function for linear layer
misko Apr 16, 2024
bbda257
Merge branch 'main' into fm-v2-pickle
misko Apr 16, 2024
102667f
change name to dataset_configs
misko Apr 16, 2024
2c571fd
add seed option
misko Apr 18, 2024
319a597
remove pickle dataset
misko Apr 18, 2024
d1f2ccf
remove pickle dataset
misko Apr 18, 2024
1e7548d
add experimental datatransform to ase_dataset
misko Apr 23, 2024
845bce3
update with main
lbluque Apr 30, 2024
86da069
clean up batchsampler and tests
lbluque Apr 30, 2024
ff628dd
base dataset class
lbluque May 1, 2024
122197f
move lin_ref to base dataset
lbluque May 1, 2024
fb4ce16
inherit basedataset for ase dataset
lbluque May 1, 2024
c9e1759
filter indices prop
lbluque May 3, 2024
85b8ab9
updated import for ase dataset
wood-b May 6, 2024
e227de3
Merge branch 'fm-v3' of github.com:Open-Catalyst-Project/ocp into fm-v3
wood-b May 6, 2024
95d3e6f
added create_dataset fn
wood-b May 6, 2024
b6c640e
yaml load fix
lbluque May 7, 2024
7fa1904
create dataset function instead of filtering in base
lbluque May 7, 2024
04c96bf
remove filtered_indices
lbluque May 7, 2024
ea35b57
make create_dataset and LMDBDatabase importable from datasets
lbluque May 8, 2024
dc98285
create_dataset cleanup
lbluque May 8, 2024
2339916
test create_dataset
lbluque May 8, 2024
9b58cc7
use metadata.natoms directly and add it to subset
lbluque May 10, 2024
63c03fc
use self.indices to handle shard
lbluque May 10, 2024
76322aa
rename _data_sizes
lbluque May 14, 2024
0e7e4a8
merge with main-legacy
lbluque May 14, 2024
bb41b13
merge with main-legacy + no more data_sizes
lbluque May 15, 2024
b4e22bc
fix Subset of metadata
lbluque May 15, 2024
a6cc2c2
fix up to be mergeable
misko May 15, 2024
7033d10
merge in monorepo
misko May 15, 2024
899a227
small fix for import and keyerror
misko May 16, 2024
29b6e68
minor change to metadata, added full path option
wood-b May 17, 2024
f9b15cd
Merge branch 'main' into balanced-batch-sampler+base-dataset
wood-b May 18, 2024
dc59f96
import updates
wood-b May 18, 2024
505cc24
Merge branch 'balanced-batch-sampler+base-dataset' into fm-v4
wood-b May 18, 2024
44234b7
minor fix to base dataset
wood-b May 19, 2024
63348fd
skip force_balance and seed
misko May 19, 2024
45a2b4a
adding get_metadata to base_dataset
wood-b May 20, 2024
64b8df2
implement get_metadata for datasets; add tests for max_atoms and bala…
misko May 20, 2024
fec7fc7
merge in basedataset branch
misko May 20, 2024
80fea27
a[:len(a)+1] does not throw error, change to check for this
misko May 21, 2024
80c8e6b
Merge branch 'balanced-batch-sampler+base-dataset' into fm-v4
misko May 21, 2024
f4910bc
bug fix for base_dataset
wood-b May 22, 2024
e93e73f
max atoms branch
misko May 28, 2024
883a15f
fix typo
misko May 28, 2024
e58a53a
Merge branch 'max_atoms' into fm-v4
misko May 28, 2024
9b87082
do pbc per system
misko May 30, 2024
d8cf857
add option to use single system pbc
misko May 31, 2024
061abf9
add multiple mapping
misko May 31, 2024
0b3f9fe
merge
misko May 31, 2024
5277b4f
Merge branch 'fm-v4-add-multiple-mapping' into fm-v4
misko May 31, 2024
18c15f8
lint and github workflow fixes
misko May 31, 2024
fb24889
track parent checkpoint for logger grouping
mshuaibii Jun 1, 2024
57a2eaf
add generator to basedataset
misko Jun 4, 2024
870fd22
Merge branch 'fm-v4' of github.com:Open-Catalyst-Project/ocp into fm-v4
misko Jun 4, 2024
e50120d
check path relative to yaml file
misko Jun 5, 2024
2e557ad
add load and exit flag to base_trainer
misko Jun 11, 2024
7ef4aec
add in merge mean and std code to utils
misko Jun 12, 2024
20e62b5
add log when passing through mean or computing; check other paths for…
misko Jun 21, 2024
87869b6
add qos flag
misko Jun 27, 2024
0850f34
use slurm_qos instead of qos
misko Jun 27, 2024
75b7e9e
fix includes
misko Jul 1, 2024
49dfca7
fix set init
misko Jul 2, 2024
94f6ce1
merge main
rayg1234 Jul 2, 2024
e10575c
Merge remote-tracking branch 'origin/main' into fm-v4
rayg1234 Jul 2, 2024
5743a59
adding new notebook for using fairchem models with NEBs without CatTS…
brookwander Jul 16, 2024
4880d0c
merge main
rayg1234 Jul 16, 2024
692147d
Merge branch 'main' into fm-v4
mshuaibii Jul 19, 2024
881890e
Merge remote-tracking branch 'origin/main' into fm-v4
rayg1234 Jul 25, 2024
f284190
Merge branch 'main' into fm-v4
misko Aug 2, 2024
ed0e936
merge
misko Aug 2, 2024
25327b0
merge main
misko Aug 5, 2024
089de08
Merge branch 'balanced-batch-sampler+base-dataset' into fm-v4
misko Aug 6, 2024
8be3c78
remove files with diff whitespace
misko Aug 6, 2024
14a073b
Merge branch 'main' into fm-v4
misko Aug 9, 2024
7a71c46
add resolution flag to escn
misko Aug 12, 2024
2aca348
Merge branch 'add_resolution_flag_to_escn' into fm-v4
misko Aug 12, 2024
371eb31
try to revert oxides
misko Aug 12, 2024
a23434c
revert typing
misko Aug 12, 2024
f11ac5e
remove white space
misko Aug 12, 2024
8951360
extra line never reached
misko Aug 13, 2024
67229dc
move out of fmv4 into dev
misko Aug 13, 2024
b031719
Merge branch 'main' into fm-v4
misko Aug 13, 2024
fc269b8
move avg num nodes
misko Aug 13, 2024
039f9e6
Merge branch 'main' into fm-v4
rayg1234 Aug 14, 2024
3ec098c
Merge branch 'main' into fm-v4
misko Aug 14, 2024
21eecd4
Merge remote-tracking branch 'origin' into fm-v4
rayg1234 Aug 20, 2024
f3e1c38
optional import from experimental
misko Aug 20, 2024
f2302bf
fix lint
misko Aug 20, 2024
69648fb
add comments, refactor common trainer args in a single dictionary
misko Aug 20, 2024
0b4c5ee
add comments, refactor common trainer args in a single dictionary
misko Aug 20, 2024
07efac0
remove parent
misko Aug 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 106 additions & 43 deletions src/fairchem/core/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,15 @@ def save_checkpoint(
return filename


multitask_required_keys = {
"tasks",
"datasets",
"combined_dataset",
"model",
"optim",
}


class Complete:
def __call__(self, data):
device = data.edge_index.device
Expand Down Expand Up @@ -393,48 +402,83 @@ def create_dict_from_args(args: list, sep: str = "."):
return return_dict


def load_config(path: str, previous_includes: list | None = None):
if previous_includes is None:
previous_includes = []
# given a filename and set of paths , return the full file path
def find_relative_file_in_paths(filename, include_paths):
if os.path.exists(filename):
return filename
for path in include_paths:
include_filename = os.path.join(path, filename)
if os.path.exists(include_filename):
return include_filename
raise ValueError(f"Cannot find include YML {filename}")


def load_config(
path: str,
files_previously_included: list | None = None,
include_paths: list | None = None,
):
"""
Load a given config with any defined imports
When imports are present this is a recursive function called on imports.
To prevent any cyclic imports we keep track of already imported yml files
using files_previously_included
"""
if include_paths is None:
include_paths = []
if files_previously_included is None:
files_previously_included = []
path = Path(path)
if path in previous_includes:
if path in files_previously_included:
raise ValueError(
f"Cyclic config include detected. {path} included in sequence {previous_includes}."
f"Cyclic config include detected. {path} included in sequence {files_previously_included}."
)
previous_includes = [*previous_includes, path]
files_previously_included = [*files_previously_included, path]

with open(path) as fp:
direct_config = yaml.load(fp, Loader=UniqueKeyLoader)
current_config = yaml.load(fp, Loader=UniqueKeyLoader)

# Load config from included files.
includes = direct_config.pop("includes") if "includes" in direct_config else []
if not isinstance(includes, list):
raise AttributeError(f"Includes must be a list, '{type(includes)}' provided")
includes_listed_in_config = (
current_config.pop("includes") if "includes" in current_config else []
)
if not isinstance(includes_listed_in_config, list):
raise AttributeError(
f"Includes must be a list, '{type(includes_listed_in_config)}' provided"
)

config = {}
config_from_includes = {}
duplicates_warning = []
duplicates_error = []

for include in includes:
for include in includes_listed_in_config:
include_filename = find_relative_file_in_paths(
include, [os.path.dirname(path), *include_paths]
)
include_config, inc_dup_warning, inc_dup_error = load_config(
include, previous_includes
include_filename, files_previously_included
)
duplicates_warning += inc_dup_warning
duplicates_error += inc_dup_error

# Duplicates between includes causes an error
config, merge_dup_error = merge_dicts(config, include_config)
config_from_includes, merge_dup_error = merge_dicts(
config_from_includes, include_config
)
duplicates_error += merge_dup_error

# Duplicates between included and main file causes warnings
config, merge_dup_warning = merge_dicts(config, direct_config)
config_from_includes, merge_dup_warning = merge_dicts(
config_from_includes, current_config
)
duplicates_warning += merge_dup_warning
return config_from_includes, duplicates_warning, duplicates_error

return config, duplicates_warning, duplicates_error


def build_config(args, args_override):
config, duplicates_warning, duplicates_error = load_config(args.config_yml)
def build_config(args, args_override, include_paths=None):
config, duplicates_warning, duplicates_error = load_config(
args.config_yml, include_paths=include_paths
)
if len(duplicates_warning) > 0:
logging.warning(
f"Overwritten config parameters from included configs "
Expand Down Expand Up @@ -999,34 +1043,53 @@ class _TrainingContext:
task_name = "s2ef"
elif trainer_name in ["energy", "equiformerv2_energy"]:
task_name = "is2re"
elif "multitask" in trainer_name:
task_name = "multitask"
else:
task_name = "ocp"

trainer_cls = registry.get_trainer_class(trainer_name)
assert trainer_cls is not None, "Trainer not found"
trainer = trainer_cls(
task=config.get("task", {}),
model=config["model"],
outputs=config.get("outputs", {}),
dataset=config["dataset"],
optimizer=config["optim"],
loss_functions=config.get("loss_functions", {}),
evaluation_metrics=config.get("evaluation_metrics", {}),
identifier=config["identifier"],
timestamp_id=config.get("timestamp_id", None),
run_dir=config.get("run_dir", "./"),
is_debug=config.get("is_debug", False),
print_every=config.get("print_every", 10),
seed=config.get("seed", 0),
logger=config.get("logger", "wandb"),
local_rank=config["local_rank"],
amp=config.get("amp", False),
cpu=config.get("cpu", False),
slurm=config.get("slurm", {}),
noddp=config.get("noddp", False),
name=task_name,
gp_gpus=config.get("gp_gpus"),
)

trainer_config = {
"model": config["model"],
"optimizer": config["optim"],
"identifier": config["identifier"],
"timestamp_id": config.get("timestamp_id", None),
"run_dir": config.get("run_dir", "./"),
"is_debug": config.get("is_debug", False),
"print_every": config.get("print_every", 10),
"seed": config.get("seed", 0),
"logger": config.get("logger", "wandb"),
"local_rank": config["local_rank"],
"amp": config.get("amp", False),
"cpu": config.get("cpu", False),
"slurm": config.get("slurm", {}),
"noddp": config.get("noddp", False),
"name": task_name,
"gp_gpus": config.get("gp_gpus"),
}

if task_name == "multitask":
trainer_config.update(
{
"tasks": config.get("tasks", {}),
"dataset_configs": config["datasets"],
"combined_dataset_config": config.get("combined_dataset", {}),
"evaluations": config.get("evaluations", {}),
}
)
else:
trainer_config.update(
{
"task": config.get("task", {}),
"outputs": config.get("outputs", {}),
"dataset": config["dataset"],
"loss_functions": config.get("loss_functions", {}),
"evaluation_metrics": config.get("evaluation_metrics", {}),
}
)
trainer = trainer_cls(**trainer_config)

task_cls = registry.get_task_class(config["mode"])
assert task_cls is not None, "Task not found"
Expand Down
21 changes: 17 additions & 4 deletions src/fairchem/core/datasets/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,32 @@
from torch_geometric.data import Data


def rename_data_object_keys(data_object: Data, key_mapping: dict[str, str]) -> Data:
def rename_data_object_keys(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment on what this does

data_object: Data, key_mapping: dict[str, str | list[str]]
) -> Data:
"""Rename data object keys
Args:
data_object: data object
key_mapping: dictionary specifying keys to rename and new names {prev_key: new_key}
new_key can be a list of new keys, for example,
prev_key: energy
new_key: [common_energy, oc20_energy]
This is currently required when we use a single target/label for multiple tasks
"""
for _property in key_mapping:
# catch for test data not containing labels
if _property in data_object:
new_property = key_mapping[_property]
if new_property not in data_object:
list_of_new_keys = key_mapping[_property]
if isinstance(list_of_new_keys, str):
list_of_new_keys = [list_of_new_keys]
for new_property in list_of_new_keys:
if new_property == _property:
continue
assert new_property not in data_object
data_object[new_property] = data_object[_property]
if _property not in list_of_new_keys:
del data_object[_property]

return data_object
3 changes: 2 additions & 1 deletion src/fairchem/core/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor):
return torch.mean(dists)
elif self.reduction == "sum":
return torch.sum(dists)
return None

return dists


class AtomwiseL2Loss(nn.Module):
Expand Down
6 changes: 6 additions & 0 deletions src/fairchem/core/modules/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@

if TYPE_CHECKING:
from torch_geometric.data import Data
from contextlib import suppress

with suppress(ImportError):
# TODO remove this in favor of a better solution
# We should never be importing * from a module
from fairchem.experimental.foundation_models.multi_task_dataloader.transforms.data_object import * # noqa
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment



class DataTransforms:
Expand Down
34 changes: 20 additions & 14 deletions src/fairchem/core/trainers/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import logging
import os
import random
import sys
from abc import ABC, abstractmethod
from itertools import chain
from typing import TYPE_CHECKING
Expand Down Expand Up @@ -232,6 +233,8 @@ def load(self) -> None:
self.load_loss()
self.load_optimizer()
self.load_extras()
if self.config["optim"].get("load_datasets_and_model_then_exit", False):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is this needed if this is the last line of the init anyways?

sys.exit(0)

def set_seed(self, seed) -> None:
# https://pytorch.org/docs/stable/notes/randomness.html
Expand Down Expand Up @@ -792,6 +795,22 @@ def update_best(
disable_tqdm=disable_eval_tqdm,
)

def _aggregate_metrics(self, metrics):
aggregated_metrics = {}
for k in metrics:
aggregated_metrics[k] = {
"total": distutils.all_reduce(
metrics[k]["total"], average=False, device=self.device
),
"numel": distutils.all_reduce(
metrics[k]["numel"], average=False, device=self.device
),
}
aggregated_metrics[k]["metric"] = (
aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]
)
return aggregated_metrics

@torch.no_grad()
def validate(self, split: str = "val", disable_tqdm: bool = False):
ensure_fitted(self._unwrapped_model, warn=True)
Expand Down Expand Up @@ -833,20 +852,7 @@ def validate(self, split: str = "val", disable_tqdm: bool = False):
metrics = self._compute_metrics(out, batch, evaluator, metrics)
metrics = evaluator.update("loss", loss.item(), metrics)

aggregated_metrics = {}
for k in metrics:
aggregated_metrics[k] = {
"total": distutils.all_reduce(
metrics[k]["total"], average=False, device=self.device
),
"numel": distutils.all_reduce(
metrics[k]["numel"], average=False, device=self.device
),
}
aggregated_metrics[k]["metric"] = (
aggregated_metrics[k]["total"] / aggregated_metrics[k]["numel"]
)
metrics = aggregated_metrics
metrics = self._aggregate_metrics(metrics)

log_dict = {k: metrics[k]["metric"] for k in metrics}
log_dict.update({"epoch": self.epoch})
Expand Down
37 changes: 36 additions & 1 deletion tests/core/common/test_yaml_loader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import os
import tempfile

import pytest
import yaml

from fairchem.core.common.utils import UniqueKeyLoader
from fairchem.core.common.utils import UniqueKeyLoader, load_config


@pytest.fixture(scope="class")
Expand All @@ -32,6 +33,14 @@ def valid_yaml_config():
"""


@pytest.fixture(scope="class")
def include_path_in_yaml_config():
return """
includes:
- other.yml
"""


def test_invalid_config(invalid_yaml_config):
with tempfile.NamedTemporaryFile(delete=False) as fp:
fp.write(invalid_yaml_config.encode())
Expand All @@ -49,3 +58,29 @@ def test_valid_config(valid_yaml_config):

with open(fname) as fp:
yaml.load(fp, Loader=UniqueKeyLoader)


def test_load_config_with_include_path(include_path_in_yaml_config, valid_yaml_config):
with tempfile.TemporaryDirectory() as tempdirname:

this_yml_path = f"{tempdirname}/this.yml"
with open(this_yml_path, "w") as fp:
fp.write(include_path_in_yaml_config)

# the include does not exist throw an error!
with pytest.raises(ValueError):
load_config(this_yml_path)

other_yml_path = f"{tempdirname}/subfolder"
os.mkdir(other_yml_path)
other_yml_full_filename = f"{other_yml_path}/other.yml"
with open(other_yml_full_filename, "w") as fp:
fp.write(valid_yaml_config)

# the include does not exist throw an error!
with pytest.raises(ValueError):
load_config(this_yml_path)

# the include does not exist throw an error!
loaded_config = load_config(this_yml_path, include_paths=[other_yml_path])
assert set(loaded_config[0].keys()) == set(["key1", "key2"])