forked from homink/deepspeech.pytorch.ko
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
324 lines (280 loc) · 12.6 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import math
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.autograd import Variable
# Recurrent layer classes that can be selected by their short name.
supported_rnns = dict(lstm=nn.LSTM, rnn=nn.RNN, gru=nn.GRU)

# Reverse lookup (layer class -> short name).
supported_rnns_inv = {rnn_cls: name for name, rnn_cls in supported_rnns.items()}
class SequenceWise(nn.Module):
    """Apply a module to every (time step, batch item) pair of a sequence.

    The T*N*H input is collapsed to (T*N)*H, passed through the wrapped
    module and expanded back to T*N*(-1). This allows handling of variable
    sequence lengths and minibatch sizes.
    """

    def __init__(self, module):
        """
        :param module: Module to apply input to.
        """
        super(SequenceWise, self).__init__()
        self.module = module

    def forward(self, x):
        """
        :param x: Tensor whose first two dimensions are T and N.
        :return: Output of the wrapped module with T and N restored as the
                 first two dimensions.
        """
        t, n = x.size(0), x.size(1)
        # reshape() instead of view(): view() raises on non-contiguous input
        # (e.g. a freshly transposed tensor), reshape() copies only if needed.
        x = x.reshape(t * n, -1)
        x = self.module(x)
        return x.reshape(t, n, -1)

    def __repr__(self):
        return '{} (\n{})'.format(self.__class__.__name__, repr(self.module))
class InferenceBatchSoftmax(nn.Module):
    """Softmax over the last dimension in eval mode, identity in training mode."""

    def forward(self, input_):
        # While training the input is passed through untouched.
        if self.training:
            return input_
        return F.softmax(input_, dim=-1)
class BatchRNN(nn.Module):
    """Recurrent layer (built without bias) with optional sequence-wise batch norm.

    When the layer is bidirectional, the outputs of the two directions are
    summed, so the last output dimension is ``hidden_size`` either way.
    """

    def __init__(self, input_size, hidden_size, rnn_type=nn.LSTM, bidirectional=False, batch_norm=True):
        """
        :param input_size: Feature size of the input.
        :param hidden_size: Hidden size of the recurrent layer.
        :param rnn_type: Recurrent layer class to instantiate.
        :param bidirectional: Whether the recurrent layer runs in both directions.
        :param batch_norm: Whether to batch-normalise the input first.
        """
        super(BatchRNN, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional
        if batch_norm:
            self.batch_norm = SequenceWise(nn.BatchNorm1d(input_size))
        else:
            self.batch_norm = None
        self.rnn = rnn_type(input_size=input_size, hidden_size=hidden_size,
                            bidirectional=bidirectional, bias=False)
        if bidirectional:
            self.num_directions = 2
        else:
            self.num_directions = 1

    def flatten_parameters(self):
        """Delegate to ``flatten_parameters`` of the wrapped recurrent layer."""
        self.rnn.flatten_parameters()

    def forward(self, x):
        """Run optional batch norm and the recurrent layer over ``x`` (TxNxH)."""
        if self.batch_norm is not None:
            x = self.batch_norm(x)
        outputs, _ = self.rnn(x)
        if not self.bidirectional:
            return outputs
        steps, batch = outputs.size(0), outputs.size(1)
        # (TxNxH*2) -> (TxNxH) by sum
        summed = outputs.view(steps, batch, 2, -1).sum(2)
        return summed.view(steps, batch, -1)
class Lookahead(nn.Module):
    """Lookahead layer: per-feature weighted sum over the next ``context`` steps.

    Wang et al 2016 - Lookahead Convolution Layer for Unidirectional Recurrent Neural Networks
    input shape - sequence, batch, feature - TxNxH
    output shape - same as input
    """

    def __init__(self, n_features, context):
        """
        :param n_features: Feature size H of the input.
        :param context: Number of future time steps to look at; must be > 0.
        """
        # should we handle batch_first=True?
        super(Lookahead, self).__init__()
        # Validate before allocating the weight so that a bad context always
        # fails here instead of inside the tensor allocation.
        assert context > 0
        self.n_features = n_features
        self.weight = Parameter(torch.Tensor(n_features, context + 1))
        self.context = context
        self.register_parameter('bias', None)
        self.init_parameters()

    def init_parameters(self):  # what's a better way to initialise this layer?
        """Fill the weight uniformly in [-stdv, stdv], stdv = 1/sqrt(context + 1)."""
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, input):
        """
        :param input: Tensor of shape TxNxH.
        :return: Tensor of the same shape as ``input``.
        """
        # pad the 0th dimension (T/sequence) with zeroes whose number = context
        padding = input.new_zeros(self.context, *input.size()[1:])
        x = torch.cat((input, padding), 0)

        # Sliding windows of width context+1 along the sequence dimension as a
        # single view (no Python loop over time steps):
        # TxNxHxL - sequence, batch, feature, context
        x = x.unfold(0, self.context + 1, 1)

        # weight is HxL, so it broadcasts over the sequence and batch dims
        return torch.mul(x, self.weight).sum(dim=3)

    def __repr__(self):
        return self.__class__.__name__ + '(' \
               + 'n_features=' + str(self.n_features) \
               + ', context=' + str(self.context) + ')'
class DeepSpeech(nn.Module):
    """Convolutional front end, a stack of BatchRNN layers and a linear classifier.

    The class also carries the metadata needed to serialize a model into a
    plain dictionary ("package") and to rebuild a model from such a package.
    """

    # Parameters that were previously erroneously saved by the model. Care
    # should be taken in future versions that if batch_norm on the first rnn
    # is required that it be named something else.
    _STATE_DICT_BLACKLIST = (
        'rnns.0.batch_norm.module.weight',
        'rnns.0.batch_norm.module.bias',
        'rnns.0.batch_norm.module.running_mean',
        'rnns.0.batch_norm.module.running_var',
    )

    def __init__(self, rnn_type=nn.LSTM, labels="abc", rnn_hidden_size=768, nb_layers=5, audio_conf=None,
                 bidirectional=True, context=20):
        """
        :param rnn_type: Recurrent layer class used by every BatchRNN layer.
        :param labels: Output labels; their count is the number of classes.
        :param rnn_hidden_size: Hidden size of every recurrent layer.
        :param nb_layers: Number of recurrent layers.
        :param audio_conf: Audio settings; "sample_rate" and "window_size" are
                           read here (defaults 16000 and 0.02).
        :param bidirectional: Whether the recurrent layers are bidirectional.
        :param context: Lookahead context, only used when not bidirectional.
        """
        super(DeepSpeech, self).__init__()

        # model metadata needed for serialization/deserialization
        self._version = '0.0.1'
        self._hidden_size = rnn_hidden_size
        self._hidden_layers = nb_layers
        self._rnn_type = rnn_type
        self._audio_conf = audio_conf or {}
        self._labels = labels
        self._bidirectional = bidirectional

        sample_rate = self._audio_conf.get("sample_rate", 16000)
        window_size = self._audio_conf.get("window_size", 0.02)
        num_classes = len(self._labels)

        self.conv = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=(41, 11), stride=(2, 2), padding=(0, 10)),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True),
            nn.Conv2d(32, 32, kernel_size=(21, 11), stride=(2, 1), ),
            nn.BatchNorm2d(32),
            nn.Hardtanh(0, 20, inplace=True)
        )
        # Based on above convolutions and spectrogram size using conv formula
        # floor((W - F + 2P) / S) + 1; P is 0 along the first kernel axis.
        rnn_input_size = int(math.floor((sample_rate * window_size) / 2) + 1)
        rnn_input_size = (rnn_input_size - 41) // 2 + 1
        rnn_input_size = (rnn_input_size - 21) // 2 + 1
        rnn_input_size *= 32  # number of conv output channels

        # The first recurrent layer takes the conv features and has no batch
        # norm; the remaining layers map hidden size to hidden size.
        rnns = [('0', BatchRNN(input_size=rnn_input_size, hidden_size=rnn_hidden_size, rnn_type=rnn_type,
                               bidirectional=bidirectional, batch_norm=False))]
        for index in range(1, nb_layers):
            rnn = BatchRNN(input_size=rnn_hidden_size, hidden_size=rnn_hidden_size, rnn_type=rnn_type,
                           bidirectional=bidirectional)
            rnns.append(('%d' % index, rnn))
        self.rnns = nn.Sequential(OrderedDict(rnns))

        self.lookahead = nn.Sequential(
            # consider adding batch norm?
            Lookahead(rnn_hidden_size, context=context),
            nn.Hardtanh(0, 20, inplace=True)
        ) if not bidirectional else None

        fully_connected = nn.Sequential(
            nn.BatchNorm1d(rnn_hidden_size),
            nn.Linear(rnn_hidden_size, num_classes, bias=False)
        )
        self.fc = nn.Sequential(
            SequenceWise(fully_connected),
        )
        self.inference_softmax = InferenceBatchSoftmax()

    def forward(self, x):
        """
        :param x: 4-D input to the convolutional front end (one input channel).
        :return: Tensor with the batch dimension first; softmax over the last
                 dimension is applied in eval mode only.
        """
        x = self.conv(x)

        sizes = x.size()
        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # Collapse feature dimension
        x = x.transpose(1, 2).transpose(0, 1).contiguous()  # TxNxH

        x = self.rnns(x)

        if not self._bidirectional:  # no need for lookahead layer in bidirectional
            x = self.lookahead(x)

        x = self.fc(x)
        x = x.transpose(0, 1)
        # identity in training mode, softmax in eval mode
        x = self.inference_softmax(x)
        return x

    @staticmethod
    def _unwrap(model):
        """Return the DeepSpeech instance behind an optional wrapper.

        Wrappers such as nn.DataParallel expose the wrapped model as
        ``.module``. Looking at the wrapper itself, rather than at whether the
        parameters live on a GPU, also works for an unwrapped CUDA model and
        for a wrapped CPU model.
        """
        if isinstance(model, DeepSpeech):
            return model
        return getattr(model, 'module', model)

    @classmethod
    def load_model(cls, path, cuda=False):
        """Load the package stored at ``path`` and build a model from it.

        :param path: File written with ``torch.save`` from a ``serialize`` package.
        :param cuda: If true, wrap the model in DataParallel and move it to CUDA.
        :return: The restored model.
        """
        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        package = torch.load(path, map_location=lambda storage, loc: storage)
        return cls.load_model_package(package, cuda=cuda)

    @classmethod
    def load_model_package(cls, package, cuda=False):
        """Build a model from an already loaded package dictionary.

        :param package: Dictionary as produced by ``serialize``; it is not modified.
        :param cuda: If true, wrap the model in DataParallel and move it to CUDA.
        :return: The restored model.
        """
        model = cls(rnn_hidden_size=package['hidden_size'], nb_layers=package['hidden_layers'],
                    labels=package['labels'], audio_conf=package['audio_conf'],
                    rnn_type=supported_rnns[package['rnn_type']],
                    bidirectional=package.get('bidirectional', True))

        # Drop the blacklisted entries on a copy so the caller's package stays intact.
        saved_state = package['state_dict']
        state_dict = OrderedDict((key, value) for key, value in saved_state.items()
                                 if key not in cls._STATE_DICT_BLACKLIST)
        # Carry over the _metadata attribute, if the saved state dict has one.
        metadata = getattr(saved_state, '_metadata', None)
        if metadata is not None:
            state_dict._metadata = metadata
        model.load_state_dict(state_dict)

        for rnn in model.rnns:
            rnn.flatten_parameters()
        if cuda:
            model = torch.nn.DataParallel(model).cuda()
        return model

    @staticmethod
    def serialize(model, optimizer=None, epoch=None, iteration=None, loss_results=None,
                  cer_results=None, wer_results=None, avg_loss=None, meta=None):
        """Pack a model and optional training state into a dictionary.

        :param model: DeepSpeech instance, optionally wrapped (e.g. DataParallel).
        :param optimizer: Optimizer whose state dict is stored as 'optim_dict'.
        :param epoch: Zero-based epoch; stored incremented by one.
        :param iteration: Stored as 'iteration'.
        :param loss_results: When given, stored together with cer/wer results.
        :param cer_results: Stored only when ``loss_results`` is given.
        :param wer_results: Stored only when ``loss_results`` is given.
        :param avg_loss: Stored as 'avg_loss'.
        :param meta: Stored as 'meta'.
        :return: The package dictionary.
        """
        model = DeepSpeech._unwrap(model)
        package = {
            'version': model._version,
            'hidden_size': model._hidden_size,
            'hidden_layers': model._hidden_layers,
            'rnn_type': supported_rnns_inv.get(model._rnn_type, model._rnn_type.__name__.lower()),
            'audio_conf': model._audio_conf,
            'labels': model._labels,
            'state_dict': model.state_dict(),
            'bidirectional': model._bidirectional
        }
        if optimizer is not None:
            package['optim_dict'] = optimizer.state_dict()
        if avg_loss is not None:
            package['avg_loss'] = avg_loss
        if epoch is not None:
            package['epoch'] = epoch + 1  # increment for readability
        if iteration is not None:
            package['iteration'] = iteration
        if loss_results is not None:
            package['loss_results'] = loss_results
            package['cer_results'] = cer_results
            package['wer_results'] = wer_results
        if meta is not None:
            package['meta'] = meta
        return package

    @staticmethod
    def get_labels(model):
        """Return the labels of a (possibly wrapped) model."""
        return DeepSpeech._unwrap(model)._labels

    @staticmethod
    def get_param_size(model):
        """Return the total number of elements over all parameters of ``model``."""
        params = 0
        for p in model.parameters():
            count = 1
            for dim in p.size():
                count *= dim
            params += count
        return params

    @staticmethod
    def get_audio_conf(model):
        """Return the audio configuration of a (possibly wrapped) model."""
        return DeepSpeech._unwrap(model)._audio_conf

    @staticmethod
    def get_meta(model):
        """Return version, hidden size, layer count and rnn type name of a model."""
        m = DeepSpeech._unwrap(model)
        meta = {
            "version": m._version,
            "hidden_size": m._hidden_size,
            "hidden_layers": m._hidden_layers,
            # same fallback as serialize() for rnn types outside the registry
            "rnn_type": supported_rnns_inv.get(m._rnn_type, m._rnn_type.__name__.lower())
        }
        return meta
if __name__ == '__main__':
    # Command line entry point: print a summary of a saved model file.
    import os.path
    import argparse

    parser = argparse.ArgumentParser(description='DeepSpeech model information')
    parser.add_argument('--model-path', default='models/deepspeech_final.pth.tar',
                        help='Path to model file created by training')
    args = parser.parse_args()
    # The raw package is loaded in addition to the model because the training
    # results and the metadata are entries of the package, not model attributes.
    # NOTE: torch.load unpickles the file; only inspect checkpoints you trust.
    package = torch.load(args.model_path, map_location=lambda storage, loc: storage)
    model = DeepSpeech.load_model(args.model_path)

    print("Model name: ", os.path.basename(args.model_path))
    print("DeepSpeech version: ", model._version)
    print("")
    print("Recurrent Neural Network Properties")
    print(" RNN Type: ", model._rnn_type.__name__.lower())
    print(" RNN Layers: ", model._hidden_layers)
    print(" RNN Size: ", model._hidden_size)
    print(" Classes: ", len(model._labels))
    print("")
    print("Model Features")
    print(" Labels: ", model._labels)
    print(" Sample Rate: ", model._audio_conf.get("sample_rate", "n/a"))
    print(" Window Type: ", model._audio_conf.get("window", "n/a"))
    print(" Window Size: ", model._audio_conf.get("window_size", "n/a"))
    print(" Window Stride: ", model._audio_conf.get("window_stride", "n/a"))

    if package.get('loss_results', None) is not None:
        print("")
        print("Training Information")
        epochs = package['epoch']
        print(" Epochs: ", epochs)
        print(" Current Loss: {0:.3f}".format(package['loss_results'][epochs - 1]))
        print(" Current CER: {0:.3f}".format(package['cer_results'][epochs - 1]))
        print(" Current WER: {0:.3f}".format(package['wer_results'][epochs - 1]))

    if package.get('meta', None) is not None:
        print("")
        print("Additional Metadata")
        # The metadata is an entry of the package; the model has no _meta attribute.
        meta = package['meta']
        pairs = meta.items() if hasattr(meta, 'items') else meta
        for k, v in pairs:
            print(" ", k, ": ", v)