Skip to content

Commit

Permalink
Context manager to get temporary gloo pg (#902)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #902

Differential Revision: D62414608
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Sep 23, 2024
1 parent 843835c commit 6e9aaf6
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 2 deletions.
97 changes: 96 additions & 1 deletion tests/utils/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,15 @@

import os
import unittest
from typing import Literal, Optional, Union
from typing import Any, Callable, Literal, Optional, Union
from unittest.mock import MagicMock, patch
from urllib.parse import parse_qs, urlparse

import torch
import torch.distributed as dist
import torch.distributed.launcher as launcher
from pyre_extensions import none_throws
from torch.distributed import ProcessGroup
from torchtnt.utils.distributed import (
_validate_global_rank_world_size,
all_gather_tensors,
Expand All @@ -25,6 +26,7 @@
get_global_rank,
get_local_rank,
get_local_world_size,
get_or_create_gloo_pg,
get_process_group_backend_from_device,
get_tcp_init_method,
get_world_size,
Expand Down Expand Up @@ -463,3 +465,96 @@ def _test_method_for_rank_zero() -> str:
val_from_test_method = _test_method_for_rank_zero()
tc = unittest.TestCase()
tc.assertEqual(val_from_test_method, "foo")

@skip_if_not_distributed
def test_get_or_create_gloo_pg(self) -> None:
spawn_multi_process(2, "gloo", self._test_get_or_create_gloo_pg)

@staticmethod
@patch("torchtnt.utils.distributed.dist.destroy_process_group")
def _test_get_or_create_gloo_pg(mock_destroy_process_group: MagicMock) -> None:
tc = unittest.TestCase()

# Test not distributed - no-op
with patch(
"torchtnt.utils.distributed.dist.is_initialized",
return_value=False,
):
with get_or_create_gloo_pg() as pg:
tc.assertIsNone(pg)

mock_destroy_process_group.assert_not_called()

# Test no-op since gloo pg already exists
mock_destroy_process_group.reset_mock()
with get_or_create_gloo_pg() as pg:
tc.assertIs(pg, dist.group.WORLD)

mock_destroy_process_group.assert_not_called()

# Test creating new gloo candidate pg - no op
mock_destroy_process_group.reset_mock()
gloo_pg = dist.new_group(backend=dist.Backend.GLOO)
with get_or_create_gloo_pg(gloo_pg) as pg:
tc.assertIs(pg, gloo_pg)

mock_destroy_process_group.assert_not_called()

# Test with NCCL backend - should create a new gloo pg and destroy
mock_destroy_process_group.reset_mock()

with patch(
"torchtnt.utils.distributed.dist.get_backend",
side_effect=_get_backend_side_effect(),
):
with get_or_create_gloo_pg() as pg:
pg = none_throws(pg)
tc.assertIsNot(pg, dist.group.WORLD)
tc.assertEqual(pg._get_backend_name(), dist.Backend.GLOO)

mock_destroy_process_group.assert_called_once_with(pg)

# Test exception handling with existing pg - forward exception, group should not be destroyed
mock_destroy_process_group.reset_mock()
with tc.assertRaisesRegex(Exception, "Test Exception"):
gloo_pg = dist.new_group(backend=dist.Backend.GLOO)
with get_or_create_gloo_pg(gloo_pg) as pg:
tc.assertIs(pg, gloo_pg)
raise Exception("Test Exception")

mock_destroy_process_group.assert_not_called()

# Test exception handling with new pg - forward exception, group should be destroyed
mock_destroy_process_group.reset_mock()
with tc.assertRaisesRegex(Exception, "Test Exception"):
with patch(
"torchtnt.utils.distributed.dist.get_backend",
side_effect=_get_backend_side_effect(),
):
with get_or_create_gloo_pg() as pg:
tc.assertIsNot(pg, dist.group.WORLD)
tc.assertEqual(
none_throws(pg)._get_backend_name(), dist.Backend.GLOO
)
raise Exception("Test Exception")

mock_destroy_process_group.assert_called_once_with(pg)


def _get_backend_side_effect() -> Callable[[Optional[ProcessGroup]], str]:
"""
Get a side effect for the get_backend function that returns NCCL the first time it is called,
and then will return GLOO for subsequent calls. For use with _test_get_or_create_gloo_pg.
"""
called_get_backend = False

def get_backend(_) -> str:
# We just want to return NCCL the first time we call this function.
nonlocal called_get_backend
if not called_get_backend:
called_get_backend = True
return dist.Backend.NCCL
else:
return dist.Backend.GLOO # real PG

return get_backend
46 changes: 45 additions & 1 deletion torchtnt/utils/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
# pyre-strict


import logging
import os
import tempfile
from contextlib import contextmanager
from dataclasses import dataclass
from datetime import timedelta
from functools import wraps
from typing import Any, Callable, cast, Dict, List, Optional, TypeVar, Union
from typing import Any, Callable, cast, Dict, Generator, List, Optional, TypeVar, Union

import torch
import torch.nn.functional as F
Expand All @@ -28,6 +30,8 @@
TParams = ParameterSpecification("TParams")
TReturn = TypeVar("TReturn")

logger: logging.Logger = logging.getLogger(__name__)


class PGWrapper:
"""
Expand Down Expand Up @@ -641,3 +645,43 @@ def wrapper(*args: Any, **kwargs: Any) -> T:
return cast(T, val)

return wrapper


@contextmanager
def get_or_create_gloo_pg(
candidate_pg: Optional[dist.ProcessGroup] = None,
) -> Generator[Optional[dist.ProcessGroup], None, None]:
"""
Context manager to ensure that a gloo process group is used for the contained operations. First checks if the
WORLD process group, or the provided candidate process group, is already gloo-based. In case it is, that is returned.
Otherwise, a new gloo process group will be created and returned. Upon exiting the context, if a new process group
was created, it will be destroyed.
Note: If the distributed environment is not initialized, this context manager will return None and will be no-op.
Args:
candidate_pg: Optional process group to check if it is gloo-based. If None, the WORLD process group will be checked.
"""
gloo_pg_created = False

if not dist.is_initialized():
logger.info("Not in a distributed environment, gloo process group not created")
pg = None

else:
pg = candidate_pg or dist.group.WORLD
if dist.get_backend(pg) != dist.Backend.GLOO:
logger.info("Creating temporary gloo process group")
pg = dist.new_group(
timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
)
gloo_pg_created = True

try:
yield pg

finally:
# Cleanup temporary gloo pg if it was created
if gloo_pg_created:
dist.destroy_process_group(pg)
logger.info("Destroyed temporary gloo process group")

0 comments on commit 6e9aaf6

Please sign in to comment.