Skip to content

Commit

Permalink
commit change
Browse files Browse the repository at this point in the history
  • Loading branch information
Chuck Tang committed Jun 19, 2024
1 parent 2bf78b9 commit 8384fdc
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2686,10 +2686,10 @@ def _train_loop(self) -> None:
def _eval_train_metrics(self, device_batch):
assert self._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'
assert self.state.train_metrics is not None, 'The train metrics should be set on __init__ or fit()'

precision = self.state.precision if self.state.precision is Precision.AMP_FP8 else Precision.AMP_BF16
with torch.no_grad(),\
model_eval_mode(self.state.model),\
_get_precision_context(self.state.precision, self.state.precision_config, self.state.deepspeed_enabled):
_get_precision_context(precision, self.state.precision_config, self.state.deepspeed_enabled):
eval_outputs = self._original_model.eval_forward(device_batch, self.state.outputs)
for metric in self.state.train_metrics.values():
self._original_model.update_metric(
Expand Down Expand Up @@ -3484,9 +3484,9 @@ def _eval_loop(
)[0]

self.engine.run_event(Event.EVAL_BEFORE_FORWARD)

precision = self.state.precision if self.state.precision is Precision.AMP_FP8 else Precision.AMP_BF16
with _get_precision_context(
self.state.precision,
precision,
self.state.precision_config,
self.state.deepspeed_enabled,
):
Expand All @@ -3501,7 +3501,7 @@ def _eval_loop(

# Run in same precision context to avoid NaNs
with _get_precision_context(
self.state.precision,
precision,
self.state.precision_config,
self.state.deepspeed_enabled,
):
Expand Down

0 comments on commit 8384fdc

Please sign in to comment.