From 4a196d15940b0f328735c888e2e861d67602ffcf Mon Sep 17 00:00:00 2001
From: aciddelgado <139922440+aciddelgado@users.noreply.github.com>
Date: Sat, 23 Mar 2024 14:30:35 -0700
Subject: [PATCH] Packed QKV and Rotary Embedding Support for sm<80 GQA (#20012)

### Description
Add support for packed QKV input and rotary embedding on sm<80 GPUs by using the memory efficient attention kernel.

### Motivation and Context
Allows lower-end GPUs to run GQA with packed QKV input and rotary embedding.
---
 .../cuda/bert/group_query_attention.cc        |  23 ++-
 .../cuda/bert/group_query_attention_impl.cu   | 160 ++++++++++++++++--
 .../cuda/bert/group_query_attention_impl.h    |   2 +
 .../python/transformers/test_flash_attn.py    |  95 ++++++-----
 4 files changed, 216 insertions(+), 64 deletions(-)

diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index 814aa1fb3c8f0..112f609d46598 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -159,8 +159,6 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
       !use_flash_attention &&
       !disable_memory_efficient_attention_ &&
       local_window_size_ == -1 &&
-      do_rotary_ == false &&
-      key != nullptr &&
       (parameters.head_size & 7) == 0 &&
       parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
       (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
@@ -172,18 +170,31 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
   if (use_memory_efficient_attention && needs_buff) {
     kv_buffer_bytes = (sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.seqlen_present_kv_cache * parameters.head_size);
   }
+  size_t rotary_buffer_bytes = 0;
+  if (use_memory_efficient_attention && do_rotary_) {
+    rotary_buffer_bytes = 2 * sizeof(T) * parameters.batch_size * parameters.num_heads * parameters.sequence_length * parameters.head_size;
+    rotary_buffer_bytes += sizeof(int64_t) * parameters.batch_size * parameters.sequence_length;
+  }
   size_t fmha_buffer_bytes = 0;
   if (use_memory_efficient_attention && MemoryEfficientAttentionParams::need_workspace(parameters.head_size, sizeof(T) == sizeof(float))) {
     fmha_buffer_bytes = (parameters.batch_size * parameters.sequence_length * parameters.num_heads * parameters.head_size * sizeof(float));
   }
+  size_t unpacked_qkv_bytes = 0;
+  if (use_memory_efficient_attention && parameters.is_packed_qkv) {
+    unpacked_qkv_bytes = (parameters.batch_size * parameters.sequence_length * (parameters.num_heads + 2 * parameters.kv_num_heads) * parameters.head_size * sizeof(T));
+  }
   auto k_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream());
   auto v_buffer = GetScratchBuffer(kv_buffer_bytes, context->GetComputeStream());
+  auto rotary_buffer = GetScratchBuffer(rotary_buffer_bytes, context->GetComputeStream());
   auto fmha_buffer = GetScratchBuffer(fmha_buffer_bytes, context->GetComputeStream());
+  auto unpacked_qkv_buffer = GetScratchBuffer(unpacked_qkv_bytes, context->GetComputeStream());
 #else
   constexpr bool use_memory_efficient_attention = false;
   auto k_buffer = GetScratchBuffer(0, context->GetComputeStream());
   auto v_buffer = GetScratchBuffer(0, context->GetComputeStream());
+  auto rotary_buffer = GetScratchBuffer(0, context->GetComputeStream());
   auto fmha_buffer = GetScratchBuffer(0, context->GetComputeStream());
+  auto unpacked_qkv_buffer = GetScratchBuffer(0, context->GetComputeStream());
 #endif

   // seqlens_k buffer
@@ -251,7 +262,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
   if (fmha_buffer != nullptr) {
     data.fmha_buffer = reinterpret_cast(fmha_buffer.get());
   }
-  // Rotary
+  if (unpacked_qkv_buffer != nullptr) {
+    data.unpacked_qkv_buffer = reinterpret_cast(unpacked_qkv_buffer.get());
+  }
+  if (rotary_buffer != nullptr) {
+    data.rotary_buffer = reinterpret_cast(rotary_buffer.get());
+  }
+  // Rotary Embedding
   if (parameters.do_rotary) {
     data.cos_cache = reinterpret_cast(cos_cache->Data());
     data.sin_cache = reinterpret_cast(sin_cache->Data());
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index afba83be34e2d..f519be1c97149 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -42,6 +42,7 @@ limitations under the License.
 #include "contrib_ops/cuda/bert/group_query_attention_impl.h"
 #include "contrib_ops/cuda/bert/attention_impl.h"
 #include "core/providers/cuda/shared_inc/cuda_call.h"
+#include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
 #include 

 using namespace onnxruntime::cuda;

@@ -150,6 +151,8 @@ __global__ void ConcatNewToPastKVLarge(const int new_seqlen,
 template 
 Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters,
                                GroupQueryAttentionData& data,
+                               const void* new_key,
+                               const void* new_value,
                                cudaStream_t stream,
                                const int max_threads_per_block,
                                const bool past_only = false) {
@@ -171,14 +174,14 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameter
     ConcatNewToPastKV<<>>(kv_sequence_length,
                               past_sequence_length,
                               reinterpret_cast(data.past_key),
-                              reinterpret_cast(data.key),
+                              reinterpret_cast(new_key),
                               reinterpret_cast(data.present_key),
                               seqlens_k,
                               past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
     ConcatNewToPastKV<<>>(kv_sequence_length,
                               past_sequence_length,
                               reinterpret_cast(data.past_value),
-                              reinterpret_cast(data.value),
+                              reinterpret_cast(new_value),
                               reinterpret_cast(data.present_value),
                               seqlens_k,
                               past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
@@ -191,7 +194,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameter
                                    H,
                                    kv_num_heads,
                                    reinterpret_cast(data.past_key),
-                                   reinterpret_cast(data.key),
+                                   reinterpret_cast(new_key),
                                    reinterpret_cast(data.present_key),
                                    seqlens_k,
                                    past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
@@ -200,7 +203,7 @@ Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameter
                                    H,
                                    kv_num_heads,
                                    reinterpret_cast(data.past_value),
-                                   reinterpret_cast(data.value),
+                                   reinterpret_cast(new_value),
                                    reinterpret_cast(data.present_value),
                                    seqlens_k,
                                    past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
@@ -281,6 +284,8 @@ __global__ void ConcatKVInPlaceLarge(const int max_seqlen,
 template 
 Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters,
                              GroupQueryAttentionData& data,
+                             const void* new_key,
+                             const void* new_value,
                              cudaStream_t stream,
                              const int max_threads_per_block) {
   const int batch_size = parameters.batch_size;
@@ -300,12 +305,12 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters,
     const dim3 block(H, kv_num_heads, 1);
     ConcatKVInPlace<<>>(present_sequence_length,
                            reinterpret_cast(data.present_key),
-                           reinterpret_cast(data.key),
+                           reinterpret_cast(new_key),
                            seqlens_k,
                            past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
     ConcatKVInPlace<<>>(present_sequence_length,
                            reinterpret_cast(data.present_value),
-                           reinterpret_cast(data.value),
+                           reinterpret_cast(new_value),
                            seqlens_k,
                            past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
   } else {
@@ -316,14 +321,14 @@ Status LaunchConcatKVInPlace(contrib::GroupQueryAttentionParameters& parameters,
                                 H,
                                 kv_num_heads,
                                 reinterpret_cast(data.present_key),
-                                reinterpret_cast(data.key),
+                                reinterpret_cast(new_key),
                                 seqlens_k,
                                 past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
     ConcatKVInPlaceLarge<<>>(present_sequence_length,
                                 H,
                                 kv_num_heads,
                                 reinterpret_cast(data.present_value),
-                                reinterpret_cast(data.value),
+                                reinterpret_cast(new_value),
                                 seqlens_k,
                                 past_kv_format == AttentionQkvFormat::Q_K_V_BSNH);
   }
@@ -468,6 +473,83 @@ Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, i
   return CUDA_CALL(cudaGetLastError());
 }

+// Kernel to unpack qkv from packed qkv
+template 
+__global__ void UnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads,
+                          const int kv_num_heads, const int head_size, const int sequence_length,
+                          const int batch_size) {
+  const int tid = threadIdx.x + blockIdx.x * blockDim.x;
+  int d = (num_heads + 2 * kv_num_heads) * head_size;
+  const int qkv_size = batch_size * sequence_length * d;
+  const int q_size = num_heads * head_size;
+  const int k_size = kv_num_heads * head_size;
+  if (tid < qkv_size) {
+    int batch = tid / (d * sequence_length);
+    int sequence = (tid % (d * sequence_length)) / d;
+    int offset = tid % d;
+    if (offset < q_size) {
+      int unpacked_i = batch * sequence_length * num_heads * head_size + sequence * num_heads * head_size + offset;
+      unpacked_q[unpacked_i] = packed_qkv[tid];
+    } else if (offset < q_size + k_size) {
+      int unpacked_i = batch * sequence_length * kv_num_heads * head_size + sequence * kv_num_heads * head_size + (offset - q_size);
+      unpacked_k[unpacked_i] = packed_qkv[tid];
+    } else {
+      int unpacked_i = batch * sequence_length * kv_num_heads * head_size + sequence * kv_num_heads * head_size + (offset - q_size - k_size);
+      unpacked_v[unpacked_i] = packed_qkv[tid];
+    }
+  }
+}
+
+// Unpack packed qkv
+template 
+Status LaunchUnpackQKV(const T* packed_qkv, T* unpacked_q, T* unpacked_k, T* unpacked_v, const int num_heads,
+                       const int kv_num_heads, const int head_size, const int sequence_length, const int batch_size,
+                       cudaStream_t stream, const int max_threads_per_block) {
+  const int threads = max_threads_per_block;
+  const int blocks = (batch_size * sequence_length * (num_heads + 2 * kv_num_heads) * head_size + threads - 1) / threads;
+  UnpackQKV<<>>(packed_qkv, unpacked_q, unpacked_k, unpacked_v, num_heads, kv_num_heads,
+                    head_size, sequence_length, batch_size);
+  return CUDA_CALL(cudaGetLastError());
+}
+
+// Kernel to convert seqlens_k to position_ids
+__global__ void SeqlensToPosIdsPrompt(int32_t* seqlens_k, int64_t* position_ids, const int seqlen,
+                                      const int batch_size) {
+  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  int b = tid / seqlen;
+  int s = tid % seqlen;
+  if (b < batch_size) {
+    if (s < seqlens_k[b] + 1) {
+      position_ids[tid] = s;
+    } else {
+      position_ids[tid] = 1;
+    }
+  }
+}
+
+// Kernel to convert seqlens_k to position_ids
+__global__ void SeqlensToPosIdsToken(int32_t* seqlens_k, int64_t* position_ids, const int batch_size) {
+  int tid = blockDim.x * blockIdx.x + threadIdx.x;
+  if (tid < batch_size) {
+    position_ids[tid] = seqlens_k[tid];
+  }
+}
+
+// Convert seqlens_k to position_ids
+Status LaunchSeqlensToPosIds(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k,
+                             int64_t* position_ids, cudaStream_t stream, const int max_threads_per_block) {
+  const int seqlen = parameters.sequence_length;
+  const int batch_size = parameters.batch_size;
+  const int threads = max_threads_per_block;
+  const int blocks = (batch_size * seqlen + threads - 1) / threads;
+  if (parameters.is_prompt) {
+    SeqlensToPosIdsPrompt<<>>(seqlens_k, position_ids, seqlen, batch_size);
+  } else {
+    SeqlensToPosIdsToken<<>>(seqlens_k, position_ids, batch_size);
+  }
+  return CUDA_CALL(cudaGetLastError());
+}
+
 ////////// Launch Kernels

 #if USE_FLASH_ATTENTION
@@ -517,7 +599,8 @@ Status FlashAttention(
       seqlens_k = data.seqlens_k_total;
     }
   } else if (!parameters.kv_share_buffer) {  // copy past kv to present kv
-    ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true));
+    ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, nullptr, nullptr, stream, max_threads_per_block,
+                                                true));
   }

   void* present_key = reinterpret_cast(const_cast(data.present_key));
@@ -563,15 +646,62 @@ Status EfficientAttention(
   const int head_size = parameters.head_size;
   AttentionQkvFormat past_kv_format = parameters.past_kv_format;

-  const void* query = reinterpret_cast(data.query);
-  const void* key = reinterpret_cast(data.key);
-  const void* value = reinterpret_cast(data.value);
+  const void* query;
+  const void* key;
+  const void* value;
+
+  if (!parameters.is_packed_qkv) {
+    query = reinterpret_cast(data.query);
+    key = reinterpret_cast(data.key);
+    value = reinterpret_cast(data.value);
+  } else {
+    size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size);
+    size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size);
+    auto q = reinterpret_cast(data.unpacked_qkv_buffer);
+    auto k = reinterpret_cast(data.unpacked_qkv_buffer + q_size);
+    auto v = reinterpret_cast(data.unpacked_qkv_buffer + q_size + k_size);
+    ORT_RETURN_IF_ERROR(LaunchUnpackQKV(reinterpret_cast(data.query), q, k, v, num_heads, kv_num_heads,
+                                        head_size, sequence_length, batch_size, stream, max_threads_per_block));
+    query = reinterpret_cast(q);
+    key = reinterpret_cast(k);
+    value = reinterpret_cast(v);
+  }
+
+  if (parameters.do_rotary) {
+    size_t q_size = static_cast(batch_size * sequence_length * num_heads * head_size);
+    size_t k_size = static_cast(batch_size * sequence_length * kv_num_heads * head_size);
+    auto q_buffer = reinterpret_cast(data.rotary_buffer);
+    auto k_buffer = q_buffer + q_size;
+    auto position_ids_buff = reinterpret_cast(k_buffer + k_size);
+    ORT_RETURN_IF_ERROR(LaunchSeqlensToPosIds(parameters, data.seqlens_k, position_ids_buff, stream,
+                                              max_threads_per_block));
+    DUMP_TENSOR_INIT();
+    DUMP_TENSOR("position_ids", position_ids_buff, batch_size, sequence_length);
+    // Launch rotary embedding kernel
+    ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, q_buffer, reinterpret_cast(query),
+                                                    position_ids_buff, data.cos_cache, data.sin_cache,
+                                                    parameters.batch_size, parameters.sequence_length,
+                                                    parameters.num_heads, parameters.head_size,
+                                                    parameters.rotary_dim, parameters.seqlen_present_kv_cache,
+                                                    /*position_ids_format*/ 1, parameters.rotary_interleaved,
+                                                    device_prop.maxThreadsPerBlock, /*transposed*/ false));
+    ORT_RETURN_IF_ERROR(LaunchRotaryEmbeddingKernel(stream, k_buffer, reinterpret_cast(key),
+                                                    position_ids_buff, data.cos_cache, data.sin_cache,
+                                                    parameters.batch_size, parameters.sequence_length,
+                                                    parameters.kv_num_heads, parameters.head_size,
+                                                    parameters.rotary_dim, parameters.seqlen_present_kv_cache,
+                                                    /*position_ids_format*/ 1, parameters.rotary_interleaved,
+                                                    device_prop.maxThreadsPerBlock, /*transposed*/ false));
+    query = reinterpret_cast(q_buffer);
+    key = reinterpret_cast(k_buffer);
+  }

   if (parameters.is_prompt) {
     // Launch kernel to copy seqlen
     constexpr int thr_per_blk = 256;
     int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk;
-    repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size);
+    repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length,
+                          batch_size);
   } else {
     ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256));
   }
@@ -583,7 +713,7 @@ Status EfficientAttention(
                            "Past and present kv shall share the same tensor when kv_share_buffer is on.");
     }
     // Concatenate new kv in place
-    ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block));
+    ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, key, value, stream, max_threads_per_block));
   } else {
     // Not share buffer case
     if (data.past_key != nullptr && data.past_key == data.present_key) {
@@ -591,7 +721,7 @@ Status EfficientAttention(
                            "Past and present kv share the same tensor but kv_share_buffer is not on.");
     }
     // Copy past and concat new KV to present buffer
-    ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
+    ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, key, value, stream, max_threads_per_block));
   }

   // Ungroup if grouped, otherwise use present kv directly
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
index 1bf91f9c875eb..32341afa0e3fa 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
@@ -30,6 +30,8 @@ struct GroupQueryAttentionData {
   int* seqlens_k_total = nullptr;
   // Memory Efficient buffers
   T* fmha_buffer = nullptr;
+  T* unpacked_qkv_buffer = nullptr;
+  T* rotary_buffer = nullptr;
   T* k = nullptr;
   T* v = nullptr;
   // Output Tensors
diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py
index b784c83329c76..183d6218567a7 100644
--- a/onnxruntime/test/python/transformers/test_flash_attn.py
+++ b/onnxruntime/test/python/transformers/test_flash_attn.py
@@ -1216,8 +1216,6 @@ def parity_check_gqa_prompt(
         dtype=torch.float16,
         requires_grad=False,
     )
-    # print(k.shape)
-    # print(new_k.shape)

    window_size = (-1, -1)
    left_window_size = -1
@@ -1328,10 +1326,6 @@ def parity_check_gqa_prompt(
    out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size))
    out = out.detach().cpu().numpy()

-    # print(cache_seqlens[0])
-    # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0])
-    # print((out - out_ref)[0, :, 0, 0])
-
    # Make sure past-present buffer updating correctly
    assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
    assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
@@ -1724,9 +1718,6 @@ def parity_check_gqa_past(
    out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
    out = out.detach().cpu().numpy()

-    # print(cache_seqlens[0])
-    # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, cache_seqlens[0], :])
-
    # Make sure past-present buffer updating correctly
    assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
    assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True)
@@ -1939,18 +1930,6 @@ def parity_check_gqa_past_no_buff(
    out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size))
    out = out.detach().cpu().numpy()

-    # print(cache_seqlens[0])
-    # print((out - out_ref)[0])
-    # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0])
-
-    # Make sure past-present buffer updating correctly
-    # assert numpy.allclose(
-    #     present_k[:, :, :-1, :], k_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True
-    # )
-    # assert numpy.allclose(
-    #     present_v[:, :, :-1, :], v_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True
-    # )
-
    # Compare results
    print(
        "NO buff",
@@ -2078,10 +2057,27 @@ def test_gqa_no_past(self):
            for sq, skv in seqs:
                for n, n2 in num_h:
                    for h in h_sizes:
-                        for past_kv_format in [Formats.BNSH]:
-                            config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
-                            parity_check_gqa_prompt(config, past_format=past_kv_format)
-                            parity_check_gqa_prompt_no_buff(config, past_format=past_kv_format)
+                        for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
+                            for packed in [False, True]:
+                                config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
+                                parity_check_gqa_prompt(
+                                    config,
+                                    rtol=2e-3,
+                                    atol=2e-3,
+                                    past_format=Formats.BNSH,
+                                    rotary=rotary,
+                                    rotary_interleaved=rotary_interleaved,
+                                    packed=packed,
+                                )
+                                parity_check_gqa_prompt_no_buff(
+                                    config,
+                                    rtol=2e-3,
+                                    atol=2e-3,
+                                    past_format=Formats.BNSH,
+                                    rotary=rotary,
+                                    rotary_interleaved=rotary_interleaved,
+                                    packed=packed,
+                                )
        if major < 8 or platform.system() != "Linux":
            return
        print("------- FLASH ATTENTION (PROMPT CASE) --------")
@@ -2092,12 +2088,12 @@ def test_gqa_no_past(self):
                    for h in h_sizes:
                        for local in [False, True]:
                            for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
-                                for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]:
+                                for packed in [False, True]:
                                    config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h)
                                    parity_check_gqa_prompt(
                                        config,
                                        local=local,
-                                        past_format=past_kv_format,
+                                        past_format=Formats.BNSH,
                                        rotary=rotary,
                                        rotary_interleaved=rotary_interleaved,
                                        packed=packed,
@@ -2105,7 +2101,7 @@ def test_gqa_no_past(self):
                                    parity_check_gqa_prompt_no_buff(
                                        config,
                                        local=local,
-                                        past_format=past_kv_format,
+                                        past_format=Formats.BNSH,
                                        rotary=rotary,
                                        rotary_interleaved=rotary_interleaved,
                                        packed=packed,
@@ -2145,21 +2141,28 @@ def test_gqa_past(self):
            for s, s2 in seqs:
                for n, n2 in num_h:
                    for h in h_sizes:
-                        for past_kv_format in [Formats.BNSH]:
-                            sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
-                            config = Config(b, s, s2, sp, n, n2, h)
-                            parity_check_gqa_past(
-                                config,
-                                past_format=past_kv_format,
-                                rtol=1e-3,
-                                atol=1e-3,
-                            )
-                            parity_check_gqa_past_no_buff(
-                                config,
-                                past_format=past_kv_format,
-                                rtol=1e-3,
-                                atol=1e-3,
-                            )
+                        for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
+                            for packed in [False, True]:
+                                sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
+                                config = Config(b, s, s2, sp, n, n2, h)
+                                parity_check_gqa_past(
+                                    config,
+                                    past_format=Formats.BNSH,
+                                    rtol=1e-3,
+                                    atol=1e-3,
+                                    rotary=rotary,
+                                    rotary_interleaved=rotary_interleaved,
+                                    packed=packed,
+                                )
+                                parity_check_gqa_past_no_buff(
+                                    config,
+                                    past_format=Formats.BNSH,
+                                    rtol=1e-3,
+                                    atol=1e-3,
+                                    rotary=rotary,
+                                    rotary_interleaved=rotary_interleaved,
+                                    packed=packed,
+                                )
        if major < 8 or platform.system() != "Linux":
            return
        print("------- FLASH ATTENTION (TOKEN GEN) -------")
@@ -2170,13 +2173,13 @@ def test_gqa_past(self):
                    for h in h_sizes:
                        for local in [False, True]:
                            for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]:
-                                for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]:
+                                for packed in [False, True]:
                                    sp = random.randint(1, s2 - s) if s2 - s > 0 else 0
                                    config = Config(b, s, s2, sp, n, n2, h)
                                    parity_check_gqa_past(
                                        config,
                                        local=local,
-                                        past_format=past_kv_format,
+                                        past_format=Formats.BNSH,
                                        rtol=1e-3,
                                        atol=1e-3,
                                        rotary=rotary,
@@ -2186,7 +2189,7 @@ def test_gqa_past(self):
                                    parity_check_gqa_past_no_buff(
                                        config,
                                        local=local,
-                                        past_format=past_kv_format,
+                                        past_format=Formats.BNSH,
                                        rtol=1e-3,
                                        atol=1e-3,
                                        rotary=rotary,
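
For reference, here is a minimal NumPy sketch (illustrative only, not part of the patch) of what the new `UnpackQKV` and `SeqlensToPosIdsPrompt`/`SeqlensToPosIdsToken` kernels compute. The function and variable names below are assumptions for illustration; shapes follow the BSNH packed layout used by the kernels in the diff above.

```python
# Host-side reference sketch of the new GQA helper kernels (assumed shapes/names).
import numpy as np


def unpack_qkv(packed_qkv, num_heads, kv_num_heads, head_size):
    """Split a packed [B, S, (N + 2*N_kv) * H] buffer into Q, K, V, mirroring UnpackQKV."""
    batch, seqlen, _ = packed_qkv.shape
    q_size = num_heads * head_size
    k_size = kv_num_heads * head_size
    q = packed_qkv[:, :, :q_size].reshape(batch, seqlen, num_heads, head_size)
    k = packed_qkv[:, :, q_size:q_size + k_size].reshape(batch, seqlen, kv_num_heads, head_size)
    v = packed_qkv[:, :, q_size + k_size:].reshape(batch, seqlen, kv_num_heads, head_size)
    return q, k, v


def seqlens_to_pos_ids(seqlens_k, seqlen, is_prompt):
    """Mirror of SeqlensToPosIdsPrompt / SeqlensToPosIdsToken.

    seqlens_k[b] holds the number of past tokens for batch entry b.
    """
    batch_size = seqlens_k.shape[0]
    if is_prompt:
        pos_ids = np.empty((batch_size, seqlen), dtype=np.int64)
        for b in range(batch_size):
            for s in range(seqlen):
                # Real tokens get positions 0..seqlens_k[b]; remaining padding collapses to 1.
                pos_ids[b, s] = s if s < seqlens_k[b] + 1 else 1
        return pos_ids
    # Token generation: one position per batch entry, equal to the number of past tokens.
    return seqlens_k.astype(np.int64).copy()


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    packed = rng.standard_normal((2, 4, (8 + 2 * 2) * 16), dtype=np.float32)
    q, k, v = unpack_qkv(packed, num_heads=8, kv_num_heads=2, head_size=16)
    print(q.shape, k.shape, v.shape)  # (2, 4, 8, 16) (2, 4, 2, 16) (2, 4, 2, 16)
    print(seqlens_to_pos_ids(np.array([2, 3], dtype=np.int32), seqlen=4, is_prompt=True))
```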