forked from Smerity/sha-rnn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lookahead.py
36 lines (33 loc) · 1.36 KB
/
lookahead.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import itertools as it
from torch.optim import Optimizer
class Lookahead(Optimizer):
    """Lookahead wrapper around another optimizer.

    The wrapped ("fast") optimizer updates the parameters on every call to
    :meth:`step`.  A detached copy of every parameter (the "slow" weights)
    is kept on the side.  Every ``k`` steps the slow weights are moved
    towards the current parameters by a factor ``alpha`` and the parameters
    are then reset to the slow weights.

    NOTE: ``Optimizer.__init__`` is not invoked by this class; the wrapper
    only shares ``param_groups`` with the base optimizer.  Attributes that
    the base class constructor would normally create are therefore not
    set up on the wrapper itself.
    """

    def __init__(self, base_optimizer, alpha=0.5, k=6):
        """Wrap ``base_optimizer``.

        Args:
            base_optimizer: optimizer whose ``param_groups`` are shared and
                whose ``step()`` performs the fast update.
            alpha: slow-weight interpolation rate, must lie in ``[0, 1]``.
            k: number of fast steps between two slow-weight updates,
                must be at least 1.

        Raises:
            ValueError: if ``alpha`` or ``k`` is out of range.
        """
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f'Invalid slow update rate: {alpha}')
        if not 1 <= k:
            raise ValueError(f'Invalid lookahead steps: {k}')
        self.optimizer = base_optimizer
        # Same list object as the base optimizer's, so changes made through
        # either one (e.g. by a scheduler) are visible to both.
        self.param_groups = self.optimizer.param_groups
        self.alpha = alpha
        self.k = k
        for group in self.param_groups:
            group["step_counter"] = 0
        # One detached snapshot per parameter, grouped like param_groups.
        self.slow_weights = [
            [p.clone().detach() for p in group['params']]
            for group in self.param_groups
        ]
        for w in it.chain.from_iterable(self.slow_weights):
            w.requires_grad = False

    def step(self, closure=None):
        """Run one step of the base optimizer and, every ``k`` steps, sync.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.  It is called once, before the base
                optimizer's ``step()``.

        Returns:
            The value returned by ``closure`` when one is given, otherwise
            whatever the base optimizer's ``step()`` returns.
        """
        loss = None
        if closure is not None:
            loss = closure()
        # Do not let the base optimizer's return value (usually None when
        # it is called without a closure) clobber the closure's loss.
        base_result = self.optimizer.step()
        if loss is None:
            loss = base_result

        for group, slow_weights in zip(self.param_groups, self.slow_weights):
            group['step_counter'] += 1
            if group['step_counter'] % self.k != 0:
                continue
            for p, q in zip(group['params'], slow_weights):
                # Parameters that received no gradient are left untouched.
                if p.grad is None:
                    continue
                # slow <- slow + alpha * (fast - slow); keyword form of
                # add_ replaces the deprecated add_(alpha, tensor) overload.
                q.data.add_(p.data - q.data, alpha=self.alpha)
                # fast <- slow
                p.data.copy_(q.data)
        return loss