diff --git a/examples/training.py b/examples/training.py index 97d10bb..944e89c 100644 --- a/examples/training.py +++ b/examples/training.py @@ -567,10 +567,10 @@ def train_step(self, global_step: Array, rng: PRNGKey) -> dict[str, Numeric]: if "aux" in stats: stats.update(stats.pop("aux", {})) - stats["progress"] = self.progress(self._python_step) - self._python_step += 1 + stats["progress"] = self.progress(self._python_step) + for name in self.config.get("per_device_stats_to_log", []): gathered_stat = jnp.reshape( kfac_jax.utils.host_all_gather(stats[name]), [-1]