Skip to content

Commit

Permalink
more fixes
Browse files — browse the repository at this point in the history
  • Loading branch information
mansheej committed Oct 10, 2023
1 parent 7a17279 commit d7d49f9
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions llmfoundry/optim/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ def __init__(self,
alpha_f_cooldown: float = 0.0,
scale_warmup: bool = False) -> None:
if alpha_f_decay < alpha_f_cooldown:
raise ValueError(
f'Required: alpha_f_decay >= alpha_f_cooldown. Current: alpha_f_decay={alpha_f_decay}, alpha_f_cooldown={alpha_f_cooldown}'
)
raise ValueError(('Required: alpha_f_decay >= alpha_f_cooldown. '
f'Current: alpha_f_decay={alpha_f_decay}, '
f'alpha_f_cooldown={alpha_f_cooldown}.'))
_raise_if_units_dur(t_warmup, 't_warmup')
_raise_if_units_dur(t_scale, 't_scale')
_raise_if_units_dur(t_cooldown, 't_cooldown')
Expand Down Expand Up @@ -131,6 +131,7 @@ def __call__(self, state: State, ssr: float = 1.0) -> float:
current_time = state.timestamp.get(t_scale.unit)

t_shift = t_scale - t_warmup
# t_cooldown_start = max(t_warmup, t_max - t_cooldown)
t_cooldown_start = t_max - t_cooldown
if t_cooldown_start < t_warmup:
t_cooldown_start = t_warmup
Expand Down

0 comments on commit d7d49f9

Please sign in to comment.