From 615e709c988c0dfed459b5750f00688e00d48f88 Mon Sep 17 00:00:00 2001
From: Timothy Nguyen
Date: Tue, 27 Aug 2024 03:24:04 -0700
Subject: [PATCH] Fix progress off by one.

Progress should be measured by number of gradient steps taken and be
equal to 1.0 at end of training. This change also reconciles train
progress with eval progress.

PiperOrigin-RevId: 667916045
---
 examples/training.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/training.py b/examples/training.py
index 97d10bb..944e89c 100644
--- a/examples/training.py
+++ b/examples/training.py
@@ -567,10 +567,10 @@ def train_step(self, global_step: Array, rng: PRNGKey) -> dict[str, Numeric]:
     if "aux" in stats:
       stats.update(stats.pop("aux", {}))
 
-    stats["progress"] = self.progress(self._python_step)
-
     self._python_step += 1
 
+    stats["progress"] = self.progress(self._python_step)
+
     for name in self.config.get("per_device_stats_to_log", []):
       gathered_stat = jnp.reshape(
           kfac_jax.utils.host_all_gather(stats[name]), [-1]