From fbfe627f27bdae7ea10f7dfa80b20efbb382f406 Mon Sep 17 00:00:00 2001
From: Misko
Date: Tue, 23 Jul 2024 21:08:15 +0000
Subject: [PATCH] fix tests and balanced batch sampler

---
Reviewer notes after the diff give illustrative sketches of the metadata
fallback and the atom-count balancing; they are not part of the applied patch.

 src/fairchem/core/common/data_parallel.py  |  4 +-
 src/fairchem/core/datasets/ase_datasets.py |  9 +++-
 tests/core/e2e/test_s2ef.py                | 59 +++++++++++++++++++++-
 3 files changed, 68 insertions(+), 4 deletions(-)

diff --git a/src/fairchem/core/common/data_parallel.py b/src/fairchem/core/common/data_parallel.py
index a0324017e..89c3b6744 100644
--- a/src/fairchem/core/common/data_parallel.py
+++ b/src/fairchem/core/common/data_parallel.py
@@ -196,7 +196,9 @@ def __init__(
 
     def _get_natoms(self, batch_idx: list[int]):
         if self.sampler.dataset.metadata_hasattr("natoms"):
-            return self.sampler.dataset.get_metadata("natoms", batch_idx)
+            return np.array(
+                self.sampler.dataset.get_metadata("natoms", batch_idx)
+            ).reshape(-1)
         if self.on_error == "warn_and_balance":
             return np.array([self.sampler.dataset[idx].num_nodes for idx in batch_idx])
         return None
diff --git a/src/fairchem/core/datasets/ase_datasets.py b/src/fairchem/core/datasets/ase_datasets.py
index 0b83b0ccc..15c22322d 100644
--- a/src/fairchem/core/datasets/ase_datasets.py
+++ b/src/fairchem/core/datasets/ase_datasets.py
@@ -184,10 +184,15 @@ def sample_property_metadata(self, num_samples: int = 100) -> dict:
         return metadata
 
     def get_metadata(self, attr, idx):
+        # try the parent method first (precomputed metadata, if any)
+        metadata = super().get_metadata(attr, idx)
+        if metadata is not None:
+            return metadata
+        # otherwise try to resolve it here
         if attr != "natoms":
             return None
-        if isinstance(idx, list):
-            return [self.get_metadata(attr, i) for i in idx]
+        if isinstance(idx, (list, np.ndarray)):
+            return np.array([self.get_metadata(attr, i) for i in idx])
         return len(self.get_atoms(idx))
 
 
diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py
index e8a4ae1bf..f234e4563 100644
--- a/tests/core/e2e/test_s2ef.py
+++ b/tests/core/e2e/test_s2ef.py
@@ -9,6 +9,8 @@
 import numpy as np
 import pytest
 import yaml
+from fairchem.core.common.test_utils import PGConfig, spawn_multi_process
+from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes
 from tensorboard.backend.event_processing.event_accumulator import EventAccumulator
 
 from fairchem.core._cli import Runner
@@ -84,6 +86,7 @@ def _run_main(
     update_run_args_with=None,
     save_checkpoint_to=None,
     save_predictions_to=None,
+    world_size=1,
 ):
     config_yaml = Path(rundir) / "train_and_val_on_val.yml"
 
@@ -91,6 +94,7 @@ def _run_main(
         yaml_config = yaml.safe_load(yaml_file)
     if update_dict_with is not None:
         yaml_config = merge_dictionary(yaml_config, update_dict_with)
+    yaml_config["backend"] = "gloo"
     with open(str(config_yaml), "w") as yaml_file:
         yaml.dump(yaml_config, yaml_file)
 
@@ -110,7 +114,14 @@ def _run_main(
     for arg_name, arg_value in run_args.items():
         setattr(args, arg_name, arg_value)
     config = build_config(args, override_args)
-    Runner()(config)
+
+    if world_size > 0:
+        pg_config = PGConfig(
+            backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False
+        )
+        spawn_multi_process(pg_config, Runner(distributed=True), config)
+    else:
+        Runner()(config)
 
     if save_checkpoint_to is not None:
         checkpoints = glob.glob(f"{rundir}/checkpoints/*/checkpoint.pt")
@@ -213,6 +224,52 @@ def test_train_and_predict(
             tutorial_val_src=tutorial_val_src,
         )
 
+    def test_ddp(self, configs, tutorial_val_src, torch_deterministic):
+        with tempfile.TemporaryDirectory() as tempdirname:
+            tempdir = Path(tempdirname)
+
+            _ = _run_main(
+                rundir=str(tempdir),
+                update_dict_with={
+                    "optim": {"max_epochs": 1},
+                    "dataset": oc20_lmdb_train_and_val_from_paths(
+                        train_src=str(tutorial_val_src),
+                        val_src=str(tutorial_val_src),
+                        test_src=str(tutorial_val_src),
+                    ),
+                },
+                update_run_args_with={"seed": 0},
+                input_yaml=configs["equiformer_v2"],
+                world_size=2,
+            )
+
+    def test_balanced_batch_sampler_ddp(
+        self, configs, tutorial_val_src, torch_deterministic
+    ):
+
+        # make dataset metadata needed by atom-count load balancing
+        parser = get_lmdb_sizes_parser()
+        args, _ = parser.parse_known_args(["--data-path", str(tutorial_val_src)])
+        make_lmdb_sizes(args)
+
+        with tempfile.TemporaryDirectory() as tempdirname:
+            tempdir = Path(tempdirname)
+
+            _ = _run_main(
+                rundir=str(tempdir),
+                update_dict_with={
+                    "optim": {"max_epochs": 1, "load_balancing": "atoms"},
+                    "dataset": oc20_lmdb_train_and_val_from_paths(
+                        train_src=str(tutorial_val_src),
+                        val_src=str(tutorial_val_src),
+                        test_src=str(tutorial_val_src),
+                    ),
+                },
+                update_run_args_with={"seed": 0},
+                input_yaml=configs["equiformer_v2"],
+                world_size=2,
+            )
+
     # train for a few steps and confirm same seeds get same results
     def test_different_seeds(
         self,
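
--
Reviewer notes (illustrative sketches, not part of the applied patch):

The ase_datasets change gives get_metadata() a parent-first resolution order:
return precomputed metadata when the parent class can supply it (e.g. sizes
written by make_lmdb_sizes), otherwise derive "natoms" on the fly, for scalar
and batched indices alike. Below is a minimal runnable sketch of that
contract; ToyDataset, its _metadata dict, and the simplified get_atoms() are
hypothetical stand-ins for the real AseAtomsDataset machinery.

import numpy as np

class ToyDataset:
    """Hypothetical stand-in for AseAtomsDataset (illustration only)."""

    def __init__(self, atoms_objects, metadata=None):
        self.atoms_objects = atoms_objects
        # precomputed per-sample metadata, e.g. what make_lmdb_sizes writes
        self._metadata = metadata or {}

    def get_atoms(self, idx):
        return self.atoms_objects[idx]

    def get_metadata(self, attr, idx):
        # 1) parent-first: serve precomputed metadata when available
        if attr in self._metadata:
            return np.array(self._metadata[attr])[idx]
        # 2) fallback: only "natoms" can be derived on demand
        if attr != "natoms":
            return None
        # batched indices (list or ndarray) resolve element-wise to an array
        if isinstance(idx, (list, np.ndarray)):
            return np.array([self.get_metadata(attr, i) for i in idx])
        return len(self.get_atoms(idx))

Either path now returns something that np.array(...).reshape(-1) flattens
cleanly, which is why the data_parallel change wraps the result that way:
scalar, list, or array metadata all normalize to a flat per-sample count
array inside _get_natoms().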
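The new test pins "load_balancing": "atoms", i.e. each global batch is
partitioned across ranks by per-sample atom counts rather than by sample
count. The sketch below shows one way such balancing can work: a greedy
heaviest-first heuristic. This is an illustrative assumption, not a claim
about BalancedBatchSampler's exact algorithm; balance_by_natoms and its
arguments are hypothetical names.

import heapq
import numpy as np

def balance_by_natoms(batch_idx, natoms, num_replicas):
    """Greedily partition indices so each rank gets a similar atom total.

    batch_idx: sample indices in the global batch; natoms: per-sample atom
    counts (the flat array _get_natoms() returns after this patch).
    """
    heap = [(0, rank, []) for rank in range(num_replicas)]
    heapq.heapify(heap)
    # place the heaviest samples first so rank totals even out
    for i in np.argsort(natoms)[::-1]:
        total, rank, members = heapq.heappop(heap)
        members.append(batch_idx[i])
        heapq.heappush(heap, (total + int(natoms[i]), rank, members))
    return [members for _, _, members in sorted(heap, key=lambda t: t[1])]

For example, balance_by_natoms(list(range(6)), np.array([12, 3, 9, 4, 8, 2]), 2)
yields [[0, 3, 1], [2, 4, 5]], giving both ranks 19 atoms, whereas a naive
contiguous split by sample count would leave them with 24 and 14.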