From 7ba1fa3eca278cade181892b8edd891a690aa795 Mon Sep 17 00:00:00 2001 From: rayg1234 <7001989+rayg1234@users.noreply.github.com> Date: Tue, 10 Sep 2024 14:05:45 -0700 Subject: [PATCH] Delete distributed option - Always use DDP (#833) * delete distributed option * update * fix test * fix test * more cleanup * update * update * fix test * fix test * update * fix tests and slurm runs * update * typo * revert 2 files * test build docs * test build docs * fix book * fix book * fix book --- docs/core/fine-tuning/fine-tuning-oxides.md | 4 +- docs/core/inference.md | 4 +- .../advanced/fine-tuning-in-python.md | 2 +- src/fairchem/core/_cli.py | 46 ++++++++++--------- src/fairchem/core/common/distutils.py | 18 ++++++-- src/fairchem/core/common/flags.py | 17 ------- src/fairchem/core/common/utils.py | 18 +++----- src/fairchem/core/trainers/base_trainer.py | 10 ++-- src/fairchem/core/trainers/ocp_trainer.py | 11 ++--- tests/core/e2e/test_e2e_commons.py | 27 ++--------- tests/core/e2e/test_e2e_finetune_hydra.py | 4 +- tests/core/e2e/test_s2ef.py | 37 ++++----------- tests/core/e2e/test_s2efs.py | 14 +++--- tests/core/test_cli.py | 12 ++--- 14 files changed, 85 insertions(+), 139 deletions(-) diff --git a/docs/core/fine-tuning/fine-tuning-oxides.md b/docs/core/fine-tuning/fine-tuning-oxides.md index 39c39cad40..41f5a5e32e 100644 --- a/docs/core/fine-tuning/fine-tuning-oxides.md +++ b/docs/core/fine-tuning/fine-tuning-oxides.md @@ -258,7 +258,7 @@ You can follow how the training is going by opening a terminal and running You can also visit it in a browser at [train.txt](./train.txt). You have to periodically refresh the view to see updates though. -This can take up to 30 minutes for 80 epochs, so we only do a few here to see what happens. +This can take up to 30 minutes for 80 epochs, so we only do a few here to see what happens. If you have a gpu or multiple gpus, you should use the flag --num-gpus= and remove the --cpu flag. ```{code-cell} ipython3 :tags: [hide-output] @@ -267,7 +267,7 @@ import time from fairchem.core.common.tutorial_utils import fairchem_main t0 = time.time() -! python {fairchem_main()} --mode train --config-yml {yml} --checkpoint {checkpoint_path} --run-dir fine-tuning --identifier ft-oxides --amp > train.txt 2>&1 +! python {fairchem_main()} --mode train --config-yml {yml} --checkpoint {checkpoint_path} --run-dir fine-tuning --identifier ft-oxides --cpu > train.txt 2>&1 print(f'Elapsed time = {time.time() - t0:1.1f} seconds') ``` diff --git a/docs/core/inference.md b/docs/core/inference.md index 63ce484653..a0701cb09f 100644 --- a/docs/core/inference.md +++ b/docs/core/inference.md @@ -98,7 +98,7 @@ yml = generate_yml_config(checkpoint_path, 'config.yml', yml ``` -It is a good idea to redirect the output to a file. If the output gets too large here, the notebook may fail to save. Normally I would use a redirect like `2&>1`, but this does not work with the main.py method. An alternative here is to open a terminal and run it there. +It is a good idea to redirect the output to a file. If the output gets too large here, the notebook may fail to save. Normally I would use a redirect like `2&>1`, but this does not work with the main.py method. An alternative here is to open a terminal and run it there. If you have a gpu or multiple gpus, you should use the flag --num-gpus= and remove the --cpu flag. ```{code-cell} ipython3 %%capture inference @@ -106,7 +106,7 @@ import time from fairchem.core.common.tutorial_utils import fairchem_main t0 = time.time() -! python {fairchem_main()} --mode predict --config-yml {yml} --checkpoint {checkpoint_path} --amp +! python {fairchem_main()} --mode predict --config-yml {yml} --checkpoint {checkpoint_path} --cpu print(f'Elapsed time = {time.time() - t0:1.1f} seconds') ``` diff --git a/docs/tutorials/advanced/fine-tuning-in-python.md b/docs/tutorials/advanced/fine-tuning-in-python.md index 0eeb8e5485..035385e6a4 100644 --- a/docs/tutorials/advanced/fine-tuning-in-python.md +++ b/docs/tutorials/advanced/fine-tuning-in-python.md @@ -119,7 +119,7 @@ parser = flags.get_parser() args, args_override = parser.parse_known_args(["--mode=train", "--config-yml=config.yml", f"--checkpoint={checkpoint_path}", - "--amp"]) + "--cpu"]) args, args_override ``` diff --git a/src/fairchem/core/_cli.py b/src/fairchem/core/_cli.py index f1270496a9..26da8c7cd2 100644 --- a/src/fairchem/core/_cli.py +++ b/src/fairchem/core/_cli.py @@ -9,10 +9,12 @@ import copy import logging +import os from typing import TYPE_CHECKING from submitit import AutoExecutor from submitit.helpers import Checkpointable, DelayedSubmission +from torch.distributed.elastic.utils.distributed import get_free_port from torch.distributed.launcher.api import LaunchConfig, elastic_launch from fairchem.core.common.flags import flags @@ -29,12 +31,11 @@ class Runner(Checkpointable): - def __init__(self, distributed: bool = False) -> None: + def __init__(self) -> None: self.config = None - self.distributed = distributed def __call__(self, config: dict) -> None: - with new_trainer_context(config=config, distributed=self.distributed) as ctx: + with new_trainer_context(config=config) as ctx: self.config = ctx.config self.task = ctx.task self.trainer = ctx.trainer @@ -42,7 +43,7 @@ def __call__(self, config: dict) -> None: self.task.run() def checkpoint(self, *args, **kwargs): - new_runner = Runner(self.distributed) + new_runner = Runner() self.trainer.save(checkpoint_file="checkpoint.pt", training_state=True) self.config["checkpoint"] = self.task.chkpt_path self.config["timestamp_id"] = self.trainer.timestamp_id @@ -52,18 +53,20 @@ def checkpoint(self, *args, **kwargs): return DelayedSubmission(new_runner, self.config) -def runner_wrapper(distributed: bool, config: dict): - Runner(distributed=distributed)(config) +def runner_wrapper(config: dict): + Runner()(config) -def main(): +def main(args: argparse.Namespace | None = None, override_args: list[str] | None = None): """Run the main fairchem program.""" setup_logging() - parser: argparse.ArgumentParser = flags.get_parser() - args: argparse.Namespace - override_args: list[str] - args, override_args = parser.parse_known_args() + if args is None: + parser: argparse.ArgumentParser = flags.get_parser() + args, override_args = parser.parse_known_args() + + # TODO: rename num_gpus -> num_ranks everywhere + assert args.num_gpus > 0, "num_gpus is used to determine number ranks, so it must be at least 1" config = build_config(args, override_args) if args.submit: # Run on cluster @@ -79,7 +82,7 @@ def main(): slurm_partition=args.slurm_partition, gpus_per_node=args.num_gpus, cpus_per_task=(config["optim"]["num_workers"] + 1), - tasks_per_node=(args.num_gpus if args.distributed else 1), + tasks_per_node=args.num_gpus, nodes=args.num_nodes, slurm_additional_parameters=slurm_add_params, slurm_qos=args.slurm_qos, @@ -88,15 +91,15 @@ def main(): for config in configs: config["slurm"] = copy.deepcopy(executor.parameters) config["slurm"]["folder"] = str(executor.folder) - jobs = executor.map_array(Runner(distributed=args.distributed), configs) + jobs = executor.map_array(Runner(), configs) logging.info(f"Submitted jobs: {', '.join([job.job_id for job in jobs])}") log_file = save_experiment_log(args, jobs, configs) logging.info(f"Experiment log saved to: {log_file}") else: # Run locally on a single node, n-processes - if args.distributed: + if args.num_gpus > 1: logging.info( - f"Running in distributed local mode with {args.num_gpus} ranks" + f"Running in local mode with {args.num_gpus} ranks" ) # HACK to disable multiprocess dataloading in local mode # there is an open issue where LMDB's environment cannot be pickled and used @@ -114,13 +117,14 @@ def main(): rdzv_backend="c10d", max_restarts=0, ) - elastic_launch(launch_config, runner_wrapper)(args.distributed, config) + elastic_launch(launch_config, runner_wrapper)(config) else: - logging.info("Running in non-distributed local mode") - assert ( - args.num_gpus == 1 - ), "Can only run with a single gpu in non distributed local mode, use --distributed flag instead if using >1 gpu" - runner_wrapper(args.distributed, config) + logging.info("Running in local mode without elastic launch (single gpu only)") + os.environ["MASTER_ADDR"] = "localhost" + os.environ["LOCAL_RANK"] = "0" + os.environ["RANK"] = "0" + os.environ["MASTER_PORT"] = str(get_free_port()) + runner_wrapper(config) if __name__ == "__main__": diff --git a/src/fairchem/core/common/distutils.py b/src/fairchem/core/common/distutils.py index f6bf88ccaf..e54ae7d969 100644 --- a/src/fairchem/core/common/distutils.py +++ b/src/fairchem/core/common/distutils.py @@ -15,11 +15,12 @@ import torch import torch.distributed as dist +from torch.distributed.elastic.utils.distributed import get_free_port from fairchem.core.common.typing import none_throws T = TypeVar("T") - +DISTRIBUTED_PORT = 13356 def os_environ_get_or_throw(x: str) -> str: if x not in os.environ: @@ -40,7 +41,7 @@ def setup(config) -> None: ) config["init_method"] = "tcp://{host}:{port}".format( host=hostnames.split()[0].decode("utf-8"), - port=config["distributed_port"], + port=DISTRIBUTED_PORT, ) nnodes = int(os_environ_get_or_throw("SLURM_NNODES")) ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") @@ -67,10 +68,11 @@ def setup(config) -> None: ) # ensures GPU0 does not have extra context/higher peak memory + logging.info(f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}") torch.cuda.set_device(config["local_rank"]) dist.init_process_group( - backend=config["distributed_backend"], + backend="nccl", init_method=config["init_method"], world_size=config["world_size"], rank=config["rank"], @@ -101,8 +103,14 @@ def setup(config) -> None: timeout=timeout, ) else: - config["local_rank"] = int(os.environ.get("LOCAL_RANK", config["local_rank"])) - dist.init_process_group(backend=config.get("backend", "nccl"), timeout=timeout) + if not os.environ.get("MASTER_ADDR"): + assert config["world_size"] == 1, "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(get_free_port()) + os.environ["LOCAL_RANK"] = "0" + os.environ["RANK"] = "0" + config["local_rank"] = int(os.environ.get("LOCAL_RANK")) + dist.init_process_group(backend=config["distributed_backend"], rank=int(os.environ.get("RANK")), world_size=config["world_size"], timeout=timeout) def cleanup() -> None: diff --git a/src/fairchem/core/common/flags.py b/src/fairchem/core/common/flags.py index 266e1e640b..65a9243f27 100644 --- a/src/fairchem/core/common/flags.py +++ b/src/fairchem/core/common/flags.py @@ -118,9 +118,6 @@ def add_core_args(self) -> None: self.parser.add_argument( "--num-gpus", default=1, type=int, help="Number of GPUs to request" ) - self.parser.add_argument( - "--distributed", action="store_true", help="Run with DDP" - ) self.parser.add_argument( "--cpu", action="store_true", help="Run CPU only training" ) @@ -130,20 +127,6 @@ def add_core_args(self) -> None: type=int, help="Number of Nodes to request", ) - self.parser.add_argument( - "--distributed-port", - type=int, - default=13356, - help="Port on master for DDP", - ) - self.parser.add_argument( - "--distributed-backend", - type=str, - default="nccl", - help="Backend for DDP", - ) - self.parser.add_argument("--local-rank", default=0, type=int, help="Local rank") - self.parser.add_argument("--no-ddp", action="store_true", help="Do not use DDP") self.parser.add_argument( "--gp-gpus", type=int, diff --git a/src/fairchem/core/common/utils.py b/src/fairchem/core/common/utils.py index 955ea1e062..e9d498153a 100644 --- a/src/fairchem/core/common/utils.py +++ b/src/fairchem/core/common/utils.py @@ -505,11 +505,8 @@ def build_config(args, args_override, include_paths=None): config["submit"] = args.submit config["summit"] = args.summit # Distributed - config["local_rank"] = args.local_rank - config["distributed_port"] = args.distributed_port config["world_size"] = args.num_nodes * args.num_gpus - config["distributed_backend"] = args.distributed_backend - config["noddp"] = args.no_ddp + config["distributed_backend"] = "gloo" if args.cpu else "nccl" config["gp_gpus"] = args.gp_gpus # Check for overridden parameters. @@ -1012,7 +1009,7 @@ def setup_env_vars() -> None: @contextmanager -def new_trainer_context(*, config: dict[str, Any], distributed: bool = False): +def new_trainer_context(*, config: dict[str, Any]): from fairchem.core.common import distutils, gp_utils from fairchem.core.common.registry import registry @@ -1031,10 +1028,9 @@ class _TrainingContext: original_config = config config = copy.deepcopy(original_config) - if distributed: - distutils.setup(config) - if config["gp_gpus"] is not None: - gp_utils.setup_gp(config) + distutils.setup(config) + if config["gp_gpus"] is not None: + gp_utils.setup_gp(config) try: setup_imports(config) trainer_name = config.get("trainer", "ocp") @@ -1065,7 +1061,6 @@ class _TrainingContext: "amp": config.get("amp", False), "cpu": config.get("cpu", False), "slurm": config.get("slurm", {}), - "noddp": config.get("noddp", False), "name": task_name, "gp_gpus": config.get("gp_gpus"), } @@ -1101,8 +1096,7 @@ class _TrainingContext: if distutils.is_master(): logging.info(f"Total time taken: {time.time() - start_time}") finally: - if distributed: - distutils.cleanup() + distutils.cleanup() def _resolve_scale_factor_submodule(model: nn.Module, name: str): diff --git a/src/fairchem/core/trainers/base_trainer.py b/src/fairchem/core/trainers/base_trainer.py index 94becb924c..bf1b8b608a 100644 --- a/src/fairchem/core/trainers/base_trainer.py +++ b/src/fairchem/core/trainers/base_trainer.py @@ -73,18 +73,19 @@ def __init__( loss_functions, evaluation_metrics, identifier: str, + # TODO: dealing with local rank is dangerous + # T201111838 remove this and use CUDA_VISIBILE_DEVICES instead so trainers don't need to know about which devie to use + local_rank: int, timestamp_id: str | None = None, run_dir: str | None = None, is_debug: bool = False, print_every: int = 100, seed: int | None = None, logger: str = "wandb", - local_rank: int = 0, amp: bool = False, cpu: bool = False, name: str = "ocp", slurm=None, - noddp: bool = False, gp_gpus: int | None = None, ) -> None: if slurm is None: @@ -96,6 +97,7 @@ def __init__( self.step = 0 if torch.cuda.is_available() and not self.cpu: + logging.info(f"local rank base: {local_rank}") self.device = torch.device(f"cuda:{local_rank}") else: self.device = torch.device("cpu") @@ -141,7 +143,6 @@ def __init__( ), }, "slurm": slurm, - "noddp": noddp, "gp_gpus": gp_gpus, } # AMP Scaler @@ -545,10 +546,9 @@ def load_model(self) -> None: ) self.logger.log_summary({"num_params": num_params}) - if distutils.initialized() and not self.config["noddp"]: + if distutils.initialized(): self.model = DistributedDataParallel( self.model, - device_ids=None if self.cpu else [self.device], ) @property diff --git a/src/fairchem/core/trainers/ocp_trainer.py b/src/fairchem/core/trainers/ocp_trainer.py index 0ced35bef3..fb3f479987 100644 --- a/src/fairchem/core/trainers/ocp_trainer.py +++ b/src/fairchem/core/trainers/ocp_trainer.py @@ -63,13 +63,10 @@ class OCPTrainer(BaseTrainer): (default: :obj:`None`) logger (str, optional): Type of logger to be used. (default: :obj:`wandb`) - local_rank (int, optional): Local rank of the process, only applicable for distributed training. - (default: :obj:`0`) amp (bool, optional): Run using automatic mixed precision. (default: :obj:`False`) slurm (dict): Slurm configuration. Currently just for keeping track. (default: :obj:`{}`) - noddp (bool, optional): Run model without DDP. """ def __init__( @@ -82,17 +79,18 @@ def __init__( loss_functions, evaluation_metrics, identifier, + # TODO: dealing with local rank is dangerous + # T201111838 remove this and use CUDA_VISIBILE_DEVICES instead so trainers don't need to know about which devie to use + local_rank, timestamp_id=None, run_dir=None, is_debug=False, print_every=100, seed=None, logger="wandb", - local_rank=0, amp=False, cpu=False, slurm=None, - noddp=False, name="ocp", gp_gpus=None, ): @@ -107,17 +105,16 @@ def __init__( loss_functions=loss_functions, evaluation_metrics=evaluation_metrics, identifier=identifier, + local_rank=local_rank, timestamp_id=timestamp_id, run_dir=run_dir, is_debug=is_debug, print_every=print_every, seed=seed, logger=logger, - local_rank=local_rank, amp=amp, cpu=cpu, slurm=slurm, - noddp=noddp, name=name, gp_gpus=gp_gpus, ) diff --git a/tests/core/e2e/test_e2e_commons.py b/tests/core/e2e/test_e2e_commons.py index ef2b860bff..256d3b0871 100644 --- a/tests/core/e2e/test_e2e_commons.py +++ b/tests/core/e2e/test_e2e_commons.py @@ -8,14 +8,8 @@ import yaml from tensorboard.backend.event_processing.event_accumulator import EventAccumulator -from fairchem.core._cli import Runner +from fairchem.core._cli import main from fairchem.core.common.flags import flags -from fairchem.core.common.test_utils import ( - PGConfig, - init_env_rank_and_launch_test, - spawn_multi_process, -) -from fairchem.core.common.utils import build_config def oc20_lmdb_train_and_val_from_paths( @@ -110,7 +104,7 @@ def _run_main( update_run_args_with=None, save_checkpoint_to=None, save_predictions_to=None, - world_size=0, + world_size=1, ): config_yaml = Path(rundir) / "train_and_val_on_val.yml" update_yaml_with_dict(input_yaml, config_yaml, update_dict_with) @@ -125,24 +119,11 @@ def _run_main( # run parser = flags.get_parser() args, override_args = parser.parse_known_args( - ["--mode", "train", "--seed", "100", "--config-yml", "config.yml", "--cpu"] + ["--mode", "train", "--seed", "100", "--config-yml", "config.yml", "--cpu", "--num-gpus", str(world_size)] ) for arg_name, arg_value in run_args.items(): setattr(args, arg_name, arg_value) - config = build_config(args, override_args) - - if world_size > 0: - pg_config = PGConfig( - backend="gloo", world_size=world_size, gp_group_size=1, use_gp=False - ) - spawn_multi_process( - pg_config, - Runner(distributed=True), - init_env_rank_and_launch_test, - config, - ) - else: - Runner()(config) + main(args, override_args) if save_checkpoint_to is not None: checkpoints = glob.glob(f"{rundir}/checkpoints/*/checkpoint.pt") diff --git a/tests/core/e2e/test_e2e_finetune_hydra.py b/tests/core/e2e/test_e2e_finetune_hydra.py index df9def3b0c..cec60e394c 100644 --- a/tests/core/e2e/test_e2e_finetune_hydra.py +++ b/tests/core/e2e/test_e2e_finetune_hydra.py @@ -37,7 +37,7 @@ def make_checkpoint(tempdir: str, data_source: Path, seed: int) -> str: }, update_run_args_with={"seed": seed}, save_checkpoint_to=ck_path, - world_size=0, + world_size=1, ) assert os.path.isfile(ck_path) return ck_path @@ -70,7 +70,7 @@ def run_main_with_ft_hydra(tempdir: str, }, update_run_args_with=run_args, save_checkpoint_to=output_checkpoint, - world_size=0, + world_size=1, ) diff --git a/tests/core/e2e/test_s2ef.py b/tests/core/e2e/test_s2ef.py index 10e3203c91..8b5aebab80 100644 --- a/tests/core/e2e/test_s2ef.py +++ b/tests/core/e2e/test_s2ef.py @@ -8,19 +8,17 @@ import numpy as np import numpy.testing as npt import pytest -from fairchem.core._cli import Runner -from fairchem.core.modules.scaling.fit import compute_scaling_factors from test_e2e_commons import ( _run_main, oc20_lmdb_train_and_val_from_paths, update_yaml_with_dict, ) +from fairchem.core.common.flags import flags from fairchem.core.common.utils import build_config, setup_logging +from fairchem.core.modules.scaling.fit import compute_scaling_factors from fairchem.core.scripts.make_lmdb_sizes import get_lmdb_sizes_parser, make_lmdb_sizes -from fairchem.core.common.flags import flags - setup_logging() @@ -197,15 +195,6 @@ def test_train_and_predict( configs, tutorial_val_src, ): - # test without ddp - self.smoke_test_train( - input_yaml=configs[model_name], - tutorial_val_src=tutorial_val_src, - otf_norms=otf_norms, - world_size=0, - num_workers=2, - ) - # test with ddp but no wokers self.smoke_test_train( input_yaml=configs[model_name], tutorial_val_src=tutorial_val_src, @@ -254,19 +243,15 @@ def test_max_num_atoms(self, configs, tutorial_val_src, torch_deterministic): ) @pytest.mark.parametrize( - ("world_size", "ddp"), + ("world_size"), [ - pytest.param( - 2, - True, - ), - pytest.param(0, False), + pytest.param(2), + pytest.param(1), ], ) def test_ddp( self, world_size, - ddp, configs, tutorial_val_src, torch_deterministic, @@ -274,8 +259,6 @@ def test_ddp( with tempfile.TemporaryDirectory() as tempdirname: tempdir = Path(tempdirname) extra_args = {"seed": 0} - if not ddp: - extra_args["no_ddp"] = True _ = _run_main( rundir=str(tempdir), update_dict_with={ @@ -292,14 +275,14 @@ def test_ddp( ) @pytest.mark.parametrize( - ("world_size", "ddp"), + ("world_size"), [ - pytest.param(2, True), - pytest.param(0, False), + pytest.param(2), + pytest.param(1), ], ) def test_balanced_batch_sampler_ddp( - self, world_size, ddp, configs, tutorial_val_src, torch_deterministic + self, world_size, configs, tutorial_val_src, torch_deterministic ): # make dataset metadata parser = get_lmdb_sizes_parser() @@ -311,8 +294,6 @@ def test_balanced_batch_sampler_ddp( with tempfile.TemporaryDirectory() as tempdirname: tempdir = Path(tempdirname) extra_args = {"seed": 0} - if not ddp: - extra_args["no_ddp"] = True _ = _run_main( rundir=str(tempdir), update_dict_with={ diff --git a/tests/core/e2e/test_s2efs.py b/tests/core/e2e/test_s2efs.py index 037979e605..81b945b8b2 100644 --- a/tests/core/e2e/test_s2efs.py +++ b/tests/core/e2e/test_s2efs.py @@ -9,16 +9,14 @@ # TODO add GemNet! @pytest.mark.parametrize( - ("model_name", "ddp"), + ("model_name"), [ - ("equiformer_v2_hydra", False), - ("escn_hydra", False), - ("equiformer_v2_hydra", True), - ("escn_hydra", True), + ("equiformer_v2_hydra"), + ("escn_hydra"), ], ) def test_smoke_s2efs_predict( - model_name, ddp, configs, dummy_binary_dataset_path, tmpdir + model_name, configs, dummy_binary_dataset_path, tmpdir ): # train an s2ef model just to have one input_yaml = configs[model_name] @@ -76,13 +74,13 @@ def test_smoke_s2efs_predict( "max_epochs": 2, "eval_every": 4, "batch_size": 5, - "num_workers": 0 if ddp else 2, + "num_workers": 0, }, **updates, }, save_checkpoint_to=checkpoint_path, save_predictions_to=training_predictions_filename, - world_size=1 if ddp else 0, + world_size=1, ) assert "train/energy_mae" in acc.Tags()["scalars"] assert "val/energy_mae" in acc.Tags()["scalars"] diff --git a/tests/core/test_cli.py b/tests/core/test_cli.py index 5a43723d73..c6e3848a21 100644 --- a/tests/core/test_cli.py +++ b/tests/core/test_cli.py @@ -6,9 +6,7 @@ from fairchem.core._cli import main -def fake_runner(distributed: bool, config: dict): - assert not distributed - assert config["local_rank"] == 0 +def fake_runner(config: dict): assert config["world_size"] == 1 def test_cli(): @@ -23,16 +21,18 @@ def test_cli(): sys.argv[1:] = sys_args main() -def test_cli_distributed(): +def test_cli_multi_rank(): with patch("fairchem.core._cli.elastic_launch") as mock_elastic_launch: sys_args = ["--debug", - "--distributed", "--mode", "train", "--identifier", "test", "--config-yml", - "configs/oc22/s2ef/equiformer_v2/equiformer_v2_N@18_L@6_M@2_e4_f100_121M.yml"] + "configs/oc22/s2ef/equiformer_v2/equiformer_v2_N@18_L@6_M@2_e4_f100_121M.yml", + "--cpu", + "--num-gpus", + "2"] sys.argv[1:] = sys_args main() mock_elastic_launch.assert_called_once()