Skip to content
This repository has been archived by the owner on Oct 13, 2022. It is now read-only.

Commit

Permalink
fix get_texts() to work with recent changes in k2.
Browse files Browse the repository at this point in the history
  • Loading branch information
danpovey committed Jun 2, 2021
1 parent 82bbd54 commit 49cd5ed
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions snowfall/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Any, Dict, Iterable, List, Optional, TextIO, Tuple, Union

import k2
import k2.ragged as k2r
import kaldialign
import torch
import torch.distributed as dist
Expand Down Expand Up @@ -314,18 +315,28 @@ def get_texts(best_paths: k2.Fsa, indices: Optional[torch.Tensor] = None) -> Lis
decoded.
'''
# remove any 0's or -1's (there should be no 0's left but may be -1's.)
aux_labels = k2.ragged.remove_values_leq(best_paths.aux_labels, 0)
aux_shape = k2.ragged.compose_ragged_shapes(best_paths.arcs.shape(),
aux_labels.shape())
# remove the states and arcs axes.
aux_shape = k2.ragged.remove_axis(aux_shape, 1)
aux_shape = k2.ragged.remove_axis(aux_shape, 1)
aux_labels = k2.RaggedInt(aux_shape, aux_labels.values())

if isinstance(best_paths.aux_labels, k2.RaggedInt):
aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0)
aux_shape = k2r.compose_ragged_shapes(best_paths.arcs.shape(),
aux_labels.shape())

# remove the states and arcs axes.
aux_shape = k2r.remove_axis(aux_shape, 1)
aux_shape = k2r.remove_axis(aux_shape, 1)
aux_labels = k2.RaggedInt(aux_shape, aux_labels.values())
else:
# remove axis corresponding to states.
aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1)
aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels)
# remove 0's and -1's.
aux_labels = k2r.remove_values_leq(aux_labels, 0)

assert (aux_labels.num_axes() == 2)
aux_labels, _ = k2.ragged.index(aux_labels,
aux_labels, _ = k2r.index(aux_labels,
invert_permutation(indices).to(dtype=torch.int32,
device=best_paths.device))
return k2.ragged.to_list(aux_labels)
return k2r.to_list(aux_labels)


def invert_permutation(indices: torch.Tensor) -> torch.Tensor:
Expand Down

0 comments on commit 49cd5ed

Please sign in to comment.