From 882832c1f5019f4149a2d5d79c2dc35e98fb2a1e Mon Sep 17 00:00:00 2001
From: Gal Rotem
Date: Fri, 21 Jun 2024 15:17:16 -0700
Subject: [PATCH] add eval support to distributed sync (#847)

Summary:
Pull Request resolved: https://github.com/pytorch/tnt/pull/847

Also support eval, in addition to predict.

Reviewed By: diego-urgell

Differential Revision: D58855082

fbshipit-source-id: b2ed61e27d1d8a786ba3352f14437b48097a2e4d
---
 .../callbacks/test_periodic_distributed_sync.py | 17 +++++++++++++++--
 .../callbacks/periodic_distributed_sync.py      | 15 +++++++++++----
 2 files changed, 26 insertions(+), 6 deletions(-)

diff --git a/tests/framework/callbacks/test_periodic_distributed_sync.py b/tests/framework/callbacks/test_periodic_distributed_sync.py
index 5173c47022..4f620dc4af 100644
--- a/tests/framework/callbacks/test_periodic_distributed_sync.py
+++ b/tests/framework/callbacks/test_periodic_distributed_sync.py
@@ -10,7 +10,7 @@
 import unittest
 from unittest.mock import MagicMock, patch
 
-from torchtnt.framework._test_utils import DummyPredictUnit
+from torchtnt.framework._test_utils import DummyEvalUnit, DummyPredictUnit
 
 from torchtnt.framework.callbacks.periodic_distributed_sync import (
     PeriodicDistributedSync,
@@ -20,7 +20,7 @@ class PeriodicDistributedSyncTest(unittest.TestCase):
     @patch("torchtnt.framework.callbacks.periodic_distributed_sync.barrier")
-    def test_frequency(self, barrier_mock: MagicMock) -> None:
+    def test_frequency_predict(self, barrier_mock: MagicMock) -> None:
         pds = PeriodicDistributedSync(sync_every_n_steps=2)
         unit = DummyPredictUnit(2)
         state = State(entry_point=EntryPoint.PREDICT)
@@ -31,3 +31,16 @@ def test_frequency(self, barrier_mock: MagicMock) -> None:
         unit.predict_progress.increment_step()  # 2 steps completed
         pds.on_predict_step_end(state, unit)
         barrier_mock.assert_called_once()
+
+    @patch("torchtnt.framework.callbacks.periodic_distributed_sync.barrier")
+    def test_frequency_evaluate(self, barrier_mock: MagicMock) -> None:
+        pds = PeriodicDistributedSync(sync_every_n_steps=2)
+        unit = DummyEvalUnit(2)
+        state = State(entry_point=EntryPoint.EVALUATE)
+        unit.eval_progress.increment_step()  # 1 step completed
+        pds.on_eval_step_end(state, unit)
+        barrier_mock.assert_not_called()
+
+        unit.eval_progress.increment_step()  # 2 steps completed
+        pds.on_eval_step_end(state, unit)
+        barrier_mock.assert_called_once()

diff --git a/torchtnt/framework/callbacks/periodic_distributed_sync.py b/torchtnt/framework/callbacks/periodic_distributed_sync.py
index 73d080e718..b1cdda1b1d 100644
--- a/torchtnt/framework/callbacks/periodic_distributed_sync.py
+++ b/torchtnt/framework/callbacks/periodic_distributed_sync.py
@@ -10,8 +10,8 @@
 
 from torchtnt.framework.callback import Callback
 from torchtnt.framework.state import State
-from torchtnt.framework.unit import TPredictUnit
-from torchtnt.utils.distributed import barrier
+from torchtnt.framework.unit import TEvalUnit, TPredictUnit
+from torchtnt.utils.distributed import barrier, get_global_rank
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -20,7 +20,7 @@ class PeriodicDistributedSync(Callback):
     """
     A callback to sync all distributed workers at a given frequency. Helpful when using distributed without DDP/FSDP
     but would still like to ensure that the workers are in sync with each other, for example large predict jobs.
-    Note that only predict is supported at the moment.
+    Both predict and evaluate are supported.
 
     Args:
         sync_every_n_steps: the frequency at which to sync the workers.
@@ -28,9 +28,16 @@ class PeriodicDistributedSync(Callback):
 
     def __init__(self, sync_every_n_steps: int = 1000) -> None:
         self.sync_every_n_steps = sync_every_n_steps
+        self._global_rank: int = get_global_rank()
 
     def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
         num_steps = unit.predict_progress.num_steps_completed
         if num_steps % self.sync_every_n_steps == 0:
-            logger.info(f"Barrier at step {num_steps}")
+            logger.info(f"Barrier at step {num_steps} on rank {self._global_rank}")
             barrier()
+
+    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
+        num_steps = unit.eval_progress.num_steps_completed
+        if num_steps % self.sync_every_n_steps == 0:
+            logger.info(f"Barrier at step {num_steps} on rank {self._global_rank}")
+            barrier()
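
For reviewers who want to exercise the new eval path end to end, here is a minimal usage sketch (not part of the patch). It assumes torchtnt's public evaluate() entry point; NoOpEvalUnit and the in-memory batch list are illustrative placeholders, not names from this diff. Under torch.distributed, every rank blocks at the barrier once sync_every_n_steps eval steps complete; in a single-process run the barrier is expected to be a no-op.

    import torch

    from torchtnt.framework.callbacks.periodic_distributed_sync import (
        PeriodicDistributedSync,
    )
    from torchtnt.framework.evaluate import evaluate
    from torchtnt.framework.state import State
    from torchtnt.framework.unit import EvalUnit


    class NoOpEvalUnit(EvalUnit[torch.Tensor]):
        # Illustrative eval unit: consumes each batch without doing any work.
        def eval_step(self, state: State, data: torch.Tensor) -> None:
            pass


    # Barrier every 2 eval steps (the default is 1000).
    pds = PeriodicDistributedSync(sync_every_n_steps=2)

    # Any iterable of batches works as the eval dataloader.
    batches = [torch.rand(4, 2) for _ in range(6)]
    evaluate(NoOpEvalUnit(), batches, callbacks=[pds])

Because the callback now implements both on_predict_step_end and on_eval_step_end, the same PeriodicDistributedSync instance can be passed to predict and evaluate alike.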