From bf85f9dc120dcb2735243e5baf265e0fdb2d43c6 Mon Sep 17 00:00:00 2001
From: MARD1NO <359521840@qq.com>
Date: Fri, 1 Nov 2024 17:57:06 +0800
Subject: [PATCH] fix unittest. remove writing M

---
 benchmark/test_attention_perf.py | 29 +++++++++++++++++++++--------
 src/flag_gems/ops/attention.py   | 12 +-----------
 2 files changed, 22 insertions(+), 19 deletions(-)

diff --git a/benchmark/test_attention_perf.py b/benchmark/test_attention_perf.py
index 403a24e7..3d08cd2a 100644
--- a/benchmark/test_attention_perf.py
+++ b/benchmark/test_attention_perf.py
@@ -1,12 +1,29 @@
+from typing import Generator
+
 import torch
 
 from .performance_utils import Benchmark
 
 
+class AttentionBenchmark(Benchmark):
+    """
+    benchmark for attention
+    """
+
+    def __init__(self, *args, input_fn, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.input_fn = input_fn
+
+    def get_input_iter(self, cur_dtype) -> Generator:
+        for seq_len in [1024, 2048, 3072, 4096]:
+            yield from self.input_fn(cur_dtype, seq_len)
+
+
 def test_perf_scaled_dot_product_attention():
-    def scaled_dot_product_attention_kwargs(dtype, batch, seq_len):
+    def scaled_dot_product_attention_kwargs(dtype, seq_len):
         num_heads = 8
         head_size = 128
+        batch = 4
 
         query = torch.randn(
             (batch, num_heads, seq_len, head_size), device="cuda", dtype=dtype
@@ -17,19 +34,15 @@ def scaled_dot_product_attention_kwargs(dtype, batch, seq_len):
         value = torch.randn(
             (batch, num_heads, seq_len, head_size), device="cuda", dtype=dtype
         )
-        return {"query": query, "key": key, "value": value, "is_causal": True}
+        yield query, key, value, None, 0.0, True
 
-    seq_len = [1024, 2048, 3072, 4096]
-    bench = Benchmark(
+    bench = AttentionBenchmark(
         op_name="scaled_dot_product_attention",
+        input_fn=scaled_dot_product_attention_kwargs,
         torch_op=torch.nn.functional.scaled_dot_product_attention,
-        arg_func=None,
         dtypes=[
             # torch.float32,
             torch.float16,
         ],
-        batch=4,
-        sizes=seq_len,
-        kwargs_func=scaled_dot_product_attention_kwargs,
     )
     bench.run()
diff --git a/src/flag_gems/ops/attention.py b/src/flag_gems/ops/attention.py
index 0f42c272..9591957c 100644
--- a/src/flag_gems/ops/attention.py
+++ b/src/flag_gems/ops/attention.py
@@ -115,7 +115,7 @@ def _attn_fwd_inner(
 configs = [
     triton.Config({"BLOCK_M": BM, "BLOCK_N": BN}, num_stages=s, num_warps=w)
     for BM in [64, 128]
-    for BN in [32, 64]
+    for BN in [32, 64, 128]
     for s in [1, 2, 3, 4]
     for w in [4, 8]
 ]
@@ -137,7 +137,6 @@ def _attn_fwd(
     V,
     attn_mask,
     sm_scale,
-    M,
     Out,  #
     stride_q_batch,
     stride_q_head,
@@ -297,10 +296,7 @@ def _attn_fwd(
         HAS_ATTN_MASK,  #
     )
     # epilogue
-    m_i += tl.math.log2(l_i)
     acc = acc / l_i[:, None]
-    m_ptrs = M + off_hz * Q_CTX + offs_m
-    tl.store(m_ptrs, m_i, mask=q_load_mask)
     tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=q_load_mask[:, None])
 
 
@@ -339,11 +335,6 @@ def scaled_dot_product_attention(
         query.shape[0] * query.shape[1],
         1,
     )
-    M = torch.empty(
-        (query.shape[0], query.shape[1], query.shape[2]),
-        device=query.device,
-        dtype=torch.float32,
-    )
 
     if attn_mask is not None:
         HAS_ATTN_MASK = True
@@ -364,7 +355,6 @@ def scaled_dot_product_attention(
         query,
         key,
         value,
         attn_mask,
         sm_scale,
-        M,
         o,  #
         query.stride(0),
         query.stride(1),
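
Note on the benchmark change: the input helper now yields a positional tuple (query, key, value, None, 0.0, True) instead of a kwargs dict, which matches the positional order of torch.nn.functional.scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False). Below is a minimal standalone sketch of how such a generator is expected to be consumed, assuming the Benchmark harness unpacks each yielded tuple as positional arguments; the sdpa_inputs helper is illustrative and not part of the patch.

    # Sketch: generator-based inputs matching the positional signature of
    # torch.nn.functional.scaled_dot_product_attention.
    # Assumption: the benchmark harness calls torch_op(*args) on each tuple.
    import torch
    import torch.nn.functional as F

    def sdpa_inputs(dtype, seq_len, batch=4, num_heads=8, head_size=128):
        shape = (batch, num_heads, seq_len, head_size)
        query = torch.randn(shape, device="cuda", dtype=dtype)
        key = torch.randn(shape, device="cuda", dtype=dtype)
        value = torch.randn(shape, device="cuda", dtype=dtype)
        # Order is (query, key, value, attn_mask, dropout_p, is_causal),
        # i.e. the positional order of F.scaled_dot_product_attention.
        yield query, key, value, None, 0.0, True

    for args in sdpa_inputs(torch.float16, seq_len=1024):
        out = F.scaled_dot_product_attention(*args)
        print(out.shape)  # torch.Size([4, 8, 1024, 128])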
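
Note on the kernel change: M held the per-row log-sum-exp statistics (m_i + log2(l_i)) that the FlashAttention algorithm keeps for the backward pass; since this kernel only implements the forward, dropping the buffer removes a global-memory allocation and store without affecting the output. The following forward-only sanity check is a sketch, assuming the patched op is importable as shown (adjust the import path to wherever flag_gems exposes it in your checkout); the call arguments mirror the tuple the benchmark yields.

    # Sketch: check that the Triton kernel still matches the PyTorch reference
    # after removing the M (log-sum-exp) buffer.
    import torch
    from flag_gems.ops.attention import scaled_dot_product_attention as gems_sdpa  # assumed import path

    shape = (4, 8, 1024, 128)  # (batch, num_heads, seq_len, head_size)
    q = torch.randn(shape, device="cuda", dtype=torch.float16)
    k = torch.randn(shape, device="cuda", dtype=torch.float16)
    v = torch.randn(shape, device="cuda", dtype=torch.float16)

    ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, is_causal=True)
    out = gems_sdpa(q, k, v, None, 0.0, True)  # same positional args the benchmark yields
    torch.testing.assert_close(out, ref, atol=1e-2, rtol=1e-2)  # fp16 tolerance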