Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 657831926
  • Loading branch information
pculliton authored and copybara-github committed Jul 31, 2024
1 parent a24eda8 commit 1982a6b
Show file tree
Hide file tree
Showing 8 changed files with 110 additions and 6 deletions.
3 changes: 3 additions & 0 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ cc_library(
"gemma/instantiations/gr2b_bf16.cc",
"gemma/instantiations/gr2b_f32.cc",
"gemma/instantiations/gr2b_sfp.cc",
"gemma/instantiations/gemma2_2b_bf16.cc",
"gemma/instantiations/gemma2_2b_f32.cc",
"gemma/instantiations/gemma2_2b_sfp.cc",
],
hdrs = [
"gemma/activations.h",
Expand Down
3 changes: 3 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ set(SOURCES
gemma/instantiations/tiny_bf16.cc
gemma/instantiations/tiny_f32.cc
gemma/instantiations/tiny_sfp.cc
gemma/instantiations/gemma2_2b_bf16.cc
gemma/instantiations/gemma2_2b_f32.cc
gemma/instantiations/gemma2_2b_sfp.cc
gemma/kv_cache.cc
gemma/kv_cache.h
gemma/tokenizer.cc
Expand Down
15 changes: 9 additions & 6 deletions gemma/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,13 @@
namespace gcpp {

constexpr const char* kModelFlags[] = {
"2b-pt", "2b-it", // Gemma 2B
"7b-pt", "7b-it", // Gemma 7B
"9b-pt", "9b-it", // Gemma 9B
"27b-pt", "27b-it", // Gemma 27B
"gr2b-pt", "gr2b-it", // RecurrentGemma
"tiny", // Gemma Tiny (mostly for debugging)
"2b-pt", "2b-it", // Gemma 2B
"7b-pt", "7b-it", // Gemma 7B
"9b-pt", "9b-it", // Gemma 9B
"27b-pt", "27b-it", // Gemma 27B
"gr2b-pt", "gr2b-it", // RecurrentGemma
"tiny", // Gemma Tiny (mostly for debugging)
"gemma2-2b-pt", "gemma2-2b-it", // Gemma2 2B
};
constexpr Model kModelTypes[] = {
Model::GEMMA_2B, Model::GEMMA_2B, // Gemma 2B
Expand All @@ -43,6 +44,7 @@ constexpr Model kModelTypes[] = {
Model::GEMMA_27B, Model::GEMMA_27B, // Gemma 27B
Model::GRIFFIN_2B, Model::GRIFFIN_2B, // RecurrentGemma
Model::GEMMA_TINY, // Gemma Tiny
Model::GEMMA2_2B, Model::GEMMA2_2B, // Gemma2 2B
};
constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 2B
Expand All @@ -51,6 +53,7 @@ constexpr ModelTraining kModelTraining[] = {
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma 27B
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // RecurrentGemma
ModelTraining::GEMMA_IT, // Gemma Tiny
ModelTraining::GEMMA_PT, ModelTraining::GEMMA_IT, // Gemma2 2B
};

constexpr size_t kNumModelFlags =
Expand Down
10 changes: 10 additions & 0 deletions gemma/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ enum class Model {
GEMMA_27B,
GRIFFIN_2B,
GEMMA_TINY,
GEMMA2_2B,
};

// Instruction-tuned models require extra 'turn structure' tokens in prompts.
Expand Down Expand Up @@ -99,6 +100,9 @@ decltype(auto) CallForModel(Model model, TArgs&&... args) {
return FuncT<ConfigGemma27B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GRIFFIN_2B:
return FuncT<ConfigGriffin2B<TWeight>>()(std::forward<TArgs>(args)...);
case Model::GEMMA2_2B:
return FuncT<ConfigGemma2_2B<TWeight>>()(std::forward<TArgs>(args)...);

default:
HWY_ABORT("Model type %d unknown.", static_cast<int>(model));
}
Expand Down Expand Up @@ -142,6 +146,7 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
GEMMA_FOREACH_WEIGHT(X, ConfigGemma9B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma27B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGriffin2B) \
GEMMA_FOREACH_WEIGHT(X, ConfigGemma2_2B) \
static_assert(true, "Allow trailing ;")

// Used by GEMMA_EXPORT_AND_DISPATCH. For a given TWEIGHT (e.g. float),
Expand Down Expand Up @@ -178,6 +183,11 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
ARGS; \
break; \
} \
case Model::GEMMA2_2B: { \
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(FUNC<ConfigGemma2_2B<TWEIGHT>>) \
ARGS; \
break; \
} \
default: \
HWY_ABORT("Model type %d unknown.", static_cast<int>(MODEL)); \
}
Expand Down
22 changes: 22 additions & 0 deletions gemma/configs.h
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,28 @@ struct ConfigGemma2B : public ConfigBaseGemmaV1 {
static constexpr bool kAbsolutePE = false;
};

// Compile-time configuration for the Gemma2 2B model. Everything not set here
// is inherited from ConfigBaseGemmaV2.
template <typename TWeight>
struct ConfigGemma2_2B : public ConfigBaseGemmaV2 {
  // Exposes the weight type to code that only has a TConfig in hand.
  using Weight = TWeight;

  // Sequence length and vocabulary size.
  static constexpr int kSeqLen = 8192;
  static constexpr int kVocabSize = 256000;

  // Layer stack: 26 entries, every one LayerAttentionType::kGemma.
  static constexpr std::array<LayerAttentionType, 26> kLayerConfig =
      FixedLayerConfig<26>(LayerAttentionType::kGemma);
  static constexpr int kLayers = kLayerConfig.size();
  static constexpr int kGemmaLayers = kLayers;

  // One window size per layer, built from the pair {4096, kSeqLen}.
  // NOTE(review): presumably the pair repeats across the 26 layers, i.e.
  // alternating 4096-token and full-length windows; confirm against
  // RepeatedAttentionWindowSizes.
  static constexpr std::array<size_t, 26> kAttentionWindowSizes =
      RepeatedAttentionWindowSizes<26, 2>({4096, kSeqLen});

  // Model and feed-forward widths.
  static constexpr int kModelDim = 2304;
  static constexpr int kFFHiddenDim = 8 * kModelDim / 2;  // = 9216

  // Attention shape: 8 heads, 4 KV heads.
  static constexpr int kHeads = 8;
  static constexpr int kKVHeads = 4;
  // A single size shared by queries, keys and values.
  static constexpr int kQKVDim = 256;
  static constexpr QueryScaleType kQueryScale = QueryScaleType::SqrtKeySize;
  static constexpr bool kAbsolutePE = false;

  // Uses the namespace-level gcpp::kTopK value.
  static constexpr int kTopK = gcpp::kTopK;
};

template <typename TWeight>
struct ConfigGemmaTiny : public ConfigNoSSM {
using Weight = TWeight; // make accessible where we only have a TConfig
Expand Down
21 changes: 21 additions & 0 deletions gemma/instantiations/gemma2_2b_bf16.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Translation unit pairing gemma/gemma-inl.h with ConfigGemma2_2B using
// hwy::bfloat16_t as the weight type.
//
// HWY_TARGET_INCLUDE is set to this file's own path before
// hwy/foreach_target.h is included; the directive order below matters.
// NOTE(review): per Highway's documentation, foreach_target.h re-includes the
// file named by HWY_TARGET_INCLUDE once per enabled SIMD target - confirm.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/gemma2_2b_bf16.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
// GEMMA_CONFIG is defined ahead of gemma-inl.h; presumably that header reads
// it to choose which configuration to compile - verify in gemma-inl.h.
#define GEMMA_CONFIG ConfigGemma2_2B<hwy::bfloat16_t>
#include "gemma/gemma-inl.h"
21 changes: 21 additions & 0 deletions gemma/instantiations/gemma2_2b_f32.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Translation unit pairing gemma/gemma-inl.h with ConfigGemma2_2B using
// float as the weight type.
//
// HWY_TARGET_INCLUDE is set to this file's own path before
// hwy/foreach_target.h is included; the directive order below matters.
// NOTE(review): per Highway's documentation, foreach_target.h re-includes the
// file named by HWY_TARGET_INCLUDE once per enabled SIMD target - confirm.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/gemma2_2b_f32.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
// GEMMA_CONFIG is defined ahead of gemma-inl.h; presumably that header reads
// it to choose which configuration to compile - verify in gemma-inl.h.
#define GEMMA_CONFIG ConfigGemma2_2B<float>
#include "gemma/gemma-inl.h"
21 changes: 21 additions & 0 deletions gemma/instantiations/gemma2_2b_sfp.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Translation unit pairing gemma/gemma-inl.h with ConfigGemma2_2B using
// SfpStream as the weight type.
//
// HWY_TARGET_INCLUDE is set to this file's own path before
// hwy/foreach_target.h is included; the directive order below matters.
// NOTE(review): per Highway's documentation, foreach_target.h re-includes the
// file named by HWY_TARGET_INCLUDE once per enabled SIMD target - confirm.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"gemma/instantiations/gemma2_2b_sfp.cc"
#include "hwy/foreach_target.h" // IWYU pragma: keep
// GEMMA_CONFIG is defined ahead of gemma-inl.h; presumably that header reads
// it to choose which configuration to compile - verify in gemma-inl.h.
#define GEMMA_CONFIG ConfigGemma2_2B<SfpStream>
#include "gemma/gemma-inl.h"

0 comments on commit 1982a6b

Please sign in to comment.