diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py
index c0daa888944d4..79cc8e41bf343 100644
--- a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py
+++ b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py
@@ -225,7 +225,6 @@ def run_performance_tests(args):
     )
     parser.set_defaults(use_smooth_softmax=False)
 
-
     args = parser.parse_args()
 
     if args.kernel == "memory_efficient":
diff --git a/onnxruntime/test/python/transformers/test_gqa_cpu.py b/onnxruntime/test/python/transformers/test_gqa_cpu.py
index b6b8aee15852f..3b16e0320da61 100644
--- a/onnxruntime/test/python/transformers/test_gqa_cpu.py
+++ b/onnxruntime/test/python/transformers/test_gqa_cpu.py
@@ -145,6 +145,7 @@ def create_group_query_attention_graph_prompt(
     rotary=False,
     rotary_interleaved=False,
     packed=False,
+    use_smooth_softmax=False,
 ):
     past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0
     present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length
@@ -169,6 +170,7 @@ def create_group_query_attention_graph_prompt(
         local_window_size=local_window_size,
         do_rotary=rotary,
         rotary_interleaved=rotary_interleaved,
+        smooth_softmax=1 if use_smooth_softmax else 0,
         # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0,
         # kv_share_buffer=1 if share_buffer else 0,
         domain="com.microsoft",
@@ -331,6 +333,7 @@ def create_group_query_attention_graph_past(
     rotary=False,
     rotary_interleaved=False,
     packed=False,
+    use_smooth_softmax=False,
 ):
     past_kv_seqlen = config.kv_sequence_length
     present_kv_seqlen = (
@@ -357,6 +360,7 @@ def create_group_query_attention_graph_past(
         local_window_size=local_window_size,
         do_rotary=rotary,
         rotary_interleaved=rotary_interleaved,
+        smooth_softmax=1 if use_smooth_softmax else 0,
         # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0,
         # kv_share_buffer=1 if share_buffer else 0,
         domain="com.microsoft",
@@ -667,6 +671,7 @@ def gqa_prompt_func(
     past_kv_format=Formats.BSNH,
     share_buffer=True,
     rotary_interleaved=False,
+    use_smooth_softmax=False,
 ):
     onnx_model_str = create_group_query_attention_graph_prompt(
         config,
@@ -676,6 +681,7 @@ def gqa_prompt_func(
         rotary=cos is not None,
         rotary_interleaved=rotary_interleaved,
         packed=new_k is None,
+        use_smooth_softmax=use_smooth_softmax,
     )
     q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1))
     past_k = k.clone() if share_buffer else None
@@ -773,6 +779,7 @@ def gqa_past_func(
     share_buffer=True,
     window_size=-1,
     rotary_interleaved=False,
+    use_smooth_softmax=False,
 ):
     onnx_model_str = create_group_query_attention_graph_past(
         config,
@@ -782,6 +789,7 @@ def gqa_past_func(
         rotary=cos is not None,
         rotary_interleaved=rotary_interleaved,
         packed=new_k is None,
+        use_smooth_softmax=use_smooth_softmax,
     )
     q = torch.reshape(q, (config.batch_size, config.sequence_length, -1))
     past_k = k.clone()
@@ -918,6 +926,7 @@ def attention_ref(
     window_size=(-1, -1),  # -1 means infinite window size
     upcast=True,
     reorder_ops=False,
+    use_smooth_softmax=False,
 ):
     """
     Arguments:
@@ -935,6 +944,7 @@ def attention_ref(
         reorder_ops: whether to change the order of operations (scaling k instead of scaling q, etc.)
             without changing the math. This is to estimate the numerical error from operation
             reordering.
+        use_smooth_softmax: whether to use smooth softmax or not
     Output:
         output: (batch_size, seqlen_q, nheads, head_dim)
         attention: (batch_size, nheads, seqlen_q, seqlen_k), softmax after dropout
@@ -964,10 +974,19 @@ def attention_ref(
                 q.device,
             )
             scores.masked_fill_(local_mask, float("-inf"))
-    attention = torch.softmax(scores, dim=-1)
+
+    if use_smooth_softmax:
+        qk_max = scores.amax(axis=-1, keepdim=True)
+        qk_max = torch.maximum(qk_max, torch.zeros_like(qk_max))
+        w = torch.exp(scores - qk_max)
+        attention = w * torch.reciprocal(w.sum(axis=-1, keepdim=True) + torch.exp(-qk_max))
+    else:
+        attention = torch.softmax(scores, dim=-1)
+
     # Some rows might be completely masked out so we fill them with zero instead of NaN
     if window_size[0] >= 0 or window_size[1] >= 0:
         attention = attention.masked_fill(torch.all(local_mask, dim=-1, keepdim=True), 0.0)
+
     # We want to mask here so that the attention matrix doesn't have any NaNs
     # Otherwise we'll get NaN in dV
     if query_padding_mask is not None:
@@ -984,7 +1003,14 @@ def attention_ref(
 
 
 def attention_qkvpacked_ref(
-    qkv, key_padding_mask=None, dropout_p=0.0, dropout_mask=None, causal=False, upcast=True, reorder_ops=False
+    qkv,
+    key_padding_mask=None,
+    dropout_p=0.0,
+    dropout_mask=None,
+    causal=False,
+    upcast=True,
+    reorder_ops=False,
+    use_smooth_softmax=False,
 ):
     return attention_ref(
         qkv[:, :, 0],
@@ -997,6 +1023,7 @@ def attention_qkvpacked_ref(
         upcast=upcast,
         causal=causal,
         reorder_ops=reorder_ops,
+        use_smooth_softmax=use_smooth_softmax,
     )
 
 
@@ -1008,6 +1035,7 @@ def parity_check_gqa_prompt(
     rotary=False,
     rotary_interleaved=False,
     packed=False,
+    use_smooth_softmax=False,
     rtol=1e-3,
     atol=1e-3,
 ):
@@ -1108,7 +1136,16 @@ def parity_check_gqa_prompt(
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     key_padding_mask = arange < cache_seqlens_expanded
     out_ref, _ = attention_ref(
-        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
+        q_ro,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        key_padding_mask,
+        0.0,
+        None,
+        causal=True,
+        window_size=window_size,
+        use_smooth_softmax=use_smooth_softmax,
     )
     out_ref = out_ref.detach().cpu().numpy()
     if past_format == Formats.BNSH:
@@ -1132,6 +1169,7 @@ def parity_check_gqa_prompt(
             past_format,
             True,
             rotary_interleaved,
+            use_smooth_softmax=use_smooth_softmax,
         )
     else:
         out, present_k, present_v = gqa_prompt_func(
@@ -1148,6 +1186,7 @@ def parity_check_gqa_prompt(
             past_format,
             True,
             rotary_interleaved,
+            use_smooth_softmax=use_smooth_softmax,
         )
     out = torch.squeeze(out, 0)
     out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size))
@@ -1172,6 +1211,8 @@ def parity_check_gqa_prompt(
         rotary,
         " rotary_interleaved:",
         rotary_interleaved,
+        " smooth_softmax:",
+        use_smooth_softmax,
         "past kv format:",
         "BSNH" if past_format == Formats.BSNH else "BNSH",
         " B:",
@@ -1201,6 +1242,7 @@ def parity_check_gqa_prompt_no_buff(
     rotary=False,
     rotary_interleaved=False,
     packed=False,
+    use_smooth_softmax=False,
     rtol=1e-3,
     atol=1e-3,
 ):
@@ -1275,7 +1317,16 @@ def parity_check_gqa_prompt_no_buff(
     k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     out_ref, _ = attention_ref(
-        q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size
+        q_ro,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        new_mask,
+        0.0,
+        None,
+        causal=True,
+        window_size=window_size,
+        use_smooth_softmax=use_smooth_softmax,
     )
     out_ref = out_ref.detach().cpu().numpy()
     if past_format == Formats.BNSH:
@@ -1299,6 +1350,7 @@ def parity_check_gqa_prompt_no_buff(
             past_format,
             False,
             rotary_interleaved,
+            use_smooth_softmax=use_smooth_softmax,
         )
     else:
         out, present_k, present_v = gqa_prompt_func(
@@ -1315,6 +1367,7 @@ def parity_check_gqa_prompt_no_buff(
             past_format,
             False,
             rotary_interleaved,
+            use_smooth_softmax=use_smooth_softmax,
         )
     out = torch.squeeze(out, 0)
     out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size))
@@ -1339,6 +1392,8 @@ def parity_check_gqa_prompt_no_buff(
         rotary,
         " rotary_interleaved:",
         rotary_interleaved,
+        " smooth_softmax:",
+        use_smooth_softmax,
         "past kv format:",
         "BSNH" if past_format == Formats.BSNH else "BNSH",
         " B:",
@@ -1368,6 +1423,7 @@ def parity_check_gqa_past(
     rotary=False,
     rotary_interleaved=False,
     packed=False,
+    use_smooth_softmax=False,
     rtol=1e-3,
     atol=1e-3,
 ):
@@ -1473,7 +1529,16 @@ def parity_check_gqa_past(
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length
     out_ref, _ = attention_ref(
-        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
+        q_ro,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        key_padding_mask,
+        0.0,
+        None,
+        causal=True,
+        window_size=window_size,
+        use_smooth_softmax=use_smooth_softmax,
     )
     out_ref = out_ref.detach().cpu().numpy()
     if past_format == Formats.BNSH:
@@ -1497,6 +1562,7 @@ def parity_check_gqa_past(
             True,
             left_window_size,
             rotary_interleaved,
+            use_smooth_softmax=use_smooth_softmax,
         )
     else:
         out, present_k, present_v = gqa_past_func(
@@ -1513,6 +1579,7 @@ def parity_check_gqa_past(
             True,
             left_window_size,
             rotary_interleaved,
+            use_smooth_softmax=use_smooth_softmax,
         )
     out = torch.squeeze(out, 0)
     out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
@@ -1539,6 +1606,8 @@ def parity_check_gqa_past(
         rotary,
         " rotary_interleaved:",
         rotary_interleaved,
+        " smooth_softmax:",
+        use_smooth_softmax,
         " B:",
         config.batch_size,
         " S:",
@@ -1566,6 +1635,7 @@ def parity_check_gqa_past_no_buff(
     rotary=False,
     rotary_interleaved=False,
     packed=False,
+    use_smooth_softmax=False,
     rtol=1e-3,
     atol=1e-3,
 ):
@@ -1677,7 +1747,16 @@ def parity_check_gqa_past_no_buff(
     v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads)
     key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length
     out_ref, _ = attention_ref(
-        q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size
+        q_ro,
+        k_cache_rep,
+        v_cache_rep,
+        None,
+        key_padding_mask,
+        0.0,
+        None,
+        causal=True,
+        window_size=window_size,
+        use_smooth_softmax=use_smooth_softmax,
     )
     out_ref = out_ref.detach().cpu().numpy()
     if past_format == Formats.BNSH:
@@ -1701,6 +1780,7 @@ def parity_check_gqa_past_no_buff(
             False,
             window_size=left_window_size,
             rotary_interleaved=rotary_interleaved,
+            use_smooth_softmax=use_smooth_softmax,
         )
     else:
         out, present_k, present_v = gqa_past_func(
@@ -1717,6 +1797,7 @@ def parity_check_gqa_past_no_buff(
             False,
             window_size=left_window_size,
             rotary_interleaved=rotary_interleaved,
+            use_smooth_softmax=use_smooth_softmax,
         )
     out = torch.squeeze(out, 0)
     out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
@@ -1737,6 +1818,8 @@ def parity_check_gqa_past_no_buff(
         rotary,
         " rotary_interleaved:",
         rotary_interleaved,
+        " smooth_softmax:",
+        use_smooth_softmax,
         "past kv format:",
         "BSNH" if past_format == Formats.BSNH else "BNSH",
         " B:",
@@ -1787,26 +1870,29 @@ def test_gqa_no_past(self):
                         for local in [False, True]:
                             for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]:
                                 for packed in [False, True]:
-                                    config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
-                                    past_kv_format = Formats.BNSH
-                                    all_close = parity_check_gqa_prompt(
-                                        config,
-                                        local=local,
-                                        past_format=past_kv_format,
-                                        rotary=rotary,
-                                        rotary_interleaved=rotary_interleaved,
-                                        packed=packed,
-                                    )
-                                    self.assertTrue(all_close)
-                                    all_close = parity_check_gqa_prompt_no_buff(
-                                        config,
-                                        local=local,
-                                        past_format=past_kv_format,
-                                        rotary=rotary,
-                                        rotary_interleaved=rotary_interleaved,
-                                        packed=packed,
-                                    )
-                                    self.assertTrue(all_close)
+                                    for use_smooth_softmax in [False, True]:
+                                        config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
+                                        past_kv_format = Formats.BNSH
+                                        all_close = parity_check_gqa_prompt(
+                                            config,
+                                            local=local,
+                                            past_format=past_kv_format,
+                                            rotary=rotary,
+                                            rotary_interleaved=rotary_interleaved,
+                                            packed=packed,
+                                            use_smooth_softmax=use_smooth_softmax,
+                                        )
+                                        self.assertTrue(all_close)
+                                        all_close = parity_check_gqa_prompt_no_buff(
+                                            config,
+                                            local=local,
+                                            past_format=past_kv_format,
+                                            rotary=rotary,
+                                            rotary_interleaved=rotary_interleaved,
+                                            packed=packed,
+                                            use_smooth_softmax=use_smooth_softmax,
+                                        )
+                                        self.assertTrue(all_close)
 
     def test_gqa_past(self):
         print("-------- TEST GQA PAST (TOKEN GEN) ---------")
@@ -1838,31 +1924,34 @@ def test_gqa_past(self):
                         for local in [False, True]:
                             for rotary, rotary_interleaved in [(False, False), (True, False), (True, True)]:
                                 for packed in [False, True]:
-                                    sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
-                                    config = Config(b, s, s2, sp, n, n2, h)
-                                    past_kv_format = Formats.BNSH
-                                    all_close = parity_check_gqa_past(
-                                        config,
-                                        local=local,
-                                        past_format=past_kv_format,
-                                        rtol=1e-3,
-                                        atol=1e-3,
-                                        rotary=rotary,
-                                        rotary_interleaved=rotary_interleaved,
-                                        packed=packed,
-                                    )
-                                    self.assertTrue(all_close)
-                                    all_close = parity_check_gqa_past_no_buff(
-                                        config,
-                                        local=local,
-                                        past_format=past_kv_format,
-                                        rtol=1e-3,
-                                        atol=1e-3,
-                                        rotary=rotary,
-                                        rotary_interleaved=rotary_interleaved,
-                                        packed=packed,
-                                    )
-                                    self.assertTrue(all_close)
+                                    for use_smooth_softmax in [False, True]:
+                                        sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+                                        config = Config(b, s, s2, sp, n, n2, h)
+                                        past_kv_format = Formats.BNSH
+                                        all_close = parity_check_gqa_past(
+                                            config,
+                                            local=local,
+                                            past_format=past_kv_format,
+                                            rtol=1e-3,
+                                            atol=1e-3,
+                                            rotary=rotary,
+                                            rotary_interleaved=rotary_interleaved,
+                                            packed=packed,
+                                            use_smooth_softmax=use_smooth_softmax,
+                                        )
+                                        self.assertTrue(all_close)
+                                        all_close = parity_check_gqa_past_no_buff(
+                                            config,
+                                            local=local,
+                                            past_format=past_kv_format,
+                                            rtol=1e-3,
+                                            atol=1e-3,
+                                            rotary=rotary,
+                                            rotary_interleaved=rotary_interleaved,
+                                            packed=packed,
+                                            use_smooth_softmax=use_smooth_softmax,
+                                        )
+                                        self.assertTrue(all_close)
 
 
 if __name__ == "__main__":
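
Note on the smooth-softmax branch added to attention_ref above: because qk_max
is clamped at zero, the extra exp(-qk_max) term in the denominator makes the
result algebraically identical to a standard softmax over the row scores with
one additional zero logit appended (a "sink" that absorbs probability mass when
every real score is strongly negative). The standalone Python sketch below
checks that identity numerically; smooth_softmax is a local helper written for
illustration here, not a function defined in this PR.

import torch

def smooth_softmax(scores: torch.Tensor) -> torch.Tensor:
    # Mirrors the branch added to attention_ref: clamp the row max at zero so
    # the implicit zero logit is covered by the usual max-subtraction trick.
    qk_max = scores.amax(dim=-1, keepdim=True)
    qk_max = torch.maximum(qk_max, torch.zeros_like(qk_max))
    w = torch.exp(scores - qk_max)
    return w * torch.reciprocal(w.sum(dim=-1, keepdim=True) + torch.exp(-qk_max))

# Scores shaped (batch, nheads, seqlen_q, seqlen_k), as in attention_ref.
scores = torch.randn(2, 4, 3, 5)
# Appending a zero logit, taking a regular softmax, and dropping the extra
# column should reproduce the smooth-softmax result up to float error.
zero = torch.zeros(*scores.shape[:-1], 1)
expected = torch.softmax(torch.cat([scores, zero], dim=-1), dim=-1)[..., :-1]
assert torch.allclose(smooth_softmax(scores), expected, atol=1e-6)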