diff --git a/TEST.py b/TEST.py index dd81f3e802..9fa9d76fea 100644 --- a/TEST.py +++ b/TEST.py @@ -41,6 +41,11 @@ from tests.trainer.test_fsdp_checkpoint import _compare_model_params_between_state_dicts, _compare_optims_between_state_dicts, _compare_metrics_between_state_dicts, get_trainer +from icecream import install +from icecream import ic + +install() +ic.configureOutput(includeContext=True) def test_1(use_tp: bool): @@ -104,6 +109,7 @@ def test_1(use_tp: bool): assert trainer1.state.tp_config is not None assert isinstance(trainer1.state.tp_config, TPConfig) + ic('Before trainer 1 fit') print('Before trainer 1 fit') trainer1.fit() print('After trainer 1 fit') @@ -146,9 +152,9 @@ def test_1(use_tp: bool): if __name__ == '__main__': - print('*'*70, '\nuse_tp=False\n', '*'*70) - test_1(use_tp=False) - print('*'*70, '\nDone\n', '*'*70) + # print('*'*70, '\nuse_tp=False\n', '*'*70) + # test_1(use_tp=False) + # print('*'*70, '\nDone\n', '*'*70) print('*'*70, '\nuse_tp=True\n', '*'*70) test_1(use_tp=True) diff --git a/composer/core/engine.py b/composer/core/engine.py index 2c27e0ec51..a6b0e4dad6 100644 --- a/composer/core/engine.py +++ b/composer/core/engine.py @@ -260,7 +260,7 @@ def run_event( traces (Traces): Ordered dictionary of trace for each algorithm. """ duration_marker = None - event = Event(event) + event = ic(Event(event)) self._debug_log(event, 'Running event') diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index 815aa50001..4215721118 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2667,7 +2667,6 @@ def _train_loop(self) -> None: self.state.batch = fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision) self.engine.run_event(Event.AFTER_DATALOADER) - self.engine.run_event(Event.BATCH_START) # Log time values @@ -2736,8 +2735,9 @@ def _train_loop(self) -> None: duration = datetime.datetime.now() - last_wct self._run_evaluators(Event.BATCH_END) last_wct = datetime.datetime.now() - duration - + ic('before') self.engine.run_event(Event.BATCH_CHECKPOINT) + ic('after') if ( self.state.timestamp >= self.state.max_duration or (