Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow_partial for intersect_dense_pruned #1087

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions k2/csrc/fsa_algo.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,10 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest,
@param[in] b_fsas Input FSAs that correspond to neural network
outputs (see documentation in fsa.h).
@param[in] search_beam Beam for frame-synchronous beam pruning,
e.g. 20. Smaller is faster, larger is more exact
(less pruning). This is the default value; it may be
modified by {min,max}_active which dictate the minimum
or maximum allowed number of active states per frame.
e.g. 20. Smaller is faster, larger is more exact
(less pruning). This is the default value; it may be
modified by {min,max}_active which dictate the minimum
or maximum allowed number of active states per frame.
@param[in] output_beam Beam with which we prune the output (analogous
to lattice-beam in Kaldi), e.g. 8. We discard arcs in
the output that are not on a path that's within
Expand All @@ -178,6 +178,11 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest,
of states are active. The hash size used per FSA is 4
times (this rounded up to a power of 2), so this
affects memory consumption.
@param [in] allow_partial If true and no final state is active on the
last frame, we will treat all the states active on the last frame
as final states. If false, we only
consider the real final state of the decoding
graph on the last frame when generating the lattice.
@param[out] out Output vector of composed, pruned FSAs, with same
glynpu marked this conversation as resolved.
Show resolved Hide resolved
Dim0() as b_fsas. Elements of it may be empty if the
composition was empty, either intrinsically or due to
Expand All @@ -196,6 +201,7 @@ void AddEpsilonSelfLoops(FsaOrVec &src, FsaOrVec *dest,
void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas,
float search_beam, float output_beam,
int32_t min_active_states, int32_t max_active_states,
bool allow_partial,
FsaVec *out, Array1<int32_t> *arc_map_a,
Array1<int32_t> *arc_map_b);

Expand Down
72 changes: 67 additions & 5 deletions k2/csrc/intersect_dense_pruned.cu
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,24 @@ class MultiGraphDenseIntersectPruned {
intersection/composition task. This is advisory,
in that it will try not to exceed that but may not
always succeed. This determines the hash size.
@param [in] allow_partial If true and no final state is active on the
last frame, we will treat all the states active on the last frame
as final states. If false, we only
consider the real final state of the decoding
graph on the last frame when generating the lattice.

*/
MultiGraphDenseIntersectPruned(FsaVec &a_fsas, DenseFsaVec &b_fsas,
float search_beam, float output_beam,
int32_t min_active, int32_t max_active)
int32_t min_active, int32_t max_active,
bool allow_partial)
: a_fsas_(a_fsas),
b_fsas_(b_fsas),
search_beam_(search_beam),
output_beam_(output_beam),
min_active_(min_active),
max_active_(max_active),
allow_partial_(allow_partial),
dynamic_beams_(a_fsas.Context(), b_fsas.shape.Dim0(), search_beam),
forward_semaphore_(1) {
NVTX_RANGE(K2_FUNC);
Expand Down Expand Up @@ -498,12 +506,27 @@ class MultiGraphDenseIntersectPruned {
int32_t dest_state_idx012 = oarc_idx01x_next +
arc_info.u.dest_info_state_idx1;
arc.dest_state = dest_state_idx012 - oarc_idx0xx;
arc.label = a_fsas_arcs[arc_info.a_fsas_arc_idx012].label;
int32_t arc_label = a_fsas_arcs[arc_info.a_fsas_arc_idx012].label;
arc.label = arc_label;
int32_t final_t = b_fsas_row_splits1[oarc_idx0+1] - b_fsas_row_splits1[oarc_idx0];
if (t == final_t - 1 && arc_label != -1) {
if (allow_partial_) {
arc.label = -1;
} else {
// Unreachable code.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this branch unreachable? If I understand correctly, this branch should do nothing instead of raising a fatal error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For arcs pointing to the super-final state, the labels must be -1 if allow_partial==false.
I just added this "else" branch to catch any as-yet-unnoticed bug in the future.

K2_LOG(FATAL) <<
"arc.labe != -1 on final_arc when allow_partial==false.";
}
}

int32_t fsa_id = oarc_idx0,
b_fsas_idx0x = b_fsas_row_splits1[fsa_id],
b_fsas_idx01 = b_fsas_idx0x + t,
b_fsas_idx2 = (arc.label + 1),
// Use arc_label instead of arc.label to keep track of
// the original arc index in b_fsas when allow_partial == true.
// Then arc_map_b stores the "correct" arc index instead of
// the non-existent, manually added arc pointing to the super-final state.
b_fsas_idx2 = (arc_label + 1),
b_fsas_arc_idx012 = b_fsas_idx01 * b_fsas_num_cols + b_fsas_idx2;

arc.score = arc_info.arc_loglike;
Expand Down Expand Up @@ -664,6 +687,9 @@ class MultiGraphDenseIntersectPruned {
const int32_t *ai_row_ids2 = ai_shape.RowIds(2).Data();
// from state_idx01 to arc_idx01x
const int32_t *ai_row_splits2 = ai_shape.RowSplits(2).Data();

const int32_t *a_fsas_row_splits1 = a_fsas_.shape.RowSplits(1).Data();
const int32_t *a_fsas_row_ids1 = a_fsas_.shape.RowIds(1).Data();
// from state_idx01 (into a_fsas_) to arc_idx01x (into a_fsas_)
const int32_t *a_fsas_row_splits2 = a_fsas_.shape.RowSplits(2).Data();

Expand All @@ -679,6 +705,29 @@ class MultiGraphDenseIntersectPruned {
Ragged<ArcInfo> ai(ai_shape);
ArcInfo *ai_data = ai.values.Data(); // uninitialized

// A valid final arc means its label == -1.
auto has_valid_final_arc = Array1<bool>(c_, NumFsas(), false);
bool *has_valid_final_arc_data = has_valid_final_arc.Data();

if (allow_partial_) {
K2_EVAL(
c_, ai.values.Dim(), set_has_non_inf_arc, (int32_t ai_arc_idx012)->void {
int32_t ai_state_idx01 = ai_row_ids2[ai_arc_idx012],
ai_fsa_idx0 = ai_row_ids1[ai_state_idx01],
ai_arc_idx01x = ai_row_splits2[ai_state_idx01],
ai_arc_idx2 = ai_arc_idx012 - ai_arc_idx01x;
StateInfo sinfo = state_values[ai_state_idx01];
int32_t a_fsas_arc_idx01x =
a_fsas_row_splits2[sinfo.a_fsas_state_idx01],
a_fsas_arc_idx012 = a_fsas_arc_idx01x + ai_arc_idx2;
Arc arc = arcs[a_fsas_arc_idx012];
auto final_t = b_fsas_row_splits1[ai_fsa_idx0+1] - b_fsas_row_splits1[ai_fsa_idx0];
if (final_t - 1 == t && -1 == arc.label) {
has_valid_final_arc_data[ai_fsa_idx0] = true;
}
});
}

K2_EVAL(
c_, ai.values.Dim(), ai_lambda, (int32_t ai_arc_idx012)->void {
int32_t ai_state_idx01 = ai_row_ids2[ai_arc_idx012],
Expand All @@ -698,6 +747,17 @@ class MultiGraphDenseIntersectPruned {
K2_DCHECK_LT(static_cast<uint32_t>(scores_idx2),
static_cast<uint32_t>(scores_num_cols));
float acoustic_score = scores_acc(scores_idx01, scores_idx2);
auto dest_state = arc.dest_state;
auto final_t = b_fsas_row_splits1[ai_fsa_idx0+1] - b_fsas_row_splits1[ai_fsa_idx0];
if (final_t - 1 == t && !has_valid_final_arc_data[ai_fsa_idx0] &&
allow_partial_) {
int32_t a_fsas_idx0 = a_fsas_row_ids1[sinfo.a_fsas_state_idx01];
// state_idx1 is 0-based.
// So "-1" is used when calculating a_fsas_final_state_idx1.
int32_t a_fsas_final_state_idx1 = a_fsas_row_splits1[a_fsas_idx0 + 1] - 1 - a_fsas_row_splits1[a_fsas_idx0];
dest_state = a_fsas_final_state_idx1;
acoustic_score = 0.0;
}
Copy link
Collaborator

@pkufool pkufool Nov 4, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this block and the above block are not necessary: we can tell which sequences have no final state from the shape of ArcInfo at the last frame (see the first K2_EVAL in FormatOutput). At the last frame, there will be only one state (the final state) or no state at all.
Another thing: if we modify ai.u.dest_a_fsas_state_idx01 here, it will mess up the arc-info and might raise an error for chunk-by-chunk decoding. Actually, if we know which sequence has no final arc, we can set the dest-state of the arcs at the last frame to the extra state we added (see the first K2_EVAL in FormatOutput) without knowing ai.u.dest_info_state_idx1.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The log_probs of the last frame are manually added as [0.0, -inf, -inf, -inf, ..., -inf, -inf].
The main purpose of this block is to set those -inf values to 0.0.
Otherwise, all active arcs will be pruned by the function PruneTimeRange.

This is also the reason why, when the input has num_frames == 20, the generated lattice length is only 10!
Not only is the final arc missing, but the last "10" frames are also pruned by PruneTimeRange.

Before fix, lattice length is 10.

0 1 0 0 -8.41582e-05                                                                                                                                                        
1 2 0 0 -4.69674e-05                                                                                                                                                        
2 3 0 0 -9.17907e-06                                                                                                                                                        
3 4 0 0 -1.10864e-05                                                                                                                                                        
4 5 0 0 -1.4305e-05                                                                                                                                                         
5 6 0 0 -2.63449e-05                                                                                                                                                        
6 7 0 0 -6.55649e-06                                                                                                                                                        
7 8 0 0 -1.19209e-05                                                                                                                                                        
8 9 0 0 -3.12323e-05                                                                                                                                                        
9 10 0 0 -2.36032e-05                                                                                                                                                       
11

After fix the length is 20.
image

ArcInfo ai;
ai.a_fsas_arc_idx012 = a_fsas_arc_idx012;
ai.arc_loglike = acoustic_score + arc.score;
Expand All @@ -709,7 +769,7 @@ class MultiGraphDenseIntersectPruned {
// convert to an idx01; this relies on the fact that
// sinfo.abs_state_id == arc.src_state + a_fsas_fsa_idx0x.
ai.u.dest_a_fsas_state_idx01 =
sinfo.a_fsas_state_idx01 + arc.dest_state - arc.src_state;
sinfo.a_fsas_state_idx01 + dest_state - arc.src_state;
ai_data[ai_arc_idx012] = ai;
});
return ai;
Expand Down Expand Up @@ -1459,6 +1519,7 @@ class MultiGraphDenseIntersectPruned {
float output_beam_;
int32_t min_active_;
int32_t max_active_;
bool allow_partial_;
Array1<float> dynamic_beams_; // dynamic beams (initially just search_beam_
// but change due to max_active/min_active
// constraints).
Expand Down Expand Up @@ -1521,13 +1582,14 @@ class MultiGraphDenseIntersectPruned {
void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas,
float search_beam, float output_beam,
int32_t min_active_states, int32_t max_active_states,
bool allow_partial,
FsaVec *out, Array1<int32_t> *arc_map_a,
Array1<int32_t> *arc_map_b) {
NVTX_RANGE("IntersectDensePruned");
FsaVec a_vec = FsaToFsaVec(a_fsas);
MultiGraphDenseIntersectPruned intersector(a_vec, b_fsas, search_beam,
output_beam, min_active_states,
max_active_states);
max_active_states, allow_partial);

intersector.Intersect();
intersector.FormatOutput(out, arc_map_a, arc_map_b);
Expand Down
22 changes: 15 additions & 7 deletions k2/csrc/intersect_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ TEST(Intersect, RandomSingle) {
K2_LOG(INFO) << "fsas_b = " << fsas_b;
FsaVec out_fsas2;
Array1<int32_t> arc_map_a2, arc_map_b2;
// IntersectDensePruned() treats epsilons as normal symbols, so we need to
// IntersectDense() treats epsilons as normal symbols, so we need to
// as well.

ArcSort(&fsa); // CAUTION if you later test the arc_maps: we arc-sort here,
Expand Down Expand Up @@ -339,7 +339,7 @@ TEST(Intersect, RandomFsaVec) {
K2_LOG(INFO) << "fsas_b = " << fsas_b;
FsaVec out_fsas2;
Array1<int32_t> arc_map_a2, arc_map_b2;
// IntersectDensePruned() treats epsilons as normal symbols, so we need to
// IntersectDense() treats epsilons as normal symbols, so we need to
// as well.

ArcSort(&fsavec); // CAUTION if you later test the arc_maps: we arc-sort
Expand Down Expand Up @@ -401,11 +401,12 @@ TEST(IntersectPruned, Simple) {

float beam = 100000;
int32_t max_active = 10000, min_active = 0;
bool allow_partial = false;

FsaVec out_fsas;
Array1<int32_t> arc_map_a, arc_map_b;
IntersectDensePruned(fsa, dfsavec, beam, beam, min_active, max_active,
&out_fsas, &arc_map_a, &arc_map_b);
allow_partial, &out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a
<< ", arc_map_b = " << arc_map_b;

Expand Down Expand Up @@ -458,11 +459,12 @@ TEST(IntersectPruned, TwoDense) {

float beam = 100000;
int32_t max_active = 10000, min_active = 0;
bool allow_partial = false;

FsaVec out_fsas;
Array1<int32_t> arc_map_a, arc_map_b;
IntersectDensePruned(fsa, dfsavec, beam, beam, min_active, max_active,
&out_fsas, &arc_map_a, &arc_map_b);
allow_partial, &out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a
<< ", arc_map_b = " << arc_map_b;

Expand Down Expand Up @@ -507,11 +509,12 @@ TEST(IntersectPruned, TwoFsas) {

float beam = 100000;
int32_t max_active = 10000, min_active = 0;
bool allow_partial = false;

FsaVec out_fsas;
Array1<int32_t> arc_map_a, arc_map_b;
IntersectDensePruned(fsa_vec, dfsavec, beam, beam, min_active, max_active,
&out_fsas, &arc_map_a, &arc_map_b);
allow_partial, &out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_a = " << arc_map_a
<< ", arc_map_b = " << arc_map_b;

Expand Down Expand Up @@ -575,8 +578,10 @@ TEST(IntersectPruned, RandomSingle) {
FsaVec out_fsas;
float beam = 1000.0;
int32_t max_active = 10000, min_active = 0;
bool allow_partial = false;

IntersectDensePruned(fsa, dfsavec, beam, beam, min_active, max_active,
&out_fsas, &arc_map_a, &arc_map_b);
allow_partial, &out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas << ", arc_map_b = " << arc_map_b;

FsaVec fsas_b = ConvertDenseToFsaVec(dfsavec);
Expand Down Expand Up @@ -679,8 +684,11 @@ TEST(IntersectPruned, RandomFsaVec) {
FsaVec out_fsas;
float search_beam = 1000.0, output_beam = 1000.0;
int32_t min_active = 0, max_active = 10;
bool allow_partial = false;

IntersectDensePruned(fsavec, dfsavec, search_beam, output_beam, min_active,
max_active, &out_fsas, &arc_map_a, &arc_map_b);
max_active, allow_partial,
&out_fsas, &arc_map_a, &arc_map_b);
K2_LOG(INFO) << "out_fsas = " << out_fsas
<< ", arc_map_a = " << arc_map_a
<< ", arc_map_b = " << arc_map_b;
Expand Down
8 changes: 5 additions & 3 deletions k2/python/csrc/torch/fsa_algo.cu
Original file line number Diff line number Diff line change
Expand Up @@ -200,21 +200,23 @@ static void PybindIntersectDensePruned(py::module &m) {
"intersect_dense_pruned",
[](FsaVec &a_fsas, DenseFsaVec &b_fsas, float search_beam,
float output_beam, int32_t min_active_states,
int32_t max_active_states)
int32_t max_active_states, bool allow_partial)
-> std::tuple<FsaVec, torch::Tensor, torch::Tensor> {
DeviceGuard guard(a_fsas.Context());
Array1<int32_t> arc_map_a;
Array1<int32_t> arc_map_b;
FsaVec out;

IntersectDensePruned(a_fsas, b_fsas, search_beam, output_beam,
min_active_states, max_active_states, &out,
min_active_states, max_active_states,
allow_partial, &out,
&arc_map_a, &arc_map_b);
return std::make_tuple(out, ToTorch(arc_map_a), ToTorch(arc_map_b));
},
py::arg("a_fsas"), py::arg("b_fsas"), py::arg("search_beam"),
py::arg("output_beam"), py::arg("min_active_states"),
py::arg("max_active_states"));
py::arg("max_active_states"),
py::arg("allow_partial") = false);
}

static void PybindIntersectDense(py::module &m) {
Expand Down
30 changes: 22 additions & 8 deletions k2/python/k2/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,7 @@ def forward(ctx,
output_beam: float,
min_active_states: int,
max_active_states: int,
allow_partial: bool,
unused_scores_a: torch.Tensor,
unused_scores_b: torch.Tensor,
seqframe_idx_name: Optional[str] = None,
Expand All @@ -383,16 +384,21 @@ def forward(ctx,
output_beam:
Pruning beam for the output of intersection (vs. best path);
equivalent to kaldi's lattice-beam. E.g. 8.
max_active_states:
Maximum number of FSA states that are allowed to be active on any
given frame for any given intersection/composition task. This is
advisory, in that it will try not to exceed that but may not always
succeed. You can use a very large number if no constraint is needed.
min_active_states:
Minimum number of FSA states that are allowed to be active on any
given frame for any given intersection/composition task. This is
advisory, in that it will try not to have fewer than this number
active. Set it to zero if there is no constraint.
max_active_states:
Maximum number of FSA states that are allowed to be active on any
given frame for any given intersection/composition task. This is
advisory, in that it will try not to exceed that but may not always
succeed. You can use a very large number if no constraint is needed.
allow_partial:
If true and no final state is active on the last frame,
we will treat all the states active on the
last frame as final states. If false, we only
consider the real final state of the decoding
graph on the last frame when generating the lattice.
unused_scores_a:
It equals to `a_fsas.scores` and its sole purpose is for back
propagation.
Expand All @@ -418,7 +424,8 @@ def forward(ctx,
search_beam=search_beam,
output_beam=output_beam,
min_active_states=min_active_states,
max_active_states=max_active_states)
max_active_states=max_active_states,
allow_partial=allow_partial)

out_fsa[0] = Fsa(ragged_arc)

Expand Down Expand Up @@ -653,7 +660,8 @@ def intersect_dense_pruned(a_fsas: Fsa,
min_active_states: int,
max_active_states: int,
seqframe_idx_name: Optional[str] = None,
frame_idx_name: Optional[str] = None) -> Fsa:
frame_idx_name: Optional[str] = None,
allow_partial: bool = False) -> Fsa:
'''Intersect array of FSAs on CPU/GPU.

Caution:
Expand Down Expand Up @@ -684,6 +692,11 @@ def intersect_dense_pruned(a_fsas: Fsa,
frame for any given intersection/composition task. This is advisory,
in that it will try not to exceed that but may not always succeed.
You can use a very large number if no constraint is needed.
allow_partial:
If true and no final state is active on the last frame,
we will treat all the states active on the
last frame as final states. If false, we only
consider the real final state of the decoding
graph on the last frame when generating the lattice.
seqframe_idx_name:
If set (e.g. to 'seqframe'), an attribute in the output will be created
that encodes the sequence-index and the frame-index within that
Expand Down Expand Up @@ -717,7 +730,8 @@ def intersect_dense_pruned(a_fsas: Fsa,
# in `out_fsa[0].scores`
_IntersectDensePrunedFunction.apply(a_fsas, b_fsas, out_fsa, search_beam,
output_beam, min_active_states,
max_active_states, a_fsas.scores,
max_active_states, allow_partial,
a_fsas.scores,
b_fsas.scores, seqframe_idx_name,
frame_idx_name)
return out_fsa[0]
Expand Down