Move spawn_multi_process() to utils/distributed.py (#704)
Summary:
Pull Request resolved: #704

`spawn_multi_process()` is a generally useful utility beyond tests, so this change moves it from `utils/test_utils.py` to `utils/distributed.py`.
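
For downstream code, only the import path changes; a minimal before/after sketch based on the import updates in this diff:

```python
# Before this change: spawn_multi_process was exported from the test helpers.
from torchtnt.utils.test_utils import spawn_multi_process

# After this change: it is imported from the distributed utilities module
# (and remains re-exported from torchtnt.utils via utils/__init__.py).
from torchtnt.utils.distributed import spawn_multi_process
```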

Reviewed By: rayg1234, anshulverma, JKSenthil

Differential Revision: D53841205

fbshipit-source-id: 9e0843169dc8df1f81717e7432e1c9e572d1096e
gunchu authored and facebook-github-bot committed Feb 16, 2024
1 parent b27462f commit 5b8db2a
Showing 16 changed files with 123 additions and 172 deletions.
2 changes: 1 addition & 1 deletion docs/source/utils/utils.rst
@@ -63,6 +63,7 @@ Distributed Utils
all_gather_tensors
rank_zero_fn
revert_sync_batchnorm
spawn_multi_process
sync_bool


@@ -284,7 +285,6 @@ Test Utils
is_asan
is_tsan
skip_if_asan
spawn_multi_process


Timer Utils
3 changes: 2 additions & 1 deletion examples/torchrec/tests/torchrec_example_test.py
@@ -7,7 +7,8 @@

import unittest

from torchtnt.utils.test_utils import skip_if_asan, skip_if_not_gpu, spawn_multi_process
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.test_utils import skip_if_asan, skip_if_not_gpu

from ..main import main

8 changes: 2 additions & 6 deletions tests/framework/callbacks/test_base_checkpointer.py
@@ -37,13 +37,9 @@

from torchtnt.framework.train import train
from torchtnt.framework.unit import AppStateMixin, TrainUnit, TTrainData
from torchtnt.utils.distributed import get_global_rank
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import (
skip_if_not_distributed,
skip_if_not_gpu,
spawn_multi_process,
)
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class BaseCheckpointSaver(BaseCheckpointer):
8 changes: 2 additions & 6 deletions tests/framework/callbacks/test_checkpoint_utils.py
@@ -29,14 +29,10 @@
get_latest_checkpoint_path,
rank_zero_read_and_broadcast,
)
from torchtnt.utils.distributed import get_global_rank, PGWrapper
from torchtnt.utils.distributed import get_global_rank, PGWrapper, spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.fsspec import get_filesystem
from torchtnt.utils.test_utils import (
get_pet_launch_config,
skip_if_not_distributed,
spawn_multi_process,
)
from torchtnt.utils.test_utils import get_pet_launch_config, skip_if_not_distributed

METADATA_FNAME: str = ".metadata"

8 changes: 2 additions & 6 deletions tests/framework/callbacks/test_dcp_saver.py
@@ -27,13 +27,9 @@
from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
from torchtnt.framework.callbacks.dcp_saver import DistributedCheckpointSaver
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.env import seed
from torchtnt.utils.test_utils import (
skip_if_not_distributed,
skip_if_not_gpu,
spawn_multi_process,
)
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class DistributedCheckpointSaverTest(unittest.TestCase):
8 changes: 2 additions & 6 deletions tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -32,13 +32,9 @@
TorchSnapshotSaver,
)
from torchtnt.framework.train import train
from torchtnt.utils.distributed import get_global_rank
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process
from torchtnt.utils.env import seed
from torchtnt.utils.test_utils import (
skip_if_not_distributed,
skip_if_not_gpu,
spawn_multi_process,
)
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class TorchSnapshotSaverTest(unittest.TestCase):
2 changes: 1 addition & 1 deletion tests/framework/test_auto_unit.py
@@ -46,11 +46,11 @@
from torchtnt.framework.train import train
from torchtnt.framework.unit import TPredictData
from torchtnt.utils.device import copy_data_to_device
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env, seed
from torchtnt.utils.lr_scheduler import TLRScheduler
from torchtnt.utils.prepare_module import DDPStrategy, FSDPStrategy, TorchCompileParams
from torchtnt.utils.progress import Progress
from torchtnt.utils.test_utils import spawn_multi_process
from torchtnt.utils.timer import Timer

TParams = ParamSpec("TParams")
7 changes: 2 additions & 5 deletions tests/framework/test_unit_utils.py
@@ -17,12 +17,9 @@
_step_requires_iterator,
)
from torchtnt.framework.state import State
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import (
skip_if_not_distributed,
skip_if_not_gpu,
spawn_multi_process,
)
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class UnitUtilsTest(unittest.TestCase):
10 changes: 10 additions & 0 deletions tests/utils/test_distributed.py
@@ -29,6 +29,7 @@
PGWrapper,
rank_zero_fn,
revert_sync_batchnorm,
spawn_multi_process,
sync_bool,
)
from torchtnt.utils.test_utils import get_pet_launch_config, skip_if_not_distributed
@@ -431,3 +432,12 @@ def _test_pg_wrapper_scatter_object_list(
)
tc = unittest.TestCase()
tc.assertEqual(output_list[0], get_local_rank() + 1)

@staticmethod
def _test_method(offset_arg: int, offset_kwarg: int) -> int:
return get_global_rank() + offset_arg - offset_kwarg

@skip_if_not_distributed
def test_spawn_multi_process(self) -> None:
mp_list = spawn_multi_process(2, "gloo", self._test_method, 3, offset_kwarg=2)
self.assertEqual(mp_list, [1, 2])
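
A quick sanity check of the expected output, assuming ranks 0 and 1 are spawned for `world_size=2`: each rank returns `get_global_rank() + offset_arg - offset_kwarg`, so rank 0 produces `0 + 3 - 2 = 1` and rank 1 produces `1 + 3 - 2 = 2`, which is why the gathered result is `[1, 2]`.
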
11 changes: 6 additions & 5 deletions tests/utils/test_distributed_gpu.py
@@ -10,13 +10,14 @@
import torch
import torch.distributed as dist
from torchtnt.utils.device import get_device_from_env
from torchtnt.utils.distributed import all_gather_tensors, get_local_rank, PGWrapper
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import (
skip_if_not_distributed,
skip_if_not_gpu,
from torchtnt.utils.distributed import (
all_gather_tensors,
get_local_rank,
PGWrapper,
spawn_multi_process,
)
from torchtnt.utils.env import init_from_env
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu


class DistributedGPUTest(unittest.TestCase):
7 changes: 2 additions & 5 deletions tests/utils/test_prepare_module.py
@@ -12,6 +12,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import MixedPrecision
from torch.nn.parallel import DistributedDataParallel as DDP
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.env import init_from_env
from torchtnt.utils.prepare_module import (
_is_fsdp_module,
@@ -23,11 +24,7 @@
prepare_module,
TorchCompileParams,
)
from torchtnt.utils.test_utils import (
skip_if_not_distributed,
skip_if_not_gpu,
spawn_multi_process,
)
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
from torchtnt.utils.version import is_torch_version_geq_1_13, is_torch_version_geq_2_0

COMPILE_AVAIL = False
23 changes: 0 additions & 23 deletions tests/utils/test_test_utils.py

This file was deleted.

7 changes: 2 additions & 5 deletions tests/utils/test_timer.py
@@ -16,11 +16,8 @@
import torch
import torch.distributed as dist
from pyre_extensions import none_throws
from torchtnt.utils.test_utils import (
skip_if_not_distributed,
skip_if_not_gpu,
spawn_multi_process,
)
from torchtnt.utils.distributed import spawn_multi_process
from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu
from torchtnt.utils.timer import (
BoundedTimer,
FullSyncPeriodicTimer,
3 changes: 2 additions & 1 deletion torchtnt/utils/__init__.py
@@ -22,6 +22,7 @@
get_process_group_backend_from_device,
get_world_size,
PGWrapper,
spawn_multi_process,
sync_bool,
)
from .early_stop_checker import EarlyStopChecker
@@ -65,7 +66,7 @@
)
from .stateful import Stateful
from .swa import AveragedModel
from .test_utils import get_pet_launch_config, spawn_multi_process
from .test_utils import get_pet_launch_config
from .timer import FullSyncPeriodicTimer, get_timer_summary, log_elapsed_time, Timer
from .tqdm import close_progress_bar, create_progress_bar, update_progress_bar
from .version import (
89 changes: 86 additions & 3 deletions torchtnt/utils/distributed.py
@@ -8,18 +8,23 @@

import os
import tempfile
from dataclasses import dataclass
from datetime import timedelta
from functools import wraps
from typing import Any, Callable, cast, List, Optional, TypeVar, Union
from typing import Any, Callable, cast, Dict, List, Optional, TypeVar, Union

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor
from pyre_extensions import ParameterSpecification
from torch import distributed as dist, multiprocessing, Tensor
from torch.distributed.elastic.utils.distributed import get_free_port
from typing_extensions import Literal


T = TypeVar("T")
DistObjList = Union[List[T], List[None]]
TParams = ParameterSpecification("TParams")
TReturn = TypeVar("TReturn")


class PGWrapper:
@@ -504,3 +509,81 @@ def sync_bool(
raise TypeError(
f'Invalid value for `coherence_mode` provided: Expected type int, float, or one of ("any", "all", "rank_zero"), but received {coherence_mode}.'
)


@dataclass
class ProcessGroupSetupParams:
backend: str
port: str
world_size: int


def spawn_multi_process(
world_size: int,
backend: str,
test_method: Callable[TParams, TReturn],
*test_method_args: Any,
**test_method_kwargs: Any,
) -> List[TReturn]:
"""
Spawn single node, multi-rank function.
Uses localhost and free port to communicate.
Args:
world_size: number of processes
backend: backend to use. for example, "nccl", "gloo", etc
test_method: callable to spawn. first 3 arguments are rank, world_size and mp output dict
test_method_args: args for the test method
test_method_kwargs: kwargs for the test method
Returns:
A list, l, where l[i] is the return value of test_method on rank i
"""
manager = multiprocessing.Manager()
mp_output_dict = manager.dict()

port = str(get_free_port())
torch.multiprocessing.spawn(
# torch.multiprocessing.spawn sends rank as the first param
# https://pytorch.org/docs/stable/multiprocessing.html#torch.multiprocessing.spawn
_init_pg_and_rank_and_launch_test,
args=(
ProcessGroupSetupParams(backend=backend, port=port, world_size=world_size),
mp_output_dict,
test_method,
test_method_args,
test_method_kwargs,
),
nprocs=world_size,
)

output_list = []
for i in range(world_size):
output_list.append(mp_output_dict[i])

return output_list


def _init_pg_and_rank_and_launch_test(
rank: int,
pg_setup_params: ProcessGroupSetupParams,
mp_output_dict: Dict[int, object],
test_method: Callable[TParams, TReturn],
args: List[object],
kwargs: Dict[str, object],
) -> None:
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = pg_setup_params.port
os.environ["WORLD_SIZE"] = str(pg_setup_params.world_size)
os.environ["LOCAL_RANK"] = str(rank)
dist.init_process_group(
rank=rank,
world_size=pg_setup_params.world_size,
backend=pg_setup_params.backend,
timeout=timedelta(seconds=10), # setting up timeout for distributed collectives
)
try:
mp_output_dict[rank] = test_method(*args, **kwargs) # pyre-fixme

finally:
destroy_process_group()
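
For context, a minimal standalone usage sketch of the relocated helper; the `_echo_rank` function, its `scale` argument, and `world_size=2` below are illustrative assumptions, not part of this commit:

```python
# Hypothetical usage sketch (not from this diff). Assumes torchtnt is
# installed and the CPU-only "gloo" backend is available.
import torch.distributed as dist

from torchtnt.utils.distributed import spawn_multi_process


def _echo_rank(scale: int) -> int:
    # Runs on every spawned rank; spawn_multi_process has already initialized
    # the process group, so rank queries and collectives are usable here.
    return dist.get_rank() * scale


if __name__ == "__main__":
    # Spawns two single-node processes over gloo; element i of the returned
    # list is _echo_rank's return value on rank i, i.e. [0, 10].
    results = spawn_multi_process(2, "gloo", _echo_rank, 10)
    print(results)
```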