Skip to content

Commit

Permalink
compile success: set default self extend values in noSSM and griffin
Browse files Browse the repository at this point in the history
  • Loading branch information
jonpsy committed Oct 18, 2024
1 parent 5a2a7ee commit 7bb4e0b
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
10 changes: 10 additions & 0 deletions gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ struct ConfigNoSSM {
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr ResidualType kResidual = ResidualType::Add;

// Self-extend parameters with default values
static constexpr bool kSelfExtend = false;
static constexpr size_t kSelfExtendNgbSize = 0;
static constexpr size_t kSelfExtendGrpSize = 1;
};

struct ConfigBaseGemmaV1 : ConfigNoSSM {
Expand Down Expand Up @@ -372,6 +377,11 @@ struct ConfigGriffin2B {
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr ResidualType kResidual = ResidualType::Add;

// Self-extend parameters with default values
static constexpr bool kSelfExtend = false;
static constexpr size_t kSelfExtendNgbSize = 0;
static constexpr size_t kSelfExtendGrpSize = 1;
};

} // namespace gcpp
Expand Down
22 changes: 13 additions & 9 deletions gemma/gemma-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -232,8 +232,8 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
const size_t num_interleaved = num_tokens * num_queries;

// Self extend
constexpr size_t ngb_size = TConfig::self_extend_ngb_size;
constexpr size_t grp_size = TConfig::self_extend_grp_size;
constexpr size_t ngb_size = TConfig::kSelfExtendNgbSize;
constexpr size_t grp_size = TConfig::kSelfExtendGrpSize;

// For the computation of Q, K, and V, it is useful to remember that
// qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
Expand Down Expand Up @@ -298,8 +298,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;

// When embedding position, we will use grouped key position
if (pos > ngb_size && TConfig::kSelfExtend) {
pos /= grp_size;
if constexpr (TConfig::kSelfExtend) {
if (pos > ngb_size) {
pos /= grp_size;
}
}
if constexpr (kIsMHA) {
// For MHA, copy KV into the KV cache from scratch space (see above).
Expand Down Expand Up @@ -331,11 +333,13 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,

// Apply rope and scaling to Q.
size_t pos = batch_start + batch_idx;
if (pos > ngb_size && TConfig::kSelfExtend) {
const grp_pos = pos / grp_size;
const shift = ngb_size - ngb_size / grp_size
const shifted_grouped_pos = grp_pos + shift
pos = shifted_grouped_pos;
if constexpr (TConfig::kSelfExtend) {
if (pos > ngb_size) {
const size_t grp_pos = pos / grp_size;
const size_t shift = ngb_size - ngb_size / grp_size;
const size_t shifted_grouped_pos = grp_pos + shift;
pos = shifted_grouped_pos;
}
}
PostQK<TConfig>(q, pos, layer);
MulByConst(kQueryScale, q, kQKVDim);
Expand Down

0 comments on commit 7bb4e0b

Please sign in to comment.