Introduce QueryResult in GemmaEnv and add a shortcut for WrapAndTokenize.

Remove max_tokens (and rely only on max_generated_tokens).

PiperOrigin-RevId: 684468885
danielkeysers authored and copybara-github committed Oct 14, 2024
1 parent 2892e23 commit 47841cf
Showing 17 changed files with 99 additions and 144 deletions.
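
For orientation, here is a minimal sketch of how the new API reads after this change. It is illustrative only: the prompt text and the include/build setup are assumptions, while GemmaEnv, QueryModel, and the QueryResult fields are taken from the evals/benchmark_helper.h and .cc hunks below.

#include <iostream>
#include <string>

#include "evals/benchmark_helper.h"  // GemmaEnv, QueryResult (added in this commit)

int main(int argc, char** argv) {
  // Loads the model from the usual loader/inference/app flags.
  gcpp::GemmaEnv env(argc, argv);

  // QueryModel(std::string&) now wraps and tokenizes the input via the new
  // WrapAndTokenize shortcut and returns a QueryResult instead of a
  // std::pair<std::string, size_t>.
  std::string prompt = "Write a haiku about tokenizers.";  // assumed prompt
  gcpp::QueryResult result = env.QueryModel(prompt);

  // The streamed response includes the wrapped prompt; response_start_pos
  // marks where the generated text begins.
  std::cout << result.response.substr(result.response_start_pos) << "\n";
  std::cout << "tokens streamed: " << result.tokens_generated << "\n";
  return 0;
}

Note also that RuntimeConfig no longer takes a max_tokens field; only max_generated_tokens remains (see the README and optimize_test.cc hunks below).
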
1 change: 0 additions & 1 deletion README.md
@@ -385,7 +385,6 @@ tokenizer : tokenizer.spm
compressed_weights : 2b-it-sfp.sbs
model : 2b-it
weights : [no path specified]
max_tokens : 3072
max_generated_tokens : 2048

*Usage*
1 change: 0 additions & 1 deletion backprop/optimize_test.cc
@@ -69,7 +69,6 @@ TEST(OptimizeTest, GradientDescent) {
return token != ReverseSequenceSampler::kEndToken;
};
RuntimeConfig runtime = {
.max_tokens = 32,
.max_generated_tokens = 16,
.temperature = 1.0f,
.verbosity = 0,
22 changes: 11 additions & 11 deletions evals/benchmark.cc
@@ -81,15 +81,15 @@ int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
size_t total_tokens = 0;
const double time_start = hwy::platform::Now();
for (auto& [question, expected_answer] : queries_answers) {
const auto [answer, token_count] = env.QueryModel(question);
total_tokens += token_count;
if (answer.find(expected_answer) != std::string::npos) {
QueryResult result = env.QueryModel(question);
total_tokens += result.tokens_generated;
if (result.response.find(expected_answer) != std::string::npos) {
correct_answers++;
} else {
std::cout << "Wrong!\n";
std::cout << "Input: " << question << "\n";
std::cout << "Expected: " << expected_answer << "\n";
std::cout << "Output: " << answer << "\n\n" << std::flush;
std::cout << "Output: " << result.response << "\n\n" << std::flush;
}
}
LogSpeedStats(time_start, total_tokens);
@@ -108,17 +108,17 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
prompt.append(ReadFileToString(text));
prompt.append("\nSummarize this text.\n");
const double time_start = hwy::platform::Now();
const auto [answer, token_count] = env.QueryModel(prompt);
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
LogSpeedStats(time_start, token_count);
QueryResult result = env.QueryModel(prompt);
std::cout << result.response.substr(result.response_start_pos) << "\n"
<< std::flush;
LogSpeedStats(time_start, result.tokens_generated);
return EXIT_SUCCESS;
}

int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t batch_tokens) {
std::string input = ReadFileToString(text);
std::vector<int> prompt = env.Tokenize(input);
prompt.resize(std::min<size_t>(env.MaxTokens(), prompt.size()));
std::cout << "Number of input tokens: " << prompt.size() << "\n";
const double time_start = hwy::platform::Now();
float total_entropy = 0.0f;
@@ -156,11 +156,11 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
while (std::getline(trivia_file, line)) {
json data = json::parse(line);
std::string q(data["question"]);
const auto [answer, token_count] = env.QueryModel(q);
std::cout << answer << "\n";
QueryResult result = env.QueryModel(q);
std::cout << result.response << "\n";
bool correct = false;
for (const std::string expected : data["answer"]["aliases"]) {
if (answer.find(expected) != std::string::npos) {
if (result.response.find(expected) != std::string::npos) {
correct = true;
break;
}
70 changes: 33 additions & 37 deletions evals/benchmark_helper.cc
@@ -18,14 +18,12 @@
#include <stdio.h>
#include <time.h>

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <utility> // std::pair
#include <vector>

// Placeholder for internal header, do not modify.
@@ -76,7 +74,6 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
}
InitGenerator(inference, gen_);
runtime_config_ = {
.max_tokens = inference.max_tokens,
.max_generated_tokens = inference.max_generated_tokens,
.temperature = inference.temperature,
.verbosity = app.verbosity,
@@ -99,29 +96,29 @@ GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
MakeAppArgs(argc, argv)) {}

std::pair<std::string, size_t> GemmaEnv::QueryModel(
const std::vector<int>& tokens) {
std::string res;
size_t total_tokens = 0;
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
QueryResult result;

const BatchStreamFunc batch_stream_token = [&res, &total_tokens, this](
size_t query_index, size_t pos,
int token, float) {
++total_tokens;
res += StringFromTokens(std::vector<int>{token});
return true;
};
const BatchStreamFunc batch_stream_token =
[&result, &tokens, this](size_t /*query_index*/, size_t /*pos*/,
int token, float /*score*/) {
++result.tokens_generated;
result.response += StringFromTokens(std::vector<int>{token});
if (result.tokens_generated == tokens.size()) {
result.response_start_pos = result.response.size();
}
return true;
};
if (runtime_config_.verbosity >= 2) {
std::cout << "Max tokens: " << runtime_config_.max_tokens
<< "\tmax generated tokens: "
std::cout << "max generated tokens: "
<< runtime_config_.max_generated_tokens
<< "\ttemperature: " << runtime_config_.temperature << "\n";
}
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
runtime_config_.batch_stream_token = batch_stream_token;
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
timing_info);
return {res, total_tokens};
return result;
}

void GemmaEnv::QueryModel(
@@ -134,27 +131,29 @@ void GemmaEnv::QueryModel(
runtime_config_.stream_token = previous_stream_token;
}

std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
const QueriesPromptTokens& queries_prompt) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(num_queries != 0);
std::vector<std::pair<std::string, size_t>> res(num_queries);
std::fill(res.begin(), res.end(), std::make_pair("", 0));
const BatchStreamFunc batch_stream_token = [&res, this](size_t query_index,
size_t pos, int token,
float) {
std::vector<QueryResult> res(num_queries);
const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
size_t query_index, size_t pos,
int token, float) {
std::string token_text;
HWY_ASSERT(
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
res[query_index].first.append(token_text);
res[query_index].second += 1;
res[query_index].response.append(token_text);
res[query_index].tokens_generated += 1;
if (res[query_index].tokens_generated ==
queries_prompt[query_index].size()) {
res[query_index].response_start_pos = res[query_index].response.size();
}
return true;
};
if (runtime_config_.verbosity >= 2) {
fprintf(stderr,
"Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
runtime_config_.max_tokens, runtime_config_.max_generated_tokens,
runtime_config_.temperature, runtime_config_.prefill_tbatch_size,
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
runtime_config_.max_generated_tokens, runtime_config_.temperature,
runtime_config_.prefill_tbatch_size,
runtime_config_.decode_qbatch_size);
}

@@ -178,21 +177,18 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
return res;
}

std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
const std::vector<int> prompt =
WrapAndTokenize(model_->Tokenizer(), model_->Info(),
/*pos=*/0, input);
QueryResult GemmaEnv::QueryModel(std::string& input) {
const std::vector<int> prompt = WrapAndTokenize(input);
return QueryModel(prompt);
}

std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
const std::vector<std::string>& inputs) {
std::vector<std::vector<int>> prompts;
prompts.reserve(inputs.size());
for (auto& input : inputs) {
std::string mutable_prompt = input;
prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
/*pos=*/0, mutable_prompt));
prompts.push_back(WrapAndTokenize(mutable_prompt));
}
std::vector<PromptTokens> prompt_vector;
prompt_vector.reserve(prompts.size());
@@ -206,7 +202,7 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
float GemmaEnv::CrossEntropy(const std::string& input) {
std::vector<int> prompt = Tokenize(input);
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*GetModel(), /*max_tokens=*/3072, prompt,
return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt,
MutableKVCache(),
/*verbosity=*/0) /
static_cast<int>(input.size());
26 changes: 19 additions & 7 deletions evals/benchmark_helper.h
@@ -21,7 +21,6 @@
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "gemma/gemma.h"
@@ -33,6 +32,14 @@ namespace gcpp {

void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);

// Return type for query model calls.
struct QueryResult {
std::string response;
size_t tokens_generated = 0;
// The position in the response at which the generated tokens start.
size_t response_start_pos = 0;
};

// Convenience class to load a model and run inference.
class GemmaEnv {
public:
@@ -41,8 +48,9 @@ class GemmaEnv {
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app);

size_t MaxTokens() const { return runtime_config_.max_tokens; }
// Sets the maximum number of output tokens to generate.
size_t MaxGeneratedTokens() const {
return runtime_config_.max_generated_tokens;
}
void SetMaxGeneratedTokens(size_t max_generated_tokens) {
runtime_config_.max_generated_tokens = max_generated_tokens;
}
@@ -59,6 +67,10 @@ class GemmaEnv {
return tokens;
}

std::vector<int> WrapAndTokenize(std::string& input) const {
return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
}

std::string StringFromTokens(const std::vector<int>& tokens) const {
std::string string;
HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
@@ -67,12 +79,12 @@

// Runs inference on the given input and returns the top-1 result string and
// the number of tokens that were generated.
std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
std::vector<std::pair<std::string, size_t>> BatchQueryModel(
QueryResult QueryModel(const std::vector<int>& tokens);
std::vector<QueryResult> BatchQueryModel(
const QueriesPromptTokens& queries_prompt);
// Adds turn structure to input, tokenizes and calls the above overload.
std::pair<std::string, size_t> QueryModel(std::string& input);
std::vector<std::pair<std::string, size_t>> BatchQueryModel(
QueryResult QueryModel(std::string& input);
std::vector<QueryResult> BatchQueryModel(
const std::vector<std::string>& inputs);

// Runs inference on the given input and calls the callback for each token.
6 changes: 3 additions & 3 deletions evals/benchmarks.cc
@@ -33,11 +33,11 @@ void RunPrompt(const std::string& original_prompt, benchmark::State& state) {
size_t total_tokens = 0;
for (auto s : state) {
std::string prompt = original_prompt; // reset from original
auto [response, n] = s_env->QueryModel(prompt);
QueryResult result = s_env->QueryModel(prompt);
if (s_env->Verbosity() != 0) {
fprintf(stdout, "|%s|\n", response.c_str());
fprintf(stdout, "|%s|\n", result.response.c_str());
}
total_tokens += n;
total_tokens += result.tokens_generated;
}

state.SetItemsProcessed(total_tokens);
11 changes: 5 additions & 6 deletions evals/cross_entropy.cc
@@ -97,7 +97,7 @@ namespace gcpp {

HWY_EXPORT(CallSoftmax);

float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity) {
const StreamFunc stream_token = [](int /*token*/, float) { return true; };
@@ -112,8 +112,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
size_t vocab_size) -> TokenAndProb {
// input is logits, not yet probabilities
HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size);
// We are called for each token, but pos starts at 1. Clamping max_tokens
// to prompt.size() should prevent overrun.
// We are called for each token, but pos starts at 1. Clamping
// max_generated_tokens to prompt.size() should prevent overrun.
HWY_ASSERT(pos < prompt.size());
const int token = prompt[pos];
const float prob = probs[token];
@@ -136,10 +136,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
};

std::vector<int> prompt0 = { prompt[0] };
max_tokens = HWY_MIN(max_tokens, prompt.size());
max_generated_tokens = HWY_MIN(max_generated_tokens, prompt.size());
RuntimeConfig runtime = {
.max_tokens = max_tokens,
.max_generated_tokens = max_tokens - 1,
.max_generated_tokens = max_generated_tokens - 1,
.temperature = 0.0f,
.verbosity = verbosity,
.gen = nullptr,
2 changes: 1 addition & 1 deletion evals/cross_entropy.h
@@ -24,7 +24,7 @@

namespace gcpp {

float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity);

5 changes: 3 additions & 2 deletions evals/debug_prompt.cc
@@ -68,8 +68,9 @@ int Run(int argc, char** argv) {
json_base[std::to_string(pos)][debug_key] = v;
};

const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush;
QueryResult result = env.QueryModel(prompt_args.prompt);
std::cout << result.response.substr(result.response_start_pos) << "\n"
<< std::flush;

if (env.MutableConfig().layers_output) {
std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
17 changes: 8 additions & 9 deletions evals/gemma_test.cc
@@ -52,13 +52,13 @@ class GemmaTest : public ::testing::Test {
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
std::string mutable_prompt = prompt;
auto [response, n] = s_env->QueryModel(mutable_prompt); // Uses turns.
return response;
QueryResult result = s_env->QueryModel(mutable_prompt); // Uses turns.
return result.response;
}
// Otherwise, do not use turn structure.
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
auto [response, n] = s_env->QueryModel(tokens);
return response;
QueryResult result = s_env->QueryModel(tokens);
return result.response;
}

std::vector<std::string> BatchGemmaReply(
@@ -72,8 +72,8 @@ class GemmaTest : public ::testing::Test {
// It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
replies.push_back(response);
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
replies.push_back(result.response);
}
return replies;
}
@@ -88,8 +88,8 @@ class GemmaTest : public ::testing::Test {
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
}
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
for (auto [response, n] : s_env->BatchQueryModel(prompts)) {
replies.push_back(response);
for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
replies.push_back(result.response);
}
return replies;
}
@@ -167,7 +167,6 @@ TEST_F(GemmaTest, Multiturn) {
return true;
};
RuntimeConfig runtime_config{
.max_tokens = 128,
.max_generated_tokens = 64,
.temperature = 0.0f,
.verbosity = 2,