Skip to content

Commit

Permalink
add ic statements
Browse files Browse the repository at this point in the history
  • Loading branch information
ez2rok committed Sep 5, 2024
1 parent b06337b commit c720ac5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
12 changes: 9 additions & 3 deletions TEST.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@


from tests.trainer.test_fsdp_checkpoint import _compare_model_params_between_state_dicts, _compare_optims_between_state_dicts, _compare_metrics_between_state_dicts, get_trainer
from icecream import install
from icecream import ic

install()
ic.configureOutput(includeContext=True)

def test_1(use_tp: bool):

Expand Down Expand Up @@ -104,6 +109,7 @@ def test_1(use_tp: bool):
assert trainer1.state.tp_config is not None
assert isinstance(trainer1.state.tp_config, TPConfig)

ic('Before trainer 1 fit')
print('Before trainer 1 fit')
trainer1.fit()
print('After trainer 1 fit')
Expand Down Expand Up @@ -146,9 +152,9 @@ def test_1(use_tp: bool):


if __name__ == '__main__':
print('*'*70, '\nuse_tp=False\n', '*'*70)
test_1(use_tp=False)
print('*'*70, '\nDone\n', '*'*70)
# print('*'*70, '\nuse_tp=False\n', '*'*70)
# test_1(use_tp=False)
# print('*'*70, '\nDone\n', '*'*70)

print('*'*70, '\nuse_tp=True\n', '*'*70)
test_1(use_tp=True)
Expand Down
2 changes: 1 addition & 1 deletion composer/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def run_event(
traces (Traces): Ordered dictionary of trace for each algorithm.
"""
duration_marker = None
event = Event(event)
event = ic(Event(event))

self._debug_log(event, 'Running event')

Expand Down
4 changes: 2 additions & 2 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2667,7 +2667,6 @@ def _train_loop(self) -> None:
self.state.batch = fix_batch_precision_for_deepspeed(self.state.batch, self.state.precision)

self.engine.run_event(Event.AFTER_DATALOADER)

self.engine.run_event(Event.BATCH_START)

# Log time values
Expand Down Expand Up @@ -2736,8 +2735,9 @@ def _train_loop(self) -> None:
duration = datetime.datetime.now() - last_wct
self._run_evaluators(Event.BATCH_END)
last_wct = datetime.datetime.now() - duration

ic('before')
self.engine.run_event(Event.BATCH_CHECKPOINT)
ic('after')

if (
self.state.timestamp >= self.state.max_duration or (
Expand Down

0 comments on commit c720ac5

Please sign in to comment.