From 6d35762281ad80d6f17a8a18ca8a58e5dac75e33 Mon Sep 17 00:00:00 2001
From: Joe Early
Date: Thu, 20 Jun 2024 14:21:17 +0100
Subject: [PATCH] Fixed batch referenced before assignment

---
 composer/trainer/trainer.py | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 77fdc538442..ad1749ed107 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -3644,6 +3644,12 @@ def _iter_dataloader(self, trainer_mode: TrainerMode):
         # 0 = not finished, 1 = finished (using integer tensors so we can use dist.all_reduce)
         iter_finished = torch.zeros(1, dtype=torch.uint8)
         iter_finished = self.state.device.tensor_to_device(iter_finished)
+
+        # Initialize batch to avoid "referenced before assignment" warnings
+        # Unique sentinel value to differentiate uninitialized state and dataloader yielding None
+        sentinel = object()
+        batch = sentinel
+
         while True:
             try:
                 # [BEFORE/AFTER]_DATALOADER only runs while training
@@ -3668,6 +3674,10 @@
 
             if iter_finished.item() == 1:
                 break
+            if batch is sentinel:
+                raise RuntimeError(
+                    "Batch should have been assigned or loop should have been broken. This shouldn't happen!"
+                )
             yield batch
 
     def _use_closures(self) -> bool: