Skip to content

Commit

Permalink
add file arg to tqdm utils / callbacks
Browse files Browse the repository at this point in the history
Summary:
# Context
tqdm by default prints to stderr. This can be unintuitive location for the progress bar to appear
# This Diff
Adds `file` arg so users can pass `sys.stdout` in they want to print there instead

Reviewed By: gunchu

Differential Revision: D53718992

fbshipit-source-id: 8dce3a23880c60a1e6238916402a008108881581
  • Loading branch information
JKSenthil authored and facebook-github-bot committed Feb 14, 2024
1 parent bbb696a commit 4a57bda
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 3 deletions.
50 changes: 50 additions & 0 deletions tests/utils/test_tqdm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys
import unittest
from io import StringIO
from unittest.mock import MagicMock, patch

from torchtnt.utils.tqdm import create_progress_bar


class TQDMTest(unittest.TestCase):
@patch("sys.stdout", new_callable=StringIO)
@patch("sys.stderr", new_callable=StringIO)
def test_tqdm_file(self, mock_stderr: MagicMock, mock_stdout: MagicMock) -> None:
"""
Test the file argument to create_progress_bar
"""

create_progress_bar(
dataloader=["foo", "bar"],
desc="foo",
num_epochs_completed=0,
num_steps_completed=0,
max_steps=None,
max_steps_per_epoch=None,
file=None,
)
self.assertIn(
"foo 0: 0%| | 0/2 [00:00<?, ?it/s]", mock_stderr.getvalue()
)
# ensure nothing written to stdout
self.assertEqual(mock_stdout.getvalue(), "")

create_progress_bar(
dataloader=["foo", "bar"],
desc="foo",
num_epochs_completed=0,
num_steps_completed=0,
max_steps=None,
max_steps_per_epoch=None,
file=sys.stdout,
)
self.assertIn(
"foo 0: 0%| | 0/2 [00:00<?, ?it/s]", mock_stdout.getvalue()
)
14 changes: 12 additions & 2 deletions torchtnt/framework/callbacks/tqdm_progress_bar.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
import io
from typing import Optional, TextIO, Union

from pyre_extensions import none_throws

Expand All @@ -27,10 +28,16 @@ class TQDMProgressBar(Callback):
Args:
refresh_rate: Determines at which rate (in number of steps) the progress bars get updated.
file: specifies where to output the progress messages (default: sys.stderr)
"""

def __init__(self, refresh_rate: int = 1) -> None:
def __init__(
self,
refresh_rate: int = 1,
file: Optional[Union[TextIO, io.StringIO]] = None,
) -> None:
self._refresh_rate = refresh_rate
self._file = file

self._train_progress_bar: Optional[tqdm] = None
self._eval_progress_bar: Optional[tqdm] = None
Expand All @@ -46,6 +53,7 @@ def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
num_steps_completed=unit.train_progress.num_steps_completed_in_epoch,
max_steps=train_state.max_steps,
max_steps_per_epoch=train_state.max_steps_per_epoch,
file=self._file,
)

def on_train_step_end(self, state: State, unit: TTrainUnit) -> None:
Expand Down Expand Up @@ -76,6 +84,7 @@ def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None:
num_steps_completed=unit.eval_progress.num_steps_completed_in_epoch,
max_steps=eval_state.max_steps,
max_steps_per_epoch=eval_state.max_steps_per_epoch,
file=self._file,
)

def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None:
Expand Down Expand Up @@ -106,6 +115,7 @@ def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None:
num_steps_completed=unit.predict_progress.num_steps_completed,
max_steps=predict_state.max_steps,
max_steps_per_epoch=predict_state.max_steps_per_epoch,
file=self._file,
)

def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None:
Expand Down
6 changes: 5 additions & 1 deletion torchtnt/utils/tqdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import io
import logging
from typing import Iterable, Optional
from typing import Iterable, Optional, TextIO, Union

from torchtnt.utils.progress import estimated_steps_in_epoch
from tqdm.auto import tqdm
Expand All @@ -22,6 +23,7 @@ def create_progress_bar(
num_steps_completed: int,
max_steps: Optional[int],
max_steps_per_epoch: Optional[int],
file: Optional[Union[TextIO, io.StringIO]] = None,
) -> tqdm:
"""Constructs a :func:`tqdm` progress bar. The number of steps in an epoch is inferred from the dataloader, num_steps_completed, max_steps and max_steps_per_epoch.
Expand All @@ -32,6 +34,7 @@ def create_progress_bar(
num_steps_completed: an integer for the number of steps completed so far in the loop.
max_steps: an optional integer for the number of max steps in the loop.
max_steps_per_epoch: an optional integer for the number of max steps per epoch.
file: specifies where to output the progress messages (default: sys.stderr)
"""
current_epoch = num_epochs_completed
total = estimated_steps_in_epoch(
Expand All @@ -45,6 +48,7 @@ def create_progress_bar(
total=total,
initial=num_steps_completed,
bar_format="{l_bar}{bar}{r_bar}\n",
file=file,
)


Expand Down

0 comments on commit 4a57bda

Please sign in to comment.