Add get_or_create_gloo_pg context manager to torchtnt.utils.distributed

The context manager yields a gloo-backed process group for the enclosed
operations:
  * yields None (no-op) when torch.distributed is not initialized;
  * yields the candidate process group (or WORLD when no candidate is given)
    if its backend is already gloo;
  * otherwise creates a temporary gloo process group (3600 s timeout) and
    destroys it when the context exits, including when the body raises.

The accompanying test covers each of these paths with
dist.destroy_process_group mocked, so it can assert whether cleanup was
requested.

diff --git a/tests/utils/test_distributed.py b/tests/utils/test_distributed.py
index 437f016144..d26e3d1d6f 100644
--- a/tests/utils/test_distributed.py
+++ b/tests/utils/test_distributed.py
@@ -9,7 +9,7 @@
 
 import os
 import unittest
-from typing import Literal, Optional, Union
+from typing import Callable, Literal, Optional, Union
 from unittest.mock import MagicMock, patch
 from urllib.parse import parse_qs, urlparse
 
@@ -17,6 +17,7 @@
 import torch.distributed as dist
 import torch.distributed.launcher as launcher
 from pyre_extensions import none_throws
+from torch.distributed import ProcessGroup
 from torchtnt.utils.distributed import (
     _validate_global_rank_world_size,
     all_gather_tensors,
@@ -25,6 +26,7 @@
     get_global_rank,
     get_local_rank,
     get_local_world_size,
+    get_or_create_gloo_pg,
     get_process_group_backend_from_device,
     get_tcp_init_method,
     get_world_size,
@@ -463,3 +465,96 @@ def _test_method_for_rank_zero() -> str:
         val_from_test_method = _test_method_for_rank_zero()
         tc = unittest.TestCase()
         tc.assertEqual(val_from_test_method, "foo")
+
+    @skip_if_not_distributed
+    def test_get_or_create_gloo_pg(self) -> None:
+        spawn_multi_process(2, "gloo", self._test_get_or_create_gloo_pg)
+
+    @staticmethod
+    @patch("torchtnt.utils.distributed.dist.destroy_process_group")
+    def _test_get_or_create_gloo_pg(mock_destroy_process_group: MagicMock) -> None:
+        tc = unittest.TestCase()
+
+        # Test not distributed - no-op
+        with patch(
+            "torchtnt.utils.distributed.dist.is_initialized",
+            return_value=False,
+        ):
+            with get_or_create_gloo_pg() as pg:
+                tc.assertIsNone(pg)
+
+        mock_destroy_process_group.assert_not_called()
+
+        # Test no-op since gloo pg already exists
+        mock_destroy_process_group.reset_mock()
+        with get_or_create_gloo_pg() as pg:
+            tc.assertIs(pg, dist.group.WORLD)
+
+        mock_destroy_process_group.assert_not_called()
+
+        # Test creating new gloo candidate pg - no op
+        mock_destroy_process_group.reset_mock()
+        gloo_pg = dist.new_group(backend=dist.Backend.GLOO)
+        with get_or_create_gloo_pg(gloo_pg) as pg:
+            tc.assertIs(pg, gloo_pg)
+
+        mock_destroy_process_group.assert_not_called()
+
+        # Test with NCCL backend - should create a new gloo pg and destroy
+        mock_destroy_process_group.reset_mock()
+
+        with patch(
+            "torchtnt.utils.distributed.dist.get_backend",
+            side_effect=_get_backend_side_effect(),
+        ):
+            with get_or_create_gloo_pg() as pg:
+                pg = none_throws(pg)
+                tc.assertIsNot(pg, dist.group.WORLD)
+                tc.assertEqual(pg._get_backend_name(), dist.Backend.GLOO)
+
+        mock_destroy_process_group.assert_called_once_with(pg)
+
+        # Test exception handling with existing pg - forward exception, group should not be destroyed
+        mock_destroy_process_group.reset_mock()
+        with tc.assertRaisesRegex(Exception, "Test Exception"):
+            gloo_pg = dist.new_group(backend=dist.Backend.GLOO)
+            with get_or_create_gloo_pg(gloo_pg) as pg:
+                tc.assertIs(pg, gloo_pg)
+                raise Exception("Test Exception")
+
+        mock_destroy_process_group.assert_not_called()
+
+        # Test exception handling with new pg - forward exception, group should be destroyed
+        mock_destroy_process_group.reset_mock()
+        with tc.assertRaisesRegex(Exception, "Test Exception"):
+            with patch(
+                "torchtnt.utils.distributed.dist.get_backend",
+                side_effect=_get_backend_side_effect(),
+            ):
+                with get_or_create_gloo_pg() as pg:
+                    tc.assertIsNot(pg, dist.group.WORLD)
+                    tc.assertEqual(
+                        none_throws(pg)._get_backend_name(), dist.Backend.GLOO
+                    )
+                    raise Exception("Test Exception")
+
+        mock_destroy_process_group.assert_called_once_with(pg)
+
+
+def _get_backend_side_effect() -> Callable[[Optional[ProcessGroup]], str]:
+    """
+    Get a side effect for the get_backend function that returns NCCL the first time it is called,
+    and then will return GLOO for subsequent calls. For use with _test_get_or_create_gloo_pg.
+    """
+    called_get_backend = False
+
+    def get_backend(_) -> str:
+        # We just want to return NCCL the first time we call this function.
+        nonlocal called_get_backend
+        if not called_get_backend:
+            called_get_backend = True
+            return dist.Backend.NCCL
+        else:
+            return dist.Backend.GLOO  # real PG
+
+    return get_backend
diff --git a/torchtnt/utils/distributed.py b/torchtnt/utils/distributed.py
index 996bd5fe48..a7b44f593a 100644
--- a/torchtnt/utils/distributed.py
+++ b/torchtnt/utils/distributed.py
@@ -8,12 +8,14 @@
 
 # pyre-strict
 
+import logging
 import os
 import tempfile
+from contextlib import contextmanager
 from dataclasses import dataclass
 from datetime import timedelta
 from functools import wraps
-from typing import Any, Callable, cast, Dict, List, Optional, TypeVar, Union
+from typing import Any, Callable, cast, Dict, Generator, List, Optional, TypeVar, Union
 
 import torch
 import torch.nn.functional as F
@@ -28,6 +30,8 @@
 TParams = ParameterSpecification("TParams")
 TReturn = TypeVar("TReturn")
 
+logger: logging.Logger = logging.getLogger(__name__)
+
 
 class PGWrapper:
     """
@@ -641,3 +645,43 @@ def wrapper(*args: Any, **kwargs: Any) -> T:
         return cast(T, val)
 
     return wrapper
+
+
+@contextmanager
+def get_or_create_gloo_pg(
+    candidate_pg: Optional[dist.ProcessGroup] = None,
+) -> Generator[Optional[dist.ProcessGroup], None, None]:
+    """
+    Context manager to ensure that a gloo process group is used for the contained operations. First checks if the
+    WORLD process group, or the provided candidate process group, is already gloo-based. In case it is, that is returned.
+    Otherwise, a new gloo process group will be created and returned. Upon exiting the context, if a new process group
+    was created, it will be destroyed.
+
+    Note: If the distributed environment is not initialized, this context manager will return None and will be no-op.
+
+    Args:
+        candidate_pg: Optional process group to check if it is gloo-based. If None, the WORLD process group will be checked.
+    """
+    gloo_pg_created = False
+
+    if not dist.is_initialized():
+        logger.info("Not in a distributed environment, gloo process group not created")
+        pg = None
+
+    else:
+        pg = candidate_pg or dist.group.WORLD
+        if dist.get_backend(pg) != dist.Backend.GLOO:
+            logger.info("Creating temporary gloo process group")
+            pg = dist.new_group(
+                timeout=timedelta(seconds=3600), backend=dist.Backend.GLOO
+            )
+            gloo_pg_created = True
+
+    try:
+        yield pg
+
+    finally:
+        # Cleanup temporary gloo pg if it was created
+        if gloo_pg_created:
+            dist.destroy_process_group(pg)
+            logger.info("Destroyed temporary gloo process group")