diff --git a/llmfoundry/optim/scheduler.py b/llmfoundry/optim/scheduler.py index fe9cd3d93b..6e863a5467 100644 --- a/llmfoundry/optim/scheduler.py +++ b/llmfoundry/optim/scheduler.py @@ -136,9 +136,9 @@ def __call__(self, state: State, ssr: float = 1.0) -> float: t_cooldown_start = t_warmup if state.timestamp < t_cooldown_start: - # rescale LR by a coeff equal to the inverse square root of the time + # Rescale LR by a coefficient equal to the inverse square root of the time # elapsed after warmup, rescaled by the time scale, such that, at - # infinite time, the LR decays to alpha_f_decay + # infinite time, the LR decays to alpha_f_decay. coeff = 1 / ((current_time + t_shift) / t_scale).value**0.5 current_factor = (self.alpha_f_decay + coeff * (1.0 - self.alpha_f_decay))