This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Add timit recipe #247

Open · wants to merge 11 commits into base: master
1 change: 1 addition & 0 deletions egs/aishell/asr/simple_v1/mmi_bigram_train.py
@@ -1,3 +1,4 @@
#!/usr/bin/env python3
# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey, Haowen Qiu, Mingshuang Luo)
# 2021 Pingfeng Luo
65 changes: 65 additions & 0 deletions egs/timit/asr/simple_v1/RESULTS.md
@@ -0,0 +1,65 @@
# TIMIT CTC Training Results

## 2021-09-03 (Mingshuang Luo)

### TIMIT CTC_Train Based on 48 phones

Test results at different training epochs (a reproduction sketch follows the log excerpt):
```
epoch=20
2021-09-03 10:54:10,903 INFO [ctc_decode.py:188] %PER 30.34% [2225 / 7333, 293 ins, 441 del, 1491 sub ]

epoch=30
2021-09-03 10:59:10,147 INFO [ctc_decode.py:188] %PER 29.77% [2183 / 7333, 221 ins, 473 del, 1489 sub ]

epoch=35
2021-09-03 11:11:00,885 INFO [ctc_decode.py:188] %PER 28.94% [2122 / 7333, 266 ins, 397 del, 1459 sub ]

epoch=40
2021-09-03 11:12:39,029 INFO [ctc_decode.py:188] %PER 29.52% [2165 / 7333, 304 ins, 348 del, 1513 sub ]
```
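These figures were presumably produced with the recipe's `ctc_decode.py` (named in the log lines above). A hypothetical invocation, assuming it accepts the same `--epoch` and `--mode` flags as `ctc_crdnn_decode.py` below:
```
python3 ctc_decode.py --epoch 35 --mode TEST
```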

### TIMIT CTC_Train Based on 39 phones

Test results at different training epochs:
```
epoch=40
2021-09-13 11:02:14,793 INFO [ctc_decode.py:189] %PER 25.61% [1848 / 7215, 301 ins, 396 del, 1151 sub ]

epoch=45
2021-09-13 11:01:20,787 INFO [ctc_decode.py:189] %PER 25.50% [1840 / 7215, 286 ins, 386 del, 1168 sub ]

epoch=47
2021-09-13 11:04:05,533 INFO [ctc_decode.py:189] %PER 26.20% [1890 / 7215, 373 ins, 367 del, 1150 sub ]

```
### TIMIT CTC_Train with CRDNN Based on 48 phones

Test results at different training epochs:
```
epoch=35
2021-09-13 11:21:01,592 INFO [ctc_crdnn_decode.py:201] %PER 20.46% [1476 / 7215, 249 ins, 356 del, 871 sub ]

epoch=45
2021-09-13 11:22:02,221 INFO [ctc_crdnn_decode.py:201] %PER 19.75% [1425 / 7215, 239 ins, 324 del, 862 sub ]

epoch=53
2021-09-13 11:23:07,969 INFO [ctc_crdnn_decode.py:201] %PER 18.86% [1361 / 7215, 214 ins, 320 del, 827 sub ]

```

### TIMIT CTC_Train with CRDNN Based on 39 phones

Test results at different training epochs (an example invocation follows the log excerpt):
```
epoch=26
2021-09-13 11:32:41,388 INFO [ctc_crdnn_decode.py:201] %PER 21.04% [1518 / 7215, 345 ins, 251 del, 922 sub ]

epoch=45
2021-09-13 11:34:27,566 INFO [ctc_crdnn_decode.py:201] %PER 18.74% [1352 / 7215, 316 ins, 239 del, 797 sub ]

epoch=55
2021-09-13 11:35:55,751 INFO [ctc_crdnn_decode.py:201] %PER 18.24% [1316 / 7215, 267 ins, 242 del, 807 sub ]

```
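The CRDNN results above come from `ctc_crdnn_decode.py` (included in this PR, shown below), which takes `--epoch` and `--mode` arguments. The epoch-55 line, for example, would correspond to something like:
```
python3 ctc_crdnn_decode.py --epoch 55 --mode TEST
```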
236 changes: 236 additions & 0 deletions egs/timit/asr/simple_v1/ctc_crdnn_decode.py
@@ -0,0 +1,236 @@
#!/usr/bin/env python3

# Copyright (c) 2020 Xiaomi Corporation (authors: Daniel Povey, Haowen Qiu)
# 2021 Xiaomi Corporation (authors: Mingshuang Luo)
# Apache 2.0

# Note: before running this script, install speechbrain first:
#   pip install speechbrain

import argparse
import k2
import logging
import os
import sys
import torch
import torch.nn as nn

from k2 import Fsa, SymbolTable
from kaldialign import edit_distance
from pathlib import Path
from typing import Union

from lhotse import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset import SingleCutSampler
from snowfall.common import find_first_disambig_symbol
from snowfall.common import get_phone_symbols
from snowfall.common import get_texts
from snowfall.common import load_checkpoint
from snowfall.common import setup_logger
from snowfall.decoding.graph import compile_HLG
from snowfall.models import AcousticModel
from snowfall.training.ctc_graph import build_ctc_topo


from speechbrain.lobes.models.CRDNN import CRDNN
from speechbrain.nnet.linear import Linear

class CrdnnModel(nn.Module):
    """Wraps a speechbrain CRDNN encoder and a linear output layer,
    producing log-posteriors in [N, C, T] layout for k2 decoding."""

    def __init__(self, num_features: int, num_classes: int,
                 subsampling_factor: int, crdnn, linear):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.subsampling_factor = subsampling_factor
        self.crdnn = crdnn
        self.linear = linear

    def forward(self, x):
        # x is [N, T, C]; the CRDNN returns a time-pooled [N, T', hidden] tensor.
        x = self.crdnn(x)
        x = self.linear(x)
        x = x.transpose(1, 2)  # [N, T', C] -> [N, C, T']
        x = nn.functional.log_softmax(x, dim=1)
        return x
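# Shape walkthrough for CrdnnModel.forward (a sketch; assumes 80-dim fbank
# input and a CRDNN output size of 512, matching the Linear layer configured
# in main() below):
#   x: [N, T, 80] -> crdnn(x): [N, T', 512]  (T' < T due to time pooling)
#   -> linear: [N, T', C] -> transpose(1, 2): [N, C, T']
#   -> log_softmax over dim=1 gives per-frame log-posteriors over the classes.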

def decode(dataloader: torch.utils.data.DataLoader, model: AcousticModel,
           device: Union[str, torch.device], HLG: Fsa, symbols: SymbolTable):
    tot_num_cuts = len(dataloader)
    num_cuts = 0
    results = []  # a list of pairs (ref_words, hyp_words)
    for batch_idx, batch in enumerate(dataloader):
        feature = batch['inputs']
        supervisions = batch['supervisions']
        # Each row of supervision_segments is (sequence_idx, start_frame,
        # num_frames), with frame counts divided by the model's subsampling
        # factor, as expected by k2.DenseFsaVec.
        supervision_segments = torch.stack(
            (supervisions['sequence_idx'],
             torch.floor_divide(supervisions['start_frame'],
                                model.subsampling_factor),
             torch.floor_divide(supervisions['num_frames'],
                                model.subsampling_factor)), 1).to(torch.int32)
        indices = torch.argsort(supervision_segments[:, 2], descending=True)
        supervision_segments = supervision_segments[indices]
        texts = supervisions['text']
        assert feature.ndim == 3

        feature = feature.to(device)
        # at entry, feature is [N, T, C]
        with torch.no_grad():
            nnet_output = model(feature)
        # nnet_output is [N, C, T]
        nnet_output = nnet_output.permute(0, 2, 1)  # now nnet_output is [N, T, C]

        # Penalize the blank symbol (id 0).
        blank_bias = -3.0
        nnet_output[:, :, 0] += blank_bias

        dense_fsa_vec = k2.DenseFsaVec(nnet_output, supervision_segments)
        assert HLG.device == nnet_output.device, \
            f"Check failed: HLG.device ({HLG.device}) == nnet_output.device ({nnet_output.device})"
        # TODO(haowen): with a small `beam`, we may get an empty `target_graph`,
        # and then `tot_scores` will be `inf`; we need to handle this later.
        # The positional arguments are search_beam, output_beam,
        # min_active_states and max_active_states.
        lattices = k2.intersect_dense_pruned(HLG, dense_fsa_vec, 20.0, 7.0, 30,
                                             10000)
        best_paths = k2.shortest_path(lattices, use_double_scores=True)
        assert best_paths.shape[0] == len(texts)
        hyps = get_texts(best_paths, indices)
        assert len(hyps) == len(texts)

        for i in range(len(texts)):
            hyp_words = [symbols.get(x) for x in hyps[i]]
            ref_words = texts[i].split(' ')
            results.append((ref_words, hyp_words))

        if batch_idx % 10 == 0:
            logging.info(
                'batch {}, cuts processed until now is {}/{} ({:.6f}%)'.format(
                    batch_idx, num_cuts, tot_num_cuts,
                    float(num_cuts) / tot_num_cuts * 100))

        num_cuts += len(texts)

    return results


def main():
    parser = argparse.ArgumentParser(
        description='Decode with a CTC CRDNN model trained on TIMIT.')
    parser.add_argument('--epoch', type=int, default=20,
                        help='the epoch of the checkpoint to load.')
    parser.add_argument('--mode', type=str, default='TEST',
                        help='the data partition to decode, e.g. TEST.')
    args = parser.parse_args()

    exp_dir = Path('exp-lstm-adam-ctc-musan-with-crdnn')
    setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

    # load L, G, symbol_table
    lang_dir = Path('data/lang_nosp')
    symbol_table = k2.SymbolTable.from_file(lang_dir / 'words.txt')
    phone_symbol_table = k2.SymbolTable.from_file(lang_dir / 'phones.txt')
    phone_ids = get_phone_symbols(phone_symbol_table)
    phone_ids_with_blank = [0] + phone_ids  # blank takes id 0 in the CTC topology
    ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))

    if not os.path.exists(lang_dir / 'HLG.pt'):
        print("Loading L_disambig.fst.txt")
        with open(lang_dir / 'L_disambig.fst.txt') as f:
            L = k2.Fsa.from_openfst(f.read(), acceptor=False)
        print("Loading G.fst.txt")
        with open(lang_dir / 'G.fst.txt') as f:
            G = k2.Fsa.from_openfst(f.read(), acceptor=False)
        first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)
        first_word_disambig_id = find_first_disambig_symbol(symbol_table)
        HLG = compile_HLG(L=L,
                          G=G,
                          H=ctc_topo,
                          labels_disambig_id_start=first_phone_disambig_id,
                          aux_labels_disambig_id_start=first_word_disambig_id)
        torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
    else:
        print("Loading pre-compiled HLG")
        d = torch.load(lang_dir / 'HLG.pt')
        HLG = k2.Fsa.from_dict(d)
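    # The compiled graph is cached; delete data/lang_nosp/HLG.pt to force a
    # rebuild after L_disambig.fst.txt or G.fst.txt change.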

    # load dataset
    feature_dir = Path('exp/data')
    print("About to get test cuts")
    cuts_test = CutSet.from_json(feature_dir /
                                 'cuts_{}.json.gz'.format(args.mode))

    print("About to create test dataset")
    test = K2SpeechRecognitionDataset(cuts_test)
    sampler = SingleCutSampler(cuts_test, max_frames=100000)
    print("About to create test dataloader")
    test_dl = torch.utils.data.DataLoader(test, batch_size=None,
                                          sampler=sampler, num_workers=1)

    if not torch.cuda.is_available():
        logging.error('No GPU detected!')
        sys.exit(-1)

    print("About to load model")
    # Note: Use "export CUDA_VISIBLE_DEVICES=N" to set the device id to N.
    device = torch.device('cuda')

    crdnn = CRDNN(
        input_size=80,  # 80-dimensional fbank features
        time_pooling=True)

    linear = Linear(
        input_size=512,
        n_neurons=len(phone_ids) + 1)  # all phones plus the blank symbol

    model = CrdnnModel(80, len(phone_ids) + 1, 2, crdnn, linear)

    checkpoint = os.path.join(exp_dir, 'epoch-{}.pt'.format(args.epoch))
    load_checkpoint(checkpoint, model)
    model.to(device)
    model.eval()

print("convert HLG to device")
HLG = HLG.to(device)
HLG.aux_labels = k2.ragged.remove_values_eq(HLG.aux_labels, 0)
HLG.requires_grad_(False)
print("About to decode")
results = decode(dataloader=test_dl,
model=model,
device=device,
HLG=HLG,
symbols=symbol_table)
s = ''
for ref, hyp in results:
s += f'ref={ref}\n'
s += f'hyp={hyp}\n'
#logging.info(s)
results = [([one for one in n[0] if one], [one for one in n[1] if one]) for n in results]
# compute WER
dists = [edit_distance(r, h) for r, h in results]
errors = {
key: sum(dist[key] for dist in dists)
for key in ['sub', 'ins', 'del', 'total']
}
total_words = sum(len(ref) for ref, _ in results)

logging.info(
f'%PER {errors["total"] / total_words:.2%} '
f'[{errors["total"]} / {total_words}, {errors["ins"]} ins, {errors["del"]} del, {errors["sub"]} sub ]'
)


torch.set_num_threads(1)
torch.set_num_interop_threads(1)

if __name__ == '__main__':
    main()