update gqa cuda benchmark for smooth_softmax
tianleiwu committed Aug 21, 2024
1 parent d3f9cab commit 5551c6e
Showing 3 changed files with 64 additions and 22 deletions.
51 changes: 31 additions & 20 deletions onnxruntime/test/python/transformers/benchmark_gqa.py
@@ -37,6 +37,7 @@ def plot_prompt_performance(
 head_size: int,
 max_seq_len: int,
 local_window_size: Optional[int] = None,
+use_smooth_softmax: bool = False,
 ):
 import triton

@@ -55,6 +56,7 @@
 "kv_num_heads": kv_num_heads,
 "head_size": head_size,
 "local_window_size": local_window_size,
+"use_smooth_softmax": use_smooth_softmax,
 },
 )
 ]
@@ -68,6 +70,7 @@ def benchmark(
 kv_num_heads: int,
 head_size: int,
 local_window_size: Optional[int] = None,
+use_smooth_softmax: bool = False,
 device="cuda",
 ):
 warmup = 15
@@ -82,6 +85,7 @@
 kv_num_heads=kv_num_heads,
 head_size=head_size,
 local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1,
+use_smooth_softmax=use_smooth_softmax,
 device=device,
 is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"],
 )
@@ -103,6 +107,7 @@ def plot_token_performance(
 head_size: int,
 max_seq_len: int,
 local_window_size: Optional[int] = None,
+use_smooth_softmax: bool = False,
 ):
 import triton

@@ -121,6 +126,7 @@
 "kv_num_heads": kv_num_heads,
 "head_size": head_size,
 "local_window_size": local_window_size,
+"use_smooth_softmax": use_smooth_softmax,
 },
 )
 ]
@@ -134,6 +140,7 @@ def benchmark(
 kv_num_heads: int,
 head_size: int,
 local_window_size: Optional[int] = None,
+use_smooth_softmax: bool = False,
 device="cuda",
 ):
 warmup = 15
@@ -150,6 +157,7 @@
 local_window_size=local_window_size if provider in ["ort_gqa_local", "ort_gqa_local_packed"] else -1,
 do_rotary=True, # Most models use rotary positional embeddings
 is_packed_qkv=provider in ["ort_gqa_packed", "ort_gqa_local_packed"],
+use_smooth_softmax=use_smooth_softmax,
 device=device,
 )

@@ -186,26 +194,29 @@ def run_performance_test(sm: int):

 for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures:
 for batch_size in [1, 4]:
-plot_prompt_performance(
-sm=sm,
-batch_size=batch_size,
-num_heads=num_heads,
-kv_num_heads=kv_num_heads,
-head_size=head_size,
-max_seq_len=min(threshold, max_seq_len),
-local_window_size=local_window_size,
-model_name=model_name,
-)
-plot_token_performance(
-sm=sm,
-batch_size=batch_size,
-num_heads=num_heads,
-kv_num_heads=kv_num_heads,
-head_size=head_size,
-max_seq_len=min(threshold, max_seq_len),
-local_window_size=local_window_size,
-model_name=model_name,
-)
+for use_smooth_softmax in [False, True]:
+plot_prompt_performance(
+sm=sm,
+batch_size=batch_size,
+num_heads=num_heads,
+kv_num_heads=kv_num_heads,
+head_size=head_size,
+max_seq_len=min(threshold, max_seq_len),
+local_window_size=local_window_size,
+use_smooth_softmax=use_smooth_softmax,
+model_name=model_name,
+)
+plot_token_performance(
+sm=sm,
+batch_size=batch_size,
+num_heads=num_heads,
+kv_num_heads=kv_num_heads,
+head_size=head_size,
+max_seq_len=min(threshold, max_seq_len),
+local_window_size=local_window_size,
+use_smooth_softmax=use_smooth_softmax,
+model_name=model_name,
+)


 if __name__ == "__main__":
19 changes: 18 additions & 1 deletion onnxruntime/test/python/transformers/benchmark_gqa_windows.py
@@ -19,6 +19,7 @@ def save_results(results, filename):
 "Max Sequence Length",
 "Sequence Length",
 "Past Sequence Length",
+"Smooth Softmax",
 "Model Name",
 ],
 )
@@ -36,6 +37,7 @@ def benchmark(
 sequence_length: int = 1,
 past_sequence_length: int = 0,
 local_window_size: Optional[int] = None,
+use_smooth_softmax: bool = False,
 model_name: str = "Llama3-8B",
 ):
 warmup = 15
@@ -50,6 +52,7 @@
 kv_num_heads=kv_num_heads,
 head_size=head_size,
 local_window_size=local_window_size if local_window_size else -1,
+use_smooth_softmax=use_smooth_softmax,
 do_rotary=True, # Most models use rotary positional embeddings
 is_packed_qkv=model_name in ["Phi-3-mini-128k", "Phi-3-small-128k"],
 device="cuda",
@@ -93,6 +96,8 @@ def run_performance_tests(args):
 # Reduce max sequence length when GPU memory is not enough.
 threshold = 131072 if memory_in_gb > 24 else 65536 if memory_in_gb > 12 else 32768

+smooth_softmax = args.use_smooth_softmax
+
 all_metrics = []
 for num_heads, head_size, kv_num_heads, max_seq_len, local_window_size, model_name in configures:
 prompt_metrics_model = []
@@ -131,6 +136,7 @@
 sequence_length=sequence_length,
 max_seq_len=min(threshold, max_seq_len),
 local_window_size=local_window_size,
+use_smooth_softmax=smooth_softmax,
 model_name=model_name,
 )
 metrics = [*metrics, batch_size, max_seq_len, sequence_length, 0, model_name]
@@ -169,9 +175,10 @@
 past_sequence_length=past_sequence_length,
 max_seq_len=min(threshold, max_seq_len),
 local_window_size=local_window_size,
+use_smooth_softmax=smooth_softmax,
 model_name=model_name,
 )
-metrics = [*metrics, batch_size, max_seq_len, 1, past_sequence_length, model_name]
+metrics = [*metrics, batch_size, max_seq_len, 1, past_sequence_length, smooth_softmax, model_name]
 token_metrics_model.append(metrics)
 all_metrics.append(metrics)
 # Calculate average inference interval and throughput for each model
@@ -209,6 +216,16 @@
 default="flash_attention",
 help="GQA Kernel to use for benchmarking. Options: flash_attention, memory_efficient",
 )
+
+parser.add_argument(
+"--use_smooth_softmax",
+required=False,
+action="store_true",
+help="test smooth softmax",
+)
+parser.set_defaults(use_smooth_softmax=False)
+
+
 args = parser.parse_args()

 if args.kernel == "memory_efficient":
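Usage note (not part of this commit): the new flag is off by default, so the Windows benchmark can be run in both modes, for example python benchmark_gqa_windows.py --kernel flash_attention --use_smooth_softmax from onnxruntime/test/python/transformers. The CSV header gains a "Smooth Softmax" column so results from the two modes can be told apart.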
16 changes: 15 additions & 1 deletion onnxruntime/test/python/transformers/test_sparse_attention.py
@@ -43,6 +43,7 @@ def __init__(
 is_packed_qkv: bool = False,
 max_cache_sequence_length=None,
 max_rotary_sequence_length=None,
+use_smooth_softmax: bool = False,
 ):
 self.operator = operator
 self.batch_size = batch_size
@@ -73,6 +74,8 @@ def __init__(
 self.share_buffer = share_buffer
 self.is_packed_qkv = is_packed_qkv

+self.use_smooth_softmax = use_smooth_softmax
+
 def shape_dict(self):
 shapes = {
 "query": (
@@ -166,6 +169,7 @@ def __init__(
 is_packed_qkv=False,
 max_cache_sequence_length=None,
 max_rotary_sequence_length=None,
+use_smooth_softmax: bool = False,
 ):
 super().__init__(
 "GroupQueryAttention",
@@ -185,6 +189,7 @@
 is_packed_qkv=is_packed_qkv,
 max_cache_sequence_length=max_cache_sequence_length,
 max_rotary_sequence_length=max_rotary_sequence_length,
+use_smooth_softmax=use_smooth_softmax,
 )
 # local_window_size is for ORT only, not for Torch implementation.
 self.local_window_size = local_window_size
@@ -529,6 +534,7 @@ def create_group_query_attention_onnx_model(config: GroupQueryAttentionConfig):
 local_window_size=config.local_window_size,
 do_rotary=1 if config.do_rotary else 0,
 rotary_interleaved=config.rotary_interleaved,
+smooth_softmax=1 if config.use_smooth_softmax else 0,
 domain="com.microsoft",
 ),
 ]
@@ -612,7 +618,15 @@ def group_query_attention_reference(
 attn = torch.einsum("bhmd,bhnd->bhmn", query, key).float() * scale
 if mask is not None:
 attn = attn.masked_fill((1 - mask).bool(), float("-inf"))
-attn = attn.softmax(-1)
+
+if config.use_smooth_softmax:
+qk_max = attn.amax(axis=-1, keepdim=True)
+qk_max = torch.maximum(qk_max, torch.zeros_like(qk_max))
+w = torch.exp(attn - qk_max)
+attn = w * torch.reciprocal(w.sum(axis=-1, keepdim=True) + torch.exp(-qk_max))
+else:
+attn = attn.softmax(-1)
+
 attn_output = torch.einsum("bhmn,bhnd->bhmd", attn.type_as(value), value)

 result = attn_output.transpose(1, 2).contiguous()
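For context on what is being benchmarked: the new reference branch in group_query_attention_reference computes exp(x_i) / (1 + sum_j exp(x_j)) instead of the usual softmax exp(x_i) / sum_j exp(x_j). The sketch below is not part of the commit and its names are illustrative; it reproduces the same numerically stable computation in plain PyTorch and checks the equivalent view of appending a single zero logit before a regular softmax, which is where the leftover probability mass goes.

# Standalone sketch, not from the commit. Mirrors the smooth-softmax
# reference branch added above; names are illustrative.
import torch


def smooth_softmax(scores: torch.Tensor) -> torch.Tensor:
    # Stable form of exp(x_i) / (1 + sum_j exp(x_j)): subtract
    # m = max(max_j x_j, 0) before exponentiating, as in the diff above.
    m = torch.clamp(scores.amax(dim=-1, keepdim=True), min=0.0)
    w = torch.exp(scores - m)
    return w / (w.sum(dim=-1, keepdim=True) + torch.exp(-m))


if __name__ == "__main__":
    scores = torch.randn(2, 4, 8)  # e.g. (batch, heads, kv_length) attention logits

    # Equivalent view: append one zero logit, take a regular softmax, and drop
    # the extra slot. Rows then sum to less than 1; the missing mass is the
    # weight given to the implicit zero logit.
    padded = torch.cat([scores, torch.zeros_like(scores[..., :1])], dim=-1)
    expected = torch.softmax(padded, dim=-1)[..., :-1]
    assert torch.allclose(smooth_softmax(scores), expected, atol=1e-6)
    print("row sums:", smooth_softmax(scores).sum(dim=-1).flatten()[:4].tolist())

In the ONNX model this behavior is switched on through the smooth_softmax attribute of the com.microsoft GroupQueryAttention node, as set in create_group_query_attention_onnx_model above.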
