
Commit

compile success: set default self extend values in noSSM and griffin
jonpsy committed Oct 19, 2024
1 parent 02ce1e3 commit 8cf3966
Showing 2 changed files with 275 additions and 0 deletions.
260 changes: 260 additions & 0 deletions gemma/configs.h
@@ -192,6 +192,266 @@ ModelConfig ConfigFromModel(Model model);

// Returns the sub-config for the ViT model of the PaliGemma model.
ModelConfig VitConfig(const ModelConfig& config);
template <class TConfig, typename = void>
struct CacheLayerSize {
constexpr size_t operator()() const {
return TConfig::kKVHeads * TConfig::kQKVDim * 2;
}
};

template <class TConfig, typename = void>
struct CachePosSize {
constexpr size_t operator()() const {
return TConfig::kGemmaLayers * CacheLayerSize<TConfig>()();
}
};
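
For a sense of scale, here is an illustrative use of the two traits (a sketch, not part of this commit; it references ConfigGemma2B, which is defined further down this file, so it would only compile after that definition):

// Per layer, K and V each hold kKVHeads * kQKVDim values:
// CacheLayerSize = 1 * 256 * 2 = 512 entries.
static_assert(CacheLayerSize<ConfigGemma2B<float>>()() == 512, "");
// Per position, that repeats across all Gemma layers:
// CachePosSize = 18 * 512 = 9216 entries.
static_assert(CachePosSize<ConfigGemma2B<float>>()() == 9216, "");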

struct ConfigNoSSM {
static constexpr int kGriffinLayers = 0;

static constexpr int kConv1dWidth = 0;
static constexpr bool kFFBiases = false;
static constexpr bool kSoftmaxAttnOutputBiases = false;
static constexpr bool kUseHalfRope = false;
static constexpr bool kUseLocalAttention = false;
static constexpr bool kInterleaveQKV = true;
static constexpr int kNumTensorScales = 0;

static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr ResidualType kResidual = ResidualType::Add;

// Self-extend parameters with default values
static constexpr bool kSelfExtend = false;
static constexpr size_t kSelfExtendNgbSize = 0;
static constexpr size_t kSelfExtendGrpSize = 1;
};
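
These three defaults disable self-extend, the mechanism exercised in the gemma-inl.h hunks below. A minimal standalone sketch of the position remapping those hunks implement (function names are ours; the formulas are taken from the diff):

#include <cstddef>

// Keys beyond the neighbor window are embedded at their grouped position.
constexpr size_t SelfExtendKeyPos(size_t pos, size_t ngb_size,
                                  size_t grp_size) {
  return pos > ngb_size ? pos / grp_size : pos;
}

// Queries are additionally shifted so their grouped position stays aligned
// with the ungrouped positions inside the neighbor window.
constexpr size_t SelfExtendQueryPos(size_t pos, size_t ngb_size,
                                    size_t grp_size) {
  if (pos > ngb_size) {
    const size_t shift = ngb_size - ngb_size / grp_size;
    return pos / grp_size + shift;
  }
  return pos;
}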

struct ConfigBaseGemmaV1 : ConfigNoSSM {
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};

struct ConfigBaseGemmaV2 : ConfigNoSSM {
static constexpr float kAttCap = 50.0f;
static constexpr float kFinalCap = 30.0f;
static constexpr PostNormType kPostNorm = PostNormType::Scale;
};

template <typename TWeight>
struct ConfigGemma27B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 46> kLayerConfig =
FixedLayerConfig<46>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 46> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<46, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 4608;
static constexpr int kFFHiddenDim = 16 * 4608 / 2; // = 36864
static constexpr int kHeads = 32;
static constexpr int kKVHeads = 16;
static constexpr int kQKVDim = 128; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale =
QueryScaleType::SqrtModelDimDivNumHeads;
};

template <typename TWeight>
struct ConfigGemma9B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 42> kLayerConfig =
FixedLayerConfig<42>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 42> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<42, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3584;
static constexpr int kFFHiddenDim = 8 * 3584 / 2; // = 14336
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 8;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};

template <typename TWeight>
struct ConfigGemma7B : public ConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 28> kLayerConfig =
FixedLayerConfig<28>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 28> kAttentionWindowSizes =
FixedAttentionWindowSizes<28>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 3072;
static constexpr int kFFHiddenDim = 16 * 3072 / 2; // = 24576
static constexpr int kHeads = 16;
static constexpr int kKVHeads = 16; // standard MHA
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
};

template <typename TWeight>
struct ConfigGemma2B : public ConfigBaseGemmaV1 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = gcpp::kSeqLen;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 18> kLayerConfig =
FixedLayerConfig<18>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 18> kAttentionWindowSizes =
FixedAttentionWindowSizes<18>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2048;
static constexpr int kFFHiddenDim = 16 * 2048 / 2; // = 16384
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
};

template <typename TWeight>
struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 8192;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
FixedLayerConfig<26>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 2304;
static constexpr int kFFHiddenDim = 8 * 2304 / 2; // = 9216
static constexpr int kHeads = 8;
static constexpr int kKVHeads = 4;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
};

template <typename TWeight>
struct ConfigGemmaTiny : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig

static constexpr int kSeqLen = 32;
static constexpr int kVocabSize = 64;
static constexpr std::array<LayerAttentionType, 3> kLayerConfig =
FixedLayerConfig<3>(LayerAttentionType::kGemma);
static constexpr std::array<size_t, 3> kAttentionWindowSizes =
FixedAttentionWindowSizes<3>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers = kLayers;
static constexpr int kModelDim = 128;
static constexpr int kFFHiddenDim = 256;
static constexpr int kHeads = 4;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 16; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;

static constexpr float kAttCap = 0.0f;
// This is required for optimize_test to pass.
static constexpr float kFinalCap = 30.0f;
};

template <typename TWeight>
struct ConfigGriffin2B {
using Weight = TWeight; // make accessible where we only have a TConfig

// Griffin uses local attention, so kSeqLen is actually the local attention
// window.
static constexpr int kSeqLen = 2048;
static constexpr int kVocabSize = 256000;
static constexpr std::array<LayerAttentionType, 26> kLayerConfig = {
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGemma,
LayerAttentionType::kGriffinRecurrentBlock,
LayerAttentionType::kGriffinRecurrentBlock,
};
static constexpr std::array<size_t, 26> kAttentionWindowSizes =
FixedAttentionWindowSizes<26>(kSeqLen);
static constexpr int kLayers = kLayerConfig.size();
static constexpr int kGemmaLayers =
NumLayersOfTypeBefore(kLayerConfig, LayerAttentionType::kGemma, kLayers);
static constexpr int kGriffinLayers =
NumLayersOfTypeBefore(kLayerConfig,
LayerAttentionType::kGriffinRecurrentBlock,
kLayers);
static constexpr int kModelDim = 2560;
static constexpr int kFFHiddenDim = 7680;
static constexpr int kHeads = 10;
static constexpr int kKVHeads = 1;
static constexpr int kQKVDim = 256; // query size == key size == value size
static constexpr int kTopK = gcpp::kTopK;
static constexpr bool kAbsolutePE = false;
static constexpr PostNormType kPostNorm = PostNormType::None;

// No SoftCap.
static constexpr float kAttCap = 0.0f;
static constexpr float kFinalCap = 0.0f;

// SSM config.
static constexpr int kConv1dWidth = 4;
static constexpr bool kFFBiases = true;
static constexpr bool kSoftmaxAttnOutputBiases = true;
static constexpr bool kUseHalfRope = true;
static constexpr bool kUseLocalAttention = true;
static constexpr bool kInterleaveQKV = false;
static constexpr int kNumTensorScales = 140;
static constexpr PostQKType kPostQK = PostQKType::Rope;
static constexpr ActivationType kActivation = ActivationType::Gelu;
static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
static constexpr ResidualType kResidual = ResidualType::Add;

// Self-extend parameters with default values
static constexpr bool kSelfExtend = false;
static constexpr size_t kSelfExtendNgbSize = 0;
static constexpr size_t kSelfExtendGrpSize = 1;
};
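
For reference, the counts the two NumLayersOfTypeBefore helpers produce from the pattern above (our arithmetic, shown as a sketch that would sit after the struct):

// 8 repetitions of {recurrent, recurrent, gemma} plus two trailing
// recurrent blocks: kGriffinLayers = 2 * 8 + 2 = 18, kGemmaLayers = 8.
static_assert(ConfigGriffin2B<float>::kGriffinLayers == 18, "");
static_assert(ConfigGriffin2B<float>::kGemmaLayers == 8, "");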

} // namespace gcpp

15 changes: 15 additions & 0 deletions gemma/gemma-inl.h
@@ -327,6 +327,13 @@ class GemmaAttention {
// Self-extend: keys beyond the neighbor window are embedded at their
// grouped position, so remap pos before applying RoPE. (ngb_size and
// grp_size are assumed to alias TConfig::kSelfExtendNgbSize and
// kSelfExtendGrpSize earlier in this function.)
if constexpr (TConfig::kSelfExtend) {
if (pos > ngb_size) {
pos /= grp_size;
}
}
PositionalEncodingQK(is_mha_ ? mha_kv : kv, pos, layer_, 1.0f, kv);

// If MHA, also copy V into KVCache.
if (is_mha_) {
hwy::CopyBytes(mha_kv + layer_config_.qkv_dim,
@@ -417,6 +424,14 @@

// Apply rope and scaling to Q.
size_t pos = queries_pos_[query_idx] + batch_idx;  // mutable: self-extend may remap it
if constexpr (TConfig::kSelfExtend) {
// Self-extend: queries beyond the neighbor window use their grouped
// position plus a shift that keeps them aligned with the ungrouped
// positions inside the window.
if (pos > ngb_size) {
const size_t grp_pos = pos / grp_size;
const size_t shift = ngb_size - ngb_size / grp_size;
pos = grp_pos + shift;
}
}
PositionalEncodingQK(q, pos, layer_, query_scale, q);

const size_t start_pos = StartPos(pos, layer_);
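
A worked example of the two hunks together (illustrative values, not from this commit): with ngb_size = 512 and grp_size = 4, a key at pos = 1000 is embedded at 1000 / 4 = 250, and a query at the same position at 250 + (512 - 512 / 4) = 634; positions at or below 512 pass through unchanged on both paths.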
