Skip to content

Commit

Permalink
Add more dataloader hooks to the Callback interface (#937)
Browse files Browse the repository at this point in the history
Summary: Pull Request resolved: #937

Reviewed By: JKSenthil

Differential Revision: D64909630
  • Loading branch information
diego-urgell authored and facebook-github-bot committed Oct 31, 2024
1 parent 641c313 commit 19497bc
Show file tree
Hide file tree
Showing 7 changed files with 234 additions and 2 deletions.
75 changes: 75 additions & 0 deletions tests/framework/test_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,19 @@ def on_train_start(self, state: State, unit: TTrainUnit) -> None:
def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_train_epoch_start"
    self.called_hooks.add(hook_name)

def on_train_dataloader_iter_creation_start(
    self, state: State, unit: TTrainUnit
) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_train_dataloader_iter_creation_start"
    self.called_hooks.add(hook_name)

def on_train_dataloader_iter_creation_end(
    self, state: State, unit: TTrainUnit
) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_train_dataloader_iter_creation_end"
    self.called_hooks.add(hook_name)

def on_train_get_next_batch_start(self, state: State, unit: TTrainUnit) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_train_get_next_batch_start"
    self.called_hooks.add(hook_name)

def on_train_get_next_batch_end(self, state: State, unit: TTrainUnit) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_train_get_next_batch_end"
    self.called_hooks.add(hook_name)

Expand All @@ -67,6 +80,19 @@ def on_eval_start(self, state: State, unit: TEvalUnit) -> None:
def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_eval_epoch_start"
    self.called_hooks.add(hook_name)

def on_eval_dataloader_iter_creation_start(
    self, state: State, unit: TEvalUnit
) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_eval_dataloader_iter_creation_start"
    self.called_hooks.add(hook_name)

def on_eval_dataloader_iter_creation_end(
    self, state: State, unit: TEvalUnit
) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_eval_dataloader_iter_creation_end"
    self.called_hooks.add(hook_name)

def on_eval_get_next_batch_start(self, state: State, unit: TEvalUnit) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_eval_get_next_batch_start"
    self.called_hooks.add(hook_name)

def on_eval_get_next_batch_end(self, state: State, unit: TEvalUnit) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_eval_get_next_batch_end"
    self.called_hooks.add(hook_name)

Expand All @@ -85,6 +111,19 @@ def on_eval_end(self, state: State, unit: TEvalUnit) -> None:
def on_predict_start(self, state: State, unit: TPredictUnit) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_predict_start"
    self.called_hooks.add(hook_name)

def on_predict_dataloader_iter_creation_start(
    self, state: State, unit: TPredictUnit
) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_predict_dataloader_iter_creation_start"
    self.called_hooks.add(hook_name)

def on_predict_dataloader_iter_creation_end(
    self, state: State, unit: TPredictUnit
) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_predict_dataloader_iter_creation_end"
    self.called_hooks.add(hook_name)

def on_predict_get_next_batch_start(self, state: State, unit: TPredictUnit) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_predict_get_next_batch_start"
    self.called_hooks.add(hook_name)

def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None:
    """Record that this hook ran by adding its name to ``self.called_hooks``."""
    hook_name = "on_predict_epoch_start"
    self.called_hooks.add(hook_name)

Expand Down Expand Up @@ -129,6 +168,15 @@ def test_callback_handler(self) -> None:
cb_handler.on_train_epoch_start(state, unit)
self.assertIn("on_train_epoch_start", called_hooks)

cb_handler.on_train_dataloader_iter_creation_start(state, unit)
self.assertIn("on_train_dataloader_iter_creation_start", called_hooks)

cb_handler.on_train_dataloader_iter_creation_end(state, unit)
self.assertIn("on_train_dataloader_iter_creation_end", called_hooks)

cb_handler.on_train_get_next_batch_start(state, unit)
self.assertIn("on_train_get_next_batch_start", called_hooks)

cb_handler.on_train_get_next_batch_end(state, unit)
self.assertIn("on_train_get_next_batch_end", called_hooks)

Expand All @@ -154,6 +202,15 @@ def test_callback_handler(self) -> None:
cb_handler.on_eval_epoch_start(state, unit)
self.assertIn("on_eval_epoch_start", called_hooks)

cb_handler.on_eval_dataloader_iter_creation_start(state, unit)
self.assertIn("on_eval_dataloader_iter_creation_start", called_hooks)

cb_handler.on_eval_dataloader_iter_creation_end(state, unit)
self.assertIn("on_eval_dataloader_iter_creation_end", called_hooks)

cb_handler.on_eval_get_next_batch_start(state, unit)
self.assertIn("on_eval_get_next_batch_start", called_hooks)

cb_handler.on_eval_get_next_batch_end(state, unit)
self.assertIn("on_eval_get_next_batch_end", called_hooks)

Expand All @@ -179,6 +236,15 @@ def test_callback_handler(self) -> None:
cb_handler.on_predict_epoch_start(state, unit)
self.assertIn("on_predict_epoch_start", called_hooks)

cb_handler.on_predict_dataloader_iter_creation_start(state, unit)
self.assertIn("on_predict_dataloader_iter_creation_start", called_hooks)

cb_handler.on_predict_dataloader_iter_creation_end(state, unit)
self.assertIn("on_predict_dataloader_iter_creation_end", called_hooks)

cb_handler.on_predict_get_next_batch_start(state, unit)
self.assertIn("on_predict_get_next_batch_start", called_hooks)

cb_handler.on_predict_get_next_batch_end(state, unit)
self.assertIn("on_predict_get_next_batch_end", called_hooks)

Expand All @@ -202,20 +268,29 @@ def test_get_implemented_callback_mapping(self) -> None:
remaining_callback_hooks = (
"on_train_start",
"on_train_epoch_start",
"on_train_dataloader_iter_creation_start",
"on_train_dataloader_iter_creation_end",
"on_train_get_next_batch_start",
"on_train_get_next_batch_end",
"on_train_step_start",
"on_train_step_end",
"on_train_epoch_end",
"on_train_end",
"on_eval_start",
"on_eval_epoch_start",
"on_eval_dataloader_iter_creation_start",
"on_eval_dataloader_iter_creation_end",
"on_eval_get_next_batch_start",
"on_eval_get_next_batch_end",
"on_eval_step_start",
"on_eval_step_end",
"on_eval_epoch_end",
"on_eval_end",
"on_predict_start",
"on_predict_epoch_start",
"on_predict_dataloader_iter_creation_start",
"on_predict_dataloader_iter_creation_end",
"on_predict_get_next_batch_start",
"on_predict_get_next_batch_end",
"on_predict_step_start",
"on_predict_step_end",
Expand Down
75 changes: 75 additions & 0 deletions torchtnt/framework/_callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,20 +63,29 @@ def _get_implemented_callback_mapping(
"on_exception",
"on_train_start",
"on_train_epoch_start",
"on_train_dataloader_iter_creation_start",
"on_train_dataloader_iter_creation_end",
"on_train_get_next_batch_start",
"on_train_get_next_batch_end",
"on_train_step_start",
"on_train_step_end",
"on_train_epoch_end",
"on_train_end",
"on_eval_start",
"on_eval_epoch_start",
"on_eval_dataloader_iter_creation_start",
"on_eval_dataloader_iter_creation_end",
"on_eval_get_next_batch_start",
"on_eval_get_next_batch_end",
"on_eval_step_start",
"on_eval_step_end",
"on_eval_epoch_end",
"on_eval_end",
"on_predict_start",
"on_predict_epoch_start",
"on_predict_dataloader_iter_creation_start",
"on_predict_dataloader_iter_creation_end",
"on_predict_get_next_batch_start",
"on_predict_get_next_batch_end",
"on_predict_step_start",
"on_predict_step_end",
Expand Down Expand Up @@ -127,6 +136,28 @@ def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
for cb in callbacks:
cb.on_train_epoch_start(state, unit)

def on_train_dataloader_iter_creation_start(
    self, state: State, unit: TTrainUnit
) -> None:
    """Forward this hook to each callback stored under its name.

    Callbacks are looked up in ``self._callbacks`` by the hook name; when no
    entry exists nothing is called.

    Args:
        state: passed through unchanged to every callback.
        unit: passed through unchanged to every callback.
    """
    hook = "on_train_dataloader_iter_creation_start"
    for callback in self._callbacks.get(hook, []):
        callback.on_train_dataloader_iter_creation_start(state, unit)

def on_train_dataloader_iter_creation_end(
    self, state: State, unit: TTrainUnit
) -> None:
    """Forward this hook to each callback stored under its name.

    Callbacks are looked up in ``self._callbacks`` by the hook name; when no
    entry exists nothing is called.

    Args:
        state: passed through unchanged to every callback.
        unit: passed through unchanged to every callback.
    """
    hook = "on_train_dataloader_iter_creation_end"
    for callback in self._callbacks.get(hook, []):
        callback.on_train_dataloader_iter_creation_end(state, unit)

def on_train_get_next_batch_start(self, state: State, unit: TTrainUnit) -> None:
    """Forward this hook to each callback stored under its name.

    Callbacks are looked up in ``self._callbacks`` by the hook name; when no
    entry exists nothing is called.

    Args:
        state: passed through unchanged to every callback.
        unit: passed through unchanged to every callback.
    """
    hook = "on_train_get_next_batch_start"
    for callback in self._callbacks.get(hook, []):
        callback.on_train_get_next_batch_start(state, unit)

def on_train_get_next_batch_end(self, state: State, unit: TTrainUnit) -> None:
fn_name = "on_train_get_next_batch_end"
callbacks = self._callbacks.get(fn_name, [])
Expand Down Expand Up @@ -169,6 +200,28 @@ def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None:
for cb in callbacks:
cb.on_eval_epoch_start(state, unit)

def on_eval_dataloader_iter_creation_start(
    self, state: State, unit: TEvalUnit
) -> None:
    """Forward this hook to each callback stored under its name.

    Callbacks are looked up in ``self._callbacks`` by the hook name; when no
    entry exists nothing is called.

    Args:
        state: passed through unchanged to every callback.
        unit: passed through unchanged to every callback.
    """
    hook = "on_eval_dataloader_iter_creation_start"
    for callback in self._callbacks.get(hook, []):
        callback.on_eval_dataloader_iter_creation_start(state, unit)

def on_eval_dataloader_iter_creation_end(
    self, state: State, unit: TEvalUnit
) -> None:
    """Forward this hook to each callback stored under its name.

    Callbacks are looked up in ``self._callbacks`` by the hook name; when no
    entry exists nothing is called.

    Args:
        state: passed through unchanged to every callback.
        unit: passed through unchanged to every callback.
    """
    hook = "on_eval_dataloader_iter_creation_end"
    for callback in self._callbacks.get(hook, []):
        callback.on_eval_dataloader_iter_creation_end(state, unit)

def on_eval_get_next_batch_start(self, state: State, unit: TEvalUnit) -> None:
    """Forward this hook to each callback stored under its name.

    Callbacks are looked up in ``self._callbacks`` by the hook name; when no
    entry exists nothing is called.

    Args:
        state: passed through unchanged to every callback.
        unit: passed through unchanged to every callback.
    """
    hook = "on_eval_get_next_batch_start"
    for callback in self._callbacks.get(hook, []):
        callback.on_eval_get_next_batch_start(state, unit)

def on_eval_get_next_batch_end(self, state: State, unit: TEvalUnit) -> None:
fn_name = "on_eval_get_next_batch_end"
callbacks = self._callbacks.get(fn_name, [])
Expand Down Expand Up @@ -211,6 +264,28 @@ def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None:
for cb in callbacks:
cb.on_predict_epoch_start(state, unit)

def on_predict_dataloader_iter_creation_start(
    self, state: State, unit: TPredictUnit
) -> None:
    """Forward this hook to each callback stored under its name.

    Callbacks are looked up in ``self._callbacks`` by the hook name; when no
    entry exists nothing is called.

    Args:
        state: passed through unchanged to every callback.
        unit: passed through unchanged to every callback.
    """
    hook = "on_predict_dataloader_iter_creation_start"
    for callback in self._callbacks.get(hook, []):
        callback.on_predict_dataloader_iter_creation_start(state, unit)

def on_predict_dataloader_iter_creation_end(
    self, state: State, unit: TPredictUnit
) -> None:
    """Forward this hook to each callback stored under its name.

    Callbacks are looked up in ``self._callbacks`` by the hook name; when no
    entry exists nothing is called.

    Args:
        state: passed through unchanged to every callback.
        unit: passed through unchanged to every callback.
    """
    hook = "on_predict_dataloader_iter_creation_end"
    for callback in self._callbacks.get(hook, []):
        callback.on_predict_dataloader_iter_creation_end(state, unit)

def on_predict_get_next_batch_start(self, state: State, unit: TPredictUnit) -> None:
    """Forward this hook to each callback stored under its name.

    Callbacks are looked up in ``self._callbacks`` by the hook name; when no
    entry exists nothing is called.

    Args:
        state: passed through unchanged to every callback.
        unit: passed through unchanged to every callback.
    """
    hook = "on_predict_get_next_batch_start"
    for callback in self._callbacks.get(hook, []):
        callback.on_predict_get_next_batch_start(state, unit)

def on_predict_get_next_batch_end(self, state: State, unit: TPredictUnit) -> None:
fn_name = "on_predict_get_next_batch_end"
callbacks = self._callbacks.get(fn_name, [])
Expand Down
48 changes: 48 additions & 0 deletions torchtnt/framework/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,22 @@ def on_train_epoch_start(self, state: State, unit: TTrainUnit) -> None:
"""Hook called before a new train epoch starts."""
pass

def on_train_dataloader_iter_creation_start(
    self, state: State, unit: TTrainUnit
) -> None:
    """Hook called before the train dataloader iterator is created.

    The default implementation does nothing; override it to react to the
    event.

    Args:
        state: the State object handed to every hook.
        unit: the train unit handed to every train hook.
    """
    pass

def on_train_dataloader_iter_creation_end(
    self, state: State, unit: TTrainUnit
) -> None:
    """Hook called after the train dataloader iterator is created.

    The default implementation does nothing; override it to react to the
    event.

    Args:
        state: the State object handed to every hook.
        unit: the train unit handed to every train hook.
    """
    pass

def on_train_get_next_batch_start(self, state: State, unit: TTrainUnit) -> None:
    """Hook called before getting the data batch for the next train step.

    The default implementation does nothing; override it to react to the
    event.

    Args:
        state: the State object handed to every hook.
        unit: the train unit handed to every train hook.
    """
    pass

def on_train_get_next_batch_end(self, state: State, unit: TTrainUnit) -> None:
    """Hook called after getting the data batch for the next train step.

    The default implementation does nothing; override it to react to the
    event.

    Args:
        state: the State object handed to every hook.
        unit: the train unit handed to every train hook.
    """
    pass
Expand Down Expand Up @@ -105,6 +121,22 @@ def on_eval_epoch_start(self, state: State, unit: TEvalUnit) -> None:
"""Hook called before a new eval epoch starts."""
pass

def on_eval_dataloader_iter_creation_start(
    self, state: State, unit: TEvalUnit
) -> None:
    """Hook called before the eval dataloader iterator is created.

    The evaluate loop fires this just before it calls ``iter()`` on the eval
    dataloader. The default implementation does nothing.

    Args:
        state: the State object handed to every hook.
        unit: the eval unit handed to every eval hook.
    """
    pass

def on_eval_dataloader_iter_creation_end(
    self, state: State, unit: TEvalUnit
) -> None:
    """Hook called after the eval dataloader iterator is created.

    The evaluate loop fires this once ``iter()`` on the eval dataloader has
    returned. The default implementation does nothing.

    Args:
        state: the State object handed to every hook.
        unit: the eval unit handed to every eval hook.
    """
    pass

def on_eval_get_next_batch_start(self, state: State, unit: TEvalUnit) -> None:
    """Hook called before getting the data batch for the next eval step.

    The evaluate loop fires this immediately before
    ``get_next_eval_batch`` and inside its ``data_wait_time`` timer, so time
    spent in this hook is counted as data wait time. The default
    implementation does nothing.

    Args:
        state: the State object handed to every hook.
        unit: the eval unit handed to every eval hook.
    """
    pass

def on_eval_get_next_batch_end(self, state: State, unit: TEvalUnit) -> None:
    """Hook called after getting the data batch for the next eval step.

    The evaluate loop fires this right after ``get_next_eval_batch`` returns.
    The default implementation does nothing.

    Args:
        state: the State object handed to every hook.
        unit: the eval unit handed to every eval hook.
    """
    pass
Expand Down Expand Up @@ -133,6 +165,22 @@ def on_predict_epoch_start(self, state: State, unit: TPredictUnit) -> None:
"""Hook called before a new predict epoch starts."""
pass

def on_predict_dataloader_iter_creation_start(
    self, state: State, unit: TPredictUnit
) -> None:
    """Hook called before the predict dataloader iterator is created.

    The default implementation does nothing; override it to react to the
    event.

    Args:
        state: the State object handed to every hook.
        unit: the predict unit handed to every predict hook.
    """
    pass

def on_predict_dataloader_iter_creation_end(
    self, state: State, unit: TPredictUnit
) -> None:
    """Hook called after the predict dataloader iterator is created.

    The default implementation does nothing; override it to react to the
    event.

    Args:
        state: the State object handed to every hook.
        unit: the predict unit handed to every predict hook.
    """
    pass

def on_predict_get_next_batch_start(self, state: State, unit: TPredictUnit) -> None:
    """Hook called before getting the data batch for the next predict step.

    The default implementation does nothing; override it to react to the
    event.

    Args:
        state: the State object handed to every hook.
        unit: the predict unit handed to every predict hook.
    """
    pass

def on_predict_get_next_batch_end(self, state: State, unit: TPredictUnit) -> None:
    """Hook called after getting the data batch for the next predict step.

    The default implementation does nothing; override it to react to the
    event.

    Args:
        state: the State object handed to every hook.
        unit: the predict unit handed to every predict hook.
    """
    pass
Expand Down
27 changes: 27 additions & 0 deletions torchtnt/framework/callbacks/lambda_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,15 @@ def __init__(
] = None,
on_train_start: Optional[Callable[[State, TTrainUnit], None]] = None,
on_train_epoch_start: Optional[Callable[[State, TTrainUnit], None]] = None,
on_train_dataloader_iter_creation_start: Optional[
Callable[[State, TTrainUnit], None]
] = None,
on_train_dataloader_iter_creation_end: Optional[
Callable[[State, TTrainUnit], None]
] = None,
on_train_get_next_batch_start: Optional[
Callable[[State, TTrainUnit], None]
] = None,
on_train_get_next_batch_end: Optional[
Callable[[State, TTrainUnit], None]
] = None,
Expand All @@ -91,13 +100,31 @@ def __init__(
on_train_end: Optional[Callable[[State, TTrainUnit], None]] = None,
on_eval_start: Optional[Callable[[State, TEvalUnit], None]] = None,
on_eval_epoch_start: Optional[Callable[[State, TEvalUnit], None]] = None,
# Eval hooks receive the eval unit, matching the Callback.on_eval_* signatures.
on_eval_dataloader_iter_creation_start: Optional[
    Callable[[State, TEvalUnit], None]
] = None,
on_eval_dataloader_iter_creation_end: Optional[
    Callable[[State, TEvalUnit], None]
] = None,
on_eval_get_next_batch_start: Optional[
    Callable[[State, TEvalUnit], None]
] = None,
on_eval_get_next_batch_end: Optional[Callable[[State, TEvalUnit], None]] = None,
on_eval_step_start: Optional[Callable[[State, TEvalUnit], None]] = None,
on_eval_step_end: Optional[Callable[[State, TEvalUnit], None]] = None,
on_eval_epoch_end: Optional[Callable[[State, TEvalUnit], None]] = None,
on_eval_end: Optional[Callable[[State, TEvalUnit], None]] = None,
on_predict_start: Optional[Callable[[State, TPredictUnit], None]] = None,
on_predict_epoch_start: Optional[Callable[[State, TPredictUnit], None]] = None,
# Predict hooks receive the predict unit, matching the Callback.on_predict_*
# signatures.
on_predict_dataloader_iter_creation_start: Optional[
    Callable[[State, TPredictUnit], None]
] = None,
on_predict_dataloader_iter_creation_end: Optional[
    Callable[[State, TPredictUnit], None]
] = None,
on_predict_get_next_batch_start: Optional[
    Callable[[State, TPredictUnit], None]
] = None,
on_predict_get_next_batch_end: Optional[
Callable[[State, TPredictUnit], None]
] = None,
Expand Down
4 changes: 3 additions & 1 deletion torchtnt/framework/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,9 +132,10 @@ def _evaluate_impl(
eval_unit.on_eval_epoch_start(state)
callback_handler.on_eval_epoch_start(state, eval_unit)

callback_handler.on_eval_dataloader_iter_creation_start(state, eval_unit)
with get_timing_context(state, "evaluate.iter(dataloader)"):
data_iter = iter(eval_state.dataloader)
step_input = data_iter
callback_handler.on_eval_dataloader_iter_creation_end(state, eval_unit)

prev_steps_in_epoch = eval_unit.eval_progress.num_steps_completed_in_epoch

Expand All @@ -151,6 +152,7 @@ def _evaluate_impl(
with get_timing_context(
state, "evaluate.next(data_iter)"
), eval_state.iteration_timer.time("data_wait_time"):
callback_handler.on_eval_get_next_batch_start(state, eval_unit)
step_input = eval_unit.get_next_eval_batch(state, data_iter)
callback_handler.on_eval_get_next_batch_end(state, eval_unit)

Expand Down
Loading

0 comments on commit 19497bc

Please sign in to comment.