From 4a57bda1a2dc9e988a2edcb478f1ef4bf8ea4474 Mon Sep 17 00:00:00 2001 From: Jason Senthil Date: Wed, 14 Feb 2024 09:00:02 -0800 Subject: [PATCH] add file arg to tqdm utils / callbacks Summary: # Context tqdm by default prints to stderr. This can be unintuitive location for the progress bar to appear # This Diff Adds `file` arg so users can pass `sys.stdout` in they want to print there instead Reviewed By: gunchu Differential Revision: D53718992 fbshipit-source-id: 8dce3a23880c60a1e6238916402a008108881581 --- tests/utils/test_tqdm.py | 50 +++++++++++++++++++ .../framework/callbacks/tqdm_progress_bar.py | 14 +++++- torchtnt/utils/tqdm.py | 6 ++- 3 files changed, 67 insertions(+), 3 deletions(-) create mode 100644 tests/utils/test_tqdm.py diff --git a/tests/utils/test_tqdm.py b/tests/utils/test_tqdm.py new file mode 100644 index 0000000000..d1b1b827e5 --- /dev/null +++ b/tests/utils/test_tqdm.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import sys +import unittest +from io import StringIO +from unittest.mock import MagicMock, patch + +from torchtnt.utils.tqdm import create_progress_bar + + +class TQDMTest(unittest.TestCase): + @patch("sys.stdout", new_callable=StringIO) + @patch("sys.stderr", new_callable=StringIO) + def test_tqdm_file(self, mock_stderr: MagicMock, mock_stdout: MagicMock) -> None: + """ + Test the file argument to create_progress_bar + """ + + create_progress_bar( + dataloader=["foo", "bar"], + desc="foo", + num_epochs_completed=0, + num_steps_completed=0, + max_steps=None, + max_steps_per_epoch=None, + file=None, + ) + self.assertIn( + "foo 0: 0%| | 0/2 [00:00 None: + def __init__( + self, + refresh_rate: int = 1, + file: Optional[Union[TextIO, io.StringIO]] = None, + ) -> None: self._refresh_rate = refresh_rate + self._file = file self._train_progress_bar: Optional[tqdm] = None self._eval_progress_bar: Optional[tqdm] = None @@ -46,6 +53,7 @@ def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None: num_steps_completed=unit.train_progress.num_steps_completed_in_epoch, max_steps=train_state.max_steps, max_steps_per_epoch=train_state.max_steps_per_epoch, + file=self._file, ) def on_train_step_end(self, state: State, unit: TTrainUnit) -> None: @@ -76,6 +84,7 @@ def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None: num_steps_completed=unit.eval_progress.num_steps_completed_in_epoch, max_steps=eval_state.max_steps, max_steps_per_epoch=eval_state.max_steps_per_epoch, + file=self._file, ) def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None: @@ -106,6 +115,7 @@ def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None: num_steps_completed=unit.predict_progress.num_steps_completed, max_steps=predict_state.max_steps, max_steps_per_epoch=predict_state.max_steps_per_epoch, + file=self._file, ) def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None: diff --git a/torchtnt/utils/tqdm.py b/torchtnt/utils/tqdm.py index 0e2f047e7c..343315a3bb 100644 --- a/torchtnt/utils/tqdm.py +++ b/torchtnt/utils/tqdm.py @@ -5,8 +5,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import io import logging -from typing import Iterable, Optional +from typing import Iterable, Optional, TextIO, Union from torchtnt.utils.progress import estimated_steps_in_epoch from tqdm.auto import tqdm @@ -22,6 +23,7 @@ def create_progress_bar( num_steps_completed: int, max_steps: Optional[int], max_steps_per_epoch: Optional[int], + file: Optional[Union[TextIO, io.StringIO]] = None, ) -> tqdm: """Constructs a :func:`tqdm` progress bar. The number of steps in an epoch is inferred from the dataloader, num_steps_completed, max_steps and max_steps_per_epoch. @@ -32,6 +34,7 @@ def create_progress_bar( num_steps_completed: an integer for the number of steps completed so far in the loop. max_steps: an optional integer for the number of max steps in the loop. max_steps_per_epoch: an optional integer for the number of max steps per epoch. + file: specifies where to output the progress messages (default: sys.stderr) """ current_epoch = num_epochs_completed total = estimated_steps_in_epoch( @@ -45,6 +48,7 @@ def create_progress_bar( total=total, initial=num_steps_completed, bar_format="{l_bar}{bar}{r_bar}\n", + file=file, )