This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

[WIP] 2-state HMM topo as an alternative to CTC topo #126

Open · wants to merge 12 commits into master
11 changes: 6 additions & 5 deletions egs/aishell/asr/simple_v1/mmi_att_transformer_decode.py
@@ -26,8 +26,8 @@
from snowfall.common import setup_logger
from snowfall.decoding.graph import compile_LG
from snowfall.models import AcousticModel
from snowfall.models.transformer import Transformer
from snowfall.models.conformer import Conformer
from snowfall.models.transformer import Transformer
from snowfall.training.ctc_graph import build_ctc_topo
from snowfall.training.mmi_graph import create_bigram_phone_lm
from snowfall.training.mmi_graph import get_phone_symbols
@@ -268,7 +268,8 @@ def main():
P.set_scores_stochastic_(model.P_scores)
print_transition_probabilities(P, phone_symbol_table, phone_ids, filename='P_scores.txt')

if not os.path.exists(lang_dir / 'LG.pt'):
HLG_path = exp_dir / 'HLG.pt'
if not HLG_path.exists():
logging.debug("Loading L_disambig.fst.txt")
with open(lang_dir / 'L_disambig.fst.txt') as f:
L = k2.Fsa.from_openfst(f.read(), acceptor=False)
@@ -282,10 +283,10 @@
ctc_topo=ctc_topo,
labels_disambig_id_start=first_phone_disambig_id,
aux_labels_disambig_id_start=first_word_disambig_id)
torch.save(LG.as_dict(), lang_dir / 'LG.pt')
torch.save(LG.as_dict(), HLG_path)
else:
logging.debug("Loading pre-compiled LG")
d = torch.load(lang_dir / 'LG.pt')
logging.debug("Loading pre-compiled HLG")
d = torch.load(HLG_path)
LG = k2.Fsa.from_dict(d)

# load dataset
28 changes: 15 additions & 13 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_decode.py
@@ -25,9 +25,9 @@
from snowfall.common import setup_logger
from snowfall.decoding.graph import compile_HLG
from snowfall.models import AcousticModel
from snowfall.models.transformer import Transformer
from snowfall.models.conformer import Conformer
from snowfall.training.ctc_graph import build_ctc_topo
from snowfall.models.transformer import Transformer
from snowfall.training.hmm_topo import build_hmm_topo_2state
from snowfall.training.mmi_graph import create_bigram_phone_lm
from snowfall.training.mmi_graph import get_phone_symbols

@@ -206,7 +206,7 @@ def main():
avg = args.avg
att_rate = args.att_rate

exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan')
exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-hmm')
setup_logger('{}/log/log-decode'.format(exp_dir), log_level='debug')

# load L, G, symbol_table
@@ -218,7 +218,8 @@
P = create_bigram_phone_lm(phone_ids)

phone_ids_with_blank = [0] + phone_ids
ctc_topo = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
# H = k2.arc_sort(build_ctc_topo(phone_ids_with_blank))
H = build_hmm_topo_2state(phone_ids_with_blank)

logging.debug("About to load model")
# Note: Use "export CUDA_VISIBLE_DEVICES=N" to setup device id to N
@@ -235,15 +236,15 @@
num_features=40,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
num_classes=2 * len(phone_ids) + 2, # 2 HMM states per token, incl. the blank
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)
else:
model = Conformer(
num_features=40,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
num_classes=2 * len(phone_ids) + 2, # 2 HMM states per token, incl. the blank
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)

@@ -267,7 +268,8 @@
P.set_scores_stochastic_(model.P_scores)
print_transition_probabilities(P, phone_symbol_table, phone_ids, filename='P_scores.txt')

if not os.path.exists(lang_dir / 'HLG.pt'):
HLG_path = exp_dir / 'HLG.pt'
if not HLG_path.exists():
logging.debug("Loading L_disambig.fst.txt")
with open(lang_dir / 'L_disambig.fst.txt') as f:
L = k2.Fsa.from_openfst(f.read(), acceptor=False)
@@ -277,14 +279,14 @@
first_phone_disambig_id = find_first_disambig_symbol(phone_symbol_table)
first_word_disambig_id = find_first_disambig_symbol(symbol_table)
HLG = compile_HLG(L=L,
G=G,
H=ctc_topo,
labels_disambig_id_start=first_phone_disambig_id,
aux_labels_disambig_id_start=first_word_disambig_id)
torch.save(HLG.as_dict(), lang_dir / 'HLG.pt')
G=G,
H=H,
labels_disambig_id_start=first_phone_disambig_id,
aux_labels_disambig_id_start=first_word_disambig_id)
torch.save(HLG.as_dict(), HLG_path)
else:
logging.debug("Loading pre-compiled HLG")
d = torch.load(lang_dir / 'HLG.pt')
d = torch.load(HLG_path)
HLG = k2.Fsa.from_dict(d)

# load dataset
14 changes: 9 additions & 5 deletions egs/librispeech/asr/simple_v1/mmi_att_transformer_train.py
@@ -29,9 +29,10 @@
from snowfall.common import save_training_info
from snowfall.common import setup_logger
from snowfall.models import AcousticModel
from snowfall.models.transformer import Noam, Transformer
from snowfall.models.conformer import Conformer
from snowfall.models.transformer import Noam, Transformer
from snowfall.training.diagnostics import measure_gradient_norms, optim_step_and_measure_param_change
from snowfall.training.hmm_topo import build_hmm_topo_2state
from snowfall.training.mmi_graph import MmiTrainingGraphCompiler
from snowfall.training.mmi_graph import create_bigram_phone_lm
from snowfall.training.mmi_graph import get_phone_symbols
@@ -64,7 +65,7 @@ def get_tot_objf_and_num_frames(tot_scores: torch.Tensor,
frames_per_seq[bad_indexes], " vs. max length ",
torch.max(frames_per_seq), ", avg ",
(torch.sum(frames_per_seq) / frames_per_seq.numel()))
# print("finite_indexes = ", finite_indexes, ", tot_scores = ", tot_scores)
#print("finite_indexes = ", finite_indexes, ", tot_scores = ", tot_scores)
ok_frames = frames_per_seq[finite_indexes].sum()
all_frames = frames_per_seq.sum()
return (tot_scores[finite_indexes].sum(), ok_frames, all_frames)
@@ -134,6 +135,7 @@ def get_objf(batch: Dict,

num = k2.intersect_dense(num, dense_fsa_vec, 10.0)
den = k2.intersect_dense(den, dense_fsa_vec, 10.0)
#den = k2.intersect_dense_pruned(den, dense_fsa_vec, search_beam=10.0, output_beam=10.0, min_active_states=100, max_active_states=1000)

num_tot_scores = num.get_tot_scores(
log_semiring=True,
@@ -446,7 +448,7 @@ def main():

fix_random_seed(42)

exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan')
exp_dir = Path('exp-' + model_type + '-noam-mmi-att-musan-hmm')
setup_logger('{}/log/log-train'.format(exp_dir))
tb_writer = SummaryWriter(log_dir=f'{exp_dir}/tensorboard')

@@ -467,11 +469,13 @@
device_id = 0
device = torch.device('cuda', device_id)

logging.info('Initializing the MMI graph compiler')
graph_compiler = MmiTrainingGraphCompiler(
L_inv=L_inv,
phones=phone_symbol_table,
words=word_symbol_table,
device=device,
topo_builder_fn=build_hmm_topo_2state
)
phone_ids = get_phone_symbols(phone_symbol_table)
P = create_bigram_phone_lm(phone_ids)
@@ -550,15 +554,15 @@
num_features=40,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
num_classes=2 * len(phone_ids) + 2, # 2 HMM states per token, incl. the blank
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)
else:
model = Conformer(
num_features=40,
nhead=args.nhead,
d_model=args.attention_dim,
num_classes=len(phone_ids) + 1, # +1 for the blank symbol
num_classes=2 * len(phone_ids) + 2, # 2 HMM states per token, incl. the blank
subsampling_factor=4,
num_decoder_layers=num_decoder_layers)

52 changes: 52 additions & 0 deletions snowfall/training/hmm_topo.py
@@ -0,0 +1,52 @@
import k2
from typing import List


def build_hmm_topo_2state(tokens: List[int]) -> k2.Fsa:
"""
Build a 2-state HMM topology used in Kaldi's chain models.
The first HMM state is entered only once for each token instance,
and the second HMM state is self-looped and optional.

Args:
tokens:
A list of token int IDs, e.g., phones, characters, etc.
Contributor: This is probably an issue in the baseline, but we should be clear whether this list is supposed to contain zero, or perhaps should not contain zero.

The IDs for the first HMM state will be the same as token IDs;
The IDs for the second HMM state are: ``token_id + len(tokens)``
Returns:
An FST that converts a sequence of HMM state IDs to a sequence of token IDs.
"""
min_token_id = min(tokens)
followup_tokens = list(range(
len(tokens) + min_token_id,
2 * len(tokens) + min_token_id
Contributor: are you making an assumption here that tokens is contiguous?

))
num_states = len(tokens) + 2 # + start state, + final state
arcs = []

# Start state -> token state
for i in range(0, len(tokens)):
arcs += [f'0 {i + 1} {tokens[i]} {tokens[i]} 0.0']

# Token state self loops
for i in range(0, len(tokens)):
arcs += [f'{i + 1} {i + 1} {followup_tokens[i]} 0 0.0']

# Cross-token transitions
for i in range(0, len(tokens)):
for j in range(0, len(tokens)):
if i != j:
arcs += [f'{i + 1} {j + 1} {tokens[i]} {tokens[i]} 0.0']
Contributor: Shouldn't this be tokens[j] and tokens[j], instead of tokens[i] and tokens[i]?


# Token state -> superfinal state
for i in range(0, len(tokens)):
arcs += [f'{i + 1} {num_states - 1} -1 -1 0.0']

# Final state
Collaborator: To fix the problem, you can change

# Final state
arcs += [f'{num_states - 1}']

# Build the FST
arcs = '\n'.join(sorted(arcs))

to

# Build the FST
arcs = '\n'.join(sorted(arcs))

# Final state
arcs += f'\n{num_states - 1}'

k2 expects that the last line contains the final state. Nothing should follow
the final state.

The documentation https://github.com/k2-fsa/k2/blob/1eeeecfac558a6ae4133e2c0b4f0022bee24c786/k2/python/k2/fsa.py#L1078
says

        Caution:
          The first column has to be non-decreasing.

Non-decreasing means numeric order, not alphabetic order, and sorted in Python sorts strings alphabetically. That is the problem.

Collaborator: The above fix is not a complete solution.
If the list is large enough, sorted() may produce an ordering like

1 ....
1 ...
11 ....
2 ....

since it sorts alphabetically; 11 should come after 2, and that will cause another crash.

Collaborator: Changing

arcs = '\n'.join(sorted(arcs))

to

arcs = '\n'.join(sorted(arcs, key=lambda arc: int(arc.split()[0])))

should work.

Collaborator (author): Thanks! I don't think I would have come up with that so fast myself ;)
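
A standalone illustration of the ordering issue discussed in this thread (a sketch with made-up arc strings; not part of the PR):

# Plain sorted() orders the arc strings lexicographically, so source state "11"
# lands before "2" and the bare final-state line is no longer last.
arcs = ['2 3 5 5 0.0', '11 12 7 7 0.0', '1 2 4 4 0.0', '12']

print(sorted(arcs))
# ['1 2 4 4 0.0', '11 12 7 7 0.0', '12', '2 3 5 5 0.0']

# Sorting on the numeric value of the first field restores the order k2 expects:
# non-decreasing source states, with the final state last.
print(sorted(arcs, key=lambda arc: int(arc.split()[0])))
# ['1 2 4 4 0.0', '2 3 5 5 0.0', '11 12 7 7 0.0', '12']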

arcs += [f'{num_states - 1}']

# Build the FST
arcs = '\n'.join(sorted(arcs, key=lambda arc: int(arc.split()[0])))
ans = k2.Fsa.from_str(arcs)
ans = k2.arc_sort(ans)
return ans
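
A minimal usage sketch for the new builder (hypothetical, not part of the PR; it assumes k2 is installed and snowfall is importable, and uses a toy contiguous token list as the docstring and review comments above expect):

from snowfall.training.hmm_topo import build_hmm_topo_2state

# Toy token set: 0 is the blank, 1 and 2 are phones; IDs are assumed contiguous.
tokens = [0, 1, 2]
H = build_hmm_topo_2state(tokens)

# Arcs leaving the start state carry the token ID itself; the self-loop on each
# token state carries token_id + len(tokens) (here 3, 4, 5) and outputs epsilon.
print(H)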
34 changes: 17 additions & 17 deletions snowfall/training/mmi_graph.py
@@ -1,14 +1,13 @@
# Copyright (c) 2020 Xiaomi Corp. (author: Fangjun Kuang)

import k2
import torch
from typing import Iterable
from typing import List
from typing import Tuple

import k2
import torch

from .ctc_graph import build_ctc_topo
from snowfall.common import get_phone_symbols
from .ctc_graph import build_ctc_topo


def create_bigram_phone_lm(phones: List[int]) -> k2.Fsa:
@@ -47,6 +46,7 @@ def __init__(self,
phones: k2.SymbolTable,
words: k2.SymbolTable,
device: torch.device,
topo_builder_fn=build_ctc_topo,
oov: str = '<UNK>'):
'''
Args:
@@ -78,10 +78,9 @@ def __init__(self,
phone_symbols = get_phone_symbols(phones)
phone_symbols_with_blank = [0] + phone_symbols

ctc_topo = build_ctc_topo(phone_symbols_with_blank).to(device)
assert ctc_topo.requires_grad is False

self.ctc_topo_inv = k2.arc_sort(ctc_topo.invert_())
H = topo_builder_fn(phone_symbols_with_blank).to(device)
assert H.requires_grad is False
self.H_inv = k2.arc_sort(H.invert_())

def compile(self, texts: Iterable[str],
P: k2.Fsa) -> Tuple[k2.Fsa, k2.Fsa]:
@@ -106,28 +105,28 @@ def compile(self, texts: Iterable[str],
assert P.device == self.device
P_with_self_loops = k2.add_epsilon_self_loops(P)

ctc_topo_P = k2.intersect(self.ctc_topo_inv,
P_with_self_loops,
treat_epsilons_specially=False).invert()

ctc_topo_P = k2.arc_sort(ctc_topo_P)
HP = k2.intersect(
self.H_inv,
P_with_self_loops,
treat_epsilons_specially=False
).invert()
HP = k2.arc_sort(HP)

num_graphs = self.build_num_graphs(texts)
num_graphs_with_self_loops = k2.remove_epsilon_and_add_self_loops(
num_graphs)

num_graphs_with_self_loops = k2.arc_sort(num_graphs_with_self_loops)

num = k2.compose(ctc_topo_P,
num = k2.compose(HP,
num_graphs_with_self_loops,
treat_epsilons_specially=False)
num = k2.arc_sort(num)

ctc_topo_P_vec = k2.create_fsa_vec([ctc_topo_P.detach()])
HP_vec = k2.create_fsa_vec([HP.detach()])
indexes = torch.zeros(len(texts),
dtype=torch.int32,
device=self.device)
den = k2.index_fsa(ctc_topo_P_vec, indexes)
den = k2.index_fsa(HP_vec, indexes)

return num, den

@@ -163,3 +162,4 @@ def build_num_graphs(self, texts: List[str]) -> k2.Fsa:
treat_epsilons_specially=False).invert_()
num_graphs = k2.arc_sort(num_graphs)
return num_graphs
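
A hedged sketch of selecting the topology through the new topo_builder_fn argument (it mirrors the change in mmi_att_transformer_train.py above; L_inv, the symbol tables, and device are assumed to be prepared as in that script):

from snowfall.training.ctc_graph import build_ctc_topo
from snowfall.training.hmm_topo import build_hmm_topo_2state
from snowfall.training.mmi_graph import MmiTrainingGraphCompiler

graph_compiler = MmiTrainingGraphCompiler(
    L_inv=L_inv,                            # inverted lexicon FST, prepared as in the training script
    phones=phone_symbol_table,
    words=word_symbol_table,
    device=device,
    topo_builder_fn=build_hmm_topo_2state,  # or build_ctc_topo, which remains the default
)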