add support for different backends and ddp in pytest
misko committed Jul 23, 2024
1 parent 0219e36 commit 2014801
Showing 3 changed files with 19 additions and 10 deletions.
8 changes: 5 additions & 3 deletions src/fairchem/core/common/distutils.py
@@ -98,7 +98,9 @@ def setup(config) -> None:
         )
     else:
         config["local_rank"] = int(os.environ.get("LOCAL_RANK", config["local_rank"]))
-        dist.init_process_group(backend="nccl")
+        dist.init_process_group(
+            backend="nccl" if "backend" not in config else config["backend"]
+        )


 def cleanup() -> None:
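
Note: the process-group backend in setup() is now read from the config when one is provided, with "nccl" kept as the default. A minimal standalone sketch of that selection logic; the helper name resolve_backend and the example configs are illustrative, not part of the commit:

def resolve_backend(config: dict) -> str:
    # Mirrors the expression above: fall back to "nccl" unless the config names a backend.
    return "nccl" if "backend" not in config else config["backend"]

assert resolve_backend({}) == "nccl"                   # default: NCCL for GPU runs
assert resolve_backend({"backend": "gloo"}) == "gloo"  # e.g. CPU-only pytest runs
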
@@ -144,7 +146,7 @@ def all_reduce(
     if not isinstance(data, torch.Tensor):
         tensor = torch.tensor(data)
     if device is not None:
-        tensor = tensor.cuda(device)
+        tensor = tensor.to(device)
     dist.all_reduce(tensor, group=group)
     if average:
         tensor /= get_world_size()
@@ -162,7 +164,7 @@ def all_gather(data, group=dist.group.WORLD, device=None) -> list[torch.Tensor]:
     if not isinstance(data, torch.Tensor):
         tensor = torch.tensor(data)
     if device is not None:
-        tensor = tensor.cuda(device)
+        tensor = tensor.to(device)
     tensor_list = [tensor.new_zeros(tensor.shape) for _ in range(get_world_size())]
     dist.all_gather(tensor_list, tensor, group=group)
     if not isinstance(data, torch.Tensor):
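
Note: replacing tensor.cuda(device) with tensor.to(device) lets all_reduce and all_gather run on CPU devices (for example under the gloo backend) as well as on CUDA. A small self-contained illustration of the difference, outside of any process group; the tensor values are arbitrary:

import torch

tensor = torch.tensor([1.0, 2.0, 3.0])

# .to() accepts any device string or torch.device, so CPU-only test runs work ...
cpu_tensor = tensor.to("cpu")

# ... and CUDA runs behave as before when a GPU is present.
if torch.cuda.is_available():
    cuda_tensor = tensor.to("cuda:0")

# tensor.cuda(device), by contrast, always requires a CUDA runtime and fails on CPU-only hosts.
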
17 changes: 11 additions & 6 deletions src/fairchem/core/common/test_utils.py
@@ -93,18 +93,23 @@ def _init_pg_and_rank_and_launch_test(
     test_method: callable,
     args: list[object],
     kwargs: dict[str, object],
+    init_process_group: bool = False,
 ) -> None:
     os.environ["MASTER_ADDR"] = "localhost"
     os.environ["MASTER_PORT"] = pg_setup_params.port
     os.environ["WORLD_SIZE"] = str(pg_setup_params.world_size)
     os.environ["LOCAL_RANK"] = str(rank)
     os.environ["RANK"] = str(rank)
     # setup default process group
-    dist.init_process_group(
-        rank=rank,
-        world_size=pg_setup_params.world_size,
-        backend=pg_setup_params.backend,
-        timeout=timedelta(seconds=10),  # setting up timeout for distributed collectives
-    )
+    if init_process_group:
+        dist.init_process_group(
+            rank=rank,
+            world_size=pg_setup_params.world_size,
+            backend=pg_setup_params.backend,
+            timeout=timedelta(
+                seconds=10
+            ),  # setting up timeout for distributed collectives
+        )
     # setup gp
     if pg_setup_params.use_gp:
         config = {
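
Note: because init_process_group now defaults to False, the launcher no longer creates the default process group unless a test asks for it, leaving the code under test (e.g. distutils.setup with a config-selected backend) free to create it itself. A standalone sketch of the same pattern; the worker function, port, and world size here are illustrative, not the repository's helper:

import os
from datetime import timedelta

import torch.distributed as dist

def worker(rank: int, world_size: int, port: str, init_process_group: bool = False) -> None:
    # Environment variables mirror what the test helper exports for each rank.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = port
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["RANK"] = str(rank)
    # Only create the default process group when the caller asks for it;
    # otherwise the code under test is expected to do so on its own.
    if init_process_group:
        dist.init_process_group(
            rank=rank,
            world_size=world_size,
            backend="gloo",  # CPU-friendly backend for pytest runs
            timeout=timedelta(seconds=10),
        )

# Single-process example: rank 0 of a world of 1, creating the group here.
worker(rank=0, world_size=1, port="12355", init_process_group=True)
dist.destroy_process_group()
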
4 changes: 3 additions & 1 deletion src/fairchem/core/trainers/base_trainer.py
@@ -456,7 +456,9 @@ def load_model(self) -> None:
             self.logger.log_summary({"num_params": self.model.num_params})

         if distutils.initialized() and not self.config["noddp"]:
-            self.model = DistributedDataParallel(self.model, device_ids=[self.device])
+            self.model = DistributedDataParallel(
+                self.model, device_ids=None if self.cpu else [self.device]
+            )

     @property
     def _unwrapped_model(self):
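
Note: passing device_ids=None keeps the DistributedDataParallel wrapper valid for CPU models (DDP requires device_ids to be None or unset for CPU modules), while GPU runs still pin the module to the trainer's device. A self-contained sketch of that branch; the toy model and single-rank gloo group are illustrative:

import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "12356")
dist.init_process_group(backend="gloo", rank=0, world_size=1)

cpu = not torch.cuda.is_available()
device = torch.device("cpu") if cpu else torch.device("cuda:0")
model = torch.nn.Linear(8, 1).to(device)

# device_ids must be None (or omitted) for CPU modules; a single-element list pins a GPU module.
model = DistributedDataParallel(model, device_ids=None if cpu else [device])

dist.destroy_process_group()
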
