-
Notifications
You must be signed in to change notification settings - Fork 2.9k
/
ema.py
195 lines (172 loc) · 6.98 KB
/
ema.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import math
import paddle
import weakref
from copy import deepcopy
from .utils import get_bn_running_state_names
# Names exported by `from <this module> import *`: the two EMA helpers below.
__all__ = ['ModelEMA', 'SimpleModelEMA']
class ModelEMA(object):
    """
    Exponential Weighted Average for Deep Neural Networks.

    Args:
        model (nn.Layer): Detector of model.
        decay (float): The decay used for updating ema parameter.
            EMA parameters are updated with the formula:
            `ema_param = decay * ema_param + (1 - decay) * cur_param`.
            Defaults to 0.9998.
        ema_decay_type (str): type in ['threshold', 'normal', 'exponential'],
            'threshold' as default. Any value other than 'threshold' or
            'exponential' uses the constant `decay` (i.e. behaves as 'normal').
        cycle_epoch (int): The epoch of interval to reset ema_param and
            step. Defaults to -1, which means not reset. Its function is to
            add a regular effect to ema, which is set according to experience
            and is effective when the total training epoch is large.
        ema_black_list (set|list|tuple, optional): The custom EMA black_list.
            Blacklist of weight names that will not participate in EMA
            calculation. A state-dict key is blacklisted when any entry of
            this list is a substring of it. Default: None.
        ema_filter_no_grad (bool, optional): If True, every parameter with
            `stop_gradient=True` whose name is not returned by
            `get_bn_running_state_names(model)` is also blacklisted.
            Default: False.
    """

    def __init__(self,
                 model,
                 decay=0.9998,
                 ema_decay_type='threshold',
                 cycle_epoch=-1,
                 ema_black_list=None,
                 ema_filter_no_grad=False):
        self.step = 0
        self.epoch = 0
        self.decay = decay
        self.ema_decay_type = ema_decay_type
        self.cycle_epoch = cycle_epoch
        model_state = model.state_dict()
        self.ema_black_list = self._match_ema_black_list(
            model_state.keys(), ema_black_list)
        bn_states_names = get_bn_running_state_names(model)
        if ema_filter_no_grad:
            for n, p in model.named_parameters():
                if p.stop_gradient and n not in bn_states_names:
                    self.ema_black_list.add(n)

        # Blacklisted keys keep the tensor obtained from the model's
        # state_dict; every other key gets a float32 zero accumulator.
        self.state_dict = dict()
        for k, v in model_state.items():
            if k in self.ema_black_list:
                self.state_dict[k] = v
            else:
                self.state_dict[k] = paddle.zeros_like(v, dtype='float32')

        # Weak references so update()/apply() can read the model's current
        # tensors without this object keeping them alive.
        self._model_state = {
            k: weakref.ref(p)
            for k, p in model_state.items()
        }

    def reset(self):
        """Zero the step/epoch counters and every non-blacklisted accumulator."""
        self.step = 0
        self.epoch = 0
        for k, v in self.state_dict.items():
            if k not in self.ema_black_list:
                # Same dtype as the accumulators created in __init__.
                self.state_dict[k] = paddle.zeros_like(v, dtype='float32')

    def resume(self, state_dict, step=0):
        """
        Load previously saved EMA values.

        Keys of `state_dict` that this EMA does not track are ignored; values
        are cast to the dtype of the tensor they replace.

        Args:
            state_dict (dict): saved EMA tensors keyed by state-dict name.
            step (int): number of updates already accumulated into them.
        """
        for k, v in state_dict.items():
            if k in self.state_dict:
                if self.state_dict[k].dtype == v.dtype:
                    self.state_dict[k] = v
                else:
                    self.state_dict[k] = v.astype(self.state_dict[k].dtype)
        self.step = step

    def _get_decay(self, step):
        """Return the decay factor applied by the update performed at `step`."""
        if self.ema_decay_type == 'threshold':
            # Small decay for the first steps, capped at self.decay.
            return min(self.decay, (1 + step) / (10 + step))
        elif self.ema_decay_type == 'exponential':
            # Ramps from ~0 towards self.decay with a 2000-step time constant.
            return self.decay * (1 - math.exp(-(step + 1) / 2000))
        return self.decay

    def _get_live_model_state(self):
        """Dereference the weakly-held model tensors, asserting none were collected."""
        model_dict = {k: p() for k, p in self._model_state.items()}
        assert all(v is not None for v in model_dict.values()), 'python gc.'
        return model_dict

    def update(self, model=None):
        """
        Fold the current model weights into the EMA accumulators.

        Args:
            model (nn.Layer, optional): model to read weights from. When None,
                the tensors captured (weakly) in __init__ are used.
        """
        decay = self._get_decay(self.step)
        # Kept for backward compatibility with code that reads `_decay`.
        self._decay = decay
        if model is not None:
            model_dict = model.state_dict()
        else:
            model_dict = self._get_live_model_state()
        for k, v in self.state_dict.items():
            if k not in self.ema_black_list:
                v = decay * v + (1 - decay) * model_dict[k].astype('float32')
                v.stop_gradient = True
                self.state_dict[k] = v
        self.step += 1

    def apply(self):
        """
        Return a state dict holding the EMA weights, ready to load into a model.

        When no step has been recorded (`self.step == 0`), the internal
        `self.state_dict` is returned as-is: no bias correction or dtype cast
        is applied and `epoch` is not advanced. Otherwise `epoch` is
        incremented and, once it reaches `cycle_epoch` (when > 0), the EMA
        is reset.
        """
        if self.step == 0:
            return self.state_dict
        # Decay of the most recent update, i.e. the one made at step
        # `self.step - 1`. It is recomputed instead of read from `self._decay`
        # so apply() also works directly after resume(), before any update().
        decay = self._get_decay(self.step - 1)
        state_dict = dict()
        model_dict = self._get_live_model_state()
        for k, v in self.state_dict.items():
            if k in self.ema_black_list:
                # NOTE(review): `v` is the tensor stored from the model's
                # state_dict in __init__ (presumably the model's own tensor),
                # so this flips stop_gradient on it in place — confirm this is
                # intended for trainable blacklisted parameters.
                v.stop_gradient = True
                state_dict[k] = v
            else:
                if self.ema_decay_type != 'exponential':
                    # Bias correction for the zero-initialised accumulator;
                    # only the most recent decay value is used.
                    v = v / (1 - decay**self.step)
                v = v.astype(model_dict[k].dtype)
                v.stop_gradient = True
                state_dict[k] = v
        self.epoch += 1
        if self.cycle_epoch > 0 and self.epoch == self.cycle_epoch:
            self.reset()
        return state_dict

    def _match_ema_black_list(self, weight_name, ema_black_list=None):
        """
        Return the set of names in `weight_name` that contain any entry of
        `ema_black_list` as a substring (empty set when no list is given).
        """
        if not ema_black_list:
            return set()
        return {
            name
            for name in weight_name
            if any(key in name for key in ema_black_list)
        }
class SimpleModelEMA(object):
    """
    Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
    Keep a moving average of everything in the model state_dict (parameters and buffers).
    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    A smoothed version of the weights is necessary for some training schemes to perform well.
    This class is sensitive to where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(self, model=None, decay=0.9996):
        """
        Args:
            model (nn.Layer): model to apply EMA. A deep copy of it is kept
                as `self.model` and holds the averaged weights.
            decay (float): ema decay rate.
        """
        self.model = deepcopy(model)
        self.decay = decay
        # Initialised here so `step` exists even when resume() is never
        # called; previously it was only set by resume().
        self.step = 0

    def update(self, model, decay=None):
        """
        Blend the floating-point entries of `model`'s state dict into the
        EMA copy: `ema = decay * ema + (1 - decay) * current`.

        Args:
            model (nn.Layer): model providing the current weights.
            decay (float, optional): overrides `self.decay` for this call.
        """
        if decay is None:
            decay = self.decay
        with paddle.no_grad():
            state = {}
            msd = model.state_dict()
            for k, v in self.model.state_dict().items():
                # Only floating-point entries are averaged; the rest are
                # passed through unchanged.
                if paddle.is_floating_point(v):
                    v *= decay
                    v += (1.0 - decay) * msd[k].detach()
                state[k] = v
        self.model.set_state_dict(state)

    def resume(self, state_dict, step=0):
        """
        Load saved EMA weights into the EMA copy.

        Only floating-point entries are taken from `state_dict` (a missing
        key raises KeyError); the remaining entries keep their current values.

        Args:
            state_dict (dict): saved EMA state keyed by state-dict name.
            step (int): stored as `self.step`.
        """
        state = {}
        for k, v in self.model.state_dict().items():
            if paddle.is_floating_point(v):
                v = state_dict[k].detach()
            state[k] = v
        self.model.set_state_dict(state)
        self.step = step