From 7bb4e0b3a7a627a1738ecf3b8707db7ada141869 Mon Sep 17 00:00:00 2001 From: Nanubala Gnana Sai <45007169+jonpsy@users.noreply.github.com> Date: Fri, 18 Oct 2024 14:25:28 +0530 Subject: [PATCH] compile success: set default self extend values in noSSM and griffin --- gemma/configs.h | 10 ++++++++++ gemma/gemma-inl.h | 22 +++++++++++++--------- 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/gemma/configs.h b/gemma/configs.h index be995f9..e24edc8 100644 --- a/gemma/configs.h +++ b/gemma/configs.h @@ -151,6 +151,11 @@ struct ConfigNoSSM { static constexpr PostQKType kPostQK = PostQKType::Rope; static constexpr ActivationType kActivation = ActivationType::Gelu; static constexpr ResidualType kResidual = ResidualType::Add; + + // Self-extend parameters with default values + static constexpr bool kSelfExtend = false; + static constexpr size_t kSelfExtendNgbSize = 0; + static constexpr size_t kSelfExtendGrpSize = 1; }; struct ConfigBaseGemmaV1 : ConfigNoSSM { @@ -372,6 +377,11 @@ struct ConfigGriffin2B { static constexpr ActivationType kActivation = ActivationType::Gelu; static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize; static constexpr ResidualType kResidual = ResidualType::Add; + + // Self-extend parameters with default values + static constexpr bool kSelfExtend = false; + static constexpr size_t kSelfExtendNgbSize = 0; + static constexpr size_t kSelfExtendGrpSize = 1; }; } // namespace gcpp diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 7bcaaa5..af8d787 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -232,8 +232,8 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, const size_t num_interleaved = num_tokens * num_queries; // Self extend - constexpr size_t ngb_size = TConfig::self_extend_ngb_size; - constexpr size_t grp_size = TConfig::self_extend_grp_size; + constexpr size_t ngb_size = TConfig::kSelfExtendNgbSize; + constexpr size_t grp_size = TConfig::kSelfExtendGrpSize; // For the 
computation of Q, K, and V, it is useful to remember that // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim] @@ -298,8 +298,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; // When embedding position, we will use grouped key position - if (pos > ngb_size && TConfig::kSelfExtend) { - pos /= grp_size; + if constexpr (TConfig::kSelfExtend) { + if (pos > ngb_size) { + pos /= grp_size; + } } if constexpr (kIsMHA) { // For MHA, copy KV into the KV cache from scratch space (see above). @@ -331,11 +333,13 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, // Apply rope and scaling to Q. size_t pos = batch_start + batch_idx; - if (pos > ngb_size && TConfig::kSelfExtend) { - const grp_pos = pos / grp_size; - const shift = ngb_size - ngb_size / grp_size - const shifted_grouped_pos = grp_pos + shift - pos = shifted_grouped_pos; + if constexpr (TConfig::kSelfExtend) { + if (pos > ngb_size) { + const size_t grp_pos = pos / grp_size; + const size_t shift = ngb_size - ngb_size / grp_size; + const size_t shifted_grouped_pos = grp_pos + shift; + pos = shifted_grouped_pos; + } } PostQK(q, pos, layer); MulByConst(kQueryScale, q, kQKVDim);