def _raise_if_units_dur(time: Union[str, Time]) -> None:
    """Reject a time value whose unit is ``"dur"``.

    Args:
        time (Union[str, Time]): A ``Time`` object or a timestring. Strings
            are parsed with ``Time.from_timestring`` before the unit is
            inspected.

    Raises:
        ValueError: If the unit of ``time`` equals ``TimeUnit('dur')``.
    """
    # Normalize to a Time object so the unit can be read uniformly.
    parsed = Time.from_timestring(time) if isinstance(time, str) else time

    is_duration = parsed.unit == TimeUnit('dur')
    if is_duration:
        raise ValueError(
            't_warmup, t_scale, and t_cooldown cannot be in units of "dur".')