fix tests and balanced batch sampler
misko committed Jul 23, 2024
1 parent 2014801 commit fbfe627
Showing 3 changed files with 68 additions and 4 deletions.
4 changes: 3 additions & 1 deletion src/fairchem/core/common/data_parallel.py
@@ -196,7 +196,9 @@ def __init__(

     def _get_natoms(self, batch_idx: list[int]):
         if self.sampler.dataset.metadata_hasattr("natoms"):
-            return self.sampler.dataset.get_metadata("natoms", batch_idx)
+            return np.array(
+                self.sampler.dataset.get_metadata("natoms", batch_idx)
+            ).reshape(-1)
         if self.on_error == "warn_and_balance":
             return np.array([self.sampler.dataset[idx].num_nodes for idx in batch_idx])
         return None
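The np.array(...).reshape(-1) wrapper looks intended to guarantee the balanced batch sampler a flat 1-D vector of atom counts, whether the dataset's metadata lookup returns a plain list or a nested, column-shaped array per index. A minimal sketch of that flattening behavior, using made-up atom counts rather than a real dataset:

import numpy as np

# Hypothetical per-structure atom counts for one batch of indices,
# in the two shapes a metadata lookup might plausibly return.
natoms_flat = [12, 40, 7]
natoms_column = [[12], [40], [7]]

# reshape(-1) normalizes both into the same flat 1-D array, which is
# what the load-balancing logic needs when comparing structure sizes.
assert (np.array(natoms_flat).reshape(-1) == np.array(natoms_column).reshape(-1)).all()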
9 changes: 7 additions & 2 deletions src/fairchem/core/datasets/ase_datasets.py
@@ -184,10 +184,15 @@ def sample_property_metadata(self, num_samples: int = 100) -> dict:
         return metadata

     def get_metadata(self, attr, idx):
+        # try the parent method
+        metadata = super().get_metadata(attr, idx)
+        if metadata is not None:
+            return metadata
+        # try to resolve it here
         if attr != "natoms":
             return None
-        if isinstance(idx, list):
-            return [self.get_metadata(attr, i) for i in idx]
+        if isinstance(idx, (list, np.ndarray)):
+            return np.array([self.get_metadata(attr, i) for i in idx])
         return len(self.get_atoms(idx))


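After this change, get_metadata("natoms", ...) first defers to the parent class and only computes counts itself when no precomputed metadata is available; it then accepts a single index, a list, or a NumPy array of indices, returning a NumPy array for the batched forms, which is what _get_natoms above expects. A rough usage sketch, assuming `dataset` is some already-constructed fairchem ASE dataset without a metadata file:

import numpy as np

# Single index: falls through to len(dataset.get_atoms(idx)) and returns an int.
single = dataset.get_metadata("natoms", 0)

# Batched indices (list or np.ndarray): returns a NumPy array of per-structure counts.
batch = dataset.get_metadata("natoms", np.array([0, 1, 2]))

assert isinstance(batch, np.ndarray)
assert batch[0] == single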
59 changes: 58 additions & 1 deletion tests/core/e2e/test_s2ef.py
@@ -9,6 +9,8 @@
 import numpy as np
 import pytest
 import yaml
+from fairchem.core.common.test_utils import PGConfig, spawn_multi_process
+from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes
 from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

 from fairchem.core._cli import Runner
@@ -84,13 +86,15 @@ def _run_main(
     update_run_args_with=None,
     save_checkpoint_to=None,
     save_predictions_to=None,
+    world_size=1,
 ):
     config_yaml = Path(rundir) / "train_and_val_on_val.yml"

     with open(input_yaml) as yaml_file:
         yaml_config = yaml.safe_load(yaml_file)
     if update_dict_with is not None:
         yaml_config = merge_dictionary(yaml_config, update_dict_with)
+    yaml_config["backend"] = "gloo"
     with open(str(config_yaml), "w") as yaml_file:
         yaml.dump(yaml_config, yaml_file)

@@ -110,7 +114,14 @@ def _run_main(
     for arg_name, arg_value in run_args.items():
         setattr(args, arg_name, arg_value)
     config = build_config(args, override_args)
-    Runner()(config)
+
+    if world_size > 0:
+        pg_config = PGConfig(
+            backend="gloo", world_size=2, gp_group_size=1, use_gp=False
+        )
+        spawn_multi_process(pg_config, Runner(distributed=True), config)
+    else:
+        Runner()(config)

     if save_checkpoint_to is not None:
         checkpoints = glob.glob(f"{rundir}/checkpoints/*/checkpoint.pt")
@@ -213,6 +224,52 @@ def test_train_and_predict(
             tutorial_val_src=tutorial_val_src,
         )

+    def test_ddp(self, configs, tutorial_val_src, torch_deterministic):
+        with tempfile.TemporaryDirectory() as tempdirname:
+            tempdir = Path(tempdirname)
+
+            _ = _run_main(
+                rundir=str(tempdir),
+                update_dict_with={
+                    "optim": {"max_epochs": 1},
+                    "dataset": oc20_lmdb_train_and_val_from_paths(
+                        train_src=str(tutorial_val_src),
+                        val_src=str(tutorial_val_src),
+                        test_src=str(tutorial_val_src),
+                    ),
+                },
+                update_run_args_with={"seed": 0},
+                input_yaml=configs["equiformer_v2"],
+                world_size=2,
+            )
+
+    def test_balanced_batch_sampler_ddp(
+        self, configs, tutorial_val_src, torch_deterministic
+    ):
+
+        # make dataset metadata
+        parser = get_lmdb_sizes_parser()
+        args, override_args = parser.parse_known_args(["--data-path", str(tutorial_val_src)])
+        make_lmdb_sizes(args)
+
+        with tempfile.TemporaryDirectory() as tempdirname:
+            tempdir = Path(tempdirname)
+
+            _ = _run_main(
+                rundir=str(tempdir),
+                update_dict_with={
+                    "optim": {"max_epochs": 1, "load_balancing": "atoms"},
+                    "dataset": oc20_lmdb_train_and_val_from_paths(
+                        train_src=str(tutorial_val_src),
+                        val_src=str(tutorial_val_src),
+                        test_src=str(tutorial_val_src),
+                    ),
+                },
+                update_run_args_with={"seed": 0},
+                input_yaml=configs["equiformer_v2"],
+                world_size=2,
+            )
+
     # train for a few steps and confirm same seeds get same results
     def test_different_seeds(
         self,
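test_balanced_batch_sampler_ddp builds the natoms metadata up front with make_lmdb_sizes, since "load_balancing": "atoms" relies on per-structure sizes being available to the sampler. A sketch of that metadata-generation step outside pytest, mirroring the in-test call; the placeholder path and the idea of running it standalone are assumptions, not taken from this commit:

from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes

# Point --data-path at the dataset directory (placeholder path below).
parser = get_lmdb_sizes_parser()
args, _ = parser.parse_known_args(["--data-path", "/path/to/dataset"])
make_lmdb_sizes(args)  # generates the per-structure size metadata ("make dataset metadata" in the test above)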
