Skip to content

Commit

Permalink
scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
mansheej committed Oct 9, 2023
1 parent 54007d3 commit 1f9ac7a
Showing 1 changed file with 12 additions and 1 deletion.
13 changes: 12 additions & 1 deletion llmfoundry/optim/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import warnings
from typing import Union

from composer.core import State, Time
from composer.core import State, Time, TimeUnit
from composer.optim import ComposerScheduler, LinearScheduler
from composer.optim.scheduler import _convert_time

Expand All @@ -23,6 +23,14 @@ def _raise_if_units_dont_match(time: Union[str, Time],
'All time units must be the same as max_duration units.')


def _raise_if_units_dur(time: Union[str, Time]) -> None:
if isinstance(time, str):
time = Time.from_timestring(time)
if time.unit == TimeUnit('dur'):
raise ValueError(
't_warmup, t_scale, and t_cooldown cannot be in units of "dur".')


class InverseSquareRootWithWarmupScheduler(ComposerScheduler):
r"""Inverse square root LR decay with warmup and optional linear cooldown.
Expand Down Expand Up @@ -80,6 +88,9 @@ def __init__(self,
scale_warmup: bool = False):
if alpha_f_decay < alpha_f_cooldown:
raise ValueError('Required: alpha_f_decay >= alpha_f_cooldown.')
_raise_if_units_dur(t_warmup)
_raise_if_units_dur(t_scale)
_raise_if_units_dur(t_cooldown)
self.t_warmup = t_warmup
self.t_scale = t_scale
self.t_cooldown = t_cooldown
Expand Down

0 comments on commit 1f9ac7a

Please sign in to comment.