From 5551c6ec2ec426cefedb99274952b4581410181f Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 21 Aug 2024 11:14:31 -0700 Subject: [PATCH] update gqa cuda benchmark for smooth_softmax --- .../test/python/transformers/benchmark_gqa.py | 51 +++++++++++-------- .../transformers/benchmark_gqa_windows.py | 19 ++++++- .../transformers/test_sparse_attention.py | 16 +++++- 3 files changed, 64 insertions(+), 22 deletions(-) diff --git a/onnxruntime/test/python/transformers/benchmark_gqa.py b/onnxruntime/test/python/transformers/benchmark_gqa.py index 5e028519b9f34..53d015a029083 100644 --- a/onnxruntime/test/python/transformers/benchmark_gqa.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa.py @@ -37,6 +37,7 @@ def plot_prompt_performance( head_size: int, max_seq_len: int, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, ): import triton @@ -55,6 +56,7 @@ def plot_prompt_performance( "kv_num_heads": kv_num_heads, "head_size": head_size, "local_window_size": local_window_size, + "use_smooth_softmax": use_smooth_softmax, }, ) ] @@ -68,6 +70,7 @@ def benchmark( kv_num_heads: int, head_size: int, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, device="cuda", ): warmup = 15 @@ -82,6 +85,7 @@ def benchmark( kv_num_heads=kv_num_heads, head_size=head_size, local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1, + use_smooth_softmax=use_smooth_softmax, device=device, is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"], ) @@ -103,6 +107,7 @@ def plot_token_performance( head_size: int, max_seq_len: int, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, ): import triton @@ -121,6 +126,7 @@ def plot_token_performance( "kv_num_heads": kv_num_heads, "head_size": head_size, "local_window_size": local_window_size, + "use_smooth_softmax": use_smooth_softmax, }, ) ] @@ -134,6 +140,7 @@ def benchmark( kv_num_heads: int, head_size: int, 
local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, device="cuda", ): warmup = 15 @@ -150,6 +157,7 @@ def benchmark( local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1, do_rotary=True, # Most models use rotary positional embeddings is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"], + use_smooth_softmax=use_smooth_softmax, device=device, ) @@ -186,26 +194,29 @@ def run_performance_test(sm: int): for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures: for batch_size in [1, 4]: - plot_prompt_performance( - sm=sm, - batch_size=batch_size, - num_heads=num_heads, - kv_num_heads=kv_num_heads, - head_size=head_size, - max_seq_len=min(threshold, max_seq_len), - local_window_size=local_window_size, - model_name=model_name, - ) - plot_token_performance( - sm=sm, - batch_size=batch_size, - num_heads=num_heads, - kv_num_heads=kv_num_heads, - head_size=head_size, - max_seq_len=min(threshold, max_seq_len), - local_window_size=local_window_size, - model_name=model_name, - ) + for use_smooth_softmax in [False, True]: + plot_prompt_performance( + sm=sm, + batch_size=batch_size, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + max_seq_len=min(threshold, max_seq_len), + local_window_size=local_window_size, + use_smooth_softmax=use_smooth_softmax, + model_name=model_name, + ) + plot_token_performance( + sm=sm, + batch_size=batch_size, + num_heads=num_heads, + kv_num_heads=kv_num_heads, + head_size=head_size, + max_seq_len=min(threshold, max_seq_len), + local_window_size=local_window_size, + use_smooth_softmax=use_smooth_softmax, + model_name=model_name, + ) if __name__ == "__main__": diff --git a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py index b781ccf03f138..c0daa888944d4 100644 --- 
a/onnxruntime/test/python/transformers/benchmark_gqa_windows.py +++ b/onnxruntime/test/python/transformers/benchmark_gqa_windows.py @@ -19,6 +19,7 @@ def save_results(results, filename): "Max Sequence Length", "Sequence Length", "Past Sequence Length", + "Smooth Softmax", "Model Name", ], ) @@ -36,6 +37,7 @@ def benchmark( sequence_length: int = 1, past_sequence_length: int = 0, local_window_size: Optional[int] = None, + use_smooth_softmax: bool = False, model_name: str = "Llama3-8B", ): warmup = 15 @@ -50,6 +52,7 @@ def benchmark( kv_num_heads=kv_num_heads, head_size=head_size, local_window_size=local_window_size if local_window_size else -1, + use_smooth_softmax=use_smooth_softmax, do_rotary=True, # Most models use rotary positional embeddings is_packed_qkv=model_name in ["Phi-3-mini-128k", "Phi-3-small-128k"], device="cuda", @@ -93,6 +96,8 @@ def run_performance_tests(args): # Reduce max sequence length when GPU memory is not enough. threshold = 131072 if memory_in_gb > 24 else 65536 if memory_in_gb > 12 else 32768 + smooth_softmax = args.use_smooth_softmax + all_metrics = [] for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures: prompt_metrics_model = [] @@ -131,6 +136,7 @@ def run_performance_tests(args): sequence_length=sequence_length, max_seq_len=min(threshold, max_seq_len), local_window_size=local_window_size, + use_smooth_softmax=smooth_softmax, model_name=model_name, ) - metrics = [*metrics, batch_size, max_seq_len, sequence_length, 0, model_name] + metrics = [*metrics, batch_size, max_seq_len, sequence_length, 0, smooth_softmax, model_name] @@ -169,9 +175,10 @@ def run_performance_tests(args): past_sequence_length=past_sequence_length, max_seq_len=min(threshold, max_seq_len), local_window_size=local_window_size, + use_smooth_softmax=smooth_softmax, model_name=model_name, ) - metrics = [*metrics, batch_size, max_seq_len, 1, past_sequence_length, model_name] + metrics = [*metrics, batch_size, max_seq_len, 1, past_sequence_length, smooth_softmax, model_name] token_metrics_model.append(metrics) 
all_metrics.append(metrics) # Calculate average inference interval and throughput for each model @@ -209,6 +216,16 @@ def run_performance_tests(args): default="flash_attention", help="GQA Kernel to use for benchmarking. Options: flash_attention, memory_efficient", ) + + parser.add_argument( + "--use_smooth_softmax", + required=False, + action="store_true", + help="test smooth softmax", + ) + parser.set_defaults(use_smooth_softmax=False) + + args = parser.parse_args() if args.kernel == "memory_efficient": diff --git a/onnxruntime/test/python/transformers/test_sparse_attention.py b/onnxruntime/test/python/transformers/test_sparse_attention.py index f18bcdba65579..3e66c465b916e 100644 --- a/onnxruntime/test/python/transformers/test_sparse_attention.py +++ b/onnxruntime/test/python/transformers/test_sparse_attention.py @@ -43,6 +43,7 @@ def __init__( is_packed_qkv: bool = False, max_cache_sequence_length=None, max_rotary_sequence_length=None, + use_smooth_softmax: bool = False, ): self.operator = operator self.batch_size = batch_size @@ -73,6 +74,8 @@ def __init__( self.share_buffer = share_buffer self.is_packed_qkv = is_packed_qkv + self.use_smooth_softmax = use_smooth_softmax + def shape_dict(self): shapes = { "query": ( @@ -166,6 +169,7 @@ def __init__( is_packed_qkv=False, max_cache_sequence_length=None, max_rotary_sequence_length=None, + use_smooth_softmax: bool = False, ): super().__init__( "GroupQueryAttention", @@ -185,6 +189,7 @@ def __init__( is_packed_qkv=is_packed_qkv, max_cache_sequence_length=max_cache_sequence_length, max_rotary_sequence_length=max_rotary_sequence_length, + use_smooth_softmax=use_smooth_softmax, ) # local_window_size is for ORT only, not for Torch implementation. 
self.local_window_size = local_window_size @@ -529,6 +534,7 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig): local_window_size=config.local_window_size, do_rotary=1 if config.do_rotary else 0, rotary_interleaved=config.rotary_interleaved, + smooth_softmax=1 if config.use_smooth_softmax else 0, domain="com.microsoft", ), ] @@ -612,7 +618,15 @@ def group_query_attention_reference( attn = torch.einsum("bhmd,bhnd->bhmn", query, key).float() * scale if mask is not None: attn = attn.masked_fill((1 - mask).bool(), float("-inf")) - attn = attn.softmax(-1) + + if config.use_smooth_softmax: + qk_max = attn.amax(axis=-1, keepdim=True) + qk_max = torch.maximum(qk_max, torch.zeros_like(qk_max)) + w = torch.exp(attn - qk_max) + attn = w * torch.reciprocal(w.sum(axis=-1, keepdim=True) + torch.exp(-qk_max)) + else: + attn = attn.softmax(-1) + attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value) result = attn_output.transpose(1, 2).contiguous()