Skip to content

Commit

Permalink
Add GetFinalFrame; Fix the online decoding issue
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Aug 7, 2023
1 parent ee49740 commit 3430ffe
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 67 deletions.
178 changes: 118 additions & 60 deletions k2/csrc/intersect_dense_pruned.cu
Original file line number Diff line number Diff line change
Expand Up @@ -257,17 +257,7 @@ class MultiGraphDenseIntersectPruned {
const std::vector<std::unique_ptr<FrameInfo>>* OnlineIntersect(
DenseFsaVec *b_fsas,
std::vector<std::unique_ptr<FrameInfo>> &frames,
Array1<float> &beams,
Array1<bool> &is_final) {
/*
T is the largest number (frames+1) of neural net output currently
received, or the largest number of frames of log-likelihoods we count the
final frame with (0, -inf, -inf..) that is used for the final-arc.
The largest number of states in the fsas represented by b_fsas equals
T+1 (e.g. 1 frame would require 2 states, because that 1 frame is the arc
from state 0 to state 1). So the #states is 2 greater than the actual
number of frames in the neural-net output.
*/
Array1<float> &beams) {
K2_CHECK(online_decoding_);
K2_CHECK(c_->IsCompatible(*b_fsas->Context()));
K2_CHECK_EQ(a_fsas_.shape.Dim0(), 1);
Expand All @@ -277,17 +267,21 @@ class MultiGraphDenseIntersectPruned {
b_fsas_ = b_fsas;
frames_.swap(frames);
dynamic_beams_ = beams.To(c_);
is_final_ = is_final.To(c_);
T_ = frames_.size() - 1;

// -1 here because we already put the initial frame info to frames_
int32_t T = T_ + b_fsas_->shape.MaxSize(1);
// T_ is the actual number of frames we have already processed in previous
// chunks, -1 here because frames_ includes the initial frame.
T_ = frames_.size() - 1;
// -1 here because we add extra frame to b_fsas_ (to handle -1 arc)
// see dense_fsa_vec.py for more details of converting nnet_outputs to fsas.
int32_t chunk_size = b_fsas_->shape.MaxSize(1) - 1;
int32_t T = T_ + chunk_size;

// we'll initially populate frames_[0.. T+1], but discard the one at T+1,
// which has no arcs or states, the ones we use are from 0 to T.
frames_.reserve(T + 2);
// plus initial frame, we actually have T + 1 frames.
frames_.reserve(T + 1);

for (int32_t t = 0; t <= b_fsas_->shape.MaxSize(1); t++) {
// we only do PropagateForward for real frames (i.e. not including the extra
// frame we added to b_fsas_).
for (int32_t t = 0; t < chunk_size; t++) {
if (state_map_.NumKeyBits() == 32) {
frames_.push_back(PropagateForward<32>(t, frames_.back().get()));
} else if (state_map_.NumKeyBits() == 36) {
Expand All @@ -296,25 +290,12 @@ class MultiGraphDenseIntersectPruned {
K2_CHECK_EQ(state_map_.NumKeyBits(), 40);
frames_.push_back(PropagateForward<40>(t, frames_.back().get()));
}
if (t == b_fsas_->shape.MaxSize(1)) {
int32_t start = std::max<int32_t>(0, T_ - 3);
PruneTimeRange(start, T_ + t - 1);
PruneTimeRange(T_ + t - 1, T_ + t);
if (t == chunk_size - 1) {
int32_t start = std::max<int32_t>(0, T_ - 2);
PruneTimeRange(start, T_ + t + 1);
}
}
// The FrameInfo for time T+1 will have no states. We did that
// last PropagateForward so that the 'arcs' member of frames_[T]
// is set up (it has no arcs but we need the shape).
frames_.pop_back();

int32_t history_t = T_;

T_ = T;
// partial_final_frame_ is the last frame to generate partial result,
// but it should not be the start frame of next chunk decoding.
partial_final_frame_ = std::move(frames_.back());
frames_.pop_back();

const int32_t *b_fsas_row_splits1 = b_fsas_->shape.RowSplits(1).Data();
int32_t *final_t_data = final_t_.Data();

Expand All @@ -325,9 +306,104 @@ class MultiGraphDenseIntersectPruned {
b_fsas_row_splits1[i + 1] - b_fsas_row_splits1[i];
final_t_data[i] = history_t + b_chunk_size;
});

// T_ will be used in FormatOutput, plus 1 here because we need an extra
// frame for final arcs (i.e. the partial_final_frame returned by
// GetFinalFrame()) to construct the lattice.
T_ = T + 1;
return &frames_;
}

/* Propagate the last frame in b_fsas_ (i.e. the extra frame containing only 0
   and -infs). See dense_fsa_vec.py for more details about b_fsas_.

   The purpose of this function is to get the final states needed to construct
   partial results for online decoding. It is supposed to be invoked in
   FormatOutput when online_decoding_ is true (this is checked on entry).

   Returns the final FrameInfo needed by FormatOutput. Only its `arcs` member
   is assigned here: the shape has one row per sequence holding 0 or 1 (final)
   state, each with zero arcs, so `arcs` contains no arc at all; we only need
   its shape. The `states` member is NOT assigned by this function.

   Side effect: this function overwrites the arcs of frames_.back(), keeping
   only the expanded arcs whose arc_loglike is finite (i.e. the arcs pointing
   to the final states). Normally the arcs of frames_.back() would be
   populated in the next ForwardPass; we populate them here so that we can get
   valid fsas in FormatOutput. It will not affect the ForwardPass because the
   ForwardPass only needs the states in frames_.back(); the arcs in
   frames_.back() will be re-expanded in the next ForwardPass.
 */
std::unique_ptr<FrameInfo> GetFinalFrame() {
  K2_CHECK(online_decoding_);

  // chunk_size is the index of the added extra frame.
  int32_t chunk_size = b_fsas_->shape.MaxSize(1) - 1;
  FrameInfo *cur_frame = frames_.back().get();

  // These are all of the expanded arcs, actually we only need the arcs
  // pointing to the final states.
  auto arcs = GetArcs(chunk_size, cur_frame);

  int32_t num_fsas = NumFsas();

  // Number of final states for each sequence, should be 0 or 1. It has one
  // extra trailing element so that it can be turned into row splits below.
  Array1<int32_t> num_final_states(c_, num_fsas + 1, 0);
  // Keep the arcs pointing to final states.
  Renumbering renumber_arcs(c_, arcs.NumElements());
  char *keep_this_arc_data = renumber_arcs.Keep().Data();
  const int32_t *arcs_row_ids1_data = arcs.RowIds(1).Data(),
                *arcs_row_ids2_data = arcs.RowIds(2).Data();
  int32_t *num_final_states_data = num_final_states.Data();
  ArcInfo *arcs_data = arcs.values.Data();

  K2_EVAL(
      c_, arcs.NumElements(), lambda_renumber_arc, (int32_t idx012) -> void {
        int32_t idx01 = arcs_row_ids2_data[idx012],
                idx0 = arcs_row_ids1_data[idx01];
        ArcInfo ai = arcs_data[idx012];
        // Arcs pointing to final states have non infinity scores.
        // `x - x == 0` only holds for finite x (it is NaN for +-inf and NaN).
        if (ai.arc_loglike - ai.arc_loglike == 0) {
          // Several kept arcs of the same sequence may store here; they all
          // store the same value 1.
          num_final_states_data[idx0] = 1;
          keep_this_arc_data[idx012] = 1;
        } else {
          keep_this_arc_data[idx012] = 0;
        }
      });

  int32_t num_arcs = renumber_arcs.NumNewElems();
  const int32_t *new2old_data = renumber_arcs.New2Old().Data();
  Array1<ArcInfo> new_arcs(c_, num_arcs);
  ArcInfo *new_arcs_data = new_arcs.Data();

  K2_EVAL(c_, num_arcs, lambda_set_new_arcs, (int32_t new_idx012) -> void {
    int32_t old_idx012 = new2old_data[new_idx012];
    ArcInfo old_ai = arcs_data[old_idx012];
    // Only 1 state (the final state) in next frame, so idx1 is always 0.
    old_ai.u.dest_info_state_idx1 = 0;
    new_arcs_data[new_idx012] = old_ai;
  });

  // Build the shape of the kept arcs: append an axis mapping each old arc to
  // the 0 or 1 new arcs it became, then remove the old arc axis.
  auto old2new_rowsplits = renumber_arcs.Old2New(true);
  auto old2new_shape = RaggedShape2(&old2new_rowsplits, nullptr, num_arcs);
  auto total_shape = ComposeRaggedShapes(arcs.shape, old2new_shape);
  auto new_arcs_shape = RemoveAxis(total_shape, 2);
  cur_frame->arcs = Ragged<ArcInfo>(new_arcs_shape, new_arcs);

  std::unique_ptr<FrameInfo> ans = std::make_unique<FrameInfo>();
  // Turn the per-sequence counts (0 or 1) into row splits, in place.
  ExclusiveSum(num_final_states, &num_final_states);
  auto final_state_shape = RaggedShape2(
      &num_final_states, nullptr, -1);
  // No arcs for final frame, but we need its shape in FormatOutput.
  auto state_to_arc_shape = RegularRaggedShape(
      c_, final_state_shape.NumElements(), 0);
  auto final_arc_shape = ComposeRaggedShapes(
      final_state_shape, state_to_arc_shape);
  ans->arcs = Ragged<ArcInfo>(final_arc_shape, Array1<ArcInfo>(c_, 0));
  return ans;
}


void BackwardPass() {
int32_t num_fsas = b_fsas_->shape.Dim0(),
num_work_items = max_active_ * num_fsas * T_;
Expand Down Expand Up @@ -401,7 +477,9 @@ class MultiGraphDenseIntersectPruned {

bool online_decoding = online_decoding_;
bool allow_partial = allow_partial_;
std::unique_ptr<FrameInfo> partial_final_frame;
if (online_decoding) {
partial_final_frame = std::move(GetFinalFrame());
K2_CHECK(arc_map_a);
K2_CHECK_EQ(arc_map_b, nullptr);
} else {
Expand All @@ -417,10 +495,10 @@ class MultiGraphDenseIntersectPruned {
arcs_row_splits1_ptrs.Data()[t] = frames_[t]->arcs.RowSplits(1).Data();
}
arcs_data_ptrs.Data()[T] = online_decoding
? partial_final_frame_->arcs.values.Data()
? partial_final_frame->arcs.values.Data()
: frames_[T]->arcs.values.Data();
arcs_row_splits1_ptrs.Data()[T] =
online_decoding ? partial_final_frame_->arcs.RowSplits(1).Data()
online_decoding ? partial_final_frame->arcs.RowSplits(1).Data()
: frames_[T]->arcs.RowSplits(1).Data();

// transfer to GPU if we're using a GPU
Expand Down Expand Up @@ -484,7 +562,7 @@ class MultiGraphDenseIntersectPruned {
for (int32_t t = 0; t < T; t++)
arcs_shapes[t] = &(frames_[t]->arcs.shape);

arcs_shapes[T] = online_decoding ? &(partial_final_frame_->arcs.shape)
arcs_shapes[T] = online_decoding ? &(partial_final_frame->arcs.shape)
: &(frames_[T]->arcs.shape);

arcs_shapes[T + 1] = &final_arcs_shape;
Expand Down Expand Up @@ -774,12 +852,6 @@ class MultiGraphDenseIntersectPruned {
});
}

bool online_decoding = online_decoding_;
bool *is_final_data = nullptr;
if (online_decoding) {
is_final_data = is_final_.Data();
}

K2_EVAL(
c_, ai.values.Dim(), ai_lambda, (int32_t ai_arc_idx012)->void {
int32_t ai_state_idx01 = ai_row_ids2[ai_arc_idx012],
Expand All @@ -802,14 +874,8 @@ class MultiGraphDenseIntersectPruned {
auto dest_state = arc.dest_state;
auto final_t = b_fsas_row_splits1[ai_fsa_idx0+1] - b_fsas_row_splits1[ai_fsa_idx0];

bool is_final_chunk = false;
if (online_decoding) {
is_final_chunk = is_final_data[ai_fsa_idx0];
}

if (final_t - 1 == t &&
((online_decoding && !is_final_chunk) ||
(allow_partial && !has_valid_final_arc_data[ai_fsa_idx0]))) {
(allow_partial && !has_valid_final_arc_data[ai_fsa_idx0])) {
int32_t a_fsas_idx0 = a_fsas_row_ids1[sinfo.a_fsas_state_idx01];
// state_idx1 is 0-based.
// So "-1" is used when calculating a_fsas_final_state_idx1.
Expand Down Expand Up @@ -1021,7 +1087,6 @@ class MultiGraphDenseIntersectPruned {

int32_t dest_a_fsas_state_idx01 = info.u.dest_a_fsas_state_idx01;


uint64_t state_map_idx = dest_a_fsas_state_idx01 +
fsa_id * state_map_fsa_stride;
uint64_t state_idx01;
Expand Down Expand Up @@ -1593,9 +1658,6 @@ class MultiGraphDenseIntersectPruned {

bool online_decoding_; // true for online decoding.
Array1<int32_t> final_t_; // record the final frame id of each DenseFsa.
Array1<bool> is_final_; // For online decoding, it has a dimension of
// b_fsas_->Dim0() indicating whether this is
// the final chunk of current sequence.
std::unique_ptr<FrameInfo> partial_final_frame_; // store the final frame for
// partial results

Expand Down Expand Up @@ -1711,8 +1773,6 @@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas,

Array1<float> beams(GetCpuContext(), num_seqs);
float *beams_data = beams.Data();
Array1<bool> is_final(GetCpuContext(), num_seqs);
bool *is_final_data = is_final.Data();
for (int32_t i = 0; i < num_seqs; ++i) {
DecodeStateInfo *decode_state_ptr = decode_states->at(i);
K2_CHECK(decode_state_ptr);
Expand All @@ -1730,12 +1790,10 @@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas,
Array1<ArcInfo>(c_, std::vector<ArcInfo>{ArcInfo()}));

decode_state_ptr->beam = search_beam_;
decode_state_ptr->is_final = false;
}
seq_states_ptr_vec[i] = &(decode_state_ptr->states);
seq_arcs_ptr_vec[i] = &(decode_state_ptr->arcs);
beams_data[i] = decode_state_ptr->beam;
is_final_data[i] = decode_state_ptr->is_final;
}

auto stack_states = Stack(0, num_seqs, seq_states_ptr_vec.data());
Expand Down Expand Up @@ -1764,7 +1822,7 @@ void OnlineDenseIntersecter::Decode(DenseFsaVec &b_fsas,
}

const auto new_frames = impl_->OnlineIntersect(
&b_fsas, frames, beams, is_final);
&b_fsas, frames, beams);

impl_->FormatOutput(ofsa, arc_map_a, nullptr/*arc_map_b*/);

Expand Down
3 changes: 0 additions & 3 deletions k2/csrc/intersect_dense_pruned.h
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,6 @@ struct DecodeStateInfo {

// current search beam for this sequence
float beam;

// True if the chunk to be decoded is the final chunk
bool is_final;
};


Expand Down
1 change: 0 additions & 1 deletion k2/python/csrc/torch/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,6 @@ static void PybindDecodeStateInfo(py::module &m) {
py::class_<PyClass> state_info(
m, "DecodeStateInfo");
state_info.def(py::init<>());
state_info.def_readwrite("is_final", &PyClass::is_final);
}

static void PybindOnlineDenseIntersecter(py::module &m) {
Expand Down
2 changes: 2 additions & 0 deletions k2/python/k2/online_dense_intersecter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(
output_beam: float,
min_active_states: int,
max_active_states: int,
allow_partial: bool = True,
) -> None:
"""Create a new online intersecter object.
Args:
Expand Down Expand Up @@ -101,6 +102,7 @@ def __init__(
output_beam,
min_active_states,
max_active_states,
allow_partial=allow_partial,
)

@property
Expand Down
1 change: 0 additions & 1 deletion k2/torch/bin/online_decode.cu
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,6 @@ int main(int argc, char *argv[]) {
if (num_frames[i] <= chunk_size * subsampling_factor) {
num_frame.push_back(num_frames[i]);
num_frames[i] = 0;
states_info[i].is_final = true;
} else {
num_frame.push_back(chunk_size * subsampling_factor);
num_frames[i] -= chunk_size * subsampling_factor;
Expand Down
3 changes: 1 addition & 2 deletions k2/torch/bin/online_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def decode_one_chunk(
current_num_frames.append(stream.num_frames - stream.position)
end = stream.num_frames
stream.position = stream.num_frames
stream.state_info.is_final = True
finised_streams.append(i)
else:
current_num_frames.append(params.chunk_size)
Expand Down Expand Up @@ -264,7 +263,7 @@ def main():
args.subsampling_factor = 4
args.feature_dim = 80
args.num_classes = 500
args.chunk_size = 10
args.chunk_size = 16

wave_list: List[Tuple[str, str]] = []
if args.wav_scp is not None:
Expand Down

0 comments on commit 3430ffe

Please sign in to comment.