
Commit

fix unittest. remove writing M
MARD1NO committed Nov 1, 2024
1 parent 5f99222 commit bf85f9d
Showing 2 changed files with 22 additions and 19 deletions.
29 changes: 21 additions & 8 deletions benchmark/test_attention_perf.py
@@ -1,12 +1,29 @@
+from typing import Generator
+
 import torch
 
 from .performance_utils import Benchmark
 
 
+class AttentionBenchmark(Benchmark):
+    """
+    benchmark for attention
+    """
+
+    def __init__(self, *args, input_fn, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input_fn = input_fn
+
+    def get_input_iter(self, cur_dtype) -> Generator:
+        for seq_len in [1024, 2048, 3072, 4096]:
+            yield from self.input_fn(cur_dtype, seq_len)
+
+
 def test_perf_scaled_dot_product_attention():
-    def scaled_dot_product_attention_kwargs(dtype, batch, seq_len):
+    def scaled_dot_product_attention_kwargs(dtype, seq_len):
         num_heads = 8
         head_size = 128
+        batch = 4
 
         query = torch.randn(
             (batch, num_heads, seq_len, head_size), device="cuda", dtype=dtype
@@ -17,19 +34,15 @@ def scaled_dot_product_attention_kwargs(dtype, batch, seq_len):
         value = torch.randn(
             (batch, num_heads, seq_len, head_size), device="cuda", dtype=dtype
         )
-        return {"query": query, "key": key, "value": value, "is_causal": True}
+        yield query, key, value, None, 0.0, True
 
-    seq_len = [1024, 2048, 3072, 4096]
-    bench = Benchmark(
+    bench = AttentionBenchmark(
         op_name="scaled_dot_product_attention",
+        input_fn=scaled_dot_product_attention_kwargs,
         torch_op=torch.nn.functional.scaled_dot_product_attention,
-        arg_func=None,
         dtypes=[
             # torch.float32,
             torch.float16,
         ],
-        batch=4,
-        sizes=seq_len,
-        kwargs_func=scaled_dot_product_attention_kwargs,
     )
     bench.run()
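
For reference, a minimal standalone sketch (not part of this commit; it assumes the Benchmark base class unpacks each yielded tuple as positional arguments and that a CUDA device is available) showing how the tuple yielded by scaled_dot_product_attention_kwargs lines up with the positional arguments of torch.nn.functional.scaled_dot_product_attention:

import torch

def run_reference_once(dtype=torch.float16, seq_len=1024):
    # Same shapes as the benchmark: (batch, num_heads, seq_len, head_size)
    batch, num_heads, head_size = 4, 8, 128
    shape = (batch, num_heads, seq_len, head_size)
    query = torch.randn(shape, device="cuda", dtype=dtype)
    key = torch.randn(shape, device="cuda", dtype=dtype)
    value = torch.randn(shape, device="cuda", dtype=dtype)
    # Yielded tuple order: (query, key, value, attn_mask, dropout_p, is_causal)
    args = (query, key, value, None, 0.0, True)
    return torch.nn.functional.scaled_dot_product_attention(*args)
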
12 changes: 1 addition & 11 deletions src/flag_gems/ops/attention.py
@@ -115,7 +115,7 @@ def _attn_fwd_inner(
 configs = [
     triton.Config({"BLOCK_M": BM, "BLOCK_N": BN}, num_stages=s, num_warps=w)
     for BM in [64, 128]
-    for BN in [32, 64]
+    for BN in [32, 64, 128]
     for s in [1, 2, 3, 4]
     for w in [4, 8]
 ]
@@ -137,7 +137,6 @@ def _attn_fwd(
     V,
     attn_mask,
     sm_scale,
-    M,
     Out,  #
     stride_q_batch,
     stride_q_head,
@@ -297,10 +296,7 @@ def _attn_fwd(
         HAS_ATTN_MASK,  #
     )
     # epilogue
-    m_i += tl.math.log2(l_i)
     acc = acc / l_i[:, None]
-    m_ptrs = M + off_hz * Q_CTX + offs_m
-    tl.store(m_ptrs, m_i, mask=q_load_mask)
     tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])
 
 
@@ -339,11 +335,6 @@ def scaled_dot_product_attention(
         query.shape[0] * query.shape[1],
         1,
     )
-    M = torch.empty(
-        (query.shape[0], query.shape[1], query.shape[2]),
-        device=query.device,
-        dtype=torch.float32,
-    )
 
     if attn_mask is not None:
         HAS_ATTN_MASK = True
@@ -364,7 +355,6 @@
         value,
         attn_mask,
         sm_scale,
-        M,
         o,  #
         query.stride(0),
         query.stride(1),
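
For context, the removed buffer M held the per-row log-sum-exp term (m_i + log2(l_i) in the kernel's base-2 arithmetic), which only a backward pass would need; the forward output uses just the normalized accumulator acc / l_i, so dropping the store leaves o unchanged. A minimal eager-mode sketch of that relationship (reference math only, not code from this repository):

import torch

def sdpa_reference(q, k, v, is_causal=True):
    # softmax(Q K^T * scale) V, plus the row-wise log-sum-exp that M used to store.
    scale = q.shape[-1] ** -0.5
    scores = (q @ k.transpose(-2, -1)) * scale
    if is_causal:
        L = q.shape[-2]
        keep = torch.ones(L, L, dtype=torch.bool, device=q.device).tril()
        scores = scores.masked_fill(~keep, float("-inf"))
    logsumexp = torch.logsumexp(scores, dim=-1)  # what M encoded (kernel kept it in log2)
    out = torch.softmax(scores, dim=-1) @ v      # forward result; does not depend on logsumexp
    return out, logsumexp
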
