Introduce QueryResult in GemmaEnv and add a shortcut for WrapAndTokenize. #419

Merged (1 commit, Oct 14, 2024)
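This change replaces the std::pair<std::string, size_t> returned by GemmaEnv::QueryModel and GemmaEnv::BatchQueryModel with a named QueryResult struct (response text, number of generated tokens, and the position at which the generated text starts), drops the separate max_tokens setting from the call sites in this diff, and adds a GemmaEnv::WrapAndTokenize member shortcut. A minimal usage sketch of the new API follows; the include path, flags, and prompt are illustrative, not part of this diff:

#include <iostream>
#include <string>

#include "evals/benchmark_helper.h"  // GemmaEnv, QueryResult

int main(int argc, char** argv) {
  // Loads the model using the --tokenizer/--weights/--model flags in argv.
  gcpp::GemmaEnv env(argc, argv);

  std::string prompt = "What is the capital of France?";
  // Adds turn structure, tokenizes, and generates.
  gcpp::QueryResult result = env.QueryModel(prompt);

  // response streams the prompt tokens first; response_start_pos marks
  // where the generated reply begins.
  std::cout << result.response.substr(result.response_start_pos) << "\n";
  std::cout << "tokens generated: " << result.tokens_generated << "\n";
  return 0;
}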
1 change: 0 additions & 1 deletion README.md
@@ -385,7 +385,6 @@ tokenizer : tokenizer.spm
compressed_weights : 2b-it-sfp.sbs
model : 2b-it
weights : [no path specified]
max_tokens : 3072
max_generated_tokens : 2048

*Usage*
1 change: 0 additions & 1 deletion backprop/optimize_test.cc
@@ -69,7 +69,6 @@ TEST(OptimizeTest, GradientDescent) {
return token != ReverseSequenceSampler::kEndToken;
};
RuntimeConfig runtime = {
.max_tokens = 32,
.max_generated_tokens = 16,
.temperature = 1.0f,
.verbosity = 0,
22 changes: 11 additions & 11 deletions evals/benchmark.cc
@@ -81,15 +81,15 @@ int BenchmarkGoldens(GemmaEnv& env, const std::string& golden_path) {
size_t total_tokens = 0;
const double time_start = hwy::platform::Now();
for (auto& [question, expected_answer] : queries_answers) {
const auto [answer, token_count] = env.QueryModel(question);
total_tokens += token_count;
if (answer.find(expected_answer) != std::string::npos) {
QueryResult result = env.QueryModel(question);
total_tokens += result.tokens_generated;
if (result.response.find(expected_answer) != std::string::npos) {
correct_answers++;
} else {
std::cout << "Wrong!\n";
std::cout << "Input: " << question << "\n";
std::cout << "Expected: " << expected_answer << "\n";
std::cout << "Output: " << answer << "\n\n" << std::flush;
std::cout << "Output: " << result.response << "\n\n" << std::flush;
}
}
LogSpeedStats(time_start, total_tokens);
@@ -108,17 +108,17 @@ int BenchmarkSummary(GemmaEnv& env, const Path& text) {
prompt.append(ReadFileToString(text));
prompt.append("\nSummarize this text.\n");
const double time_start = hwy::platform::Now();
const auto [answer, token_count] = env.QueryModel(prompt);
std::cout << answer.substr(prompt.size()) << "\n" << std::flush;
LogSpeedStats(time_start, token_count);
QueryResult result = env.QueryModel(prompt);
std::cout << result.response.substr(result.response_start_pos) << "\n"
<< std::flush;
LogSpeedStats(time_start, result.tokens_generated);
return EXIT_SUCCESS;
}

int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t batch_tokens) {
std::string input = ReadFileToString(text);
std::vector<int> prompt = env.Tokenize(input);
prompt.resize(std::min<size_t>(env.MaxTokens(), prompt.size()));
std::cout << "Number of input tokens: " << prompt.size() << "\n";
const double time_start = hwy::platform::Now();
float total_entropy = 0.0f;
@@ -156,11 +156,11 @@ int BenchmarkTriviaQA(GemmaEnv& env, const Path& json_file,
while (std::getline(trivia_file, line)) {
json data = json::parse(line);
std::string q(data["question"]);
const auto [answer, token_count] = env.QueryModel(q);
std::cout << answer << "\n";
QueryResult result = env.QueryModel(q);
std::cout << result.response << "\n";
bool correct = false;
for (const std::string expected : data["answer"]["aliases"]) {
if (answer.find(expected) != std::string::npos) {
if (result.response.find(expected) != std::string::npos) {
correct = true;
break;
}
70 changes: 33 additions & 37 deletions evals/benchmark_helper.cc
@@ -18,14 +18,12 @@
#include <stdio.h>
#include <time.h>

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <utility> // std::pair
#include <vector>

// Placeholder for internal header, do not modify.
@@ -76,7 +74,6 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
}
InitGenerator(inference, gen_);
runtime_config_ = {
.max_tokens = inference.max_tokens,
.max_generated_tokens = inference.max_generated_tokens,
.temperature = inference.temperature,
.verbosity = app.verbosity,
@@ -99,29 +96,29 @@ GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
MakeAppArgs(argc, argv)) {}

std::pair<std::string, size_t> GemmaEnv::QueryModel(
const std::vector<int>& tokens) {
std::string res;
size_t total_tokens = 0;
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
QueryResult result;

const BatchStreamFunc batch_stream_token = [&res, &total_tokens, this](
size_t query_index, size_t pos,
int token, float) {
++total_tokens;
res += StringFromTokens(std::vector<int>{token});
return true;
};
const BatchStreamFunc batch_stream_token =
[&result, &tokens, this](size_t /*query_index*/, size_t /*pos*/,
int token, float /*score*/) {
++result.tokens_generated;
result.response += StringFromTokens(std::vector<int>{token});
if (result.tokens_generated == tokens.size()) {
result.response_start_pos = result.response.size();
}
return true;
};
if (runtime_config_.verbosity >= 2) {
std::cout << "Max tokens: " << runtime_config_.max_tokens
<< "\tmax generated tokens: "
std::cout << "max generated tokens: "
<< runtime_config_.max_generated_tokens
<< "\ttemperature: " << runtime_config_.temperature << "\n";
}
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
runtime_config_.batch_stream_token = batch_stream_token;
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
timing_info);
return {res, total_tokens};
return result;
}

void GemmaEnv::QueryModel(
@@ -134,27 +131,29 @@ void GemmaEnv::QueryModel(
runtime_config_.stream_token = previous_stream_token;
}

std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
const QueriesPromptTokens& queries_prompt) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(num_queries != 0);
std::vector<std::pair<std::string, size_t>> res(num_queries);
std::fill(res.begin(), res.end(), std::make_pair("", 0));
const BatchStreamFunc batch_stream_token = [&res, this](size_t query_index,
size_t pos, int token,
float) {
std::vector<QueryResult> res(num_queries);
const BatchStreamFunc batch_stream_token = [&res, &queries_prompt, this](
size_t query_index, size_t pos,
int token, float) {
std::string token_text;
HWY_ASSERT(
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
res[query_index].first.append(token_text);
res[query_index].second += 1;
res[query_index].response.append(token_text);
res[query_index].tokens_generated += 1;
if (res[query_index].tokens_generated ==
queries_prompt[query_index].size()) {
res[query_index].response_start_pos = res[query_index].response.size();
}
return true;
};
if (runtime_config_.verbosity >= 2) {
fprintf(stderr,
"Max tok: %zu max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
runtime_config_.max_tokens, runtime_config_.max_generated_tokens,
runtime_config_.temperature, runtime_config_.prefill_tbatch_size,
fprintf(stderr, "Max gen: %zu temp: %f tbatch: %zu qbatch: %zu\n",
runtime_config_.max_generated_tokens, runtime_config_.temperature,
runtime_config_.prefill_tbatch_size,
runtime_config_.decode_qbatch_size);
}

@@ -178,21 +177,18 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
return res;
}

std::pair<std::string, size_t> GemmaEnv::QueryModel(std::string& input) {
const std::vector<int> prompt =
WrapAndTokenize(model_->Tokenizer(), model_->Info(),
/*pos=*/0, input);
QueryResult GemmaEnv::QueryModel(std::string& input) {
const std::vector<int> prompt = WrapAndTokenize(input);
return QueryModel(prompt);
}

std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
std::vector<QueryResult> GemmaEnv::BatchQueryModel(
const std::vector<std::string>& inputs) {
std::vector<std::vector<int>> prompts;
prompts.reserve(inputs.size());
for (auto& input : inputs) {
std::string mutable_prompt = input;
prompts.push_back(WrapAndTokenize(model_->Tokenizer(), model_->Info(),
/*pos=*/0, mutable_prompt));
prompts.push_back(WrapAndTokenize(mutable_prompt));
}
std::vector<PromptTokens> prompt_vector;
prompt_vector.reserve(prompts.size());
@@ -206,7 +202,7 @@ std::vector<std::pair<std::string, size_t>> GemmaEnv::BatchQueryModel(
float GemmaEnv::CrossEntropy(const std::string& input) {
std::vector<int> prompt = Tokenize(input);
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*GetModel(), /*max_tokens=*/3072, prompt,
return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt,
MutableKVCache(),
/*verbosity=*/0) /
static_cast<int>(input.size());
26 changes: 19 additions & 7 deletions evals/benchmark_helper.h
@@ -21,7 +21,6 @@
#include <memory>
#include <random>
#include <string>
#include <utility>
#include <vector>

#include "gemma/gemma.h"
@@ -33,6 +32,14 @@ namespace gcpp {

void InitGenerator(const InferenceArgs& inference, std::mt19937& gen);

// Return type for query model calls.
struct QueryResult {
std::string response;
size_t tokens_generated = 0;
// The position in the response at which the generated tokens start.
size_t response_start_pos = 0;
};

// Convenience class to load a model and run inference.
class GemmaEnv {
public:
@@ -41,8 +48,9 @@ class GemmaEnv {
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app);

size_t MaxTokens() const { return runtime_config_.max_tokens; }
// Get/set the maximum number of output tokens to generate.
size_t MaxGeneratedTokens() const {
return runtime_config_.max_generated_tokens;
}
void SetMaxGeneratedTokens(size_t max_generated_tokens) {
runtime_config_.max_generated_tokens = max_generated_tokens;
}
@@ -59,6 +67,10 @@ class GemmaEnv {
return tokens;
}

std::vector<int> WrapAndTokenize(std::string& input) const {
return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->Info(), 0, input);
}

std::string StringFromTokens(const std::vector<int>& tokens) const {
std::string string;
HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
@@ -67,12 +79,12 @@

// Runs inference on the given input and returns a QueryResult with the
// top-1 response string, the number of generated tokens, and the offset
// at which the generated text starts.
std::pair<std::string, size_t> QueryModel(const std::vector<int>& tokens);
std::vector<std::pair<std::string, size_t>> BatchQueryModel(
QueryResult QueryModel(const std::vector<int>& tokens);
std::vector<QueryResult> BatchQueryModel(
const QueriesPromptTokens& queries_prompt);
// Adds turn structure to input, tokenizes and calls the above overload.
std::pair<std::string, size_t> QueryModel(std::string& input);
std::vector<std::pair<std::string, size_t>> BatchQueryModel(
QueryResult QueryModel(std::string& input);
std::vector<QueryResult> BatchQueryModel(
const std::vector<std::string>& inputs);

// Runs inference on the given input and calls the callback for each token.
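For reference, a short sketch of how the new WrapAndTokenize shortcut and the QueryResult-returning overloads compose. The Demo helper and prompt strings are illustrative, not part of this change:

#include <iostream>
#include <string>
#include <vector>

#include "evals/benchmark_helper.h"  // GemmaEnv, QueryResult

// Illustrative helper: runs one tokenized prompt, then a small batch.
void Demo(gcpp::GemmaEnv& env) {
  // Member shortcut: no need to pass Tokenizer()/Info() explicitly any more.
  std::string prompt = "Write a haiku about C++.";
  const std::vector<int> tokens = env.WrapAndTokenize(prompt);
  gcpp::QueryResult single = env.QueryModel(tokens);
  std::cout << single.response.substr(single.response_start_pos) << "\n";

  // Batch overload: one QueryResult per input string.
  const std::vector<std::string> inputs = {"Hello!", "Name three primes."};
  for (const gcpp::QueryResult& result : env.BatchQueryModel(inputs)) {
    std::cout << result.tokens_generated << " tokens: "
              << result.response.substr(result.response_start_pos) << "\n";
  }
}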
6 changes: 3 additions & 3 deletions evals/benchmarks.cc
@@ -33,11 +33,11 @@ void RunPrompt(const std::string& original_prompt, benchmark::State& state) {
size_t total_tokens = 0;
for (auto s : state) {
std::string prompt = original_prompt; // reset from original
auto [response, n] = s_env->QueryModel(prompt);
QueryResult result = s_env->QueryModel(prompt);
if (s_env->Verbosity() != 0) {
fprintf(stdout, "|%s|\n", response.c_str());
fprintf(stdout, "|%s|\n", result.response.c_str());
}
total_tokens += n;
total_tokens += result.tokens_generated;
}

state.SetItemsProcessed(total_tokens);
11 changes: 5 additions & 6 deletions evals/cross_entropy.cc
@@ -97,7 +97,7 @@ namespace gcpp {

HWY_EXPORT(CallSoftmax);

float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity) {
const StreamFunc stream_token = [](int /*token*/, float) { return true; };
@@ -112,8 +112,8 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
size_t vocab_size) -> TokenAndProb {
// input is logits, not yet probabilities
HWY_DYNAMIC_DISPATCH(CallSoftmax)(probs, vocab_size);
// We are called for each token, but pos starts at 1. Clamping max_tokens
// to prompt.size() should prevent overrun.
// We are called for each token, but pos starts at 1. Clamping
// max_generated_tokens to prompt.size() should prevent overrun.
HWY_ASSERT(pos < prompt.size());
const int token = prompt[pos];
const float prob = probs[token];
@@ -136,10 +136,9 @@ float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
};

std::vector<int> prompt0 = { prompt[0] };
max_tokens = HWY_MIN(max_tokens, prompt.size());
max_generated_tokens = HWY_MIN(max_generated_tokens, prompt.size());
RuntimeConfig runtime = {
.max_tokens = max_tokens,
.max_generated_tokens = max_tokens - 1,
.max_generated_tokens = max_generated_tokens - 1,
.temperature = 0.0f,
.verbosity = verbosity,
.gen = nullptr,
2 changes: 1 addition & 1 deletion evals/cross_entropy.h
@@ -24,7 +24,7 @@

namespace gcpp {

float ComputeCrossEntropy(Gemma& gemma, size_t max_tokens,
float ComputeCrossEntropy(Gemma& gemma, size_t max_generated_tokens,
const std::vector<int>& prompt, KVCache& kv_cache,
int verbosity);

5 changes: 3 additions & 2 deletions evals/debug_prompt.cc
@@ -68,8 +68,9 @@ int Run(int argc, char** argv) {
json_base[std::to_string(pos)][debug_key] = v;
};

const auto [answer, token_count] = env.QueryModel(prompt_args.prompt);
std::cout << answer.substr(prompt_args.prompt.size()) << "\n" << std::flush;
QueryResult result = env.QueryModel(prompt_args.prompt);
std::cout << result.response.substr(result.response_start_pos) << "\n"
<< std::flush;

if (env.MutableConfig().layers_output) {
std::ofstream output_f(prompt_args.layers_output.path, std::ofstream::out);
17 changes: 8 additions & 9 deletions evals/gemma_test.cc
@@ -52,13 +52,13 @@ class GemmaTest : public ::testing::Test {
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
std::string mutable_prompt = prompt;
auto [response, n] = s_env->QueryModel(mutable_prompt); // Uses turns.
return response;
QueryResult result = s_env->QueryModel(mutable_prompt); // Uses turns.
return result.response;
}
// Otherwise, do not use turn structure.
const std::vector<int> tokens = s_env->TokenizeAndPrependBOS(prompt);
auto [response, n] = s_env->QueryModel(tokens);
return response;
QueryResult result = s_env->QueryModel(tokens);
return result.response;
}

std::vector<std::string> BatchGemmaReply(
@@ -72,8 +72,8 @@ class GemmaTest : public ::testing::Test {
// It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
for (auto [response, n] : s_env->BatchQueryModel(inputs)) {
replies.push_back(response);
for (const QueryResult& result : s_env->BatchQueryModel(inputs)) {
replies.push_back(result.response);
}
return replies;
}
@@ -88,8 +88,8 @@ class GemmaTest : public ::testing::Test {
prompt_spans.push_back(PromptTokens(prompt.data(), prompt.size()));
}
QueriesPromptTokens prompts(prompt_spans.data(), prompt_spans.size());
for (auto [response, n] : s_env->BatchQueryModel(prompts)) {
replies.push_back(response);
for (const QueryResult& result : s_env->BatchQueryModel(prompts)) {
replies.push_back(result.response);
}
return replies;
}
@@ -167,7 +167,6 @@ TEST_F(GemmaTest, Multiturn) {
return true;
};
RuntimeConfig runtime_config{
.max_tokens = 128,
.max_generated_tokens = 64,
.temperature = 0.0f,
.verbosity = 2,
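Downstream note: callers now set only max_generated_tokens when building a RuntimeConfig, as the hunks above show. A hedged sketch of the updated initializer shape; the values, the MakeRuntimeConfig wrapper, and the assumption that StreamFunc and the member order come from gemma/gemma.h are illustrative, not part of this diff:

#include <random>

#include "gemma/gemma.h"  // RuntimeConfig, StreamFunc (assumed location)

// Illustrative only: builds a RuntimeConfig without the removed max_tokens.
gcpp::RuntimeConfig MakeRuntimeConfig(std::mt19937& gen,
                                      const gcpp::StreamFunc& stream_token) {
  return gcpp::RuntimeConfig{
      .max_generated_tokens = 64,
      .temperature = 0.0f,
      .verbosity = 0,
      .gen = &gen,
      .stream_token = stream_token,
  };
}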