This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

[WIP] 2-state HMM topo as an alternative to CTC topo #126

Open · wants to merge 12 commits into master

Conversation

pzelasko (Collaborator)

I'm trying to build a topology where the "blank" is phone-specific instead of shared between phones (I believe that corresponds to Kaldi's chain topology).

I added a function build_hmm_topo_2state which seems to work OK for small numbers of tokens, but crashes on inputs with more than 8 elements. Apart from this function, the PR is not otherwise ready for review.

The first problematic input is [0, 1, 2, 3, 4, 5, 6, 7, 8], everything smaller than that works. This is how the FSA looks for 1, 2, and 3 element lists:

[image: the FSA for 1-, 2-, and 3-element token lists]

It also seems to work when 0 is among the input IDs, although I'm not sure whether k2 has any hard-coded assumptions about symbols with ID 0 (I think OpenFST did).
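For reference, here is a minimal, hypothetical sketch of what such a builder might look like. It is reconstructed from the snippets quoted in this thread (and incorporates the sorting and final-state fixes discussed below); the function name, the exact arc layout, and the state numbering are my assumptions, not the PR's actual code:

```python
def build_hmm_topo_2state_sketch(tokens):
    """Hypothetical sketch: one entry state per token, plus a token-specific
    self-loop symbol acting as a per-token "blank"."""
    min_token_id = min(tokens)
    # Follow-up ("token-specific blank") IDs, offset past the real token IDs.
    followup_tokens = list(range(len(tokens) + min_token_id,
                                 len(tokens) * 2 + min_token_id))
    num_states = len(tokens) + 2  # start state + one per token + final state
    arcs = []
    for i in range(len(tokens)):
        # Enter token i from the start state, emitting its ID once.
        arcs.append(f'0 {i + 1} {tokens[i]} {tokens[i]} 0.0')
        # Stay in token i via its own follow-up symbol (epsilon output).
        arcs.append(f'{i + 1} {i + 1} {followup_tokens[i]} 0 0.0')
        # Move directly to any other token.
        for j in range(len(tokens)):
            if i != j:
                arcs.append(f'{i + 1} {j + 1} {tokens[j]} {tokens[j]} 0.0')
        # Or finish.
        arcs.append(f'{i + 1} {num_states - 1} -1 -1 0.0')
    # Sort numerically by source state; k2 requires the first column to be
    # non-decreasing and the final state to appear alone on the last line.
    arcs.sort(key=lambda arc: int(arc.split()[0]))
    return '\n'.join(arcs) + f'\n{num_states - 1}'
```

The resulting string would then be passed to something like k2.Fsa.from_str (assuming the k2 Python API of that era accepts the transducer string format shown in this thread).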

For the problematic inputs, the program crashes with message:

[F] /exp/pzelasko/k2/k2/csrc/fsa_utils.cu:k2::Fsa k2::K2TransducerFromStream(std::istringstream&, k2::Array1<int>*):203 Check failed: finished == false (1 vs. 0)


[ Stack-Trace: ]
/exp/pzelasko/k2/build/lib/libk2_log.so(k2::internal::GetStackTrace()+0x34) [0x2aab587d98e4]
/exp/pzelasko/k2/build/lib/libk2context.so(k2::internal::Logger::~Logger()+0x28) [0x2aab53cc5fe8]
/exp/pzelasko/k2/build/lib/libk2context.so(+0xf6b37) [0x2aab53d4fb37]
/exp/pzelasko/k2/build/lib/libk2context.so(k2::FsaFromString(std::string const&, bool, k2::Array1<int>*)+0x416) [0x2aab53d50b06]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/_k2.cpython-37m-x86_64-linux-gnu.so(+0x3b836) [0x2aab50447836]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/_k2.cpython-37m-x86_64-linux-gnu.so(+0x1a6d3) [0x2aab504266d3]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyMethodDef_RawFastCallKeywords+0x274) [0x5555556b9914]

for i in range(0, len(tokens)):
    arcs += [f'{i + 1} {num_states - 1} -1 -1 0.0']

# Final state
Collaborator

To fix the problem, you can change

# Final state
arcs += [f'{num_states - 1}']

# Build the FST
arcs = '\n'.join(sorted(arcs))

to

# Build the FST
arcs = '\n'.join(sorted(arcs))

# Final state
arcs += f'\n{num_states - 1}'

k2 expects that the last line contains the final state. Nothing should follow
the final state.

The documentation https://github.com/k2-fsa/k2/blob/1eeeecfac558a6ae4133e2c0b4f0022bee24c786/k2/python/k2/fsa.py#L1078
says

        Caution:
          The first column has to be non-decreasing.

Non-decreasing here means numeric order, not alphabetical order, and Python's sorted sorts strings alphabetically. That is the problem.

Collaborator

The above fix is not a complete solution.
If the list is too large, it may result in

1 ....
1 ...
11 ....
2 ....

because sorted compares strings alphabetically: 11 should come after 2, and the wrong order will cause another crash.

Collaborator

Changing

arcs = '\n'.join(sorted(arcs))

to

arcs = '\n'.join(sorted(arcs, key=lambda arc: int(arc.split()[0])))

should work.
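A quick standalone demonstration of the difference (the arc strings here are made up for illustration):

```python
arcs = ['11 12 3 3 0.0', '2 5 1 1 0.0', '1 2 1 1 0.0']

# Alphabetic sort puts '11' before '2', which violates k2's requirement
# that the first column be numerically non-decreasing.
assert sorted(arcs) == ['1 2 1 1 0.0', '11 12 3 3 0.0', '2 5 1 1 0.0']

# Sorting on the integer value of the source state gives the required order.
fixed = sorted(arcs, key=lambda arc: int(arc.split()[0]))
assert fixed == ['1 2 1 1 0.0', '2 5 1 1 0.0', '11 12 3 3 0.0']
```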

Collaborator Author

Thanks! I don't think I would have come up with that so fast myself ;)

    Returns:
      An FST that converts a sequence of HMM state IDs to a sequence of token IDs.
    """
    followup_tokens = range(len(tokens), len(tokens) * 2)

should it be

followup_tokens = range(len(tokens) + 1, len(tokens) * 2 + 1)

as token IDs start from 1?

Collaborator Author

Yes, you're right. In the general case, to avoid surprises, I think that should be len(tokens) + min_token_id.
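A small illustration of the off-by-one (the token IDs here are hypothetical):

```python
tokens = [1, 2, 3]  # token IDs start at 1

# Original version: the follow-up ID range overlaps the real token IDs.
bad = list(range(len(tokens), len(tokens) * 2))
assert bad == [3, 4, 5] and set(bad) & set(tokens) == {3}

# Offsetting by min_token_id removes the overlap for this input.
min_token_id = min(tokens)
good = list(range(len(tokens) + min_token_id, len(tokens) * 2 + min_token_id))
assert good == [4, 5, 6] and not set(good) & set(tokens)
```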

@pzelasko (Collaborator Author)

After the changes it seems to work (sometimes). In general it seems to consume much more GPU memory; I have been decreasing max_frames and the beam for intersect_dense for dens (the denominators), but I still haven't found the right combination and keep hitting CUDA OOM.

On a good run, it's converging:

2021-03-13 00:13:46,821 INFO [mmi_att_transformer_train.py:599] epoch 0, learning rate 0
2021-03-13 00:13:49,124 INFO [mmi_att_transformer_train.py:310] batch 0, epoch 0/10 global average objf: 1.927489 over 4948.0 frames (100.0% kept), current batch average objf: 1.927489 over 4948 frames (100.0% kept) avg time waiting for batch 0.590s
2021-03-13 00:14:02,008 INFO [mmi_att_transformer_train.py:310] batch 10, epoch 0/10 global average objf: 1.761968 over 52529.0 frames (100.0% kept), current batch average objf: 1.614887 over 4715 frames (100.0% kept) avg time waiting for batch 0.062s
2021-03-13 00:14:14,586 INFO [mmi_att_transformer_train.py:310] batch 20, epoch 0/10 global average objf: 1.648654 over 99927.0 frames (100.0% kept), current batch average objf: 1.352455 over 4765 frames (100.0% kept) avg time waiting for batch 0.033s
2021-03-13 00:14:27,491 INFO [mmi_att_transformer_train.py:310] batch 30, epoch 0/10 global average objf: 1.580114 over 147669.0 frames (100.0% kept), current batch average objf: 1.385895 over 4603 frames (100.0% kept) avg time waiting for batch 0.023s
2021-03-13 00:14:40,817 INFO [mmi_att_transformer_train.py:310] batch 40, epoch 0/10 global average objf: 1.516337 over 195646.0 frames (100.0% kept), current batch average objf: 1.351712 over 4658 frames (100.0% kept) avg time waiting for batch 0.018s
2021-03-13 00:14:55,210 INFO [mmi_att_transformer_train.py:310] batch 50, epoch 0/10 global average objf: 1.479551 over 242952.0 frames (100.0% kept), current batch average objf: 1.247315 over 4623 frames (100.0% kept) avg time waiting for batch 0.015s
2021-03-13 00:15:09,543 INFO [mmi_att_transformer_train.py:310] batch 60, epoch 0/10 global average objf: 1.441809 over 291021.0 frames (100.0% kept), current batch average objf: 1.128138 over 4720 frames (100.0% kept) avg time waiting for batch 0.013s
2021-03-13 00:15:24,617 INFO [mmi_att_transformer_train.py:310] batch 70, epoch 0/10 global average objf: 1.406366 over 339101.0 frames (100.0% kept), current batch average objf: 1.045948 over 4713 frames (100.0% kept) avg time waiting for batch 0.011s

However, it sometimes also crashes during intersection (this happened in a specific batch I got when setting max_frames=30000).

2021-03-13 00:11:31,534 INFO [mmi_att_transformer_train.py:599] epoch 0, learning rate 0
[F] /exp/pzelasko/k2/k2/csrc/intersect_dense.cu:lambda [](int)->void::operator()(int)->void:728 block:[0,0,0], thread: [24,0,0] Check failed: tot_score_end == tot_score_start || fabs(tot_score_end - tot_score_start) < 1.0 -455.167328 vs -inf
/exp/pzelasko/k2/k2/csrc/intersect_dense.cu:728: lambda [](int)->void::operator()(int)->void: block: [0,0,0], thread: [24,0,0] Assertion `Some bad things happened` failed.
[F] /exp/pzelasko/k2/k2/csrc/array.h:T k2::Array1<T>::operator[](int32_t) const [with T = int; int32_t = int]:275 Check failed: ret == cudaSuccess (710 vs. 0)  Error: device-side assert triggered.


[ Stack-Trace: ]
/exp/pzelasko/k2/build/lib/libk2_log.so(k2::internal::GetStackTrace()+0x34) [0x2aab31b748e4]
/exp/pzelasko/k2/build/lib/libk2context.so(k2::internal::Logger::~Logger()+0x28) [0x2aab2d060fe8]
/exp/pzelasko/k2/build/lib/libk2context.so(k2::Array1<int>::operator[](int) const+0x1929) [0x2aab2d062c69]
/exp/pzelasko/k2/build/lib/libk2context.so(k2::Renumbering::ComputeOld2New()+0x13a) [0x2aab2d05e74a]
/exp/pzelasko/k2/build/lib/libk2context.so(k2::Renumbering::ComputeNew2Old()+0x5e0) [0x2aab2d05f780]
/exp/pzelasko/k2/build/lib/libk2context.so(k2::MultiGraphDenseIntersect::FormatOutput(k2::Array1<int>*, k2::Array1<int>*)+0x157c) [0x2aab2d178bec]
/exp/pzelasko/k2/build/lib/libk2context.so(k2::IntersectDense(k2::Ragged<k2::Arc>&, k2::DenseFsaVec&, k2::Array1<int> const*, float, k2::Ragged<k2::Arc>*, k2::Array1<int>*, k2::Array1<int>*)+0x415) [0x2aab2d16cc35]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/_k2.cpython-37m-x86_64-linux-gnu.so(+0x57aba) [0x2aab29819aba]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/_k2.cpython-37m-x86_64-linux-gnu.so(+0x1a6d3) [0x2aab297dc6d3]
python3(_PyMethodDef_RawFastCallKeywords+0x316) [0x5555556b99b6]

Any thoughts?

@pzelasko (Collaborator Author)

Hmm, I think the latter error was related to having too few outputs in the nnet (I was off by two). I fixed that and the error disappeared...

@pzelasko (Collaborator Author)

FYI, I updated to the most recent k2 because I remembered there were some new memory optimizations for intersection; it does help. For the dens intersection with posteriors, I am now getting messages like:

[I] /exp/pzelasko/k2/k2/csrc/intersect_dense.cu:k2::FsaVec k2::MultiGraphDenseIntersect::FormatOutput(k2::Array1<int>*, k2::Array1<int>*):267 Num-states 16903729 exceeds limit 15000000, decreasing beam from 10.000000 to 7.500000
[I] /exp/pzelasko/k2/k2/csrc/intersect_dense.cu:k2::FsaVec k2::MultiGraphDenseIntersect::FormatOutput(k2::Array1<int>*, k2::Array1<int>*):267 Num-states 20420932 exceeds limit 15000000, decreasing beam from 10.000000 to 7.500000
[I] /exp/pzelasko/k2/k2/csrc/intersect_dense.cu:k2::FsaVec k2::MultiGraphDenseIntersect::FormatOutput(k2::Array1<int>*, k2::Array1<int>*):267 Num-states 25258280 exceeds limit 15000000, decreasing beam from 10.000000 to 7.500000
[I] /exp/pzelasko/k2/k2/csrc/intersect_dense.cu:k2::FsaVec k2::MultiGraphDenseIntersect::FormatOutput(k2::Array1<int>*, k2::Array1<int>*):267 Num-states 72044903 exceeds limit 15000000, decreasing beam from 10.000000 to 4.562932
[I] /exp/pzelasko/k2/k2/csrc/intersect_dense.cu:k2::FsaVec k2::MultiGraphDenseIntersect::FormatOutput(k2::Array1<int>*, k2::Array1<int>*):267 Num-states 17092940 exceeds limit 15000000, decreasing beam from 4.562932 to 3.422199

I'll let this one run and see what happens. I wonder if it makes sense to use k2.intersect_dense_pruned instead for this topology.

@danpovey (Contributor)

danpovey commented Mar 14, 2021 via email

@pzelasko (Collaborator Author)

I am getting an error during decoding graph composition. @csukuangfj @qindazhu @danpovey, can you suggest the right approach to debugging it (or what the cause could be)? Also, FYI, it's taking quite a long time (2h+) to compose H o LG with this topo.

2021-03-16 11:03:49,572 INFO [graph.py:50] LG shape = (18560454, None)
2021-03-16 11:03:49,573 INFO [graph.py:51] Connecting det(L*G)
2021-03-16 11:03:49,573 INFO [graph.py:53] LG shape = (18560454, None)
2021-03-16 11:03:49,573 INFO [graph.py:54] Removing disambiguation symbols on L*G
2021-03-16 11:03:49,949 INFO [graph.py:60] Removing epsilons
2021-03-16 11:04:26,037 INFO [graph.py:62] LG shape = (16546647, None)
2021-03-16 11:04:26,038 INFO [graph.py:63] Connecting rm-eps(det(L*G))
2021-03-16 11:04:32,251 INFO [graph.py:65] LG shape = (11282948, None)
2021-03-16 11:04:33,064 INFO [graph.py:68] Arc sorting LG
2021-03-16 11:04:36,200 INFO [graph.py:71] Composing ctc_topo LG
[F] /exp/pzelasko/k2/k2/csrc/host/intersect.cc:bool k2host::Intersection::GetOutput(k2host::Fsa*, int32_t*, int32_t*):191 Check failed: arcs_.size() == c->size2 (4709562396 vs. 414595100)


[ Stack-Trace: ]
/exp/pzelasko/k2/build/lib/libk2_log.so(k2::internal::GetStackTrace()+0x34) [0x2aab3d5aa4d4]
/exp/pzelasko/k2/build/lib/libk2context.so(k2::internal::Logger::~Logger()+0x28) [0x2aab38a357c8]
/exp/pzelasko/k2/build/lib/libk2fsa.so(k2host::Intersection::GetOutput(k2host::Fsa*, int*, int*)+0x27f) [0x2aab3d37070f]
/exp/pzelasko/k2/build/lib/libk2context.so(k2::Intersect(k2::Ragged<k2::Arc>&, int, k2::Ragged<k2::Arc>&, int, bool, k2::Ragged<k2::Arc>*, k2::Array1<int>*, k2::Array1<int>*)+0xb00) [0x2aab38aa0730]
/exp/pzelasko/k2/build/lib/libk2context.so(k2::Intersect(k2::Ragged<k2::Arc>&, int, k2::Ragged<k2::Arc>&, int, bool, k2::Ragged<k2::Arc>*, k2::Array1<int>*, k2::Array1<int>*)+0x12a3) [0x2aab38aa0ed3]
/exp/pzelasko/k2/build/lib/libk2context.so(k2::Intersect(k2::Ragged<k2::Arc>&, int, k2::Ragged<k2::Arc>&, int, bool, k2::Ragged<k2::Arc>*, k2::Array1<int>*, k2::Array1<int>*)+0x1320) [0x2aab38aa0f50]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/_k2.cpython-37m-x86_64-linux-gnu.so(+0x4f661) [0x2aab35153661]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/lib/python3.7/site-packages/_k2.cpython-37m-x86_64-linux-gnu.so(+0x19503) [0x2aab3511d503]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyMethodDef_RawFastCallKeywords+0x274) [0x5555556b9914]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyCFunction_FastCallKeywords+0x21) [0x5555556b9a31]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyEval_EvalFrameDefault+0x4e1d) [0x555555725ebd]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyEval_EvalCodeWithName+0x2f9) [0x555555668829]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyFunction_FastCallKeywords+0x387) [0x5555556b9107]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyEval_EvalFrameDefault+0x14e5) [0x555555722585]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyEval_EvalCodeWithName+0x2f9) [0x555555668829]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyFunction_FastCallKeywords+0x387) [0x5555556b9107]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyEval_EvalFrameDefault+0x14e5) [0x555555722585]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyEval_EvalCodeWithName+0xc30) [0x555555669160]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyFunction_FastCallKeywords+0x387) [0x5555556b9107]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyEval_EvalFrameDefault+0x416) [0x5555557214b6]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_PyEval_EvalCodeWithName+0x2f9) [0x555555668829]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(PyEval_EvalCodeEx+0x44) [0x555555669714]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(PyEval_EvalCode+0x1c) [0x55555566973c]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(+0x22cf14) [0x555555780f14]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(PyRun_FileExFlags+0xa1) [0x55555578b331]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(PyRun_SimpleFileExFlags+0x1c3) [0x55555578b523]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(+0x238655) [0x55555578c655]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(_Py_UnixMain+0x3c) [0x55555578c77c]
/lib64/libc.so.6(__libc_start_main+0xf5) [0x2aaaaaf0d445]
/home/hltcoe/pzelasko/miniconda3/envs/k2env/bin/python(+0x1dcff0) [0x555555730ff0]

@pzelasko (Collaborator Author)

BTW, the training seems to have gone OK:

2021-03-16 05:41:50,279 INFO [mmi_att_transformer_train.py:336] Validation average objf: 0.151131 over 481977.0 frames (100.0% kept)
2021-03-16 05:42:09,358 INFO [mmi_att_transformer_train.py:311] batch 3610, epoch 9/10 global average objf: 0.104617 over 26368480.0 frames (100.0% kept), current batch average objf: 0.104625 over 7327 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:42:27,843 INFO [mmi_att_transformer_train.py:311] batch 3620, epoch 9/10 global average objf: 0.104634 over 26441049.0 frames (100.0% kept), current batch average objf: 0.126236 over 7180 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:42:46,441 INFO [mmi_att_transformer_train.py:311] batch 3630, epoch 9/10 global average objf: 0.104630 over 26513812.0 frames (100.0% kept), current batch average objf: 0.101007 over 7290 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:43:05,153 INFO [mmi_att_transformer_train.py:311] batch 3640, epoch 9/10 global average objf: 0.104635 over 26586828.0 frames (100.0% kept), current batch average objf: 0.102357 over 7295 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:43:24,110 INFO [mmi_att_transformer_train.py:311] batch 3650, epoch 9/10 global average objf: 0.104638 over 26660072.0 frames (100.0% kept), current batch average objf: 0.117503 over 7167 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:43:42,966 INFO [mmi_att_transformer_train.py:311] batch 3660, epoch 9/10 global average objf: 0.104630 over 26733403.0 frames (100.0% kept), current batch average objf: 0.103639 over 7249 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:44:01,469 INFO [mmi_att_transformer_train.py:311] batch 3670, epoch 9/10 global average objf: 0.104624 over 26806061.0 frames (100.0% kept), current batch average objf: 0.095448 over 7148 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:44:20,269 INFO [mmi_att_transformer_train.py:311] batch 3680, epoch 9/10 global average objf: 0.104619 over 26879448.0 frames (100.0% kept), current batch average objf: 0.104208 over 7357 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:44:38,788 INFO [mmi_att_transformer_train.py:311] batch 3690, epoch 9/10 global average objf: 0.104607 over 26952162.0 frames (100.0% kept), current batch average objf: 0.092308 over 7318 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:44:57,588 INFO [mmi_att_transformer_train.py:311] batch 3700, epoch 9/10 global average objf: 0.104614 over 27025008.0 frames (100.0% kept), current batch average objf: 0.091770 over 7417 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:45:16,313 INFO [mmi_att_transformer_train.py:311] batch 3710, epoch 9/10 global average objf: 0.104628 over 27098069.0 frames (100.0% kept), current batch average objf: 0.091939 over 7398 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:45:35,197 INFO [mmi_att_transformer_train.py:311] batch 3720, epoch 9/10 global average objf: 0.104638 over 27171457.0 frames (100.0% kept), current batch average objf: 0.098617 over 7103 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:45:53,545 INFO [mmi_att_transformer_train.py:311] batch 3730, epoch 9/10 global average objf: 0.104618 over 27243992.0 frames (100.0% kept), current batch average objf: 0.089249 over 7309 frames (100.0% kept) avg time waiting for batch 0.003s
2021-03-16 05:45:54,641 INFO [common.py:156] Save checkpoint to exp-conformer-noam-mmi-att-musan-hmm/epoch-9.pt: epoch=9, learning_rate=0.00034101112768565475, objf=0.10461679651370448, valid_objf=0.15113134427210284
2021-03-16 05:45:55,822 INFO [common.py:199] write training info to exp-conformer-noam-mmi-att-musan-hmm/epoch-9-info


Args:
tokens:
A list of token int IDs, e.g., phones, characters, etc.
Contributor

This is probably an issue in the baseline too, but we should be clear about whether this list is supposed to contain zero, or perhaps must not contain zero.

min_token_id = min(tokens)
followup_tokens = list(range(
    len(tokens) + min_token_id,
    2 * len(tokens) + min_token_id
))
Contributor

are you making an assumption here that tokens is contiguous?
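For what it's worth, with non-contiguous token IDs the min_token_id offset can indeed still collide (a hypothetical example; the "gap-proof" alternative is my assumption, not from the PR):

```python
tokens = [1, 5, 9]  # non-contiguous token IDs
n, m = len(tokens), min(tokens)

followup = list(range(n + m, 2 * n + m))
assert followup == [4, 5, 6]
assert 5 in tokens and 5 in followup  # 5 doubles as a real token and a "blank"

# One gap-proof alternative: offset every token past the largest real ID.
safe = [t + max(tokens) for t in tokens]
assert safe == [10, 14, 18] and not set(safe) & set(tokens)
```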

for i in range(0, len(tokens)):
    for j in range(0, len(tokens)):
        if i != j:
            arcs += [f'{i + 1} {j + 1} {tokens[i]} {tokens[i]} 0.0']
Contributor

Shouldn't this be tokens[j] and tokens[j], instead of tokens[i] and tokens[i]?

@csukuangfj (Collaborator)

csukuangfj commented Mar 25, 2021

> I am getting an error during decoding graph composition,

Have you fixed it?

@pzelasko (Collaborator Author)

Sorry, I had to de-prioritize it to take care of other stuff. I will eventually get back to it.
