From 214a1295331adf1d395d9434d1f5f4f4e52578d0 Mon Sep 17 00:00:00 2001 From: pkufool Date: Mon, 24 Jun 2024 17:20:59 +0800 Subject: [PATCH] black --- k2/python/k2/rnnt_loss.py | 101 +++++++++++++++++++++--------- k2/python/tests/rnnt_loss_test.py | 8 ++- 2 files changed, 76 insertions(+), 33 deletions(-) diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index d0df5c8bb..555fd76a0 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -204,7 +204,9 @@ def get_rnnt_logprobs( dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf.. - px_lm = torch.gather(lm[:, :S], dim=2, index=symbols.unsqueeze(-1)) # [B][S][1] + px_lm = torch.gather( + lm[:, :S], dim=2, index=symbols.unsqueeze(-1) + ) # [B][S][1] px = px_am + px_lm # [B][S][T+1], last slice with indexes out of # boundary is -inf @@ -313,9 +315,9 @@ def rnnt_loss_simple( ).expand(B, 1, 1) else: offset = (boundary[:, 3] - 1) / 2 - penalty = offset.reshape(B, 1, 1) - torch.arange(T0, device=px.device).reshape( - 1, 1, T0 - ) + penalty = offset.reshape(B, 1, 1) - torch.arange( + T0, device=px.device + ).reshape(1, 1, T0) penalty = penalty * delay_penalty px += penalty.to(px.dtype) @@ -423,14 +425,18 @@ def get_rnnt_logprobs_joint( px = torch.cat( ( px, - torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype), + torch.full( + (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype + ), ), dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf.. px[:, :, :T] -= normalizers[:, :S, :] - py = logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone() # [B][S+1][T] + py = ( + logits[:, :, :, termination_symbol].permute((0, 2, 1)).clone() + ) # [B][S+1][T] py -= normalizers if rnnt_type == "regular": @@ -515,9 +521,9 @@ def rnnt_loss( ).expand(B, 1, 1) else: offset = (boundary[:, 3] - 1) / 2 - penalty = offset.reshape(B, 1, 1) - torch.arange(T0, device=px.device).reshape( - 1, 1, T0 - ) + penalty = offset.reshape(B, 1, 1) - torch.arange( + T0, device=px.device + ).reshape(1, 1, T0) penalty = penalty * delay_penalty px += penalty.to(px.dtype) @@ -571,7 +577,9 @@ def _monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor: return x -def _adjust_pruning_lower_bound(s_begin: torch.Tensor, s_range: int) -> torch.Tensor: +def _adjust_pruning_lower_bound( + s_begin: torch.Tensor, s_range: int +) -> torch.Tensor: """Adjust s_begin (pruning lower bounds) to make it satisfy the following constraints @@ -608,13 +616,17 @@ def _adjust_pruning_lower_bound(s_begin: torch.Tensor, s_range: int) -> torch.Te (B, T) = s_begin.shape s_begin = _monotonic_lower_bound(s_begin) # do the magic transformation - s_begin = -(s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)) + s_begin = -( + s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device) + ) # make the transformed tensor to be non-decreasing s_begin = _monotonic_lower_bound(s_begin) # make start symbol to be zero. 
s_begin = torch.clamp(s_begin, min=0) # do the magic transformation again to recover s_begin - s_begin = -(s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device)) + s_begin = -( + s_begin - (s_range - 1) * torch.arange(0, T, device=s_begin.device) + ) return s_begin @@ -685,7 +697,9 @@ def get_rnnt_prune_ranges( if is_regular: Ss = boundary[:, 2] Ts = boundary[:, 3] - s_range_min = Ss.sub(1).div(Ts, rounding_mode="trunc").add(2).max().item() + s_range_min = ( + Ss.sub(1).div(Ts, rounding_mode="trunc").add(2).max().item() + ) if s_range < s_range_min: print( f"Warning: get_rnnt_prune_ranges - got s_range={s_range} " @@ -841,13 +855,19 @@ def get_rnnt_prune_ranges_deprecated( than 2, or no valid paths could survive pruning. Given {s_range}""" px_pad = torch.zeros((B, 1, T1), dtype=px_grad.dtype, device=px_grad.device) - py_pad = torch.zeros((B, S + 1, 1), dtype=py_grad.dtype, device=py_grad.device) + py_pad = torch.zeros( + (B, S + 1, 1), dtype=py_grad.dtype, device=py_grad.device + ) py_grad_padded = py_grad if T1 == T else torch.cat((py_grad, py_pad), dim=2) - tot_grad = torch.cat((px_grad, px_pad), dim=1) + py_grad_padded # (B, S + 1, T1) + tot_grad = ( + torch.cat((px_grad, px_pad), dim=1) + py_grad_padded + ) # (B, S + 1, T1) tot_grad = torch.cat( ( - torch.zeros((B, 1, T1), dtype=tot_grad.dtype, device=tot_grad.device), + torch.zeros( + (B, 1, T1), dtype=tot_grad.dtype, device=tot_grad.device + ), tot_grad, ), dim=1, @@ -927,7 +947,9 @@ def do_rnnt_pruning( lm_pruned = torch.gather( lm, dim=1, - index=ranges.reshape(B, T * s_range, 1).expand((B, T * s_range, decoder_dim)), + index=ranges.reshape(B, T * s_range, 1).expand( + (B, T * s_range, decoder_dim) + ), ).reshape(B, T, s_range, decoder_dim) return am_pruned, lm_pruned @@ -958,7 +980,10 @@ def _roll_by_shifts(src: torch.Tensor, shifts: torch.LongTensor): assert shifts.shape == (B, T), (shifts.shape, B, T) index = ( - torch.arange(S, device=src.device).view((1, S)).repeat((T, 1)).repeat((B, 1, 1)) + torch.arange(S, device=src.device) + .view((1, S)) + .repeat((T, 1)) + .repeat((B, 1, 1)) ) index = (index - shifts.reshape(B, T, 1)) % S return torch.gather(src, 2, index) @@ -1118,7 +1143,9 @@ def get_hat_logprobs_pruned( px = torch.cat( ( px, - torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype), + torch.full( + (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype + ), ), dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf.. @@ -1293,7 +1320,9 @@ def get_rnnt_logprobs_pruned( px = torch.cat( ( px, - torch.full((B, S, 1), float("-inf"), device=px.device, dtype=px.dtype), + torch.full( + (B, S, 1), float("-inf"), device=px.device, dtype=px.dtype + ), ), dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf.. 
@@ -1433,9 +1462,9 @@ def rnnt_loss_pruned( ).expand(B, 1, 1) else: offset = (boundary[:, 3] - 1) / 2 - penalty = offset.reshape(B, 1, 1) - torch.arange(T0, device=px.device).reshape( - 1, 1, T0 - ) + penalty = offset.reshape(B, 1, 1) - torch.arange( + T0, device=px.device + ).reshape(1, 1, T0) penalty = penalty * delay_penalty px += penalty.to(px.dtype) @@ -1597,7 +1626,9 @@ def get_rnnt_logprobs_smoothed( + torch.finfo(lm_probs.dtype).tiny ) # [1][1][C] amonly_normalizers = ( - torch.mv(am_probs.reshape(-1, C), unigram_lm.reshape(C)).reshape(B, T, 1).log() + torch.mv(am_probs.reshape(-1, C), unigram_lm.reshape(C)) + .reshape(B, T, 1) + .log() + am_max ) # [B][T][1] amonly_normalizers = amonly_normalizers.transpose(1, 2) # [B][1][T] @@ -1631,12 +1662,16 @@ def get_rnnt_logprobs_smoothed( dim=2, ) # now: [B][S][T+1], index [:,:,T] has -inf.. - px_lm = torch.gather(lm[:, :S], dim=2, index=symbols.unsqueeze(-1)) # [B][S][1] + px_lm = torch.gather( + lm[:, :S], dim=2, index=symbols.unsqueeze(-1) + ) # [B][S][1] px_lm_unigram = torch.gather( unigram_lm.expand(B, S, C), dim=2, index=symbols.unsqueeze(-1) ) # [B][S][1] - px = px_am + px_lm # [B][S][T+1] if rnnt_type == "regular", otherwise [B][S][T] + px = ( + px_am + px_lm + ) # [B][S][T+1] if rnnt_type == "regular", otherwise [B][S][T] px[:, :, :T] -= normalizers[:, :S, :] # px: [B][S][T+1] or [B][S][T] px_amonly = ( @@ -1664,10 +1699,14 @@ def get_rnnt_logprobs_smoothed( am_only_scale = 1.0e-20 px_interp = ( - px * combined_scale + px_lmonly * lm_only_scale + px_amonly * am_only_scale + px * combined_scale + + px_lmonly * lm_only_scale + + px_amonly * am_only_scale ) py_interp = ( - py * combined_scale + py_lmonly * lm_only_scale + py_amonly * am_only_scale + py * combined_scale + + py_lmonly * lm_only_scale + + py_amonly * am_only_scale ) if rnnt_type == "regular": @@ -1780,9 +1819,9 @@ def rnnt_loss_smoothed( ).expand(B, 1, 1) else: offset = (boundary[:, 3] - 1) / 2 - penalty = offset.reshape(B, 1, 1) - torch.arange(T0, device=px.device).reshape( - 1, 1, T0 - ) + penalty = offset.reshape(B, 1, 1) - torch.arange( + T0, device=px.device + ).reshape(1, 1, T0) penalty = penalty * delay_penalty px += penalty.to(px.dtype) diff --git a/k2/python/tests/rnnt_loss_test.py b/k2/python/tests/rnnt_loss_test.py index dfcd01e6d..77d1a3b81 100644 --- a/k2/python/tests/rnnt_loss_test.py +++ b/k2/python/tests/rnnt_loss_test.py @@ -280,11 +280,15 @@ def test_rnnt_loss_random(self): rnnt_type=rnnt_type, ) assert ( - px.shape == (B, S, T) if rnnt_type != "regular" else (B, S, T + 1) + px.shape == (B, S, T) + if rnnt_type != "regular" + else (B, S, T + 1) ) assert py.shape == (B, S + 1, T) assert symbols.shape == (B, S) - m = k2.mutual_information_recursion(px=px, py=py, boundary=boundary) + m = k2.mutual_information_recursion( + px=px, py=py, boundary=boundary + ) if device == torch.device("cpu"): expected = -torch.mean(m)
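
For reference, the delay-penalty blocks that this patch re-wraps in rnnt_loss_simple, rnnt_loss, rnnt_loss_pruned and rnnt_loss_smoothed all add the same per-frame term to px, centered on the utterance midpoint (T - 1) / 2. A minimal numeric sketch of that term (toy values, one utterance with no boundary, for illustration only):

    import torch

    T, delay_penalty = 5, 0.1                     # toy frame count and scale
    offset = (T - 1) / 2                          # utterance midpoint, here 2.0
    penalty = (offset - torch.arange(T)) * delay_penalty
    print(penalty)
    # tensor([ 0.2000,  0.1000,  0.0000, -0.1000, -0.2000])

Emitting a symbol before the midpoint is rewarded and emitting it after is penalized; since the per-frame terms sum to zero over t = 0..T-1, the centering biases paths toward earlier emission rather than uniformly shifting the loss.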
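The "magic transformation" in _adjust_pruning_lower_bound deserves a gloss: substituting t_i = (s_range - 1) * i - s_begin_i turns the step constraint s_begin[i + 1] - s_begin[i] < s_range into plain monotonicity of t, which _monotonic_lower_bound can then enforce. A standalone sketch of the trick; _monotonic_lower_bound is assumed here to be the flip/cummin/flip construction (its body lies outside this hunk), and the names and example values are illustrative:

    import torch

    def monotonic_lower_bound(x: torch.Tensor) -> torch.Tensor:
        # Largest non-decreasing tensor that is <= x elementwise,
        # computed as a reversed running minimum.
        x = torch.flip(x, dims=(-1,))
        x = torch.cummin(x, dim=-1)[0]
        return torch.flip(x, dims=(-1,))

    def adjust_pruning_lower_bound(
        s_begin: torch.Tensor, s_range: int
    ) -> torch.Tensor:
        B, T = s_begin.shape
        arange = torch.arange(T, device=s_begin.device)
        s_begin = monotonic_lower_bound(s_begin)  # make it non-decreasing
        # Change of variables: t_i = (s_range - 1) * i - s_begin_i, so that
        # t non-decreasing <=> s_begin[i + 1] - s_begin[i] <= s_range - 1.
        t = -(s_begin - (s_range - 1) * arange)
        t = monotonic_lower_bound(t)              # enforce the step constraint
        t = torch.clamp(t, min=0)                 # pin s_begin[:, 0] at zero
                                                  # (for non-negative input)
        return -(t - (s_range - 1) * arange)      # undo the change of variables

    s_begin = torch.tensor([[0, 2, 0, 1, 4, 4]])
    print(adjust_pruning_lower_bound(s_begin, s_range=2))
    # tensor([[0, 1, 2, 3, 4, 4]])

The output is non-decreasing, starts at 0, and never steps up by more than s_range - 1 between adjacent symbols, which is exactly what the docstring's constraints require.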
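Similarly, the index arithmetic in _roll_by_shifts (whose construction the patch reflows across several lines) rolls each (b, t) row of a (B, T, S) tensor to the right by its own per-row shift with a single gather. A compact reproduction of the indexing, with shapes and values made up for the demo (expand suffices in place of the original repeat calls, since the subtraction materializes a fresh index tensor):

    import torch

    B, T, S = 1, 2, 5
    src = torch.arange(1, 1 + B * T * S).reshape(B, T, S)  # rows [1..5], [6..10]
    shifts = torch.tensor([[2, 0]])                        # roll row 0 right by 2
    index = torch.arange(S).view(1, 1, S).expand(B, T, S)
    index = (index - shifts.reshape(B, T, 1)) % S          # negative offsets wrap
    print(torch.gather(src, 2, index))
    # tensor([[[ 4,  5,  1,  2,  3],
    #          [ 6,  7,  8,  9, 10]]])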