
Commit 214a129

black

pkufool committed Jun 24, 2024
1 parent 656a840 commit 214a129
Showing 2 changed files with 76 additions and 33 deletions.
101 changes: 70 additions & 31 deletions k2/python/k2/rnnt_loss.py
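This commit runs the black formatter over the two files, rewrapping lines that exceed the configured width; no behaviour changes. As a rough illustration (the black version and the exact line length configured in k2's style checks are assumptions here), the rewrapping of one of the affected lines can be reproduced with black's Python API:

import black

# One of the lines shortened in this commit; at an assumed line length of 80
# the call plus its trailing comment is too long, so black splits the call.
src = (
    "px_lm = torch.gather(lm[:, :S], dim=2, index=symbols.unsqueeze(-1))"
    "  # [B][S][1]\n"
)
print(black.format_str(src, mode=black.Mode(line_length=80)))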
@@ -204,7 +204,9 @@ def get_rnnt_logprobs(
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..
 
-    px_lm = torch.gather(lm[:, :S], dim=2, index=symbols.unsqueeze(-1))  # [B][S][1]
+    px_lm = torch.gather(
+        lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
+    )  # [B][S][1]
 
     px = px_am + px_lm  # [B][S][T+1], last slice with indexes out of
     # boundary is -inf
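For reference, a toy-sized sketch (shapes invented for illustration) of what the wrapped gather computes: for each position s it picks out the LM log-prob of the reference symbol at that position.

import torch

B, S, C = 1, 2, 3
lm = torch.randn(B, S + 1, C).log_softmax(dim=-1)  # (B, S+1, C) LM log-probs
symbols = torch.tensor([[2, 0]])  # (B, S) reference symbols
px_lm = torch.gather(lm[:, :S], dim=2, index=symbols.unsqueeze(-1))  # (B, S, 1)
print(px_lm.shape)  # torch.Size([1, 2, 1])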
@@ -313,9 +315,9 @@ def rnnt_loss_simple(
             ).expand(B, 1, 1)
         else:
             offset = (boundary[:, 3] - 1) / 2
-        penalty = offset.reshape(B, 1, 1) - torch.arange(T0, device=px.device).reshape(
-            1, 1, T0
-        )
+        penalty = offset.reshape(B, 1, 1) - torch.arange(
+            T0, device=px.device
+        ).reshape(1, 1, T0)
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)
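A small self-contained sketch of the delay-penalty term being rewrapped here, with toy values (in k2, boundary[:, 3] holds the number of valid frames per sequence):

import torch

B, T0, delay_penalty = 2, 4, 0.1
boundary = torch.tensor([[0, 0, 3, 4], [0, 0, 2, 3]])  # [.., .., S, T] rows
offset = (boundary[:, 3] - 1) / 2
penalty = offset.reshape(B, 1, 1) - torch.arange(T0).reshape(1, 1, T0)
penalty = penalty * delay_penalty
print(penalty.squeeze(1))  # positive for early frames, negative for late ones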

@@ -423,14 +425,18 @@ def get_rnnt_logprobs_joint(
     px = torch.cat(
         (
             px,
-            torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype),
+            torch.full(
+                (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
+            ),
         ),
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..
 
     px[:, :, :T] -= normalizers[:, :S, :]
 
-    py = logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()  # [B][S+1][T]
+    py = (
+        logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()
+    )  # [B][S+1][T]
     py -= normalizers
 
     if rnnt_type == "regular":
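The py tensor rewrapped above gathers the log-prob of the termination (blank) symbol at every (t, s) position of the joint network output; a toy sketch with shapes invented for illustration:

import torch

B, T, S1, C = 1, 3, 2, 4  # S1 = S + 1
termination_symbol = 0
logits = torch.randn(B, T, S1, C).log_softmax(dim=-1)  # (B, T, S+1, C)
py = logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone()  # (B, S+1, T)
print(py.shape)  # torch.Size([1, 2, 3])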
@@ -515,9 +521,9 @@ def rnnt_loss(
             ).expand(B, 1, 1)
         else:
             offset = (boundary[:, 3] - 1) / 2
-        penalty = offset.reshape(B, 1, 1) - torch.arange(T0, device=px.device).reshape(
-            1, 1, T0
-        )
+        penalty = offset.reshape(B, 1, 1) - torch.arange(
+            T0, device=px.device
+        ).reshape(1, 1, T0)
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)

@@ -571,7 +577,9 @@ def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
     return x
 
 
-def _adjust_pruning_lower_bound(s_begin: torch.Tensor, s_range: int) -> torch.Tensor:
+def _adjust_pruning_lower_bound(
+    s_begin: torch.Tensor, s_range: int
+) -> torch.Tensor:
     """Adjust s_begin (pruning lower bounds) to make it satisfy the following
     constraints
@@ -608,13 +616,17 @@ def _adjust_pruning_lower_bound(s_begin: torch.Tensor, s_range: int) -> torch.Tensor:
     (B, T) = s_begin.shape
     s_begin = _monotonic_lower_bound(s_begin)
     # do the magic transformation
-    s_begin = -(s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device))
+    s_begin = -(
+        s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
+    )
     # make the transformed tensor to be non-decreasing
     s_begin = _monotonic_lower_bound(s_begin)
     # make start symbol to be zero.
     s_begin = torch.clamp(s_begin, min=0)
     # do the magic transformation again to recover s_begin
-    s_begin = -(s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device))
+    s_begin = -(
+        s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)
+    )
     return s_begin
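A sketch of the two-step trick above on a toy tensor. The helper below assumes _monotonic_lower_bound returns the tightest non-decreasing sequence that stays pointwise below its input (a reversed cumulative minimum). Negating around (s_range - 1) * t turns the step constraint s_begin[i + 1] - s_begin[i] <= s_range - 1 into plain monotonicity, so the same primitive enforces both constraints:

import torch

def monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
    # assumed behaviour of k2's _monotonic_lower_bound, CPU-only sketch
    return torch.flip(torch.cummin(torch.flip(x, (-1,)), dim=-1).values, (-1,))

s_range = 3
s_begin = torch.tensor([[0, 4, 1, 2, 8]])  # violates both constraints
T = s_begin.shape[1]
s_begin = monotonic_lower_bound(s_begin)
s_begin = -(s_begin - (s_range - 1) * torch.arange(0, T))
s_begin = monotonic_lower_bound(s_begin)
s_begin = torch.clamp(s_begin, min=0)
s_begin = -(s_begin - (s_range - 1) * torch.arange(0, T))
print(s_begin)  # tensor([[0, 2, 4, 6, 8]]): non-decreasing, steps <= s_range - 1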


@@ -685,7 +697,9 @@ def get_rnnt_prune_ranges(
     if is_regular:
         Ss = boundary[:, 2]
         Ts = boundary[:, 3]
-        s_range_min = Ss.sub(1).div(Ts, rounding_mode="trunc").add(2).max().item()
+        s_range_min = (
+            Ss.sub(1).div(Ts, rounding_mode="trunc").add(2).max().item()
+        )
         if s_range < s_range_min:
             print(
                 f"Warning: get_rnnt_prune_ranges - got s_range={s_range} "
@@ -841,13 +855,19 @@ def get_rnnt_prune_ranges_deprecated(
     than 2, or no valid paths could survive pruning. Given {s_range}"""
 
     px_pad = torch.zeros((B, 1, T1), dtype=px_grad.dtype, device=px_grad.device)
-    py_pad = torch.zeros((B, S + 1, 1), dtype=py_grad.dtype, device=py_grad.device)
+    py_pad = torch.zeros(
+        (B, S + 1, 1), dtype=py_grad.dtype, device=py_grad.device
+    )
     py_grad_padded = py_grad if T1 == T else torch.cat((py_grad, py_pad), dim=2)
-    tot_grad = torch.cat((px_grad, px_pad), dim=1) + py_grad_padded  # (B, S + 1, T1)
+    tot_grad = (
+        torch.cat((px_grad, px_pad), dim=1) + py_grad_padded
+    )  # (B, S + 1, T1)
 
     tot_grad = torch.cat(
         (
-            torch.zeros((B, 1, T1), dtype=tot_grad.dtype, device=tot_grad.device),
+            torch.zeros(
+                (B, 1, T1), dtype=tot_grad.dtype, device=tot_grad.device
+            ),
             tot_grad,
         ),
         dim=1,
@@ -927,7 +947,9 @@ def do_rnnt_pruning(
     lm_pruned = torch.gather(
         lm,
         dim=1,
-        index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, decoder_dim)),
+        index=ranges.reshape(B, T * s_range, 1).expand(
+            (B, T * s_range, decoder_dim)
+        ),
     ).reshape(B, T, s_range, decoder_dim)
     return am_pruned, lm_pruned
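The rewrapped gather keeps, for every frame t, the s_range decoder states listed in ranges; a minimal sketch with invented shapes:

import torch

B, T, s_range, decoder_dim = 1, 2, 2, 3
lm = torch.randn(B, 4, decoder_dim)  # (B, S + 1, decoder_dim) decoder output
ranges = torch.tensor([[[0, 1], [1, 2]]])  # (B, T, s_range) kept symbol indexes
lm_pruned = torch.gather(
    lm,
    dim=1,
    index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, decoder_dim)),
).reshape(B, T, s_range, decoder_dim)
print(lm_pruned.shape)  # torch.Size([1, 2, 2, 3])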

@@ -958,7 +980,10 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor):
     assert shifts.shape == (B, T), (shifts.shape, B, T)
 
     index = (
-        torch.arange(S, device=src.device).view((1, S)).repeat((T, 1)).repeat((B, 1, 1))
+        torch.arange(S, device=src.device)
+        .view((1, S))
+        .repeat((T, 1))
+        .repeat((B, 1, 1))
     )
     index = (index - shifts.reshape(B, T, 1)) % S
     return torch.gather(src, 2, index)
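The chained expression builds a per-row rotation index: each (b, t) row of src is rotated right by shifts[b, t] positions. A toy run:

import torch

src = torch.arange(6).reshape(1, 2, 3)  # (B, T, S) = (1, 2, 3)
shifts = torch.tensor([[1, 2]])  # rotate row 0 by 1, row 1 by 2
B, T, S = src.shape
index = torch.arange(S).view((1, S)).repeat((T, 1)).repeat((B, 1, 1))
index = (index - shifts.reshape(B, T, 1)) % S
print(torch.gather(src, 2, index))  # tensor([[[2, 0, 1], [4, 5, 3]]])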
@@ -1118,7 +1143,9 @@ def get_hat_logprobs_pruned(
     px = torch.cat(
         (
             px,
-            torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype),
+            torch.full(
+                (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
+            ),
         ),
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..
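The padding rewrapped here appends one all -inf column so that index [:, :, T] can never be taken as a real emission; the same cat on toy shapes:

import torch

B, S, T = 1, 2, 3
px = torch.zeros(B, S, T)
px = torch.cat(
    (px, torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype)),
    dim=2,
)  # (B, S, T + 1)
print(px[:, :, T])  # tensor([[-inf, -inf]])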
@@ -1293,7 +1320,9 @@ def get_rnnt_logprobs_pruned(
     px = torch.cat(
         (
             px,
-            torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype),
+            torch.full(
+                (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype
+            ),
         ),
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..
@@ -1433,9 +1462,9 @@ def rnnt_loss_pruned(
             ).expand(B, 1, 1)
         else:
             offset = (boundary[:, 3] - 1) / 2
-        penalty = offset.reshape(B, 1, 1) - torch.arange(T0, device=px.device).reshape(
-            1, 1, T0
-        )
+        penalty = offset.reshape(B, 1, 1) - torch.arange(
+            T0, device=px.device
+        ).reshape(1, 1, T0)
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)

@@ -1597,7 +1626,9 @@ def get_rnnt_logprobs_smoothed(
         + torch.finfo(lm_probs.dtype).tiny
     )  # [1][1][C]
     amonly_normalizers = (
-        torch.mv(am_probs.reshape(-1, C), unigram_lm.reshape(C)).reshape(B, T, 1).log()
+        torch.mv(am_probs.reshape(-1, C), unigram_lm.reshape(C))
+        .reshape(B, T, 1)
+        .log()
         + am_max
     )  # [B][T][1]
     amonly_normalizers = amonly_normalizers.transpose(1, 2)  # [B][1][T]
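The chained expression computes log(sum_c p_am(t, c) * p_unigram(c)) for every frame with a single matrix-vector product; a toy-shape sketch (am_max, which the real code adds back after working with shifted probabilities for numerical stability, is omitted here since these probs are not shifted):

import torch

B, T, C = 1, 2, 4
am_probs = torch.softmax(torch.randn(B, T, C), dim=-1)  # (B, T, C)
unigram_lm = torch.softmax(torch.randn(1, 1, C), dim=-1)  # (1, 1, C)
amonly_normalizers = (
    torch.mv(am_probs.reshape(-1, C), unigram_lm.reshape(C)).reshape(B, T, 1).log()
)  # (B, T, 1)
print(amonly_normalizers.transpose(1, 2).shape)  # (B, 1, T)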
@@ -1631,12 +1662,16 @@ def get_rnnt_logprobs_smoothed(
         dim=2,
     )  # now: [B][S][T+1], index [:,:,T] has -inf..
 
-    px_lm = torch.gather(lm[:, :S], dim=2, index=symbols.unsqueeze(-1))  # [B][S][1]
+    px_lm = torch.gather(
+        lm[:, :S], dim=2, index=symbols.unsqueeze(-1)
+    )  # [B][S][1]
     px_lm_unigram = torch.gather(
         unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1)
     )  # [B][S][1]
 
-    px = px_am + px_lm  # [B][S][T+1] if rnnt_type == "regular", otherwise [B][S][T]
+    px = (
+        px_am + px_lm
+    )  # [B][S][T+1] if rnnt_type == "regular", otherwise [B][S][T]
     px[:, :, :T] -= normalizers[:, :S, :]  # px: [B][S][T+1] or [B][S][T]
 
     px_amonly = (
@@ -1664,10 +1699,14 @@ def get_rnnt_logprobs_smoothed(
         am_only_scale = 1.0e-20
 
     px_interp = (
-        px * combined_scale + px_lmonly * lm_only_scale + px_amonly * am_only_scale
+        px * combined_scale
+        + px_lmonly * lm_only_scale
+        + px_amonly * am_only_scale
     )
     py_interp = (
-        py * combined_scale + py_lmonly * lm_only_scale + py_amonly * am_only_scale
+        py * combined_scale
+        + py_lmonly * lm_only_scale
+        + py_amonly * am_only_scale
     )
 
     if rnnt_type == "regular":
@@ -1780,9 +1819,9 @@ def rnnt_loss_smoothed(
             ).expand(B, 1, 1)
         else:
             offset = (boundary[:, 3] - 1) / 2
-        penalty = offset.reshape(B, 1, 1) - torch.arange(T0, device=px.device).reshape(
-            1, 1, T0
-        )
+        penalty = offset.reshape(B, 1, 1) - torch.arange(
+            T0, device=px.device
+        ).reshape(1, 1, T0)
         penalty = penalty * delay_penalty
         px += penalty.to(px.dtype)

8 changes: 6 additions & 2 deletions k2/python/tests/rnnt_loss_test.py
@@ -280,11 +280,15 @@ def test_rnnt_loss_random(self):
                     rnnt_type=rnnt_type,
                 )
                 assert (
-                    px.shape == (B, S, T) if rnnt_type != "regular" else (B, S, T + 1)
+                    px.shape == (B, S, T)
+                    if rnnt_type != "regular"
+                    else (B, S, T + 1)
                 )
                 assert py.shape == (B, S + 1, T)
                 assert symbols.shape == (B, S)
-                m = k2.mutual_information_recursion(px=px, py=py, boundary=boundary)
+                m = k2.mutual_information_recursion(
+                    px=px, py=py, boundary=boundary
+                )
 
                 if device == torch.device("cpu"):
                     expected = -torch.mean(m)
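One thing this rewrap makes easier to see: the conditional expression is the whole operand of assert, so when rnnt_type == "regular" the else branch yields the non-empty tuple (B, S, T + 1), which is truthy, and the assertion can never fail on that path. A stricter form of the same shape check (a sketch, not part of this commit) would be:

expected = (B, S, T + 1) if rnnt_type == "regular" else (B, S, T)
assert px.shape == expected, (px.shape, expected)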
