Skip to content

Commit

Permalink
Context manager to get temporary gloo pg (pytorch#902)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: pytorch#902

Differential Revision: D62414608
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Sep 23, 2024
1 parent 843835c commit eeb0667
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 5 deletions.
6 changes: 2 additions & 4 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,14 @@ jobs:
shell: bash -l {0}
run: |
set -eux
conda activate test
pip install -r requirements.txt
pip install -r dev-requirements.txt
pip install --no-build-isolation -e .
conda install pytorch cpuonly -c pytorch-nightly
pip install --ignore-installed --pre torch --index-url https://download.pytorch.org/whl/nightly/cpu
pip install --no-build-isolation -e .
- name: Run unit tests with coverage
shell: bash -l {0}
run: |
set -eux
conda activate test
pytest --cov=. --cov-report xml tests -vv
- name: Upload Coverage to Codecov
uses: codecov/codecov-action@v2
Expand Down
74 changes: 74 additions & 0 deletions tests/utils/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
get_global_rank,
get_local_rank,
get_local_world_size,
get_or_create_gloo_pg,
get_process_group_backend_from_device,
get_tcp_init_method,
get_world_size,
Expand Down Expand Up @@ -463,3 +464,76 @@ def _test_method_for_rank_zero() -> str:
val_from_test_method = _test_method_for_rank_zero()
tc = unittest.TestCase()
tc.assertEqual(val_from_test_method, "foo")

@skip_if_not_distributed
def test_get_or_create_gloo_pg(self) -> None:
spawn_multi_process(2, "gloo", self._test_get_or_create_gloo_pg)

@staticmethod
@patch("torchtnt.utils.distributed.dist.destroy_process_group")
def _test_get_or_create_gloo_pg(mock_destroy_process_group: MagicMock) -> None:
tc = unittest.TestCase()

# Test not distributed - no-op
with patch(
"torchtnt.utils.distributed.dist.is_initialized",
return_value=False,
):
with get_or_create_gloo_pg() as pg:
tc.assertIsNone(pg)

mock_destroy_process_group.assert_not_called()

# Test no-op since gloo pg already exists
mock_destroy_process_group.reset_mock()
with get_or_create_gloo_pg() as pg:
tc.assertIs(pg, dist.group.WORLD)

mock_destroy_process_group.assert_not_called()

# Test creating new gloo candidate pg - no op
mock_destroy_process_group.reset_mock()
gloo_pg = dist.new_group(backend=dist.Backend.GLOO)
with get_or_create_gloo_pg(gloo_pg) as pg:
tc.assertIs(pg, gloo_pg)

mock_destroy_process_group.assert_not_called()

# Test with NCCL backend - should create a new gloo pg and destroy
mock_destroy_process_group.reset_mock()
with patch(
"torchtnt.utils.distributed.dist.get_backend",
return_value=dist.Backend.NCCL,
):
with get_or_create_gloo_pg() as pg:
pg = none_throws(pg)
tc.assertIsNot(pg, dist.group.WORLD)
tc.assertEqual(pg._get_backend_name(), dist.Backend.GLOO)

mock_destroy_process_group.assert_called_once_with(pg)

# Test exception handling with existing pg - forward exception, group should not be destroyed
mock_destroy_process_group.reset_mock()
with tc.assertRaisesRegex(Exception, "Test Exception"):
gloo_pg = dist.new_group(backend=dist.Backend.GLOO)
with get_or_create_gloo_pg(gloo_pg) as pg:
tc.assertIs(pg, gloo_pg)
raise Exception("Test Exception")

mock_destroy_process_group.assert_not_called()

# Test exception handling with new pg - forward exception, group should be destroyed
mock_destroy_process_group.reset_mock()
with tc.assertRaisesRegex(Exception, "Test Exception"):
with patch(
"torchtnt.utils.distributed.dist.get_backend",
return_value=dist.Backend.NCCL,
):
with get_or_create_gloo_pg() as pg:
tc.assertIsNot(pg, dist.group.WORLD)
tc.assertEqual(
none_throws(pg)._get_backend_name(), dist.Backend.GLOO
)
raise Exception("Test Exception")

mock_destroy_process_group.assert_called_once_with(pg)
46 changes: 45 additions & 1 deletion torchtnt/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
# pyre-strict


import logging
import os
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from functools import wraps
from typing import Any, Callable, cast, Dict, List, Optional, TypeVar, Union
from typing import Any, Callable, cast, Dict, Generator, List, Optional, TypeVar, Union

import torch
import torch.nn.functional as F
Expand All @@ -28,6 +30,8 @@
TParams = ParameterSpecification("TParams")
TReturn = TypeVar("TReturn")

logger: logging.Logger = logging.getLogger(__name__)


class PGWrapper:
"""
Expand Down Expand Up @@ -641,3 +645,43 @@ def wrapper(*args: Any, **kwargs: Any) -> T:
return cast(T, val)

return wrapper


@contextmanager
def get_or_create_gloo_pg(
candidate_pg: Optional[dist.ProcessGroup] = None,
) -> Generator[Optional[dist.ProcessGroup], None, None]:
"""
Context manager to ensure that a gloo process group is used for the contained operations. First checks if the
WORLD process group, or the provided candidate process group, is already gloo-based. In case it is, that is returned.
Otherwise, a new gloo process group will be created and returned. Upon exiting the context, if a new process group
was created, it will be destroyed.
Note: If the distributed environment is not initialized, this context manager will return None and will be no-op.
Args:
candidate_pg: Optional process group to check if it is gloo-based. If None, the WORLD process group will be checked.
"""
gloo_pg_created = False

if not dist.is_initialized():
logger.info("Not in a distributed environment, gloo process group not created")
pg = None

else:
pg = candidate_pg or dist.group.WORLD
if dist.get_backend(pg) != dist.Backend.GLOO:
logger.info("Creating temporary gloo process group")
pg = dist.new_group(
timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
)
gloo_pg_created = True

try:
yield pg

finally:
# Cleanup temporary gloo pg if it was created
if gloo_pg_created:
dist.destroy_process_group(pg)
logger.info("Destroyed temporary gloo process group")

0 comments on commit eeb0667

Please sign in to comment.