diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 633d898..7bcaaa5 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -231,6 +231,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, const size_t batch_start = interleaved_start / num_queries; const size_t num_interleaved = num_tokens * num_queries; + // Self extend + constexpr size_t ngb_size = TConfig::self_extend_ngb_size; + constexpr size_t grp_size = TConfig::self_extend_grp_size; + // For the computation of Q, K, and V, it is useful to remember that // qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim] // and kQStride = kQKVDim * (kIsMHA ? 3 : 1); @@ -286,12 +290,17 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, const size_t interleaved_idx = task / kKVHeads; const size_t query_idx = interleaved_idx % num_queries; const size_t batch_idx = interleaved_idx / num_queries; - const size_t pos = batch_start + batch_idx; + size_t pos = batch_start + batch_idx; const size_t cache_pos = div_seq_len.Remainder(pos); const size_t kv_offset = cache_pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim * 2; KVCache& kv_cache = kv_caches[query_idx]; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; + + // When embedding position, we will use grouped key position + if (pos > ngb_size && TConfig::kSelfExtend) { + pos /= grp_size; + } if constexpr (kIsMHA) { // For MHA, copy KV into the KV cache from scratch space (see above). const float* HWY_RESTRICT q = @@ -321,7 +330,13 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens, activations.q.Batch(interleaved_idx) + head * kQStride; // Apply rope and scaling to Q. - const size_t pos = batch_start + batch_idx; + size_t pos = batch_start + batch_idx; + if (pos > ngb_size && TConfig::kSelfExtend) { + const grp_pos = pos / grp_size; + const shift = ngb_size - ngb_size / grp_size + const shifted_grouped_pos = grp_pos + shift + pos = shifted_grouped_pos; + } PostQK(q, pos, layer); MulByConst(kQueryScale, q, kQKVDim);