-
Notifications
You must be signed in to change notification settings - Fork 266
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add progress reporter callback (#785)
Summary: Pull Request resolved: #785 Reviewed By: JKSenthil Differential Revision: D56175728 fbshipit-source-id: be61bf67dd0b0ac18d3633574ac7f91259e08432
- Loading branch information
1 parent
5beb537
commit 6de95a5
Showing 4 changed files with 154 additions and 0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
# pyre-strict | ||
|
||
import unittest | ||
|
||
import torch | ||
from torchtnt.framework._test_utils import DummyAutoUnit | ||
from torchtnt.framework.callbacks.progress_reporter import ProgressReporter | ||
from torchtnt.framework.state import EntryPoint, State | ||
from torchtnt.utils.distributed import get_global_rank, spawn_multi_process | ||
from torchtnt.utils.progress import Progress | ||
|
||
|
||
class ProgressReporterTest(unittest.TestCase):
    """Tests for the ``ProgressReporter`` callback."""

    def test_log_with_rank(self) -> None:
        """Run the log-message check via ``spawn_multi_process`` (2, "gloo")."""
        spawn_multi_process(2, "gloo", self._test_log_with_rank)

    @staticmethod
    def _test_log_with_rank() -> None:
        """Assert that ``on_train_end`` logs this rank plus train and eval progress.

        The entry point is ``FIT``, so the message is expected to contain both
        the train progress and the eval progress of the unit.
        """
        progress_reporter = ProgressReporter()
        unit = DummyAutoUnit(module=torch.nn.Linear(2, 2))
        unit.train_progress = Progress(
            num_epochs_completed=1,
            num_steps_completed=5,
            num_steps_completed_in_epoch=3,
        )
        unit.eval_progress = Progress(
            num_epochs_completed=2,
            num_steps_completed=15,
            num_steps_completed_in_epoch=7,
        )
        state = State(entry_point=EntryPoint.FIT)

        # This is a static method with no ``self``, so a standalone TestCase
        # instance supplies the assertion helpers.
        tc = unittest.TestCase()

        # Capture only the callback's own logger. Capturing on the root logger
        # would also collect unrelated INFO records emitted in this process and
        # make the exact-list comparison below brittle.
        with tc.assertLogs(
            logger="torchtnt.framework.callbacks.progress_reporter", level="INFO"
        ) as log:
            progress_reporter.on_train_end(state, unit)

        tc.assertEqual(
            log.output,
            [
                f"INFO:torchtnt.framework.callbacks.progress_reporter:Progress Reporter: rank {get_global_rank()} at on_train_end. "
                "Train progress: completed epochs: 1, completed steps: 5, completed steps in current epoch: 3. "
                "Eval progress: completed epochs: 2, completed steps: 15, completed steps in current epoch: 7."
            ],
        )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import logging | ||
from typing import cast | ||
|
||
from torchtnt.framework.callback import Callback | ||
from torchtnt.framework.state import EntryPoint, State | ||
from torchtnt.framework.unit import AppStateMixin, TEvalUnit, TPredictUnit, TTrainUnit | ||
from torchtnt.utils.distributed import get_global_rank | ||
|
||
logger: logging.Logger = logging.getLogger(__name__) | ||
|
||
|
||
class ProgressReporter(Callback):
    """
    Callback that writes one INFO log line for every loop, epoch and step
    start/end hook of the train, eval and predict loops.

    Each line contains the global rank, the name of the hook, and the progress
    of the unit for the current entry point. Comparing these lines across ranks
    helps when debugging problems caused by ranks progressing unequally, for
    instance NCCL timeouts.

    If used, it's recommended to pass this callback as the first item in the
    callbacks list.
    """

    # ---- train hooks ----

    def on_train_start(self, state: State, unit: TTrainUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_train_start")

    def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_train_epoch_start")

    def on_train_step_start(self, state: State, unit: TTrainUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_train_step_start")

    def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_train_step_end")

    def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_train_epoch_end")

    def on_train_end(self, state: State, unit: TTrainUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_train_end")

    # ---- eval hooks ----

    def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_eval_start")

    def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_eval_epoch_start")

    def on_eval_step_start(self, state: State, unit: TEvalUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_eval_step_start")

    def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_eval_step_end")

    def on_eval_epoch_end(self, state: State, unit: TEvalUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_eval_epoch_end")

    def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_eval_end")

    # ---- predict hooks ----

    def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_predict_start")

    def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_predict_epoch_start")

    def on_predict_step_start(self, state: State, unit: TPredictUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_predict_step_start")

    def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_predict_step_end")

    def on_predict_epoch_end(self, state: State, unit: TPredictUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_predict_epoch_end")

    def on_predict_end(self, state: State, unit: TPredictUnit) -> None:
        self._log_with_rank_and_unit(state, unit, "on_predict_end")

    # ---- shared implementation ----

    @classmethod
    def _log_with_rank_and_unit(
        cls, state: State, unit: AppStateMixin, hook: str
    ) -> None:
        """
        Log the global rank, the hook name and the unit's progress at INFO level.

        Which progress is reported depends on ``state.entry_point``: train
        progress for TRAIN, eval progress for EVALUATE, predict progress for
        PREDICT, and train followed by eval progress for FIT.

        Args:
            state: the current state; only its ``entry_point`` is read.
            unit: the unit whose progress is reported.
            hook: name of the callback hook being reported.

        Raises:
            ValueError: if ``state.entry_point`` is none of the entry points
                listed above.
        """
        header = f"Progress Reporter: rank {get_global_rank()} at {hook}."
        entry_point = state.entry_point

        if entry_point == EntryPoint.FIT:
            # Fit runs both loops, so both progress strings are reported.
            train_str = cast(TTrainUnit, unit).train_progress.get_progress_string()
            eval_str = cast(TEvalUnit, unit).eval_progress.get_progress_string()
            details = f"Train progress: {train_str} Eval progress: {eval_str}"
        elif entry_point == EntryPoint.TRAIN:
            train_str = cast(TTrainUnit, unit).train_progress.get_progress_string()
            details = f"Train progress: {train_str}"
        elif entry_point == EntryPoint.EVALUATE:
            eval_str = cast(TEvalUnit, unit).eval_progress.get_progress_string()
            details = f"Eval progress: {eval_str}"
        elif entry_point == EntryPoint.PREDICT:
            predict_str = cast(
                TPredictUnit, unit
            ).predict_progress.get_progress_string()
            details = f"Predict progress: {predict_str}"
        else:
            raise ValueError(
                f"State entry point {entry_point} is not supported in ProgressReporter"
            )

        logger.info(f"{header} {details}")