Delete distributed option - Always use DDP (#833)
* delete distributed option

* update

* fix test

* fix test

* more cleanup

* update

* update

* fix test

* fix test

* update

* fix tests and slurm runs

* update

* typo

* revert 2 files

* test build docs

* test build docs

* fix book

* fix book

* fix book
rayg1234 committed Sep 10, 2024
1 parent eddb484 commit 7ba1fa3
Showing 14 changed files with 85 additions and 139 deletions.
4 changes: 2 additions & 2 deletions docs/core/fine-tuning/fine-tuning-oxides.md
@@ -258,7 +258,7 @@ You can follow how the training is going by opening a terminal and running

You can also visit it in a browser at [train.txt](./train.txt). You have to periodically refresh the view to see updates though.

-This can take up to 30 minutes for 80 epochs, so we only do a few here to see what happens.
+This can take up to 30 minutes for 80 epochs, so we only do a few here to see what happens. If you have a gpu or multiple gpus, you should use the flag --num-gpus=<number of gpus> and remove the --cpu flag.

```{code-cell} ipython3
:tags: [hide-output]
@@ -267,7 +267,7 @@ import time
from fairchem.core.common.tutorial_utils import fairchem_main
t0 = time.time()
-! python {fairchem_main()} --mode train --config-yml {yml} --checkpoint {checkpoint_path} --run-dir fine-tuning --identifier ft-oxides --amp > train.txt 2>&1
+! python {fairchem_main()} --mode train --config-yml {yml} --checkpoint {checkpoint_path} --run-dir fine-tuning --identifier ft-oxides --cpu > train.txt 2>&1
print(f'Elapsed time = {time.time() - t0:1.1f} seconds')
```

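The docs now steer CPU runs to `--cpu` and GPU runs to `--num-gpus`. As a rough illustration only (not part of the commit; the config, checkpoint, and rank count below are placeholders), the two local launch modes could be driven from Python like this:

```python
# Illustrative sketch: CPU vs. GPU launch after this commit.
# config.yml / checkpoint.pt stand in for files produced earlier in the tutorial.
import subprocess
from fairchem.core.common.tutorial_utils import fairchem_main

base = [
    "python", str(fairchem_main()),
    "--mode", "train",
    "--config-yml", "config.yml",
    "--checkpoint", "checkpoint.pt",
    "--run-dir", "fine-tuning",
]

# CPU-only run, as used in the tutorial cells above.
subprocess.run([*base, "--cpu"], check=True)

# GPU run: drop --cpu and request ranks via --num-gpus; DDP is always used now.
# subprocess.run([*base, "--num-gpus", "2"], check=True)
```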
4 changes: 2 additions & 2 deletions docs/core/inference.md
@@ -98,15 +98,15 @@ yml = generate_yml_config(checkpoint_path, 'config.yml',
yml
```

-It is a good idea to redirect the output to a file. If the output gets too large here, the notebook may fail to save. Normally I would use a redirect like `2&>1`, but this does not work with the main.py method. An alternative here is to open a terminal and run it there.
+It is a good idea to redirect the output to a file. If the output gets too large here, the notebook may fail to save. Normally I would use a redirect like `2&>1`, but this does not work with the main.py method. An alternative here is to open a terminal and run it there. If you have a gpu or multiple gpus, you should use the flag --num-gpus=<number of gpus> and remove the --cpu flag.

```{code-cell} ipython3
%%capture inference
import time
from fairchem.core.common.tutorial_utils import fairchem_main
t0 = time.time()
-! python {fairchem_main()} --mode predict --config-yml {yml} --checkpoint {checkpoint_path} --amp
+! python {fairchem_main()} --mode predict --config-yml {yml} --checkpoint {checkpoint_path} --cpu
print(f'Elapsed time = {time.time() - t0:1.1f} seconds')
```

2 changes: 1 addition & 1 deletion docs/tutorials/advanced/fine-tuning-in-python.md
@@ -119,7 +119,7 @@ parser = flags.get_parser()
args, args_override = parser.parse_known_args(["--mode=train",
"--config-yml=config.yml",
f"--checkpoint={checkpoint_path}",
"--amp"])
"--cpu"])
args, args_override
```

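Since `main()` in `_cli.py` (changed below in this same commit) now accepts pre-parsed arguments, a tutorial-style run could hand these args straight to it. A minimal sketch, assuming `config.yml` and `checkpoint_path` are set up as in the tutorial:

```python
# Sketch: pass pre-parsed args to the CLI entry point instead of shelling out.
from fairchem.core._cli import main
from fairchem.core.common.flags import flags

parser = flags.get_parser()
args, override_args = parser.parse_known_args(
    [
        "--mode=train",
        "--config-yml=config.yml",
        f"--checkpoint={checkpoint_path}",  # checkpoint_path from the tutorial setup
        "--cpu",
    ]
)
main(args, override_args)
```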
46 changes: 25 additions & 21 deletions src/fairchem/core/_cli.py
@@ -9,10 +9,12 @@

import copy
import logging
+import os
from typing import TYPE_CHECKING

from submitit import AutoExecutor
from submitit.helpers import Checkpointable, DelayedSubmission
+from torch.distributed.elastic.utils.distributed import get_free_port
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from fairchem.core.common.flags import flags
@@ -29,20 +31,19 @@


class Runner(Checkpointable):
-def __init__(self, distributed: bool = False) -> None:
+def __init__(self) -> None:
self.config = None
-self.distributed = distributed

def __call__(self, config: dict) -> None:
-with new_trainer_context(config=config, distributed=self.distributed) as ctx:
+with new_trainer_context(config=config) as ctx:
self.config = ctx.config
self.task = ctx.task
self.trainer = ctx.trainer
self.task.setup(self.trainer)
self.task.run()

def checkpoint(self, *args, **kwargs):
-new_runner = Runner(self.distributed)
+new_runner = Runner()
self.trainer.save(checkpoint_file="checkpoint.pt", training_state=True)
self.config["checkpoint"] = self.task.chkpt_path
self.config["timestamp_id"] = self.trainer.timestamp_id
@@ -52,18 +53,20 @@ def checkpoint(self, *args, **kwargs):
return DelayedSubmission(new_runner, self.config)


-def runner_wrapper(distributed: bool, config: dict):
-Runner(distributed=distributed)(config)
+def runner_wrapper(config: dict):
+Runner()(config)


-def main():
+def main(args: argparse.Namespace | None = None, override_args: list[str] | None = None):
"""Run the main fairchem program."""
setup_logging()

-parser: argparse.ArgumentParser = flags.get_parser()
-args: argparse.Namespace
-override_args: list[str]
-args, override_args = parser.parse_known_args()
+if args is None:
+parser: argparse.ArgumentParser = flags.get_parser()
+args, override_args = parser.parse_known_args()
+
+# TODO: rename num_gpus -> num_ranks everywhere
+assert args.num_gpus > 0, "num_gpus is used to determine number ranks, so it must be at least 1"
config = build_config(args, override_args)

if args.submit: # Run on cluster
@@ -79,7 +82,7 @@ def main():
slurm_partition=args.slurm_partition,
gpus_per_node=args.num_gpus,
cpus_per_task=(config["optim"]["num_workers"] + 1),
-tasks_per_node=(args.num_gpus if args.distributed else 1),
+tasks_per_node=args.num_gpus,
nodes=args.num_nodes,
slurm_additional_parameters=slurm_add_params,
slurm_qos=args.slurm_qos,
@@ -88,15 +91,15 @@ def main():
for config in configs:
config["slurm"] = copy.deepcopy(executor.parameters)
config["slurm"]["folder"] = str(executor.folder)
-jobs = executor.map_array(Runner(distributed=args.distributed), configs)
+jobs = executor.map_array(Runner(), configs)
logging.info(f"Submitted jobs: {', '.join([job.job_id for job in jobs])}")
log_file = save_experiment_log(args, jobs, configs)
logging.info(f"Experiment log saved to: {log_file}")

else: # Run locally on a single node, n-processes
-if args.distributed:
+if args.num_gpus > 1:
logging.info(
f"Running in distributed local mode with {args.num_gpus} ranks"
f"Running in local mode with {args.num_gpus} ranks"
)
# HACK to disable multiprocess dataloading in local mode
# there is an open issue where LMDB's environment cannot be pickled and used
@@ -114,13 +117,14 @@ def main():
rdzv_backend="c10d",
max_restarts=0,
)
-elastic_launch(launch_config, runner_wrapper)(args.distributed, config)
+elastic_launch(launch_config, runner_wrapper)(config)
else:
logging.info("Running in non-distributed local mode")
assert (
args.num_gpus == 1
), "Can only run with a single gpu in non distributed local mode, use --distributed flag instead if using >1 gpu"
runner_wrapper(args.distributed, config)
logging.info("Running in local mode without elastic launch (single gpu only)")
os.environ["MASTER_ADDR"] = "localhost"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
os.environ["MASTER_PORT"] = str(get_free_port())
runner_wrapper(config)


if __name__ == "__main__":
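The single-rank branch above no longer skips distributed setup: it fills in the rendezvous environment variables so `dist.init_process_group` can still form a world of size one, and the model is then wrapped in DDP as usual. A standalone sketch of that bootstrap (illustrative; it mirrors the code above rather than adding any new API):

```python
# Rough equivalent of the single-rank local path in main() above.
import os

from torch.distributed.elastic.utils.distributed import get_free_port

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
# With these set, distutils.setup() (called from new_trainer_context) can
# initialize a process group with world_size=1 without torchrun/elastic_launch.
```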
18 changes: 13 additions & 5 deletions src/fairchem/core/common/distutils.py
@@ -15,11 +15,12 @@

import torch
import torch.distributed as dist
+from torch.distributed.elastic.utils.distributed import get_free_port

from fairchem.core.common.typing import none_throws

T = TypeVar("T")

+DISTRIBUTED_PORT = 13356

def os_environ_get_or_throw(x: str) -> str:
if x not in os.environ:
@@ -40,7 +41,7 @@ def setup(config) -> None:
)
config["init_method"] = "tcp://{host}:{port}".format(
host=hostnames.split()[0].decode("utf-8"),
port=config["distributed_port"],
port=DISTRIBUTED_PORT,
)
nnodes = int(os_environ_get_or_throw("SLURM_NNODES"))
ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE")
@@ -67,10 +68,11 @@ def setup(config) -> None:
)

# ensures GPU0 does not have extra context/higher peak memory
logging.info(f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}")
torch.cuda.set_device(config["local_rank"])

dist.init_process_group(
backend=config["distributed_backend"],
backend="nccl",
init_method=config["init_method"],
world_size=config["world_size"],
rank=config["rank"],
@@ -101,8 +103,14 @@ def setup(config) -> None:
timeout=timeout,
)
else:
config["local_rank"] = int(os.environ.get("LOCAL_RANK", config["local_rank"]))
dist.init_process_group(backend=config.get("backend", "nccl"), timeout=timeout)
if not os.environ.get("MASTER_ADDR"):
assert config["world_size"] == 1, "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
config["local_rank"] = int(os.environ.get("LOCAL_RANK"))
dist.init_process_group(backend=config["distributed_backend"], rank=int(os.environ.get("RANK")), world_size=config["world_size"], timeout=timeout)


def cleanup() -> None:
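The new fallback branch boils down to a single-rank process group bootstrapped from environment variables. A self-contained sketch of the same idea, assuming a CPU-only run (the commit selects "nccl" via `config["distributed_backend"]` when GPUs are used):

```python
# Standalone single-rank init, mirroring the fallback branch above.
import os

import torch.distributed as dist
from torch.distributed.elastic.utils.distributed import get_free_port

if not os.environ.get("MASTER_ADDR"):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(get_free_port())
    os.environ["LOCAL_RANK"] = "0"
    os.environ["RANK"] = "0"

dist.init_process_group(backend="gloo", rank=0, world_size=1)  # "gloo" for CPU
print(dist.get_rank(), dist.get_world_size())  # -> 0 1
dist.destroy_process_group()
```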
17 changes: 0 additions & 17 deletions src/fairchem/core/common/flags.py
@@ -118,9 +118,6 @@ def add_core_args(self) -> None:
self.parser.add_argument(
"--num-gpus", default=1, type=int, help="Number of GPUs to request"
)
-self.parser.add_argument(
-"--distributed", action="store_true", help="Run with DDP"
-)
self.parser.add_argument(
"--cpu", action="store_true", help="Run CPU only training"
)
@@ -130,20 +127,6 @@
type=int,
help="Number of Nodes to request",
)
-self.parser.add_argument(
-"--distributed-port",
-type=int,
-default=13356,
-help="Port on master for DDP",
-)
-self.parser.add_argument(
-"--distributed-backend",
-type=str,
-default="nccl",
-help="Backend for DDP",
-)
-self.parser.add_argument("--local-rank", default=0, type=int, help="Local rank")
-self.parser.add_argument("--no-ddp", action="store_true", help="Do not use DDP")
self.parser.add_argument(
"--gp-gpus",
type=int,
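With those options gone, the launch-related surface is reduced to `--num-gpus`, `--num-nodes`, and `--cpu`. A quick sanity-check sketch (the flag values below are arbitrary examples):

```python
# Sketch: the remaining launch flags after this commit.
from fairchem.core.common.flags import flags

parser = flags.get_parser()
args, unknown = parser.parse_known_args(
    ["--mode=train", "--config-yml=config.yml", "--num-gpus=2", "--num-nodes=1"]
)
print(args.num_gpus, args.num_nodes, args.cpu)  # 2 1 False
# Removed flags such as --distributed or --no-ddp would now end up in `unknown`
# rather than being recognized by the parser.
```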
18 changes: 6 additions & 12 deletions src/fairchem/core/common/utils.py
@@ -505,11 +505,8 @@ def build_config(args, args_override, include_paths=None):
config["submit"] = args.submit
config["summit"] = args.summit
# Distributed
config["local_rank"] = args.local_rank
config["distributed_port"] = args.distributed_port
config["world_size"] = args.num_nodes * args.num_gpus
config["distributed_backend"] = args.distributed_backend
config["noddp"] = args.no_ddp
config["distributed_backend"] = "gloo" if args.cpu else "nccl"
config["gp_gpus"] = args.gp_gpus

# Check for overridden parameters.
@@ -1012,7 +1009,7 @@ def setup_env_vars() -> None:


@contextmanager
-def new_trainer_context(*, config: dict[str, Any], distributed: bool = False):
+def new_trainer_context(*, config: dict[str, Any]):
from fairchem.core.common import distutils, gp_utils
from fairchem.core.common.registry import registry

@@ -1031,10 +1028,9 @@ class _TrainingContext:
original_config = config
config = copy.deepcopy(original_config)

-if distributed:
-distutils.setup(config)
-if config["gp_gpus"] is not None:
-gp_utils.setup_gp(config)
+distutils.setup(config)
+if config["gp_gpus"] is not None:
+gp_utils.setup_gp(config)
try:
setup_imports(config)
trainer_name = config.get("trainer", "ocp")
Expand Down Expand Up @@ -1065,7 +1061,6 @@ class _TrainingContext:
"amp": config.get("amp", False),
"cpu": config.get("cpu", False),
"slurm": config.get("slurm", {}),
"noddp": config.get("noddp", False),
"name": task_name,
"gp_gpus": config.get("gp_gpus"),
}
@@ -1101,8 +1096,7 @@ class _TrainingContext:
if distutils.is_master():
logging.info(f"Total time taken: {time.time() - start_time}")
finally:
-if distributed:
-distutils.cleanup()
+distutils.cleanup()


def _resolve_scale_factor_submodule(model: nn.Module, name: str):
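Callers of `new_trainer_context` drop the `distributed=` keyword; process-group setup and cleanup now always happen inside the context manager. A minimal usage sketch, assuming `config` came from `build_config` as above (it mirrors `Runner.__call__` in `_cli.py`):

```python
# Sketch: using the trimmed-down context manager.
from fairchem.core.common.utils import new_trainer_context

with new_trainer_context(config=config) as ctx:  # no distributed= argument anymore
    ctx.task.setup(ctx.trainer)
    ctx.task.run()
```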
10 changes: 5 additions & 5 deletions src/fairchem/core/trainers/base_trainer.py
@@ -73,18 +73,19 @@ def __init__(
loss_functions,
evaluation_metrics,
identifier: str,
+# TODO: dealing with local rank is dangerous
+# T201111838 remove this and use CUDA_VISIBILE_DEVICES instead so trainers don't need to know about which devie to use
+local_rank: int,
timestamp_id: str | None = None,
run_dir: str | None = None,
is_debug: bool = False,
print_every: int = 100,
seed: int | None = None,
logger: str = "wandb",
-local_rank: int = 0,
amp: bool = False,
cpu: bool = False,
name: str = "ocp",
slurm=None,
-noddp: bool = False,
gp_gpus: int | None = None,
) -> None:
if slurm is None:
@@ -96,6 +97,7 @@
self.step = 0

if torch.cuda.is_available() and not self.cpu:
logging.info(f"local rank base: {local_rank}")
self.device = torch.device(f"cuda:{local_rank}")
else:
self.device = torch.device("cpu")
@@ -141,7 +143,6 @@ def __init__(
),
},
"slurm": slurm,
"noddp": noddp,
"gp_gpus": gp_gpus,
}
# AMP Scaler
@@ -545,10 +546,9 @@ def load_model(self) -> None:
)
self.logger.log_summary({"num_params": num_params})

-if distutils.initialized() and not self.config["noddp"]:
+if distutils.initialized():
self.model = DistributedDataParallel(
self.model,
-device_ids=None if self.cpu else [self.device],
)

@property
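The trainer now wraps the model in `DistributedDataParallel` whenever a process group is initialized, and no longer passes `device_ids`; DDP then infers the device from the module's parameters, which covers both CPU (gloo) and one-GPU-per-rank (nccl) runs. A minimal sketch of that wrap, assuming a process group has already been initialized (for example with the single-rank bootstrap shown earlier):

```python
# Sketch: unconditional DDP wrap without explicit device_ids.
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

model = nn.Linear(8, 1)  # stand-in for the trainer's model
# On GPU runs the trainer moves the model to cuda:{local_rank} before wrapping.
ddp_model = DistributedDataParallel(model)  # device inferred from model parameters
```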
