From 8384fdc4137ac0bf0e3f0cae8bda1c1d7a1eb88a Mon Sep 17 00:00:00 2001
From: Chuck Tang
Date: Tue, 18 Jun 2024 17:45:22 -0700
Subject: [PATCH] commit change

---
 composer/trainer/trainer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 4447698bebb..3ab6e2f9ca3 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -2686,10 +2686,10 @@ def _train_loop(self) -> None:
     def _eval_train_metrics(self, device_batch):
         assert self._train_data_spec is not None, 'The train data spec should be set on __init__ or fit()'
         assert self.state.train_metrics is not None, 'The train metrics should be set on __init__ or fit()'
-
+        precision = self.state.precision if self.state.precision is Precision.AMP_FP8 else Precision.AMP_BF16
         with torch.no_grad(),\
                 model_eval_mode(self.state.model),\
-                _get_precision_context(self.state.precision, self.state.precision_config, self.state.deepspeed_enabled):
+                _get_precision_context(precision, self.state.precision_config, self.state.deepspeed_enabled):
             eval_outputs = self._original_model.eval_forward(device_batch, self.state.outputs)
             for metric in self.state.train_metrics.values():
                 self._original_model.update_metric(
@@ -3484,9 +3484,9 @@ def _eval_loop(
                     )[0]
 
                     self.engine.run_event(Event.EVAL_BEFORE_FORWARD)
-
+                    precision = self.state.precision if self.state.precision is Precision.AMP_FP8 else Precision.AMP_BF16
                     with _get_precision_context(
-                        self.state.precision,
+                        precision,
                         self.state.precision_config,
                         self.state.deepspeed_enabled,
                     ):
@@ -3501,7 +3501,7 @@
 
                         # Run in same precision context to avoid NaNs
                         with _get_precision_context(
-                            self.state.precision,
+                            precision,
                             self.state.precision_config,
                             self.state.deepspeed_enabled,
                         ):