add eval support to distributed sync (pytorch#847)
Summary:
Pull Request resolved: pytorch#847

Also support eval, in addition to predict.

Reviewed By: diego-urgell

Differential Revision: D58855082

fbshipit-source-id: b2ed61e27d1d8a786ba3352f14437b48097a2e4d
galrotem authored and facebook-github-bot committed Jun 21, 2024
1 parent 0f72333 commit 882832c
Showing 2 changed files with 26 additions and 6 deletions.
tests/framework/callbacks/test_periodic_distributed_sync.py (17 changes: 15 additions & 2 deletions)

@@ -10,7 +10,7 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
-from torchtnt.framework._test_utils import DummyPredictUnit
+from torchtnt.framework._test_utils import DummyEvalUnit, DummyPredictUnit
 
 from torchtnt.framework.callbacks.periodic_distributed_sync import (
     PeriodicDistributedSync,
@@ -20,7 +20,7 @@
 
 class PeriodicDistributedSyncTest(unittest.TestCase):
     @patch("torchtnt.framework.callbacks.periodic_distributed_sync.barrier")
-    def test_frequency(self, barrier_mock: MagicMock) -> None:
+    def test_frequency_predict(self, barrier_mock: MagicMock) -> None:
         pds = PeriodicDistributedSync(sync_every_n_steps=2)
         unit = DummyPredictUnit(2)
         state = State(entry_point=EntryPoint.PREDICT)
@@ -31,3 +31,16 @@ def test_frequency(self, barrier_mock: MagicMock) -> None:
         unit.predict_progress.increment_step()  # 2 steps completed
         pds.on_predict_step_end(state, unit)
         barrier_mock.assert_called_once()
+
+    @patch("torchtnt.framework.callbacks.periodic_distributed_sync.barrier")
+    def test_frequency_evaluate(self, barrier_mock: MagicMock) -> None:
+        pds = PeriodicDistributedSync(sync_every_n_steps=2)
+        unit = DummyEvalUnit(2)
+        state = State(entry_point=EntryPoint.EVALUATE)
+        unit.eval_progress.increment_step()  # 1 step completed
+        pds.on_eval_step_end(state, unit)
+        barrier_mock.assert_not_called()
+
+        unit.eval_progress.increment_step()  # 2 steps completed
+        pds.on_eval_step_end(state, unit)
+        barrier_mock.assert_called_once()
torchtnt/framework/callbacks/periodic_distributed_sync.py (15 changes: 11 additions & 4 deletions)

@@ -10,8 +10,8 @@
 
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.state import State
-from torchtnt.framework.unit import TPredictUnit
-from torchtnt.utils.distributed import barrier
+from torchtnt.framework.unit import TEvalUnit, TPredictUnit
+from torchtnt.utils.distributed import barrier, get_global_rank
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -20,17 +20,24 @@ class PeriodicDistributedSync(Callback):
     """
     A callback to sync all distributed workers at a given frequency.
     Helpful when using distributed without DDP/FSDP but would still like to ensure that the workers are in sync with each other, for example large predict jobs.
-    Note that only predict is supported at the moment.
+    Both predict and evaluate are supported.
 
     Args:
         sync_every_n_steps: the frequency at which to sync the workers.
     """
 
     def __init__(self, sync_every_n_steps: int = 1000) -> None:
         self.sync_every_n_steps = sync_every_n_steps
+        self._global_rank: int = get_global_rank()
 
     def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
         num_steps = unit.predict_progress.num_steps_completed
         if num_steps % self.sync_every_n_steps == 0:
-            logger.info(f"Barrier at step {num_steps}")
+            logger.info(f"Barrier at step {num_steps} on rank {self._global_rank}")
             barrier()
+
+    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
+        num_steps = unit.eval_progress.num_steps_completed
+        if num_steps % self.sync_every_n_steps == 0:
+            logger.info(f"Barrier at step {num_steps} on rank {self._global_rank}")
+            barrier()
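The docstring above describes the callback's intent; below is a minimal usage sketch. It assumes torchtnt's standard predict and evaluate entry points, and the unit and dataloader names are hypothetical placeholders, not part of this commit.

from torchtnt.framework import evaluate, predict
from torchtnt.framework.callbacks.periodic_distributed_sync import (
    PeriodicDistributedSync,
)

# Sync all ranks every 100 steps. Helpful for large predict/eval jobs that
# run distributed without DDP/FSDP and could otherwise drift apart.
pds = PeriodicDistributedSync(sync_every_n_steps=100)

# my_predict_unit, my_eval_unit, and the dataloaders are hypothetical;
# any TPredictUnit/TEvalUnit and iterable of batches work here.
predict(my_predict_unit, predict_dataloader, callbacks=[pds])

# With this change, the same callback also fires a barrier during evaluation.
evaluate(my_eval_unit, eval_dataloader, callbacks=[pds])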
