Skip to content

Commit

Permalink
Fix prune range for empty reference
Browse files Browse the repository at this point in the history
  • Loading branch information
pkufool committed Sep 13, 2024
1 parent 21302da commit c0b787f
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions k2/python/k2/rnnt_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,7 @@ def get_rnnt_prune_ranges(
)
s_range = S + 1

if is_regular:
if is_regular and S != 0:
assert (
s_range >= 2
), f"""Pruning range for standard RNN-T should be equal to or greater
Expand Down Expand Up @@ -1079,7 +1079,9 @@ def get_hat_logprobs_pruned(
rnnt_type != "modified" or T >= S
), f"Modified transducer requires T >= S, but got T={T} and S={S}"
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
assert termination_symbol == 0, f"Termination symbol must be 0, but got {termination_symbol}"
assert (
termination_symbol == 0
), f"Termination symbol must be 0, but got {termination_symbol}"

# For blank symbol, log-prob is log-sigmoid of the score
logp_b = torch.nn.functional.logsigmoid(logits[..., 0])
Expand Down Expand Up @@ -1141,7 +1143,9 @@ def get_hat_logprobs_pruned(
px = torch.cat(
(
px,
torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype),
torch.full(
(B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
),
),
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..
Expand Down

0 comments on commit c0b787f

Please sign in to comment.