diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 77fdc53844..1df5406aba 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -3644,6 +3644,12 @@ def _iter_dataloader(self, trainer_mode: TrainerMode):
         # 0 = not finished, 1 = finished (using integer tensors so we can use dist.all_reduce)
         iter_finished = torch.zeros(1, dtype=torch.uint8)
         iter_finished = self.state.device.tensor_to_device(iter_finished)
+
+        # Initialize batch to avoid "referenced before assignment" warnings.
+        # A unique sentinel value differentiates an uninitialized state from a dataloader yielding None.
+        sentinel = object()
+        batch = sentinel
+
         while True:
             try:
                 # [BEFORE/AFTER]_DATALOADER only runs while training
@@ -3668,6 +3674,10 @@ def _iter_dataloader(self, trainer_mode: TrainerMode):
             if iter_finished.item() == 1:
                 break
 
+            if batch is sentinel:
+                raise RuntimeError(
+                    "Batch should have been assigned or the loop should have been broken. This shouldn't happen!",
+                )
             yield batch
 
     def _use_closures(self) -> bool:
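
The change above is an instance of the classic sentinel-object pattern: `object()` creates a unique instance, so `batch is sentinel` can never be true for anything a dataloader yields, including `None`. Below is a minimal, standalone sketch of the same pattern; the names `iter_batches`, `finished`, and `_SENTINEL` are illustrative, not Composer's API, and the distributed synchronization from the real method is reduced to a plain boolean flag.

```python
from typing import Any, Iterator

_SENTINEL = object()  # unique instance; `x is _SENTINEL` is False for any real batch


def iter_batches(dataloader: Any) -> Iterator[Any]:
    """Yield batches, distinguishing 'never assigned' from 'yielded None'."""
    dataloader_iter = iter(dataloader)
    finished = False
    batch = _SENTINEL  # avoids "referenced before assignment" warnings below
    while True:
        try:
            batch = next(dataloader_iter)
        except StopIteration:
            finished = True  # Composer additionally all-reduces this flag across ranks
        if finished:
            break
        if batch is _SENTINEL:
            # Unreachable in practice: batch is always assigned before this point.
            # The explicit check documents that invariant and satisfies static analysis.
            raise RuntimeError('batch should have been assigned before yielding')
        yield batch


# The sentinel keeps batches that are legitimately None distinguishable:
assert list(iter_batches([None, 0, 1])) == [None, 0, 1]
```

Comparing with `is` rather than `==` matters here: a batch object could overload `__eq__`, but identity with a module-private `object()` instance cannot be spoofed.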