Add fused_log_softmax option to pruned_rnnt_loss #1293

Open
wants to merge 2 commits into master
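For context, a minimal usage sketch of the new option, condensed from the test added below. The shapes, the pruning helpers (`rnnt_loss_simple`, `get_rnnt_prune_ranges`, `do_rnnt_pruning`) and the tolerance are taken or assumed from that test; this is an illustration, not part of the PR itself.

```python
import torch
import k2

B, T, S, C = 2, 20, 5, 10
am = torch.randn(B, T, C)
lm = torch.randn(B, S + 1, C)
symbols = torch.randint(0, C - 1, (B, S))
terminal_symbol = C - 1
boundary = torch.tensor([[0, 0, S, T]] * B, dtype=torch.int64)

# Pruning setup, as in the test below.
_, (px_grad, py_grad) = k2.rnnt_loss_simple(
    lm=lm,
    am=am,
    symbols=symbols,
    termination_symbol=terminal_symbol,
    boundary=boundary,
    return_grad=True,
    reduction="none",
)
ranges = k2.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=4
)
pruned_am, pruned_lm = k2.do_rnnt_pruning(am=am, lm=lm, ranges=ranges)
logits = pruned_am + pruned_lm  # (B, T, s_range, C)

# Default behaviour: log_softmax is fused into the loss.
fused = k2.rnnt_loss_pruned(
    logits=logits,
    symbols=symbols,
    ranges=ranges,
    termination_symbol=terminal_symbol,
    boundary=boundary,
    reduction="none",
)

# New option: normalize outside the loss and skip the fused log_softmax.
unfused = k2.rnnt_loss_pruned(
    logits=logits.log_softmax(dim=-1),
    symbols=symbols,
    ranges=ranges,
    termination_symbol=terminal_symbol,
    boundary=boundary,
    reduction="none",
    fused_log_softmax=False,
)

assert torch.allclose(fused, unfused, atol=1e-5)
```

With `fused_log_softmax=False` the loss skips its internal normalization and the caller is responsible for applying `log_softmax` beforehand; the two calls above should then produce the same loss values, which is exactly what the new test checks.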
19 changes: 16 additions & 3 deletions k2/python/k2/rnnt_loss.py
@@ -1079,7 +1079,9 @@ def get_hat_logprobs_pruned(
rnnt_type != "modified" or T >= S
), f"Modified transducer requires T >= S, but got T={T} and S={S}"
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
assert termination_symbol == 0, f"Termination symbol must be 0, but got {termination_symbol}"
assert (
termination_symbol == 0
), f"Termination symbol must be 0, but got {termination_symbol}"

# For blank symbol, log-prob is log-sigmoid of the score
logp_b = torch.nn.functional.logsigmoid(logits[..., 0])
@@ -1141,7 +1143,9 @@ def get_hat_logprobs_pruned(
px = torch.cat(
(
px,
torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype),
torch.full(
(B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
),
),
dim=2,
) # now: [B][S][T+1], index [:,:,T] has -inf..
@@ -1184,6 +1188,7 @@ def get_rnnt_logprobs_pruned(
termination_symbol: int,
boundary: Tensor,
rnnt_type: str = "regular",
fused_log_softmax: bool = True,
) -> Tuple[Tensor, Tensor]:
"""Construct px, py for mutual_information_recursion with pruned output.

@@ -1222,6 +1227,8 @@
*next* context on the *current* frame, e.g. if we emit
c given "a b" context, we are forced to emit "blank"
given "b c" context on the current frame.
fused_log_softmax:
If False, logits are assumed to be normalized already and you should call log_softmax on them outside of the loss. Default: True.
Returns:
(px, py) (the names are quite arbitrary)::

@@ -1261,7 +1268,9 @@
assert rnnt_type in ["regular", "modified", "constrained"], rnnt_type
_validate_st_lengths(S, T, rnnt_type == "regular", boundary)

normalizers = torch.logsumexp(logits, dim=3)
normalizers = 0
if fused_log_softmax:
normalizers = torch.logsumexp(logits, dim=3)

symbols_with_terminal = torch.cat(
(
@@ -1358,6 +1367,7 @@ def rnnt_loss_pruned(
delay_penalty: float = 0.0,
reduction: Optional[str] = "mean",
use_hat_loss: bool = False,
fused_log_softmax: bool = True,
) -> Tensor:
"""A RNN-T loss with pruning, which uses the output of a pruned 'joiner'
network as input, i.e. a 4 dimensions tensor with shape (B, T, s_range, C),
@@ -1414,6 +1424,8 @@ def rnnt_loss_pruned(
the blank distribution separately as a Bernoulli distribution, and the
non-blanks are modeled as a multinomial. This formulation may be useful
for performing internal LM estimation, as described in the paper.
fused_log_softmax:
If False, logits are assumed to be normalized already and you should call log_softmax on them outside of the loss. Default: True.
Returns:
If reduction is `none`, returns a tensor of shape (B,), containing the
total RNN-T loss values for each sequence of the batch, otherwise a scalar
@@ -1427,6 +1439,7 @@
termination_symbol=termination_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
fused_log_softmax=fused_log_softmax,
)
else:
px, py = get_hat_logprobs_pruned(
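For intuition about what the flag controls inside `get_rnnt_logprobs_pruned`: with `fused_log_softmax=True` the logsumexp normalizer computed above is subtracted from the gathered logits, which is equivalent to the caller applying `log_softmax` beforehand and the loss leaving `normalizers` at 0. A small self-contained sketch in plain PyTorch (shapes are illustrative, not the k2 internals):

```python
import torch

B, T, s_range, C = 2, 5, 3, 7
logits = torch.randn(B, T, s_range, C)

# Fused path (done inside the loss): subtract the logsumexp normalizer.
normalizers = torch.logsumexp(logits, dim=3, keepdim=True)
fused_logprobs = logits - normalizers

# Un-fused path (done by the caller): log_softmax before the loss,
# after which the internal normalizer can stay at 0.
unfused_logprobs = logits.log_softmax(dim=-1)

assert torch.allclose(fused_logprobs, unfused_logprobs, atol=1e-6)
```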
92 changes: 92 additions & 0 deletions k2/python/tests/rnnt_loss_test.py
@@ -568,6 +568,98 @@ def test_rnnt_loss_pruned(self):
)
print(f"Pruned loss with range {r} : {pruned_loss}")

def test_rnnt_loss_pruned_fused_log_softmax(self):
print("\ntest_rnnt_loss_pruned_fused_log_softmax.")
B = 4
T = 300
S = 50
C = 10

frames = torch.randint(S, T, (B,))
seq_length = torch.randint(3, S - 1, (B,))
T = torch.max(frames)
S = torch.max(seq_length)

am_ = torch.randn((B, T, C), dtype=torch.float64)
lm_ = torch.randn((B, S + 1, C), dtype=torch.float64)
symbols_ = torch.randint(0, C - 1, (B, S))
terminal_symbol = C - 1

boundary_ = torch.zeros((B, 4), dtype=torch.int64)
boundary_[:, 2] = seq_length
boundary_[:, 3] = frames

for rnnt_type in ["regular", "modified", "constrained"]:
for device in self.devices:
# normal rnnt
am = am_.to(device)
lm = lm_.to(device)
symbols = symbols_.to(device)
boundary = boundary_.to(device)

logits = am.unsqueeze(2) + lm.unsqueeze(1)
logits = logits.float()

# nonlinear transform
logits = torch.sigmoid(logits)
k2_loss = k2.rnnt_loss(
logits=logits,
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
reduction="none",
)

print(f"Unpruned rnnt loss with {rnnt_type} rnnt : {k2_loss}")

# pruning
k2_simple_loss, (px_grad, py_grad) = k2.rnnt_loss_simple(
lm=lm,
am=am,
symbols=symbols,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
return_grad=True,
reduction="none",
)

for r in range(2, 50, 5):
ranges = k2.get_rnnt_prune_ranges(
px_grad=px_grad,
py_grad=py_grad,
boundary=boundary,
s_range=r,
)
# (B, T, r, C)
pruned_am, pruned_lm = k2.do_rnnt_pruning(
am=am, lm=lm, ranges=ranges
)

logits = pruned_am + pruned_lm
# nonlinear transform
logits = torch.sigmoid(logits)

for fused_log_softmax in [True, False]:
if not fused_log_softmax:
logits = logits.log_softmax(dim=-1)
pruned_loss = k2.rnnt_loss_pruned(
logits=logits,
symbols=symbols,
ranges=ranges,
termination_symbol=terminal_symbol,
boundary=boundary,
rnnt_type=rnnt_type,
reduction="none",
fused_log_softmax=fused_log_softmax,
)
if fused_log_softmax:
expected = pruned_loss
assert torch.allclose(expected, pruned_loss)

print(f"Pruned loss with range {r} : {pruned_loss}")

# Test the sequences that only have small number of symbols,
# at this circumstance, the s_range would be greater than S, which will
# raise errors (like, nan or inf loss) in our previous versions.
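Assuming the test file keeps its usual `unittest.main()` entry point, the new case can presumably be run on its own with `python3 k2/python/tests/rnnt_loss_test.py -k fused_log_softmax` once k2 is built and importable.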