Skip to content

Commit

Permalink
context manager to get temporary gloo pg
Browse files Browse the repository at this point in the history
Differential Revision: D62414608
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Sep 23, 2024
1 parent 843835c commit bf7d741
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 1 deletion.
74 changes: 74 additions & 0 deletions tests/utils/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_global_rank,
get_local_rank,
get_local_world_size,
get_or_create_gloo_pg,
get_process_group_backend_from_device,
get_tcp_init_method,
get_world_size,
Expand Down Expand Up @@ -463,3 +464,76 @@ def _test_method_for_rank_zero() -> str:
val_from_test_method = _test_method_for_rank_zero()
tc = unittest.TestCase()
tc.assertEqual(val_from_test_method, "foo")

@skip_if_not_distributed
def test_get_or_create_gloo_pg(self) -> None:
spawn_multi_process(2, "gloo", self._test_get_or_create_gloo_pg)

@staticmethod
@patch("torchtnt.utils.distributed.dist.destroy_process_group")
def _test_get_or_create_gloo_pg(mock_destroy_process_group: MagicMock) -> None:
tc = unittest.TestCase()

# Test not distributed - no-op
with patch(
"torchtnt.utils.distributed.dist.is_initialized",
return_value=False,
):
with get_or_create_gloo_pg() as pg:
tc.assertIsNone(pg)

mock_destroy_process_group.assert_not_called()

# Test no-op since gloo pg already exists
mock_destroy_process_group.reset_mock()
with get_or_create_gloo_pg() as pg:
tc.assertIs(pg, dist.group.WORLD)

mock_destroy_process_group.assert_not_called()

# Test creating new gloo candidate pg - no op
mock_destroy_process_group.reset_mock()
gloo_pg = dist.new_group(backend=dist.Backend.GLOO)
with get_or_create_gloo_pg(gloo_pg) as pg:
tc.assertIs(pg, gloo_pg)

mock_destroy_process_group.assert_not_called()

# Test with NCCL backend - should create a new gloo pg and destroy
mock_destroy_process_group.reset_mock()
with patch(
"torchtnt.utils.distributed.dist.get_backend",
return_value=dist.Backend.NCCL,
):
with get_or_create_gloo_pg() as pg:
pg = none_throws(pg)
tc.assertIsNot(pg, dist.group.WORLD)
tc.assertEqual(pg._get_backend_name(), dist.Backend.GLOO)

mock_destroy_process_group.assert_called_once_with(pg)

# Test exception handling with existing pg - forward exception, group should not be destroyed
mock_destroy_process_group.reset_mock()
with tc.assertRaisesRegex(Exception, "Test Exception"):
gloo_pg = dist.new_group(backend=dist.Backend.GLOO)
with get_or_create_gloo_pg(gloo_pg) as pg:
tc.assertIs(pg, gloo_pg)
raise Exception("Test Exception")

mock_destroy_process_group.assert_not_called()

# Test exception handling with new pg - forward exception, group should be destroyed
mock_destroy_process_group.reset_mock()
with tc.assertRaisesRegex(Exception, "Test Exception"):
with patch(
"torchtnt.utils.distributed.dist.get_backend",
return_value=dist.Backend.NCCL,
):
with get_or_create_gloo_pg() as pg:
tc.assertIsNot(pg, dist.group.WORLD)
tc.assertEqual(
none_throws(pg)._get_backend_name(), dist.Backend.GLOO
)
raise Exception("Test Exception")

mock_destroy_process_group.assert_called_once_with(pg)
46 changes: 45 additions & 1 deletion torchtnt/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
# pyre-strict


import logging
import os
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from functools import wraps
from typing import Any, Callable, cast, Dict, List, Optional, TypeVar, Union
from typing import Any, Callable, cast, Dict, Generator, List, Optional, TypeVar, Union

import torch
import torch.nn.functional as F
Expand All @@ -28,6 +30,8 @@
TParams = ParameterSpecification("TParams")
TReturn = TypeVar("TReturn")

logger: logging.Logger = logging.getLogger(__name__)


class PGWrapper:
"""
Expand Down Expand Up @@ -641,3 +645,43 @@ def wrapper(*args: Any, **kwargs: Any) -> T:
return cast(T, val)

return wrapper


@contextmanager
def get_or_create_gloo_pg(
candidate_pg: Optional[dist.ProcessGroup] = None,
) -> Generator[Optional[dist.ProcessGroup], None, None]:
"""
Context manager to ensure that a gloo process group is used for the contained operations. First checks if the
WORLD process group, or the provided candidate process group, is already gloo-based. In case it is, that is returned.
Otherwise, a new gloo process group will be created and returned. Upon exiting the context, if a new process group
was created, it will be destroyed.
Note: If the distributed environment is not initialized, this context manager will return None and will be no-op.
Args:
candidate_pg: Optional process group to check if it is gloo-based. If None, the WORLD process group will be checked.
"""
gloo_pg_created = False

if not dist.is_initialized():
logger.info("Not in a distributed environment, gloo process group not created")
pg = None

else:
pg = candidate_pg or dist.group.WORLD
if dist.get_backend(pg) != dist.Backend.GLOO:
logger.info("Creating temporary gloo process group")
pg = dist.new_group(
timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
)
gloo_pg_created = True

try:
yield pg

finally:
# Cleanup temporary gloo pg if it was created
if gloo_pg_created:
dist.destroy_process_group(pg)
logger.info("Destroyed temporary gloo process group")

0 comments on commit bf7d741

Please sign in to comment.