Skip to content

Commit

Permalink
init push
Browse files Browse the repository at this point in the history
  • Loading branch information
jonpsy committed Oct 18, 2024
1 parent 1982a6b commit 5a2a7ee
Showing 1 changed file with 17 additions and 2 deletions.
19 changes: 17 additions & 2 deletions gemma/gemma-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
const size_t batch_start = interleaved_start / num_queries;
const size_t num_interleaved = num_tokens * num_queries;

// Self-extend (grouped attention for long context): positions beyond a
// neighbor window of `ngb_size` tokens are merged into groups of `grp_size`
// before the rotary embedding is applied (see the uses further down).
// NOTE(review): assumes every TConfig declares these members alongside
// kSelfExtend -- confirm non-self-extend configs still compile.
constexpr size_t ngb_size = TConfig::self_extend_ngb_size;
constexpr size_t grp_size = TConfig::self_extend_grp_size;

// For the computation of Q, K, and V, it is useful to remember that
// qkv_einsum_w has shape [(kHeads + kKVHeads * 2), kKQVDim, kModelDim]
// and kQStride = kQKVDim * (kIsMHA ? 3 : 1);
Expand Down Expand Up @@ -286,12 +290,17 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
// Map the flat task index back to (query, batch) coordinates.
const size_t interleaved_idx = task / kKVHeads;
const size_t query_idx = interleaved_idx % num_queries;
const size_t batch_idx = interleaved_idx / num_queries;
size_t pos = batch_start + batch_idx;
// Cache slot is derived from the *ungrouped* position on purpose:
// self-extend grouping below only changes the rotary embedding, not
// where KV is stored in the ring buffer.
const size_t cache_pos = div_seq_len.Remainder(pos);
const size_t kv_offset = cache_pos * kCachePosSize +
                         layer * kCacheLayerSize + head * kQKVDim * 2;
KVCache& kv_cache = kv_caches[query_idx];
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;

// Self-extend: keys outside the neighbor window are position-encoded with
// their grouped position (pos / grp_size). Guarding with `if constexpr`
// (matching the `if constexpr (kIsMHA)` style below) prunes the dead branch
// for configs without self-extend instead of testing the compile-time flag
// at runtime after the position comparison.
if constexpr (TConfig::kSelfExtend) {
  if (pos > ngb_size) {
    pos /= grp_size;
  }
}
if constexpr (kIsMHA) {
// For MHA, copy KV into the KV cache from scratch space (see above).
const float* HWY_RESTRICT q =
Expand Down Expand Up @@ -321,7 +330,13 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
activations.q.Batch(interleaved_idx) + head * kQStride;

// Apply rope and scaling to Q.
size_t pos = batch_start + batch_idx;
// Self-extend: queries beyond the neighbor window use the *shifted* grouped
// position, so the last grouped positions line up with the neighbor tokens'
// ungrouped positions (shift = ngb_size - ngb_size / grp_size).
// Fixes compile errors in the original: `grp_pos`/`shift`/
// `shifted_grouped_pos` were declared without a type and two statements
// were missing semicolons.
if constexpr (TConfig::kSelfExtend) {
  if (pos > ngb_size) {
    const size_t grp_pos = pos / grp_size;
    const size_t shift = ngb_size - ngb_size / grp_size;
    pos = grp_pos + shift;
  }
}
PostQK<TConfig>(q, pos, layer);
MulByConst(kQueryScale, q, kQKVDim);

Expand Down

0 comments on commit 5a2a7ee

Please sign in to comment.