move rank_zero_read_and_broadcast to distributed utils (#796)
Summary: Pull Request resolved: #796

Reviewed By: diego-urgell, JKSenthil

Differential Revision: D56506784

fbshipit-source-id: 331450f67abe2a60653b546a9d3bf60045daaf2a
galrotem authored and facebook-github-bot committed Apr 24, 2024
1 parent e6739ab commit e7b9e64
Showing 5 changed files with 61 additions and 73 deletions.
tests/framework/callbacks/test_checkpoint_utils.py (0 additions & 20 deletions)

@@ -421,23 +421,3 @@ def test_get_app_state(self) -> None:
             app_state.keys(),
             ["module", "optimizer", "loss_fn", "train_progress"],
         )
-
-    @skip_if_not_distributed
-    def test_rank_zero_read_and_broadcast(self) -> None:
-        spawn_multi_process(2, "gloo", self._test_rank_zero_read_and_broadcast)
-
-    @staticmethod
-    def _test_rank_zero_read_and_broadcast() -> None:
-        """
-        Tests that rank_zero_read_and_broadcast decorator works as expected
-        """
-
-        @rank_zero_read_and_broadcast
-        def _test_method_for_rank_zero() -> str:
-            assert get_global_rank() == 0
-            return "foo"
-
-        init_from_env()
-        val_from_test_method = _test_method_for_rank_zero()
-        tc = unittest.TestCase()
-        tc.assertEqual(val_from_test_method, "foo")
tests/utils/test_distributed.py (20 additions & 0 deletions)

@@ -30,6 +30,7 @@
     get_world_size,
     PGWrapper,
     rank_zero_fn,
+    rank_zero_read_and_broadcast,
     revert_sync_batchnorm,
     spawn_multi_process,
     sync_bool,
@@ -443,3 +444,22 @@ def _test_method(offset_arg: int, offset_kwarg: int) -> int:
     def test_spawn_multi_process(self) -> None:
         mp_list = spawn_multi_process(2, "gloo", self._test_method, 3, offset_kwarg=2)
         self.assertEqual(mp_list, [1, 2])
+
+    @skip_if_not_distributed
+    def test_rank_zero_read_and_broadcast(self) -> None:
+        spawn_multi_process(2, "gloo", self._test_rank_zero_read_and_broadcast)
+
+    @staticmethod
+    def _test_rank_zero_read_and_broadcast() -> None:
+        """
+        Tests that rank_zero_read_and_broadcast decorator works as expected
+        """
+
+        @rank_zero_read_and_broadcast
+        def _test_method_for_rank_zero() -> str:
+            assert get_global_rank() == 0
+            return "foo"
+
+        val_from_test_method = _test_method_for_rank_zero()
+        tc = unittest.TestCase()
+        tc.assertEqual(val_from_test_method, "foo")
torchtnt/framework/callbacks/_checkpoint_utils.py (2 additions & 51 deletions)

@@ -10,18 +10,7 @@
 import os
 import re
 
-from typing import (
-    Any,
-    Callable,
-    cast,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    Pattern,
-    Tuple,
-    TypeVar,
-)
+from typing import Any, Dict, List, Literal, Optional, Pattern, Tuple, TypeVar
 
 import fsspec
 
@@ -30,7 +19,7 @@
 from torchtnt.framework.callbacks.checkpointer_types import RestoreOptions
 from torchtnt.framework.state import State
 from torchtnt.framework.unit import AppStateMixin
-from torchtnt.utils.distributed import get_global_rank, PGWrapper
+from torchtnt.utils.distributed import rank_zero_read_and_broadcast
 
 from torchtnt.utils.fsspec import get_filesystem
 from torchtnt.utils.stateful import Stateful
@@ -40,44 +29,6 @@
 T = TypeVar("T")
 
 
-def rank_zero_read_and_broadcast(
-    func: Callable[..., T],
-) -> Callable[..., T]:
-    """
-    Decorator that ensures a function is only executed by rank 0 and returns the result to all ranks.
-
-    Note:
-        By default will use the global process group. To use a custom process group, `process_group` must be an arg to the function and passed as a keyword argument.
-    """
-
-    def wrapper(*args: Any, **kwargs: Any) -> T:
-        ret = None
-        rank = get_global_rank()
-        process_group = kwargs.pop("process_group", None)
-
-        # Do all filesystem reads from rank 0 only
-        if rank == 0:
-            ret = func(*args, **kwargs)
-
-        # If not running in a distributed setting, return as is
-        if not (dist.is_available() and dist.is_initialized()):
-            # we cast here to avoid type errors, since it is
-            # guaranteed the return value is of type T
-            return cast(T, ret)
-
-        # Otherwise, broadcast result from rank 0 to all ranks
-        pg = PGWrapper(process_group)
-        path_container = [ret]
-        pg.broadcast_object_list(path_container, 0)
-        val = path_container[0]
-
-        # we cast here to avoid type errors, since it is
-        # guaranteed the return value is of type T
-        return cast(T, val)
-
-    return wrapper
-
-
 @rank_zero_read_and_broadcast
 def get_latest_checkpoint_path(
     dirpath: str,
torchtnt/framework/callbacks/base_checkpointer.py (1 addition & 2 deletions)

@@ -24,7 +24,6 @@
     get_best_checkpoint_path,
     get_checkpoint_dirpaths,
     get_latest_checkpoint_path,
-    rank_zero_read_and_broadcast,
 )
 from torchtnt.framework.callbacks.checkpointer_types import (
     BestCheckpointConfig,
@@ -33,7 +32,7 @@
 from torchtnt.framework.state import EntryPoint, State
 from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TTrainData, TTrainUnit
 from torchtnt.framework.utils import get_timing_context
-from torchtnt.utils.distributed import PGWrapper
+from torchtnt.utils.distributed import PGWrapper, rank_zero_read_and_broadcast
 from torchtnt.utils.fsspec import get_filesystem
 from torchtnt.utils.rank_zero_log import rank_zero_info, rank_zero_warn
torchtnt/utils/distributed.py (38 additions & 0 deletions)

@@ -590,3 +590,41 @@ def _init_pg_and_rank_and_launch_method(
 
     finally:
         destroy_process_group()
+
+
+def rank_zero_read_and_broadcast(
+    func: Callable[..., T],
+) -> Callable[..., T]:
+    """
+    Decorator that ensures a function is only executed by rank 0 and returns the result to all ranks.
+
+    Note:
+        By default will use the global process group. To use a custom process group, `process_group` must be an arg to the function and passed as a keyword argument.
+    """
+
+    def wrapper(*args: Any, **kwargs: Any) -> T:
+        ret = None
+        rank = get_global_rank()
+        process_group = kwargs.pop("process_group", None)
+
+        # Do all filesystem reads from rank 0 only
+        if rank == 0:
+            ret = func(*args, **kwargs)
+
+        # If not running in a distributed setting, return as is
+        if not (dist.is_available() and dist.is_initialized()):
+            # we cast here to avoid type errors, since it is
+            # guaranteed the return value is of type T
+            return cast(T, ret)
+
+        # Otherwise, broadcast result from rank 0 to all ranks
+        pg = PGWrapper(process_group)
+        path_container = [ret]
+        pg.broadcast_object_list(path_container, 0)
+        val = path_container[0]
+
+        # we cast here to avoid type errors, since it is
+        # guaranteed the return value is of type T
+        return cast(T, val)
+
+    return wrapper
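For context, a minimal usage sketch of the decorator from its new location. The function name `read_best_metric`, the file path, and `my_pg` are hypothetical, for illustration only; the import path is the one this commit establishes:

```python
from torchtnt.utils.distributed import rank_zero_read_and_broadcast


@rank_zero_read_and_broadcast
def read_best_metric(dirpath: str) -> str:
    # Executes on rank 0 only; the decorator broadcasts the
    # return value to every rank in the process group.
    with open(f"{dirpath}/best_metric.txt") as f:
        return f.read()


# All ranks receive rank 0's result. Per the docstring, a custom
# process group can be supplied as a keyword argument:
#   read_best_metric("/tmp/run", process_group=my_pg)
metric = read_best_metric("/tmp/run")
```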
