From 49cd5eddb3cd12c011119702819952d213ce84b8 Mon Sep 17 00:00:00 2001 From: Daniel Povey Date: Wed, 2 Jun 2021 12:33:23 +0800 Subject: [PATCH] fix get_texts() to work with recent changes in k2. --- snowfall/common.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/snowfall/common.py b/snowfall/common.py index d5dc393d..ff460198 100755 --- a/snowfall/common.py +++ b/snowfall/common.py @@ -12,6 +12,7 @@ from typing import Any, Dict, Iterable, List, Optional, TextIO, Tuple, Union import k2 +import k2.ragged as k2r import kaldialign import torch import torch.distributed as dist @@ -314,18 +315,28 @@ def get_texts(best_paths: k2.Fsa, indices: Optional[torch.Tensor] = None) -> Lis decoded. ''' # remove any 0's or -1's (there should be no 0's left but may be -1's.) - aux_labels = k2.ragged.remove_values_leq(best_paths.aux_labels, 0) - aux_shape = k2.ragged.compose_ragged_shapes(best_paths.arcs.shape(), - aux_labels.shape()) - # remove the states and arcs axes. - aux_shape = k2.ragged.remove_axis(aux_shape, 1) - aux_shape = k2.ragged.remove_axis(aux_shape, 1) - aux_labels = k2.RaggedInt(aux_shape, aux_labels.values()) + + if isinstance(best_paths.aux_labels, k2.RaggedInt): + aux_labels = k2r.remove_values_leq(best_paths.aux_labels, 0) + aux_shape = k2r.compose_ragged_shapes(best_paths.arcs.shape(), + aux_labels.shape()) + + # remove the states and arcs axes. + aux_shape = k2r.remove_axis(aux_shape, 1) + aux_shape = k2r.remove_axis(aux_shape, 1) + aux_labels = k2.RaggedInt(aux_shape, aux_labels.values()) + else: + # remove axis corresponding to states. + aux_shape = k2r.remove_axis(best_paths.arcs.shape(), 1) + aux_labels = k2.RaggedInt(aux_shape, best_paths.aux_labels) + # remove 0's and -1's. + aux_labels = k2r.remove_values_leq(aux_labels, 0) + assert (aux_labels.num_axes() == 2) - aux_labels, _ = k2.ragged.index(aux_labels, + aux_labels, _ = k2r.index(aux_labels, invert_permutation(indices).to(dtype=torch.int32, device=best_paths.device)) - return k2.ragged.to_list(aux_labels) + return k2r.to_list(aux_labels) def invert_permutation(indices: torch.Tensor) -> torch.Tensor: