subfuncs.py
from __future__ import division

from chainer.training import extension


class VaswaniRule(extension.Extension):

    """Trainer extension to scale an optimizer attribute with the learning
    rate schedule of Vaswani et al., "Attention Is All You Need":

        value = scale * d ** (-0.5) * min(t ** (-0.5), t * warmup_steps ** (-1.5))

    i.e. a linear warmup over ``warmup_steps`` calls followed by an inverse
    square-root decay.

    Args:
        attr (str): Name of the optimizer attribute to adjust
            (e.g. ``'alpha'`` for Adam).
        d (int): Model dimensionality (``d_model`` in the paper).
        warmup_steps (int): Number of warmup steps.
        init (float): Initial value of the attribute. If it is ``None``, the
            schedule value at step 1 is used as the initial value.
        target (float): Target value of the attribute. Kept for interface
            compatibility; the current schedule does not use it.
        optimizer (~chainer.Optimizer): Target optimizer to adjust the
            attribute. If it is ``None``, the main optimizer of the updater
            is used.
        scale (float): Constant factor applied to the schedule.
    """

    def __init__(self, attr, d, warmup_steps=4000,
                 init=None, target=None, optimizer=None,
                 scale=1.):
        self._attr = attr
        self._d_inv05 = d ** (-0.5) * scale
        self._warmup_steps_inv15 = warmup_steps ** (-1.5)
        self._init = init
        self._target = target
        self._optimizer = optimizer
        self._t = 0
        self._last_value = None

    def initialize(self, trainer):
        optimizer = self._get_optimizer(trainer)
        # Ensure that _init is set: default to the schedule value at t=1,
        # which lies on the warmup branch.
        if self._init is None:
            # self._init = getattr(optimizer, self._attr)
            self._init = self._d_inv05 * (1. * self._warmup_steps_inv15)
        if self._last_value is not None:  # resuming from a snapshot
            self._update_value(optimizer, self._last_value)
        else:
            self._update_value(optimizer, self._init)

    def __call__(self, trainer):
        self._t += 1
        optimizer = self._get_optimizer(trainer)
        # Warmup branch (t * warmup_steps ** -1.5) dominates until
        # t == warmup_steps; afterwards the t ** -0.5 decay takes over.
        value = self._d_inv05 * \
            min(self._t ** (-0.5), self._t * self._warmup_steps_inv15)
        self._update_value(optimizer, value)

    def serialize(self, serializer):
        self._t = serializer('_t', self._t)
        self._last_value = serializer('_last_value', self._last_value)

    def _get_optimizer(self, trainer):
        return self._optimizer or trainer.updater.get_optimizer('main')

    def _update_value(self, optimizer, value):
        setattr(optimizer, self._attr, value)
        self._last_value = value
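

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal example of attaching
# VaswaniRule to a Chainer trainer. The toy model, random dataset, and the
# values d=512 / warmup_steps=4000 are assumptions for illustration; Adam
# exposes its learning rate as the ``alpha`` attribute, which VaswaniRule
# overwrites every iteration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import numpy as np
    import chainer
    import chainer.links as L
    from chainer import iterators, optimizers, training

    # Toy classifier and random data, purely for demonstration.
    model = L.Classifier(L.Linear(8, 2))
    x = np.random.randn(64, 8).astype(np.float32)
    y = np.random.randint(0, 2, size=64).astype(np.int32)
    train_iter = iterators.SerialIterator(
        chainer.datasets.TupleDataset(x, y), batch_size=16)

    # Adam settings from the Transformer paper (beta2=0.98, eps=1e-9);
    # alpha itself is managed by the extension.
    optimizer = optimizers.Adam(beta1=0.9, beta2=0.98, eps=1e-9)
    optimizer.setup(model)

    updater = training.updaters.StandardUpdater(train_iter, optimizer)
    trainer = training.Trainer(updater, (200, 'iteration'))

    # Rescale ``alpha`` on every iteration with the warmup/decay schedule.
    trainer.extend(VaswaniRule('alpha', d=512, warmup_steps=4000),
                   trigger=(1, 'iteration'))
    trainer.run()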