working relaxations on local, add submitit
rayg1234 committed Sep 20, 2024
1 parent a740b29 commit 89dbc7d
Showing 6 changed files with 123 additions and 23 deletions.
13 changes: 6 additions & 7 deletions src/fairchem/core/_cli.py
@@ -9,14 +9,13 @@

import copy
import logging
import os
from typing import TYPE_CHECKING

from submitit import AutoExecutor
from submitit.helpers import Checkpointable, DelayedSubmission
from torch.distributed.elastic.utils.distributed import get_free_port
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from fairchem.core.common import distutils
from fairchem.core.common.flags import flags
from fairchem.core.common.utils import (
build_config,
@@ -29,6 +28,8 @@
if TYPE_CHECKING:
import argparse

logger = logging.getLogger(__name__)


class Runner(Checkpointable):
def __init__(self) -> None:
@@ -73,6 +74,7 @@ def main(
from fairchem.core._cli_hydra import main

main(args, override_args)
return

# TODO: rename num_gpus -> num_ranks everywhere
assert (
@@ -84,7 +86,7 @@ def main(
slurm_add_params = config.get("slurm", None) # additional slurm arguments
configs = create_grid(config, args.sweep_yml) if args.sweep_yml else [config]

logging.info(f"Submitting {len(configs)} jobs")
logger.info(f"Submitting {len(configs)} jobs")
executor = AutoExecutor(folder=args.logdir / "%j", slurm_max_num_timeout=3)
executor.update_parameters(
name=args.identifier,
@@ -131,10 +133,7 @@ def main(
logging.info(
"Running in local mode without elastic launch (single gpu only)"
)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
os.environ["MASTER_PORT"] = str(get_free_port())
distutils.setup_env_local()
runner_wrapper(config)


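For context: the four inline environment assignments removed above now live in distutils.setup_env_local(). A minimal sketch of the resulting local branch (flag names taken from this file; runner_wrapper is the helper defined earlier in it):

    # Sketch, not verbatim: the simplified single-GPU local path.
    from fairchem.core.common import distutils

    def launch_local(config, args):
        assert args.num_gpus == 1, "local mode without elastic launch is single-gpu only"
        distutils.setup_env_local()  # MASTER_ADDR/PORT, RANK=0, LOCAL_RANK=0
        runner_wrapper(config)       # train in the current process
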
70 changes: 63 additions & 7 deletions src/fairchem/core/_cli_hydra.py
@@ -17,18 +17,28 @@
import argparse

from omegaconf import DictConfig

from submitit import AutoExecutor
from submitit.helpers import Checkpointable, DelayedSubmission
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from fairchem.core.common import distutils
from fairchem.core.common.flags import flags
from fairchem.core.common.utils import setup_imports
from fairchem.core.common.utils import get_timestamp_uid, setup_imports, setup_logging
from fairchem.core.components.runner import Runner

logger = logging.getLogger(__name__)


class Submitit(Checkpointable):
def __call__(self, dict_config) -> None:
def __call__(self, dict_config: DictConfig, cli_args: argparse.Namespace) -> None:
self.config = dict_config
# TODO: this is not needed if we stop instantiating models with Registry.
self.cli_args = cli_args
# TODO: setup_imports is not needed if we stop instantiating models with Registry.
setup_imports()
setup_logging()

distutils.setup(map_cli_args_to_dist_config(cli_args))
runner: Runner = hydra.utils.instantiate(dict_config.runner)
runner.load_state()
runner.run()
@@ -41,6 +51,17 @@ def checkpoint(self, *args, **kwargs):
return DelayedSubmission(new_runner, self.config)


def map_cli_args_to_dist_config(cli_args: argparse.Namespace) -> dict:
return {
"world_size": cli_args.num_nodes * cli_args.num_gpus,
"distributed_backend": "gloo" if cli_args.cpu else "nccl",
"submit": cli_args.submit,
"summit": None,
"cpu": cli_args.cpu,
"use_cuda_visibile_devices": True,
}


def get_hydra_config_from_yaml(
config_yml: str, overrides_args: list[str]
) -> DictConfig:
@@ -49,10 +70,13 @@ def get_hydra_config_from_yaml(
config_directory = os.path.dirname(os.path.abspath(config_yml))
config_name = os.path.basename(config_yml)
hydra.initialize_config_dir(config_directory)
# cfg = omegaconf.OmegaConf.load(args.config_yml)
return hydra.compose(config_name=config_name, overrides=overrides_args)


def runner_wrapper(config: DictConfig, cli_args: argparse.Namespace):
Submitit()(config, cli_args)


# this is meant as a future replacement for the main entrypoint
def main(
args: argparse.Namespace | None = None, override_args: list[str] | None = None
@@ -62,6 +86,38 @@ def main(
args, override_args = parser.parse_known_args()

cfg = get_hydra_config_from_yaml(args.config_yml, override_args)
srunner = Submitit()
logging.info("Running in local mode without elastic launch (single gpu only)")
srunner(cfg)
timestamp_id = get_timestamp_uid()
log_dir = os.path.join(args.run_dir, timestamp_id, "logs")
if args.submit: # Run on cluster
executor = AutoExecutor(folder=log_dir, slurm_max_num_timeout=3)
executor.update_parameters(
name=args.identifier,
mem_gb=args.slurm_mem,
timeout_min=args.slurm_timeout * 60,
slurm_partition=args.slurm_partition,
gpus_per_node=args.num_gpus,
cpus_per_task=8,
tasks_per_node=args.num_gpus,
nodes=args.num_nodes,
slurm_qos=args.slurm_qos,
slurm_account=args.slurm_account,
)
job = executor.submit(runner_wrapper, cfg, args)
logger.info(
f"Submitted job id: {timestamp_id}, slurm id: {job}, logs: {log_dir}"
)
else:
if args.num_gpus > 1:
logger.info(f"Running in local mode with {args.num_gpus} ranks")
launch_config = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=args.num_gpus,
rdzv_backend="c10d",
max_restarts=0,
)
elastic_launch(launch_config, runner_wrapper)(cfg, args)
else:
logger.info("Running in local mode without elastic launch")
distutils.setup_env_local()
runner_wrapper(cfg, args)
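For context on the submitit pieces above: AutoExecutor submits a picklable callable to SLURM, and because Submitit subclasses Checkpointable, submitit invokes checkpoint() on preemption or timeout to requeue the job. A minimal sketch of that pattern (illustrative names, not fairchem code):

    from submitit import AutoExecutor
    from submitit.helpers import Checkpointable, DelayedSubmission

    class TrainJob(Checkpointable):
        def __call__(self, cfg: dict) -> None:
            ...  # train; save state periodically so a requeued run can resume

        def checkpoint(self, *args, **kwargs) -> DelayedSubmission:
            # Called by submitit on SLURM preemption/timeout: resubmit a
            # fresh instance with the same arguments.
            return DelayedSubmission(TrainJob(), *args, **kwargs)

    executor = AutoExecutor(folder="logs/%j", slurm_max_num_timeout=3)
    executor.update_parameters(nodes=1, tasks_per_node=1, gpus_per_node=1)
    job = executor.submit(TrainJob(), {"lr": 1e-3})
    print(job.job_id)  # SLURM job id once queued
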
48 changes: 43 additions & 5 deletions src/fairchem/core/common/distutils.py
@@ -21,6 +21,7 @@

T = TypeVar("T")
DISTRIBUTED_PORT = 13356
CURRENT_DEVICE_STR = "CURRENT_DEVICE"


def os_environ_get_or_throw(x: str) -> str:
@@ -72,7 +73,15 @@ def setup(config) -> None:
logging.info(
f"local rank: {config['local_rank']}, visible devices: {os.environ['CUDA_VISIBLE_DEVICES']}"
)
torch.cuda.set_device(config["local_rank"])

# In the new hydra runners, we set up the device for each rank as either cuda:0 or cpu;
# after this point, the local rank should be using either "cpu" or "cuda"
if config.get("use_cuda_visibile_devices"):
assign_device_for_local_rank(config["cpu"], config["local_rank"])
else:
# in the old code, all ranks can see all devices but need to be assigned a device equal to their local rank
# this is dangerous and should be deprecated
torch.cuda.set_device(config["local_rank"])

dist.init_process_group(
backend="nccl",
@@ -110,11 +119,9 @@ def setup(config) -> None:
assert (
config["world_size"] == 1
), "Can only setup master address and port at this point for a single rank, otherwise we assume the processes and the comm addr/port have already been setup"
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(get_free_port())
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
setup_env_local()
config["local_rank"] = int(os.environ.get("LOCAL_RANK"))
assign_device_for_local_rank(config["cpu"], config["local_rank"])
dist.init_process_group(
backend=config["distributed_backend"],
rank=int(os.environ.get("RANK")),
@@ -210,3 +217,34 @@ def gather_objects(data: T, group: dist.ProcessGroup = dist.group.WORLD) -> list
output = [None for _ in range(get_world_size())] if is_master() else None
dist.gather_object(data, output, group=group, dst=0)
return output


def assign_device_for_local_rank(cpu: bool, local_rank: int):
assert (
os.environ.get(CURRENT_DEVICE_STR) is None
), "environment variable CURRENT_DEVICE is used for another purpose!"
if cpu:
os.environ[CURRENT_DEVICE_STR] = "cpu"
else:
# pin this rank to its own GPU via CUDA_VISIBLE_DEVICES; it is then addressed as "cuda" (device 0 of its visible set)
os.environ[CURRENT_DEVICE_STR] = "cuda"
os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank)


def get_device_for_local_rank():
cur_dev_env = os.environ.get(CURRENT_DEVICE_STR)
if cur_dev_env is not None:
return cur_dev_env
else:
device = "cuda" if torch.cuda.available() else "cpu"
logging.warn(
f"{CURRENT_DEVICE_STR} env variable not found, defaulting to {device}"
)
return device


def setup_env_local():
os.environ["MASTER_ADDR"] = "localhost"
os.environ["LOCAL_RANK"] = "0"
os.environ["RANK"] = "0"
os.environ["MASTER_PORT"] = str(get_free_port())
5 changes: 3 additions & 2 deletions src/fairchem/core/common/relaxation/optimizers/lbfgs_torch.py
@@ -52,7 +52,8 @@ def __init__(
self.traj_dir = traj_dir
self.traj_names = traj_names
self.early_stop_batch = early_stop_batch
self.otf_graph = model.model._unwrapped_model.otf_graph
# TODO: read otf_graph from model.model._unwrapped_model again once the new runner exposes it
self.otf_graph = True
assert not self.traj_dir or (
traj_dir and len(traj_names)
), "Trajectory names should be specified to save trajectories"
@@ -225,7 +226,7 @@ def __init__(self, model, transform=None) -> None:
self.transform = transform

def get_energy_and_forces(self, atoms, apply_constraint: bool = True):
predictions = self.model.predict(atoms, per_image=False, disable_tqdm=True)
predictions = self.model.predict(atoms)
energy = predictions["energy"]
forces = predictions["forces"]
if apply_constraint:
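The bare predict(atoms) call above leans on the new defaults introduced in ocp_trainer.py below; the two calls are equivalent after this commit:

    predictions = self.model.predict(atoms)
    predictions = self.model.predict(atoms, per_image=False, disable_tqdm=True)
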
6 changes: 6 additions & 0 deletions src/fairchem/core/common/utils.py
@@ -10,6 +10,7 @@
import ast
import collections
import copy
import datetime
import errno
import importlib
import itertools
@@ -26,6 +27,7 @@
from itertools import product
from pathlib import Path
from typing import TYPE_CHECKING, Any
from uuid import uuid4

import numpy as np
import torch
@@ -1446,3 +1448,7 @@ def load_model_and_weights_from_checkpoint(checkpoint_path: str) -> nn.Module:
matched_dict = match_state_dict(model.state_dict(), checkpoint["state_dict"])
load_state_dict(model, matched_dict, strict=True)
return model


def get_timestamp_uid() -> str:
return datetime.datetime.now().strftime("%Y%m-%d%H-%M%S-") + str(uuid4())[:4]
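For reference, the pattern groups year+month, day+hour, and minute+second, then appends four uuid4 characters; a call on 2024-09-20 13:45:12 would yield something like:

    get_timestamp_uid()  # -> '202409-2013-4512-a3f1' (last four characters vary)
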
4 changes: 2 additions & 2 deletions src/fairchem/core/trainers/ocp_trainer.py
@@ -411,9 +411,9 @@ def _compute_metrics(self, out, batch, evaluator, metrics=None):
def predict(
self,
data_loader,
per_image: bool = True,
per_image: bool = False,
results_file: str | None = None,
disable_tqdm: bool = False,
disable_tqdm: bool = True,
):
if self.is_debug and per_image:
raise FileNotFoundError("Predictions require debug mode to be turned off.")
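With both defaults flipped, callers that relied on per-image results or the progress bar must now opt in explicitly, e.g. (trainer and loader names illustrative):

    predictions = trainer.predict(data_loader, per_image=True, disable_tqdm=False)
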
