From e3ffa1ffcbc29d3bf831a12f2fa270444c5e5e19 Mon Sep 17 00:00:00 2001 From: Gal Rotem Date: Thu, 25 Apr 2024 16:22:03 -0700 Subject: [PATCH] throughput logger Summary: Introduce throughput logger. Internal # Context The stack adds a throughput logger that can be used to log generic throughput per second, based on user config. This diff will add the throughput logger including logging per step. The next diff will add throughput on an epoch granularity. # This diff Adds throughput logger: 1. It uses the already collected iteration time and data wait time timers to get the step time. 2. It's slightly confusing but when `on_train_step_end` is called, the iteration time timer hasn't been populated yet, while the data wait time timer has been populated, hence there's a difference between the two when we are logging for (step-1). On the `on_train_end` both lists are fully populated so we can just use the last element safely. Reviewed By: JKSenthil Differential Revision: D56496451 fbshipit-source-id: e6b119b1a42264d3e764da86e853deb03bd1cf82 --- docs/source/framework/callbacks.rst | 3 +- .../callbacks/test_throughput_logger.py | 245 ++++++++++++++++++ torchtnt/framework/callbacks/__init__.py | 2 + .../framework/callbacks/throughput_logger.py | 156 +++++++++++ 4 files changed, 405 insertions(+), 1 deletion(-) create mode 100644 tests/framework/callbacks/test_throughput_logger.py create mode 100644 torchtnt/framework/callbacks/throughput_logger.py diff --git a/docs/source/framework/callbacks.rst b/docs/source/framework/callbacks.rst index 15c0ace0de..ffea4599fd 100644 --- a/docs/source/framework/callbacks.rst +++ b/docs/source/framework/callbacks.rst @@ -22,6 +22,7 @@ We offer several pre-written callbacks which are ready to be used out of the box BaseCSVWriter EarlyStopping GarbageCollector + IterationTimeLogger Lambda LearningRateMonitor MemorySnapshot @@ -33,7 +34,7 @@ We offer several pre-written callbacks which are ready to be used out of the box TensorBoardParameterMonitor TimeLimitInterrupter TimeWaitForBatchLogger - IterationTimeLogger + ThroughputLogger TorchSnapshotSaver TQDMProgressBar TrainProgressMonitor diff --git a/tests/framework/callbacks/test_throughput_logger.py b/tests/framework/callbacks/test_throughput_logger.py new file mode 100644 index 0000000000..12f578265b --- /dev/null +++ b/tests/framework/callbacks/test_throughput_logger.py @@ -0,0 +1,245 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from unittest.mock import ANY, call, MagicMock + +import torch +from pyre_extensions import none_throws + +from torchtnt.framework._callback_handler import CallbackHandler +from torchtnt.framework._test_utils import ( + DummyAutoUnit, + DummyPredictUnit, + generate_random_dataloader, +) +from torchtnt.framework.callbacks.throughput_logger import ThroughputLogger +from torchtnt.framework.predict import predict + +from torchtnt.framework.state import EntryPoint, PhaseState, State +from torchtnt.framework.train import _train_impl +from torchtnt.utils.loggers.logger import MetricLogger + + +class ThroughputLoggerTest(unittest.TestCase): + def test_maybe_log_for_step(self) -> None: + logger = MagicMock(spec=MetricLogger) + throughput_logger = ThroughputLogger(logger, {"Batches": 1, "Items": 32}, 1) + phase_state = PhaseState(dataloader=[]) + phase_state.iteration_timer.recorded_durations = { + "data_wait_time": [1, 4], + "train_iteration_time": [3], + } + state = State(entry_point=EntryPoint.TRAIN, train_state=phase_state) + throughput_logger._maybe_log_for_step(state, 1) + logger.log.assert_has_calls( + [ + call( + "Train: Batches per second (step granularity)", + 0.25, # 1/(1+3) + 1, + ), + call( + "Train: Items per second (step granularity)", + 8, # 32/(1+3) + 1, + ), + ], + any_order=True, + ) + logger.log.reset_mock() + phase_state.iteration_timer.recorded_durations["train_iteration_time"].append(4) + throughput_logger._maybe_log_for_step(state, 2, is_step_end_hook=False) + logger.log.assert_has_calls( + [ + call( + "Train: Batches per second (step granularity)", + 0.125, # 1/(4+4) + 2, + ), + call( + "Train: Items per second (step granularity)", + 4, # 32/(4+4) + 2, + ), + ] + ) + + def test_maybe_log_for_step_early_return(self) -> None: + logger = MagicMock(spec=MetricLogger) + throughput_logger = ThroughputLogger(logger, {"Batches": 1}, 1) + phase_state = PhaseState(dataloader=[]) + recorded_durations_dict = { + "data_wait_time": [0.0, 4.0], + "train_iteration_time": [0.0], + } + # total_time <= 0 + phase_state.iteration_timer.recorded_durations = recorded_durations_dict + state = State(entry_point=EntryPoint.TRAIN, train_state=phase_state) + throughput_logger._maybe_log_for_step(state, step_logging_for=1) + logger.log.assert_not_called() + + # empty iteration_time_list + recorded_durations_dict["data_wait_time"] = [1.0, 2.0] + recorded_durations_dict["train_iteration_time"] = [] + throughput_logger._maybe_log_for_step(state, step_logging_for=1) + logger.log.assert_not_called() + + # small data_wait_time list + recorded_durations_dict["data_wait_time"] = [1.0] + recorded_durations_dict["train_iteration_time"] = [1.0] + throughput_logger._maybe_log_for_step(state, step_logging_for=1) + logger.log.assert_not_called() + + # step_logging_for % log_every_n_steps != 0 + recorded_durations_dict["data_wait_time"] = [1.0, 2.0] + throughput_logger = ThroughputLogger(logger, {"Batches": 1}, 2) + throughput_logger._maybe_log_for_step(state, step_logging_for=1) + logger.log.assert_not_called() + + def test_with_comparing_time(self) -> None: + logger = MagicMock(spec=MetricLogger) + dataloader = generate_random_dataloader( + num_samples=8, input_dim=2, batch_size=2 + ) + state = State( + entry_point=EntryPoint.FIT, + train_state=PhaseState( + dataloader=dataloader, + max_epochs=2, + max_steps_per_epoch=2, + ), + eval_state=PhaseState( + dataloader=dataloader, + max_steps_per_epoch=2, + evaluate_every_n_epochs=2, + ), + ) + + # we want to be able to compare the logging value to the state, so we need to create state manually and + # call _train_impl. This would have been similar to calling fit() and getting the state as a ret value + _train_impl( + state, + DummyAutoUnit(module=torch.nn.Linear(2, 2)), + CallbackHandler( + [ + ThroughputLogger( + logger=logger, + throughput_per_batch={"Batches": 1, "Queries": 8}, + log_every_n_steps=1, + ) + ], + ), + ) + + train_iteration_times = none_throws( + state.train_state + ).iteration_timer.recorded_durations["train_iteration_time"] + train_twfb_times = none_throws( + state.train_state + ).iteration_timer.recorded_durations["data_wait_time"] + eval_iteration_times = none_throws( + state.eval_state + ).iteration_timer.recorded_durations["eval_iteration_time"] + eval_twfb_times = none_throws( + state.eval_state + ).iteration_timer.recorded_durations["data_wait_time"] + + self.assertEqual(len(train_iteration_times), 4) + self.assertEqual(len(train_twfb_times), 4) + self.assertEqual(len(eval_iteration_times), 2) + self.assertEqual(len(eval_twfb_times), 2) + + train_step_times = [ + train_iteration_times[i] + train_twfb_times[i] for i in range(4) + ] + eval_step_times = [ + eval_iteration_times[i] + eval_twfb_times[i] for i in range(2) + ] + self.assertEqual( + logger.log.call_count, 12 + ) # 8 train (2epochs x 2steps x 2items), 4 eval (1x2x2) + train_batches_step_logs = [ + call( + "Train: Batches per second (step granularity)", + 1 / (train_step_times[i]), + i + 1, + ) + for i in range(4) + ] + train_queries_step_logs = [ + call( + "Train: Queries per second (step granularity)", + 8 / (train_step_times[i]), + i + 1, + ) + for i in range(4) + ] + eval_batches_step_logs = [ + call( + "Eval: Batches per second (step granularity)", + 1 / (eval_step_times[i]), + i + 1, + ) + for i in range(2) + ] + eval_queries_step_logs = [ + call( + "Eval: Queries per second (step granularity)", + 8 / (eval_step_times[i]), + i + 1, + ) + for i in range(2) + ] + logger.log.assert_has_calls( + train_batches_step_logs + + train_queries_step_logs + + eval_batches_step_logs + + eval_queries_step_logs, + any_order=True, + ) + + def test_with_predict(self) -> None: + logger = MagicMock(spec=MetricLogger) + predict( + DummyPredictUnit(input_dim=2), + generate_random_dataloader(num_samples=8, input_dim=2, batch_size=2), + max_steps_per_epoch=1, + callbacks=[ + ThroughputLogger( + logger=logger, + throughput_per_batch={"Batches": 1}, + log_every_n_steps=1, + ) + ], + ) + logger.log.assert_has_calls( + [ + call( + "Predict: Batches per second (step granularity)", + ANY, + 1, + ) + ], + ) + + def test_input_validation(self) -> None: + logger = MagicMock(spec=MetricLogger) + with self.assertRaisesRegex(ValueError, "throughput_per_batch cannot be empty"): + ThroughputLogger(logger, {}, 1) + + with self.assertRaisesRegex( + ValueError, "throughput_per_batch item Batches must be at least 1, got -1" + ): + ThroughputLogger(logger, {"Queries": 8, "Batches": -1}, 1) + + with self.assertRaisesRegex( + ValueError, "log_every_n_steps must be at least 1, got 0" + ): + ThroughputLogger(logger, {"Batches": 1}, 0) diff --git a/torchtnt/framework/callbacks/__init__.py b/torchtnt/framework/callbacks/__init__.py index 29c9996f69..8dbc29f0c1 100644 --- a/torchtnt/framework/callbacks/__init__.py +++ b/torchtnt/framework/callbacks/__init__.py @@ -21,6 +21,7 @@ from .slow_rank_detector import SlowRankDetector from .system_resources_monitor import SystemResourcesMonitor from .tensorboard_parameter_monitor import TensorBoardParameterMonitor +from .throughput_logger import ThroughputLogger from .time_limit_interrupter import TimeLimitInterrupter from .time_wait_for_batch_logger import TimeWaitForBatchLogger from .torch_compile import TorchCompile @@ -43,6 +44,7 @@ "SlowRankDetector", "SystemResourcesMonitor", "TensorBoardParameterMonitor", + "ThroughputLogger", "TimeLimitInterrupter", "TimeWaitForBatchLogger", "TorchCompile", diff --git a/torchtnt/framework/callbacks/throughput_logger.py b/torchtnt/framework/callbacks/throughput_logger.py new file mode 100644 index 0000000000..e56a42ae9d --- /dev/null +++ b/torchtnt/framework/callbacks/throughput_logger.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + + +from typing import Mapping + +from pyre_extensions import none_throws + +from torchtnt.framework.callback import Callback +from torchtnt.framework.state import ActivePhase, State +from torchtnt.framework.unit import TEvalUnit, TPredictUnit, TTrainUnit +from torchtnt.utils.loggers.logger import MetricLogger + +ACTIVE_PHASE_TO_ITERATION_TIME_KEY: Mapping[ActivePhase, str] = { + ActivePhase.TRAIN: "train_iteration_time", + ActivePhase.EVALUATE: "eval_iteration_time", + ActivePhase.PREDICT: "predict_iteration_time", +} + +ACTIVE_PHASE_TO_LABEL_PREFIX: Mapping[ActivePhase, str] = { + ActivePhase.TRAIN: "Train", + ActivePhase.EVALUATE: "Eval", + ActivePhase.PREDICT: "Predict", +} + + +class ThroughputLogger(Callback): + """ + A callback which logs the train/eval/predict/fit throughput. For instance, it can be used to log QPS and number of batches processed per second. + The callback logs the throughput on a step basis. + We measure the throughput by dividing the number of batches processed (times the number of items in batch) by the time it took to process the batch: + On a step granularity, we do this by leveraging the already collected timers for the iteration time and data wait time. + + Args: + logger: A a subclass of :class:`torchtnt.utils.loggers.logger.MetricLogger`. + throughput_per_batch: a dict mapping the item name to the number of corresponding items in the batch. + For instace, a user can pass in {Batches: 1, Queries: 32} which will visualize two charts - + one for Batches per second and one for Queries per second. + As an example, if each of your batches is of type: {data: torch.Size([16, 8, 8]), labels: torch.Size([16,1])}, then you could pass {Queries: 16}. + log_every_n_steps: an optional int to control the log frequency. + + Note: + The values reported are only for rank 0. + """ + + def __init__( + self, + logger: MetricLogger, + throughput_per_batch: Mapping[str, int], + log_every_n_steps: int = 1, + ) -> None: + self._logger = logger + + if not throughput_per_batch: + raise ValueError("throughput_per_batch cannot be empty") + + for item, num_items in throughput_per_batch.items(): + if num_items < 1: + raise ValueError( + f"throughput_per_batch item {item} must be at least 1, got {num_items}" + ) + + self._throughput_per_batch = throughput_per_batch + + if log_every_n_steps < 1: + raise ValueError( + f"log_every_n_steps must be at least 1, got {log_every_n_steps}" + ) + + self._log_every_n_steps = log_every_n_steps + + def on_train_step_end(self, state: State, unit: TTrainUnit) -> None: + self._maybe_log_for_step( + state, + unit.train_progress.num_steps_completed - 1, + ) + + def on_train_end(self, state: State, unit: TTrainUnit) -> None: + self._maybe_log_for_step( + state, + unit.train_progress.num_steps_completed, + is_step_end_hook=False, + ) + + def on_eval_step_end(self, state: State, unit: TEvalUnit) -> None: + self._maybe_log_for_step( + state, + unit.eval_progress.num_steps_completed - 1, + ) + + def on_eval_end(self, state: State, unit: TEvalUnit) -> None: + self._maybe_log_for_step( + state, + unit.eval_progress.num_steps_completed, + is_step_end_hook=False, + ) + + def on_predict_step_end(self, state: State, unit: TPredictUnit) -> None: + self._maybe_log_for_step( + state, + unit.predict_progress.num_steps_completed - 1, + ) + + def on_predict_end(self, state: State, unit: TPredictUnit) -> None: + self._maybe_log_for_step( + state, + unit.predict_progress.num_steps_completed, + is_step_end_hook=False, + ) + + def _maybe_log_for_step( + self, + state: State, + step_logging_for: int, + *, + is_step_end_hook: bool = True, + ) -> None: + if step_logging_for % self._log_every_n_steps != 0: + return + + active_phase_state = none_throws(state.active_phase_state()) + timer_recorded_durations = active_phase_state.iteration_timer.recorded_durations + iteration_time_list = timer_recorded_durations.get( + ACTIVE_PHASE_TO_ITERATION_TIME_KEY[state.active_phase] + ) + data_wait_time_list = timer_recorded_durations.get("data_wait_time") + + # if it's a step hook, we're logging for the previous step, but the data wait time list + # has already been populated with the current step, so the offset is 2 + data_wait_time_offset = 2 if is_step_end_hook else 1 + + if ( + (not iteration_time_list) + or (not data_wait_time_list) + or len(data_wait_time_list) < data_wait_time_offset + ): + return + + prev_iteration_time = iteration_time_list[-1] + data_wait_time = data_wait_time_list[-data_wait_time_offset] + total_time = prev_iteration_time + data_wait_time + + if total_time <= 0: + return + + for item, num_items in self._throughput_per_batch.items(): + self._logger.log( + f"{ACTIVE_PHASE_TO_LABEL_PREFIX[state.active_phase]}: {item} per second (step granularity)", + num_items / total_time, + step_logging_for, + )