add get_next_X_batch to unit (#617)
Summary:
Pull Request resolved: #617

Extract the next-batch fetching logic from the loop scripts (train/evaluate/predict) into a method on Unit/AutoUnit

Reviewed By: JKSenthil

Differential Revision: D51162751

fbshipit-source-id: 67302adfd1ffce91f4942218e442ddfd16731c16
galrotem authored and facebook-github-bot committed Nov 14, 2023
1 parent 8b258a6 commit f8e5313
Showing 8 changed files with 219 additions and 70 deletions.
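
For orientation, here is a minimal usage sketch of the new API, mirroring the test added in tests/framework/test_unit.py below. MyTrainUnit is a hypothetical example class; the call pattern (fetch a batch via get_next_train_batch, then pass it to train_step) follows what the hunks in this commit show.

```python
import torch

from torchtnt.framework._test_utils import get_dummy_train_state
from torchtnt.framework.state import State
from torchtnt.framework.unit import TrainUnit


class MyTrainUnit(TrainUnit[torch.Tensor]):
    def train_step(self, state: State, data: torch.Tensor) -> None:
        # data arrives as a single prefetched batch; the loop obtained it via
        # get_next_train_batch inside its data_wait_time timer.
        pass


unit = MyTrainUnit()
state = get_dummy_train_state()
data_iter = iter([torch.ones(1), torch.zeros(1)])
batch = unit.get_next_train_batch(state, data_iter)  # returns torch.ones(1)
unit.train_step(state, batch)
```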
28 changes: 21 additions & 7 deletions tests/framework/test_auto_unit.py
@@ -138,7 +138,10 @@ def test_mixed_precision_fp16(self, mock_autocast: MagicMock) -> None:
)
dummy_iterable = [(torch.ones(2, 2), torch.ones(2, 2))]
state = get_dummy_train_state(dummy_iterable)
auto_unit.train_step(state=state, data=iter(dummy_iterable))
auto_unit.train_step(
state=state,
data=auto_unit.get_next_train_batch(state, iter(dummy_iterable)),
)
mock_autocast.assert_called_with(
device_type="cuda", dtype=torch.float16, enabled=True
)
@@ -159,7 +162,10 @@ def test_mixed_precision_bf16(self, mock_autocast: MagicMock) -> None:
)
dummy_iterable = [(torch.ones(2, 2), torch.ones(2, 2))]
state = get_dummy_train_state(dummy_iterable)
auto_unit.train_step(state=state, data=iter(dummy_iterable))
auto_unit.train_step(
state=state,
data=auto_unit.get_next_train_batch(state, iter(dummy_iterable)),
)
mock_autocast.assert_called_with(
device_type="cuda", dtype=torch.bfloat16, enabled=True
)
@@ -431,13 +437,17 @@ def _test_ddp_no_sync() -> None:

# for the first step no_sync should be called since we accumulate gradients
with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
auto_unit.train_step(state=state, data=dummy_iterator)
auto_unit.train_step(
state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
)
no_sync_mock.assert_called_once()

auto_unit.train_progress.increment_step()
# for the second step no_sync should not be called since we run optimizer step
with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
auto_unit.train_step(state=state, data=dummy_iterator)
auto_unit.train_step(
state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
)
no_sync_mock.assert_not_called()

@staticmethod
@@ -462,13 +472,17 @@ def _test_fsdp_no_sync() -> None:

# for the first step no_sync should be called since we accumulate gradients
with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
auto_unit.train_step(state=state, data=dummy_iterator)
auto_unit.train_step(
state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
)
no_sync_mock.assert_called_once()

auto_unit.train_progress.increment_step()
# for the second step no_sync should not be called since we run optimizer step
with patch.object(auto_unit.module, "no_sync") as no_sync_mock:
auto_unit.train_step(state=state, data=dummy_iterator)
auto_unit.train_step(
state=state, data=auto_unit.get_next_train_batch(state, dummy_iterator)
)
no_sync_mock.assert_not_called()

def test_move_data_to_device(self) -> None:
@@ -493,7 +507,7 @@ def test_move_data_to_device(self) -> None:
) as move_data_to_device_mock:
dummy_data = copy_data_to_device(dummy_data, device)
move_data_to_device_mock.return_value = dummy_data
auto_unit.train_step(state=state, data=data_iter)
auto_unit._get_next_batch(state=state, data=data_iter)
move_data_to_device_mock.assert_called_once()

def test_configure_optimizers_and_lr_scheduler(self) -> None:
58 changes: 58 additions & 0 deletions tests/framework/test_unit.py
@@ -0,0 +1,58 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from typing import Iterator

import torch
from torchtnt.framework._test_utils import get_dummy_train_state
from torchtnt.framework.state import State
from torchtnt.framework.unit import EvalUnit, PredictUnit, TrainUnit


class TestUnit(
EvalUnit[Iterator[torch.Tensor]], PredictUnit[torch.Tensor], TrainUnit[torch.Tensor]
):
def __init__(self) -> None:
super().__init__()

def train_step(self, state: State, data: torch.Tensor) -> None:
return

def eval_step(self, state: State, data: Iterator[torch.Tensor]) -> None:
return

def predict_step(self, state: State, data: torch.Tensor) -> None:
return


class UnitTest(unittest.TestCase):
def test_initialization_and_get_next_batch(self) -> None:
unit = TestUnit()
self.assertIsNotNone(unit.train_progress)
self.assertIsNotNone(unit.eval_progress)
self.assertIsNotNone(unit.predict_progress)

tensor_1 = torch.ones(1)
tensor_2 = torch.zeros(1)
state = get_dummy_train_state()

# test train next batch - expect to return the elements within the iterable
train_data_iter = iter([tensor_1, tensor_2])
self.assertEqual(unit.get_next_train_batch(state, train_data_iter), tensor_1)
self.assertEqual(unit.get_next_train_batch(state, train_data_iter), tensor_2)

# test predict next batch - expect to return the elements within the iterable
self.assertEqual(
unit.get_next_predict_batch(state, iter([tensor_1, tensor_2])), tensor_1
)

# test eval next batch - expect to return the iterable
data_iter = iter([tensor_1, tensor_2])
next_eval_batch = unit.get_next_eval_batch(state, data_iter)
self.assertEqual(next_eval_batch, data_iter)
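
In the test above, train and predict receive individual tensors while eval receives the iterator itself, because TestUnit annotates the data parameter of eval_step as Iterator[torch.Tensor]. The helper that drives this dispatch, _step_requires_iterator (imported in auto_unit.py below), is not part of this diff; the following is only a rough sketch of the idea, assuming it inspects the step method's type annotation.

```python
import inspect
from collections.abc import Iterator as ABCIterator
from typing import Callable, get_origin


def step_expects_iterator(step_func: Callable) -> bool:
    # Rough approximation, not the torchtnt implementation: if the step's
    # data parameter is annotated as Iterator[...], hand the step the raw
    # iterator; otherwise prefetch a single batch for it.
    annotation = inspect.signature(step_func).parameters["data"].annotation
    return get_origin(annotation) is ABCIterator
```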
81 changes: 50 additions & 31 deletions torchtnt/framework/auto_unit.py
@@ -12,6 +12,7 @@
from typing import (
Any,
Callable,
cast,
Generic,
Iterator,
List,
@@ -29,7 +30,7 @@
from torch.optim.swa_utils import SWALR
from torchtnt.framework.state import ActivePhase, EntryPoint, State
from torchtnt.framework.unit import EvalUnit, PredictUnit, TPredictData, TrainUnit
from torchtnt.framework.utils import get_timing_context
from torchtnt.framework.utils import _step_requires_iterator, get_timing_context
from torchtnt.utils.device import copy_data_to_device, record_data_in_stream
from torchtnt.utils.env import init_from_env
from torchtnt.utils.lr_scheduler import TLRScheduler
@@ -307,10 +308,7 @@ def __init__(
)

# pyre-fixme[3]: Return annotation cannot be `Any`.
def predict_step(self, state: State, data: Iterator[TPredictData]) -> Any:
with none_throws(state.predict_state).iteration_timer.time("data_wait_time"):
batch = self._get_next_batch(state, data)

def predict_step(self, state: State, data: TPredictData) -> Any:
# if detect_anomaly is true, run forward pass under detect_anomaly context
detect_anomaly = self.detect_anomaly
maybe_detect_anomaly = (
@@ -321,10 +319,10 @@ def predict_step(self, state: State, data: Iterator[TPredictData]) -> Any:

with self.maybe_autocast_precision, maybe_detect_anomaly:
with get_timing_context(state, f"{self.__class__.__name__}.forward"):
outputs = self.module(batch)
outputs = self.module(data)

step = self.predict_progress.num_steps_completed
self.on_predict_step_end(state, batch, step, outputs)
self.on_predict_step_end(state, data, step, outputs)
return outputs

def on_predict_step_end(
@@ -347,6 +345,15 @@ def on_predict_step_end(
"""
pass

def get_next_predict_batch(
self, state: State, data_iter: Iterator[TPredictData]
) -> Union[Iterator[TPredictData], TPredictData]:
# Override the default behavior from PredictUnit in order to enable prefetching if possible.
pass_data_iter_to_step = _step_requires_iterator(self.predict_step)
if pass_data_iter_to_step:
return data_iter
return self._get_next_batch(state, data_iter)


class AutoUnit(
_AutoUnitMixin[TData],
@@ -523,14 +530,7 @@ def compute_loss(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
...

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def train_step(
self, state: State, data: Iterator[TData]
) -> Tuple[torch.Tensor, Any]:
# In auto unit they will not be exclusive since data fetching is done as
# part of the training step
with none_throws(state.train_state).iteration_timer.time("data_wait_time"):
batch = self._get_next_batch(state, data)

def train_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
should_update_weights = (
self.train_progress.num_steps_completed_in_epoch + 1
) % self.gradient_accumulation_steps == 0 or self._is_last_batch
@@ -564,7 +564,7 @@ def train_step(
state, f"{self.__class__.__name__}.compute_loss"
):
# Run the forward pass and compute the loss
loss, outputs = self.compute_loss(state, batch)
loss, outputs = self.compute_loss(state, data)

# normalize loss to account for gradient accumulation
loss = loss / self.gradient_accumulation_steps
@@ -581,7 +581,7 @@ def train_step(
self._update_weights(state)

step = self.train_progress.num_steps_completed
self.on_train_step_end(state, batch, step, loss, outputs)
self.on_train_step_end(state, data, step, loss, outputs)
return loss, outputs

def on_train_step_end(
@@ -617,23 +617,18 @@ def on_train_epoch_end(self, state: State) -> None:
self._is_last_batch = False

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def eval_step(
self, state: State, data: Iterator[TData]
) -> Tuple[torch.Tensor, Any]:
with none_throws(state.eval_state).iteration_timer.time("data_wait_time"):
batch = self._get_next_batch(state, data)

def eval_step(self, state: State, data: TData) -> Tuple[torch.Tensor, Any]:
with self.maybe_autocast_precision:
# users must override this
with get_timing_context(state, f"{self.__class__.__name__}.compute_loss"):
loss, outputs = self.compute_loss(state, batch)
loss, outputs = self.compute_loss(state, data)

if state.entry_point == EntryPoint.FIT:
step = self.train_progress.num_steps_completed
else:
step = self.eval_progress.num_steps_completed

self.on_eval_step_end(state, batch, step, loss, outputs)
self.on_eval_step_end(state, data, step, loss, outputs)
return loss, outputs

def on_eval_step_end(
@@ -659,20 +654,17 @@ def on_eval_step_end(
pass

# pyre-fixme[3]: Return annotation cannot contain `Any`.
def predict_step(self, state: State, data: Iterator[TData]) -> Any:
with none_throws(state.predict_state).iteration_timer.time("data_wait_time"):
batch = self._get_next_batch(state, data)

def predict_step(self, state: State, data: TData) -> Any:
with self.maybe_autocast_precision:
with get_timing_context(state, f"{self.__class__.__name__}.forward"):
outputs = self.module(batch)
outputs = self.module(data)

step = self.predict_progress.num_steps_completed
# users can override this, by default this is a no-op
with get_timing_context(
state, f"{self.__class__.__name__}.on_predict_step_end"
):
self.on_predict_step_end(state, batch, step, outputs)
self.on_predict_step_end(state, data, step, outputs)
return outputs

def on_predict_step_end(
@@ -755,6 +747,33 @@ def _update_weights(self, state: State) -> None:
if self.step_lr_interval == "step":
self._update_lr_and_swa(state, self.train_progress.num_steps_completed)

def get_next_train_batch(
self, state: State, data_iter: Iterator[TData]
) -> Union[Iterator[TData], TData]:
# Override the default behavior from TrainUnit in order to enable prefetching if possible.
pass_data_iter_to_step = _step_requires_iterator(self.train_step)
if pass_data_iter_to_step:
return data_iter
return self._get_next_batch(state, data_iter)

def get_next_eval_batch(
self, state: State, data_iter: Iterator[TData]
) -> Union[Iterator[TData], TData]:
# Override the default behavior from EvalUnit in order to enable prefetching if possible.
pass_data_iter_to_step = _step_requires_iterator(self.eval_step)
if pass_data_iter_to_step:
return data_iter
return self._get_next_batch(state, data_iter)

def get_next_predict_batch(
self, state: State, data_iter: Iterator[TData]
) -> Union[Iterator[TData], TData]:
# Override the default behavior from PredictUnit in order to enable prefetching if possible.
pass_data_iter_to_step = _step_requires_iterator(self.predict_step)
if pass_data_iter_to_step:
return data_iter
return self._get_next_batch(state, data_iter)

def _should_update_swa(self) -> bool:
if not self.swa_params:
return False
14 changes: 5 additions & 9 deletions torchtnt/framework/evaluate.py
@@ -18,7 +18,6 @@
_is_epoch_done,
_reset_module_training_mode,
_set_module_training_mode,
_step_requires_iterator,
get_timing_context,
log_api_usage,
)
@@ -66,7 +65,7 @@ def evaluate(
while not done:
call on_eval_epoch_start on unit first and then callbacks
try:
data = next(dataloader)
call get_next_eval_batch on unit
call on_eval_step_start on callbacks
call eval_step on unit
increment step counter
@@ -132,7 +131,6 @@ def _evaluate_impl(
data_iter = iter(eval_state.dataloader)
step_input = data_iter

pass_data_iter_to_step = _step_requires_iterator(eval_unit.eval_step)
prev_steps_in_epoch = eval_unit.eval_progress.num_steps_completed_in_epoch

while not (
@@ -144,12 +142,10 @@
)
):
try:
if not pass_data_iter_to_step:
# get the next batch from the data iterator
with get_timing_context(
state, "evaluate.next(data_iter)"
), eval_state.iteration_timer.time("data_wait_time"):
step_input = next(data_iter)
with get_timing_context(
state, "evaluate.next(data_iter)"
), eval_state.iteration_timer.time("data_wait_time"):
step_input = eval_unit.get_next_eval_batch(state, data_iter)
with eval_state.iteration_timer.time("eval_iteration_time"):
callback_handler.on_eval_step_start(state, eval_unit)
eval_state._step_output = eval_unit.eval_step(state, step_input)
2 changes: 1 addition & 1 deletion torchtnt/framework/fit.py
@@ -77,7 +77,7 @@ def fit(
while epoch is not done:
call on_train_epoch_start on unit first and then callbacks
try:
data = next(dataloader)
call get_next_train_batch on unit
call on_train_step_start on callbacks
call train_step on unit
increment step counter