Fix progress off by one.
Progress should be measured by the number of gradient steps taken and should equal 1.0 at the end of training. This change also reconciles train progress with eval progress.

PiperOrigin-RevId: 667916045
timothyn617 authored and KfacJaxDev committed Aug 27, 2024
1 parent b59b188 commit 615e709
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions examples/training.py
@@ -567,10 +567,10 @@ def train_step(self, global_step: Array, rng: PRNGKey) -> dict[str, Numeric]:
     if "aux" in stats:
       stats.update(stats.pop("aux", {}))
 
-    stats["progress"] = self.progress(self._python_step)
-
     self._python_step += 1
 
+    stats["progress"] = self.progress(self._python_step)
+
     for name in self.config.get("per_device_stats_to_log", []):
       gathered_stat = jnp.reshape(
           kfac_jax.utils.host_all_gather(stats[name]), [-1]
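
The effect of moving the progress computation after the step increment can be illustrated with a small standalone sketch. This is not the code from examples/training.py: the `progress` function below is a hypothetical stand-in that assumes progress is simply the step count divided by the total number of steps, and `run` is an illustrative loop, not the real `train_step`.

```python
def progress(step: int, total_steps: int) -> float:
    """Hypothetical stand-in: fraction of training done after `step` gradient steps."""
    return step / total_steps


def run(total_steps: int, increment_first: bool) -> list[float]:
    """Simulates the logging order before and after the commit."""
    python_step = 0
    logged = []
    for _ in range(total_steps):
        # ... one gradient step would be taken here ...
        if increment_first:
            # Order after the commit: count the step, then log progress.
            python_step += 1
            logged.append(progress(python_step, total_steps))
        else:
            # Order before the commit: log progress, then count the step.
            logged.append(progress(python_step, total_steps))
            python_step += 1
    return logged


before = run(4, increment_first=False)  # [0.0, 0.25, 0.5, 0.75]
after = run(4, increment_first=True)    # [0.25, 0.5, 0.75, 1.0]
print(before, after)
```

Under these assumptions, the old ordering never reports 1.0, because the last logged value lags one step behind; the new ordering reports 1.0 on the final step.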
