diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index 9a80c5f9d..9966b8248 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -716,7 +716,7 @@ def get_rnnt_prune_ranges( ) s_range = S + 1 - if is_regular: + if is_regular and S != 0: assert ( s_range >= 2 ), f"""Pruning range for standard RNN-T should be equal to or greater @@ -1079,7 +1079,9 @@ def get_hat_logprobs_pruned( rnnt_type != "modified" or T >= S ), f"Modified transducer requires T >= S, but got T={T} and S={S}" assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type - assert termination_symbol == 0, f"Termination symbol must be 0, but got {termination_symbol}" + assert ( + termination_symbol == 0 + ), f"Termination symbol must be 0, but got {termination_symbol}" # For blank symbol, log-prob is log-sigmoid of the score logp_b = torch.nn.functional.logsigmoid(logits[..., 0]) @@ -1141,7 +1143,9 @@ def get_hat_logprobs_pruned( px = torch.cat( ( px, - torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype), + torch.full( + (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype + ), ), dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf..