[OpenAI] Add include_usage in streaming, add prefill decode speed to usage #456

Merged (4 commits) on Jun 5, 2024
Changes from all commits
README.md (11 changes: 7 additions & 4 deletions)
@@ -162,7 +162,7 @@ const reply = await engine.chat.completions.create({
messages,
});
console.log(reply.choices[0].message);
console.log(await engine.runtimeStatsText());
console.log(reply.usage);
```

### Streaming
@@ -180,17 +180,20 @@ const chunks = await engine.chat.completions.create({
messages,
temperature: 1,
stream: true, // <-- Enable streaming
stream_options: { include_usage: true },
});

let reply = "";
for await (const chunk of chunks) {
reply += chunk.choices[0].delta.content || "";
reply += chunk.choices[0]?.delta.content || "";
console.log(reply);
if (chunk.usage) {
console.log(chunk.usage); // only last chunk has usage
}
}

const fullReply = await engine.getMessage()
const fullReply = await engine.getMessage();
console.log(fullReply);
console.log(await engine.runtimeStatsText());
```

## Advanced Usage
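Note on the `usage` objects logged in the README snippets above: per the engine changes in src/engine.ts later in this diff, both the non-streaming `reply.usage` and the final streaming chunk's `chunk.usage` carry token counts plus prefill/decode throughput under `extra`. A minimal consumer-side sketch of reading those fields; the local type and helper function are illustrative, not part of this PR:

```typescript
// Sketch only: mirrors the usage shape attached in src/engine.ts below.
// `extra` and its field names come from this PR; the helper itself is illustrative.
type UsageWithSpeed = {
  prompt_tokens: number;
  completion_tokens: number;
  total_tokens: number;
  extra?: { prefill_tokens_per_s: number; decode_tokens_per_s: number };
};

function logUsage(usage?: UsageWithSpeed): void {
  if (!usage) return;
  console.log(
    `tokens: ${usage.prompt_tokens} prompt + ${usage.completion_tokens} completion = ${usage.total_tokens}`,
  );
  if (usage.extra) {
    console.log(
      `prefill ${usage.extra.prefill_tokens_per_s.toFixed(1)} tok/s, ` +
        `decode ${usage.extra.decode_tokens_per_s.toFixed(1)} tok/s`,
    );
  }
}
```

In the streaming case such a helper would only ever receive a defined `usage` on the final chunk, since earlier chunks leave it unset.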
examples/function-calling/src/function_calling.ts (16 changes: 9 additions & 7 deletions)
@@ -41,6 +41,7 @@ async function main() {

const request: webllm.ChatCompletionRequest = {
stream: true, // works with stream as well, where the last chunk returns tool_calls
stream_options: { include_usage: true },
messages: [
{
role: "user",
@@ -55,24 +56,25 @@ if (!request.stream) {
if (!request.stream) {
const reply0 = await engine.chat.completions.create(request);
console.log(reply0.choices[0]);
console.log(reply0.usage);
} else {
// If streaming, the last chunk returns tool calls
const asyncChunkGenerator = await engine.chat.completions.create(request);
let message = "";
let lastChunk: webllm.ChatCompletionChunk | undefined;
let usageChunk: webllm.ChatCompletionChunk | undefined;
for await (const chunk of asyncChunkGenerator) {
console.log(chunk);
if (chunk.choices[0].delta.content) {
// Last chunk has undefined content
message += chunk.choices[0].delta.content;
}
message += chunk.choices[0]?.delta?.content || "";
setLabel("generate-label", message);
lastChunk = chunk;
if (!chunk.usage) {
lastChunk = chunk;
}
usageChunk = chunk;
}
console.log(lastChunk!.choices[0].delta);
console.log(usageChunk!.usage);
}

console.log(await engine.runtimeStatsText());
}

main();
examples/get-started-web-worker/src/main.ts (12 changes: 6 additions & 6 deletions)
@@ -46,7 +46,7 @@ async function mainNonStreaming() {
const reply0 = await engine.chat.completions.create(request);
console.log(reply0);

console.log(await engine.runtimeStatsText());
console.log(reply0.usage);
}

/**
@@ -67,6 +67,7 @@ async function mainStreaming() {

const request: webllm.ChatCompletionRequest = {
stream: true,
stream_options: { include_usage: true },
messages: [
{
role: "system",
@@ -86,15 +87,14 @@
let message = "";
for await (const chunk of asyncChunkGenerator) {
console.log(chunk);
if (chunk.choices[0].delta.content) {
// Last chunk has undefined content
message += chunk.choices[0].delta.content;
}
message += chunk.choices[0]?.delta?.content || "";
setLabel("generate-label", message);
if (chunk.usage) {
console.log(chunk.usage); // only last chunk has usage
}
// engine.interruptGenerate(); // works with interrupt as well
}
console.log("Final message:\n", await engine.getMessage()); // the concatenated message
console.log(await engine.runtimeStatsText());
}

// Run one of the functions below
examples/get-started/src/get_started.ts (2 changes: 1 addition & 1 deletion)
@@ -52,7 +52,7 @@ async function main() {
top_logprobs: 2,
});
console.log(reply0);
console.log(await engine.runtimeStatsText());
console.log(reply0.usage);

// To change model, either create a new engine via `CreateMLCEngine()`, or call `engine.reload(modelId)`
}
examples/json-mode/src/json_mode.ts (2 changes: 1 addition & 1 deletion)
@@ -34,7 +34,7 @@ async function main() {
const reply0 = await engine.chatCompletion(request);
console.log(reply0);
console.log("First reply's last choice:\n" + (await engine.getMessage()));
console.log(await engine.runtimeStatsText());
console.log(reply0.usage);
}

main();
examples/json-schema/src/json_schema.ts (6 changes: 3 additions & 3 deletions)
@@ -62,7 +62,7 @@ async function simpleStructuredTextExample() {
const reply0 = await engine.chatCompletion(request);
console.log(reply0);
console.log("Output:\n" + (await engine.getMessage()));
console.log(await engine.runtimeStatsText());
console.log(reply0.usage);
}

// The json schema and prompt is taken from
@@ -129,7 +129,7 @@ async function harryPotterExample() {
const reply = await engine.chatCompletion(request);
console.log(reply);
console.log("Output:\n" + (await engine.getMessage()));
console.log(await engine.runtimeStatsText());
console.log(reply.usage);
}

async function functionCallingExample() {
@@ -207,7 +207,7 @@ async function functionCallingExample() {
const reply = await engine.chat.completions.create(request);
console.log(reply.choices[0].message.content);

console.log(await engine.runtimeStatsText());
console.log(reply.usage);
}

async function main() {
examples/multi-round-chat/src/multi_round_chat.ts (4 changes: 2 additions & 2 deletions)
@@ -43,6 +43,7 @@ async function main() {
const replyMessage0 = await engine.getMessage();
console.log(reply0);
console.log(replyMessage0);
console.log(reply0.usage);

// Round 1
// Append generated response to messages
@@ -62,6 +63,7 @@
const replyMessage1 = await engine.getMessage();
console.log(reply1);
console.log(replyMessage1);
console.log(reply1.usage);

// If we used multiround chat, request1 should only prefill a small number of tokens
const prefillTokens0 = reply0.usage?.prompt_tokens;
@@ -75,8 +77,6 @@
) {
throw Error("Multi-round chat is not triggered as expected.");
}

console.log(await engine.runtimeStatsText());
}

main();
examples/seed-to-reproduce/src/seed.ts (3 changes: 2 additions & 1 deletion)
@@ -38,6 +38,7 @@ async function main() {
const reply0 = await engine.chat.completions.create(request);
console.log(reply0);
console.log("First reply's last choice:\n" + (await engine.getMessage()));
console.log(reply0.usage);

const reply1 = await engine.chat.completions.create(request);
console.log(reply1);
@@ -56,7 +57,7 @@
}
}

console.log(await engine.runtimeStatsText());
console.log(reply1.usage);
}

// Run one of the functions
examples/service-worker/src/main.ts (12 changes: 6 additions & 6 deletions)
@@ -65,7 +65,7 @@ async function mainNonStreaming() {
console.log(reply0);
setLabel("generate-label", reply0.choices[0].message.content || "");

console.log(await engine.runtimeStatsText());
console.log(reply0.usage);
}

/**
@@ -84,6 +84,7 @@ async function mainStreaming() {

const request: webllm.ChatCompletionRequest = {
stream: true,
stream_options: { include_usage: true },
messages: [
{
role: "system",
@@ -103,15 +104,14 @@
let message = "";
for await (const chunk of asyncChunkGenerator) {
console.log(chunk);
if (chunk.choices[0].delta.content) {
// Last chunk has undefined content
message += chunk.choices[0].delta.content;
}
message += chunk.choices[0]?.delta?.content || "";
setLabel("generate-label", message);
if (chunk.usage) {
console.log(chunk.usage); // only last chunk has usage
}
// engine.interruptGenerate(); // works with interrupt as well
}
console.log("Final message:\n", await engine.getMessage()); // the concatenated message
console.log(await engine.runtimeStatsText());
}

registerServiceWorker();
examples/streaming/src/streaming.ts (10 changes: 5 additions & 5 deletions)
@@ -23,6 +23,7 @@ async function main() {

const request: webllm.ChatCompletionRequest = {
stream: true,
stream_options: { include_usage: true },
messages: [
{
role: "system",
@@ -39,15 +40,14 @@
let message = "";
for await (const chunk of asyncChunkGenerator) {
console.log(chunk);
if (chunk.choices[0].delta.content) {
// Last chunk has undefined content
message += chunk.choices[0].delta.content;
}
message += chunk.choices[0]?.delta?.content || "";
setLabel("generate-label", message);
if (chunk.usage) {
console.log(chunk.usage); // only last chunk has usage
}
// engine.interruptGenerate(); // works with interrupt as well
}
console.log("Final message:\n", await engine.getMessage()); // the concatenated message
console.log(await engine.runtimeStatsText());
}

main();
src/engine.ts (36 changes: 35 additions & 1 deletion)
@@ -456,7 +456,6 @@ export class MLCEngine implements MLCEngineInterface {
tool_calls: tool_calls,
}
: {},
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
finish_reason: finish_reason,
index: 0,
},
Expand All @@ -466,6 +465,33 @@ export class MLCEngine implements MLCEngineInterface {
created: created,
};
yield lastChunk;

if (request.stream_options?.include_usage) {
const completion_tokens =
this.getPipeline().getCurRoundDecodingTotalTokens();
const prompt_tokens = this.getPipeline().getCurRoundPrefillTotalTokens();
const prefill_tokens_per_s =
this.getPipeline().getCurRoundPrefillTokensPerSec();
const decode_tokens_per_s =
this.getPipeline().getCurRoundDecodingTokensPerSec();
const usageChunk: ChatCompletionChunk = {
id: id,
choices: [],
usage: {
completion_tokens: completion_tokens,
prompt_tokens: prompt_tokens,
total_tokens: completion_tokens + prompt_tokens,
extra: {
prefill_tokens_per_s: prefill_tokens_per_s,
decode_tokens_per_s: decode_tokens_per_s,
},
} as CompletionUsage,
model: model,
object: "chat.completion.chunk",
created: created,
};
yield usageChunk;
}
}

/**
@@ -522,6 +548,8 @@
const choices: Array<ChatCompletion.Choice> = [];
let completion_tokens = 0;
let prompt_tokens = 0;
let prefill_time = 0;
let decode_time = 0;
for (let i = 0; i < n; i++) {
let outputMessage: string;
if (this.interruptSignal) {
@@ -573,6 +601,8 @@
});
completion_tokens += this.getPipeline().getCurRoundDecodingTotalTokens();
prompt_tokens += this.getPipeline().getCurRoundPrefillTotalTokens();
prefill_time += this.getPipeline().getCurRoundPrefillTotalTime();
decode_time += this.getPipeline().getCurRoundDecodingTotalTime();
}

const response: ChatCompletion = {
@@ -585,6 +615,10 @@
completion_tokens: completion_tokens,
prompt_tokens: prompt_tokens,
total_tokens: completion_tokens + prompt_tokens,
extra: {
prefill_tokens_per_s: prompt_tokens / prefill_time,
decode_tokens_per_s: completion_tokens / decode_time,
},
} as CompletionUsage,
};

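With these engine.ts changes, setting `stream_options.include_usage` makes the streaming generator yield one extra chunk after the finish chunk: its `choices` array is empty and only `usage` is populated, which is why the updated examples switch to optional chaining on `chunk.choices[0]`. A minimal sketch of consuming the stream under that contract; it assumes an `engine` that has already loaded a model (for example via `CreateMLCEngine()` as in the README), and the prompt text and variable names are illustrative:

```typescript
// Sketch: separate content chunks from the trailing usage-only chunk.
// Assumes `engine` is an already-initialized engine with a model loaded.
const chunks = await engine.chat.completions.create({
  messages: [{ role: "user", content: "Hello!" }],
  stream: true,
  stream_options: { include_usage: true },
});

let reply = "";
for await (const chunk of chunks) {
  if (chunk.choices.length > 0) {
    // Normal delta chunk; the finish chunk may carry no content.
    reply += chunk.choices[0]?.delta?.content ?? "";
  } else if (chunk.usage) {
    // Usage-only chunk, yielded last when include_usage is requested.
    console.log(chunk.usage); // includes extra.prefill_tokens_per_s and decode_tokens_per_s
  }
}
console.log(reply);
```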