Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix lattice length of rnnt_decode #1089

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions k2/csrc/rnnt_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,14 @@ void RnntDecodingStreams::TerminateAndFlushToStreams() {
NVTX_RANGE(K2_FUNC);
// return directly if already detached or no frames decoded.
if (!attached_ || prev_frames_.empty()) return;

// We do this extra Advance to get arcs point to the super-final state.
const Array2<float> dummy_logprobs(c_,
states_.TotSize(1),
config_.vocab_size,
0);
Advance(dummy_logprobs);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Current implementation will cause an error when decoding chunk by chunk, I think you need to delete the states gennerated by dummy advance before flushing back to streams.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will say, a dummy advance is a good idea to fix this issue.

Copy link
Collaborator

@pkufool pkufool Nov 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BTW, we'd better make this dummy advance run only once per sequence (i.e. the last chunk), If we do it for every chunk, it will be a big overhead.


std::vector<Ragged<int64_t>> states;
std::vector<Ragged<double>> scores;
Unstack(states_, 0, &states);
Expand Down Expand Up @@ -682,16 +690,19 @@ void RnntDecodingStreams::GatherPrevFrames(
Array1<int32_t> stream2t_row_splits(GetCpuContext(), num_frames.size() + 1);

for (size_t i = 0; i < num_frames.size(); ++i) {
stream2t_row_splits.Data()[i] = num_frames[i];
K2_CHECK_LE(num_frames[i],
// + 1 for the last dummy_logprobs.
stream2t_row_splits.Data()[i] = num_frames[i] + 1;
K2_CHECK_LE(num_frames[i] + 1,
static_cast<int32_t>(srcs_[i]->prev_frames.size()));
for (int32_t j = 0; j < num_frames[i]; ++j) {

// + 1 for the last dummy_logprobs.
for (int32_t j = 0; j < num_frames[i] + 1; ++j) {
frames_ptr.push_back(srcs_[i]->prev_frames[j].get());
}
}

// frames has a shape of [t][state][arc],
// its Dim0() equals std::sum(num_frames)
// its Dim0() equals std::sum(num_frames) + num_frames.size()
auto frames = Stack(0, frames_ptr.size(), frames_ptr.data());

stream2t_row_splits = stream2t_row_splits.To(c_);
Expand Down