-
Notifications
You must be signed in to change notification settings - Fork 5
/
scheduler.py
313 lines (265 loc) · 13.3 KB
/
scheduler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import types

import numpy as np
from torch.optim import SGD, lr_scheduler
# Implementation based on https://github.com/timesler/lr-momentum-scheduler
class _LRMomentumScheduler(lr_scheduler._LRScheduler):
    """Base class for schedulers that update learning rate and momentum together.

    Subclasses implement :meth:`get_lr` and :meth:`get_momentum`; each must
    return one value per entry of ``optimizer.param_groups``. Every param
    group is expected to carry a ``'momentum'`` entry.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        KeyError: If ``last_epoch`` is -1 and a param group has no
            ``'momentum'`` entry, or if ``last_epoch`` is not -1 and a param
            group has no ``'initial_momentum'`` entry (resuming an optimizer).
    """

    def __init__(self, optimizer, last_epoch=-1):
        if last_epoch == -1:
            # Fresh run: remember the momentum the optimizer was created with.
            for group in optimizer.param_groups:
                group.setdefault('initial_momentum', group['momentum'])
        else:
            # Resuming: the initial momentum must already have been recorded.
            for i, group in enumerate(optimizer.param_groups):
                if 'initial_momentum' not in group:
                    raise KeyError("param 'initial_momentum' is not specified "
                                   "in param_groups[{}] when resuming an optimizer".format(i))
        self.base_momentums = [group['initial_momentum'] for group in optimizer.param_groups]
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for each param group (subclass hook)."""
        raise NotImplementedError

    def get_momentum(self):
        """Return the momentum for each param group (subclass hook)."""
        raise NotImplementedError

    def step(self, epoch=None):
        """Advance the schedule and write lr/momentum into the optimizer.

        Args:
            epoch (int, optional): Epoch index to move to. When omitted, the
                schedule advances to ``last_epoch + 1``.
        """
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        for param_group, lr, momentum in zip(self.optimizer.param_groups,
                                             self.get_lr(), self.get_momentum()):
            param_group['lr'] = lr
            param_group['momentum'] = momentum
        # The inherited get_last_lr() reads self._last_lr, which is normally
        # populated by the base step() that this method replaces.
        self._last_lr = [group['lr'] for group in self.optimizer.param_groups]
class ParameterUpdate(object):
    """A callable class used to define an arbitrary schedule defined by a list.

    This object is designed to be passed to the LambdaLR or LambdaScheduler scheduler to apply
    the given schedule. Calling it with an epoch index returns the scheduled value divided by
    ``base_param``, i.e. the multiplicative factor relative to the optimizer's initial value.
    If ``base_param`` is (numerically) zero, the factor is always 1 so no update is applied.

    Arguments:
        params {list or numpy.array} -- List or numpy array defining parameter schedule.
        base_param {float} -- Parameter value used to initialize the optimizer.
    """

    def __init__(self, params, base_param):
        # A trailing 0 is appended to the schedule; presumably so that the
        # step taken after the last scheduled epoch still has an entry.
        self.params = np.hstack([params, 0])
        self.base_param = base_param
        if base_param < 1e-12:
            # Dividing by a zero base would yield inf/NaN factors; fall back
            # to a neutral factor of 1 for every epoch instead.
            self.base_param = 1
            self.params = np.ones_like(self.params, dtype=float)

    def __call__(self, epoch):
        """Return the multiplicative factor for ``epoch``."""
        return self.params[epoch] / self.base_param
def apply_lambda(last_epoch, bases, lambdas):
    """Scale each base value by its lambda evaluated at ``last_epoch``.

    ``lambdas`` and ``bases`` are paired positionally; the result has one
    entry per pair, in order.
    """
    scaled = []
    for factor_fn, base_value in zip(lambdas, bases):
        scaled.append(base_value * factor_fn(last_epoch))
    return scaled
class LambdaScheduler(_LRMomentumScheduler):
    """Sets the learning rate and momentum of each parameter group to the initial lr and momentum
    times a given function. When last_epoch=-1, sets initial lr and momentum to the optimizer
    values.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr_lambda (function or list): A function which computes a multiplicative
            factor given an integer parameter epoch, or a list of such
            functions, one for each group in optimizer.param_groups.
            Default: lambda x: x (the factor equals the epoch index).
        momentum_lambda (function or list): As for lr_lambda but applied to momentum.
            Default: lambda x: x (the factor equals the epoch index).
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: If a list/tuple of lambdas is given whose length differs
            from the number of param groups.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lr_lambda = [
        ...     lambda epoch: epoch // 30,
        ...     lambda epoch: 0.95 ** epoch
        ... ]
        >>> mom_lambda = [
        ...     lambda epoch: max(0, (50 - epoch) // 50),
        ...     lambda epoch: 0.99 ** epoch
        ... ]
        >>> scheduler = LambdaScheduler(optimizer, lr_lambda, mom_lambda)
        >>> for epoch in range(100):
        ...     train(...)
        ...     validate(...)
        ...     scheduler.step()
    """

    def __init__(self, optimizer, lr_lambda=lambda x: x, momentum_lambda=lambda x: x, last_epoch=-1):
        self.optimizer = optimizer
        group_count = len(optimizer.param_groups)
        self.lr_lambdas = self._per_group(lr_lambda, group_count, 'lr_lambdas')
        self.momentum_lambdas = self._per_group(momentum_lambda, group_count, 'momentum_lambdas')
        self.last_epoch = last_epoch
        super().__init__(optimizer, last_epoch)

    @staticmethod
    def _per_group(fn_or_fns, group_count, label):
        """Return one callable per param group, repeating a single callable if needed."""
        if not isinstance(fn_or_fns, (list, tuple)):
            return [fn_or_fns] * group_count
        if len(fn_or_fns) != group_count:
            raise ValueError("Expected {} {}, but got {}".format(
                group_count, label, len(fn_or_fns)))
        return list(fn_or_fns)

    @staticmethod
    def _callable_state(fn):
        """Return a copy of a callable object's ``__dict__``, or None for plain functions."""
        if isinstance(fn, types.FunctionType):
            return None
        return fn.__dict__.copy()

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        The learning rate and momentum lambda functions will only be saved if they are
        callable objects and not if they are functions or lambdas; the latter are
        stored as ``None``.
        """
        state = {key: value for key, value in self.__dict__.items()
                 if key not in ('optimizer', 'lr_lambdas', 'momentum_lambdas')}
        state['lr_lambdas'] = [self._callable_state(fn) for fn in self.lr_lambdas]
        state['momentum_lambdas'] = [self._callable_state(fn) for fn in self.momentum_lambdas]
        return state

    def load_state_dict(self, state_dict):
        """Loads the schedulers state.

        The passed dictionary is left unmodified.

        Arguments:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.

        Raises:
            KeyError: If ``state_dict`` has no ``'lr_lambdas'`` or
                ``'momentum_lambdas'`` entry.
        """
        lr_states = state_dict['lr_lambdas']
        momentum_states = state_dict['momentum_lambdas']
        # The lambda entries hold per-callable attribute dicts, not callables,
        # so they must not overwrite self.lr_lambdas / self.momentum_lambdas.
        self.__dict__.update({key: value for key, value in state_dict.items()
                              if key not in ('lr_lambdas', 'momentum_lambdas')})
        for idx, fn_state in enumerate(lr_states):
            if fn_state is not None:
                self.lr_lambdas[idx].__dict__.update(fn_state)
        for idx, fn_state in enumerate(momentum_states):
            if fn_state is not None:
                self.momentum_lambdas[idx].__dict__.update(fn_state)

    def get_lr(self):
        """Return each group's base lr times its lr lambda at ``last_epoch``."""
        return apply_lambda(self.last_epoch, self.base_lrs, self.lr_lambdas)

    def get_momentum(self):
        """Return each group's base momentum times its momentum lambda at ``last_epoch``."""
        return apply_lambda(self.last_epoch, self.base_momentums, self.momentum_lambdas)
class ParameterUpdate(object):
    """Callable that turns a list-defined schedule into multiplicative factors.

    Intended to be handed to LambdaLR or LambdaScheduler. Calling the object
    with an epoch index returns the scheduled value for that epoch divided by
    ``base_param``. When ``base_param`` is below 1e-12 the schedule is replaced
    so that the returned factor is 1, i.e. no updates are applied.

    Arguments:
        params {list or numpy.array} -- List or numpy array defining parameter schedule.
        base_param {float} -- Parameter value used to initialize the optimizer.
    """

    def __init__(self, params, base_param):
        # The stored schedule is the given values followed by a single 0.
        schedule = np.hstack((params, 0))
        if base_param < 1e-12:
            self.params = schedule * 0.0 + 1.0
            self.base_param = 1
        else:
            self.params = schedule
            self.base_param = base_param

    def __call__(self, epoch):
        scheduled_value = self.params[epoch]
        return scheduled_value / self.base_param
class ListScheduler(LambdaScheduler):
    """Sets the learning rate and momentum of each parameter group to values defined by lists.

    When last_epoch=-1, sets initial lr and momentum to the optimizer values. One or both of lr
    and momentum schedules may be specified; a parameter without a schedule keeps its initial
    value.

    Note that the parameters used to initialize the optimizer are overridden by those defined by
    this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lrs (list or numpy.ndarray): A list of learning rates, or a list of lists, one for each
            parameter group. One- or two-dimensional numpy arrays may also be passed.
        momentums (list or numpy.ndarray): A list of momentums, or a list of lists, one for each
            parameter group. One- or two-dimensional numpy arrays may also be passed.
        last_epoch (int): The index of last epoch. Default: -1.

    Example:
        >>> # Assuming optimizer has two groups.
        >>> lrs = [
        ...     np.linspace(0.01, 0.1, 100),
        ...     np.logspace(-2, 0, 100)
        ... ]
        >>> momentums = [
        ...     np.linspace(0.85, 0.95, 100),
        ...     np.linspace(0.8, 0.99, 100)
        ... ]
        >>> scheduler = ListScheduler(optimizer, lrs, momentums)
        >>> for epoch in range(100):
        ...     train(...)
        ...     validate(...)
        ...     scheduler.step()
    """

    def __init__(self, optimizer, lrs=None, momentums=None, last_epoch=-1):
        groups = optimizer.param_groups
        lr_lambda = self._schedule_lambdas(lrs, groups, 'lr')
        momentum_lambda = self._schedule_lambdas(momentums, groups, 'momentum')
        # Forward last_epoch so that a resumed scheduler is not silently
        # restarted from the beginning.
        super().__init__(optimizer, lr_lambda, momentum_lambda, last_epoch)

    @staticmethod
    def _schedule_lambdas(values, groups, key):
        """Build the multiplicative-factor callables for one optimizer parameter.

        Returns a constant factor of 1 when no schedule is given, otherwise one
        ParameterUpdate per param group.
        """
        if values is None:
            # The scheduler multiplies the initial value by this factor, so a
            # constant 1 leaves the parameter unchanged (the identity function
            # would instead scale it by the epoch index).
            return lambda epoch: 1.0

        values = np.array(values) if isinstance(values, (list, tuple)) else values

        def base_of(group):
            # The parent scheduler multiplies by 'initial_<key>' when that
            # entry already exists, so normalise by the same quantity.
            return group.get('initial_' + key, group[key])

        if len(values.shape) == 1:
            # A single schedule is shared by every param group.
            return [ParameterUpdate(values, base_of(g)) for g in groups]
        return [ParameterUpdate(row, base_of(g)) for row, g in zip(values, groups)]
class RangeFinder(ListScheduler):
    """Scheduler that runs the learning-rate range search described in:

    A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch
    size, momentum, and weight decay. Leslie N. Smith, 2018, arXiv:1803.09820.

    The learning rate sweeps logarithmically from 1e-7 to 1, using 'epochs' increments. Only a
    learning-rate schedule is passed on to ListScheduler.

    Note that the parameters used to initialize the optimizer are overridden by those defined by
    this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        epochs (int): Number of epochs over which to run test.

    Example:
        >>> scheduler = RangeFinder(optimizer, 100)
        >>> for epoch in range(100):
        ...     train(...)
        ...     validate(...)
        ...     scheduler.step()
    """

    def __init__(self, optimizer, epochs):
        # Exponents -7..0 in base 10, i.e. learning rates 1e-7 .. 1.
        log_spaced_lrs = np.logspace(-7, 0, num=epochs)
        super().__init__(optimizer, lrs=log_spaced_lrs)
class OneCyclePolicy(ListScheduler):
    """Scheduler class that implements the 1cycle policy specified in:

    A disciplined approach to neural network hyper-parameters: Part 1 -- learning rate, batch
    size, momentum, and weight decay. Leslie N. Smith, 2018, arXiv:1803.09820.

    When a momentum range is given, each param group's ``'momentum'`` is set to the range's
    second value before the schedule is installed.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        lr (float or list): Maximum learning rate in range. If a list of values is passed, they
            should correspond to parameter groups.
        epochs (int): The number of epochs covered by the schedule.
        momentum_rng (list): Optional pair of momentum values (may be both equal). Momentum
            starts at the second value, moves to the first while the learning rate rises,
            returns to the second while it falls, and then stays there. Set to None to run
            without a momentum schedule. Default: (0.85, 0.95). If a list of lists is passed,
            they should correspond to parameter groups.
        phase_ratio (float): Fraction of epochs used for the increasing and decreasing phase of
            the schedule. For example, if phase_ratio=0.45 and epochs=100, the learning rate will
            increase from lr/10 to lr over 45 epochs, then decrease back to lr/10 over 45 epochs,
            then decrease to lr/100 over the remaining 10 epochs. Default: 0.45.

    Raises:
        ValueError: If ``phase_ratio`` and ``epochs`` produce a negative number of epochs for
            any phase.
    """

    def __init__(self, optimizer, lr, epochs, momentum_rng=(0.85, 0.95), phase_ratio=0.45):
        phase_epochs = int(phase_ratio * epochs)
        tail_epochs = epochs - 2 * phase_epochs
        if phase_epochs < 0 or tail_epochs < 0:
            # np.linspace would reject a negative sample count anyway; fail
            # early with a message that names the offending arguments.
            raise ValueError(
                "phase_ratio={} with epochs={} gives {} epochs per phase and {} "
                "remaining epochs; both must be non-negative".format(
                    phase_ratio, epochs, phase_epochs, tail_epochs))

        def lr_cycle(peak):
            # lr/10 -> lr -> lr/10 -> lr/100
            return self._cycle(peak * 1e-1, peak, peak * 1e-2, phase_epochs, tail_epochs)

        def momentum_cycle(rng):
            # rng[1] -> rng[0] -> rng[1], then held at rng[1]
            return self._cycle(rng[1], rng[0], rng[1], phase_epochs, tail_epochs)

        if isinstance(lr, (list, tuple)):
            lrs = [lr_cycle(peak) for peak in lr]
        else:
            lrs = lr_cycle(lr)

        if momentum_rng is None:
            momentums = None
        else:
            momentum_rng = np.array(momentum_rng)
            if len(momentum_rng.shape) == 2:
                # One range per param group.
                for i, group in enumerate(optimizer.param_groups):
                    group['momentum'] = momentum_rng[i][1]
                momentums = [momentum_cycle(rng) for rng in momentum_rng]
            else:
                # A single range shared by all param groups.
                for group in optimizer.param_groups:
                    group['momentum'] = momentum_rng[1]
                momentums = momentum_cycle(momentum_rng)

        super().__init__(optimizer, lrs, momentums)

    @staticmethod
    def _cycle(start, peak, final, phase_epochs, tail_epochs):
        """Return the schedule start -> peak -> start -> final as one array."""
        return np.hstack([
            np.linspace(start, peak, phase_epochs),
            np.linspace(peak, start, phase_epochs),
            np.linspace(start, final, tail_epochs),
        ])