[OpenAI] Add usage to streaming, add prefill and decode speed to usage
CharlieFRuan committed Jun 4, 2024
1 parent 22f0d37 commit 4d340b1
Showing 13 changed files with 113 additions and 19 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -162,7 +162,7 @@ const reply = await engine.chat.completions.create({
messages,
});
console.log(reply.choices[0].message);
console.log(await engine.runtimeStatsText());
console.log(reply.usage);
```

### Streaming
@@ -183,14 +183,16 @@ const chunks = await engine.chat.completions.create({
});

let reply = "";
let lastChunk: webllm.ChatCompletionChunk | undefined = undefined;
for await (const chunk of chunks) {
reply += chunk.choices[0].delta.content || "";
lastChunk = chunk;
console.log(reply);
}

const fullReply = await engine.getMessage()
console.log(fullReply);
console.log(await engine.runtimeStatsText());
console.log(lastChunk.usage);
```
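
The `usage` reported on the last streaming chunk mirrors `reply.usage` from the non-streaming call, including the prefill and decode throughput added in this commit. A minimal sketch of reading it (field names are taken from this change; treating `usage` as possibly absent is an assumption for the interrupted case):

```typescript
// Sketch only: assumes `lastChunk` is the final ChatCompletionChunk from the loop above.
// `usage` may be absent, e.g. if generation was interrupted, so guard before reading it.
const usage = lastChunk?.usage;
if (usage) {
  console.log(`prompt tokens:     ${usage.prompt_tokens}`);
  console.log(`completion tokens: ${usage.completion_tokens}`);
  console.log(`total tokens:      ${usage.total_tokens}`);
  console.log(`prefill speed:     ${usage.prefill_tokens_per_s} tok/s`);
  console.log(`decode speed:      ${usage.decode_tokens_per_s} tok/s`);
}
```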

## Advanced Usage
4 changes: 2 additions & 2 deletions examples/function-calling/src/function_calling.ts
@@ -55,6 +55,7 @@ async function main() {
if (!request.stream) {
const reply0 = await engine.chat.completions.create(request);
console.log(reply0.choices[0]);
console.log(reply0.usage);
} else {
// If streaming, the last chunk returns tool calls
const asyncChunkGenerator = await engine.chat.completions.create(request);
@@ -70,9 +71,8 @@ async function main() {
lastChunk = chunk;
}
console.log(lastChunk!.choices[0].delta);
console.log(lastChunk!.usage);
}

console.log(await engine.runtimeStatsText());
}

main();
6 changes: 4 additions & 2 deletions examples/get-started-web-worker/src/main.ts
@@ -46,7 +46,7 @@ async function mainNonStreaming() {
const reply0 = await engine.chat.completions.create(request);
console.log(reply0);

console.log(await engine.runtimeStatsText());
console.log(reply0.usage);
}

/**
@@ -84,17 +84,19 @@ async function mainStreaming() {

const asyncChunkGenerator = await engine.chat.completions.create(request);
let message = "";
let lastChunk: webllm.ChatCompletionChunk | undefined = undefined;
for await (const chunk of asyncChunkGenerator) {
console.log(chunk);
if (chunk.choices[0].delta.content) {
// Last chunk has undefined content
message += chunk.choices[0].delta.content;
}
setLabel("generate-label", message);
lastChunk = chunk;
// engine.interruptGenerate(); // works with interrupt as well
}
console.log("Final message:\n", await engine.getMessage()); // the concatenated message
console.log(await engine.runtimeStatsText());
console.log(lastChunk!.usage);
}

// Run one of the functions below
2 changes: 1 addition & 1 deletion examples/get-started/src/get_started.ts
@@ -52,7 +52,7 @@ async function main() {
top_logprobs: 2,
});
console.log(reply0);
console.log(await engine.runtimeStatsText());
console.log(reply0.usage);

// To change model, either create a new engine via `CreateMLCEngine()`, or call `engine.reload(modelId)`
}
2 changes: 1 addition & 1 deletion examples/json-mode/src/json_mode.ts
@@ -34,7 +34,7 @@ async function main() {
const reply0 = await engine.chatCompletion(request);
console.log(reply0);
console.log("First reply's last choice:\n" + (await engine.getMessage()));
console.log(await engine.runtimeStatsText());
console.log(reply0.usage);
}

main();
6 changes: 3 additions & 3 deletions examples/json-schema/src/json_schema.ts
@@ -62,7 +62,7 @@ async function simpleStructuredTextExample() {
const reply0 = await engine.chatCompletion(request);
console.log(reply0);
console.log("Output:\n" + (await engine.getMessage()));
console.log(await engine.runtimeStatsText());
console.log(reply0.usage);
}

// The json schema and prompt are taken from
@@ -129,7 +129,7 @@ async function harryPotterExample() {
const reply = await engine.chatCompletion(request);
console.log(reply);
console.log("Output:\n" + (await engine.getMessage()));
console.log(await engine.runtimeStatsText());
console.log(reply.usage);
}

async function functionCallingExample() {
@@ -207,7 +207,7 @@ async function functionCallingExample() {
const reply = await engine.chat.completions.create(request);
console.log(reply.choices[0].message.content);

console.log(await engine.runtimeStatsText());
console.log(reply.usage);
}

async function main() {
4 changes: 2 additions & 2 deletions examples/multi-round-chat/src/multi_round_chat.ts
@@ -43,6 +43,7 @@ async function main() {
const replyMessage0 = await engine.getMessage();
console.log(reply0);
console.log(replyMessage0);
console.log(reply0.usage);

// Round 1
// Append generated response to messages
@@ -62,6 +63,7 @@ async function main() {
const replyMessage1 = await engine.getMessage();
console.log(reply1);
console.log(replyMessage1);
console.log(reply1.usage);

// If we used multiround chat, request1 should only prefill a small number of tokens
const prefillTokens0 = reply0.usage?.prompt_tokens;
@@ -75,8 +77,6 @@
) {
throw Error("Multi-round chat is not triggered as expected.");
}

console.log(await engine.runtimeStatsText());
}

main();
3 changes: 2 additions & 1 deletion examples/seed-to-reproduce/src/seed.ts
@@ -38,6 +38,7 @@ async function main() {
const reply0 = await engine.chat.completions.create(request);
console.log(reply0);
console.log("First reply's last choice:\n" + (await engine.getMessage()));
console.log(reply0.usage);

const reply1 = await engine.chat.completions.create(request);
console.log(reply1);
@@ -56,7 +57,7 @@ async function main() {
}
}

console.log(await engine.runtimeStatsText());
console.log(reply1.usage);
}

// Run one of the functions
6 changes: 4 additions & 2 deletions examples/service-worker/src/main.ts
@@ -65,7 +65,7 @@ async function mainNonStreaming() {
console.log(reply0);
setLabel("generate-label", reply0.choices[0].message.content || "");

console.log(await engine.runtimeStatsText());
console.log(reply0.usage);
}

/**
@@ -101,17 +101,19 @@ async function mainStreaming() {

const asyncChunkGenerator = await engine.chat.completions.create(request);
let message = "";
let lastChunk: webllm.ChatCompletionChunk | undefined = undefined;
for await (const chunk of asyncChunkGenerator) {
console.log(chunk);
if (chunk.choices[0].delta.content) {
// Last chunk has undefined content
message += chunk.choices[0].delta.content;
}
setLabel("generate-label", message);
lastChunk = chunk;
// engine.interruptGenerate(); // works with interrupt as well
}
console.log("Final message:\n", await engine.getMessage()); // the concatenated message
console.log(await engine.runtimeStatsText());
console.log(lastChunk!.usage);
}

registerServiceWorker();
7 changes: 6 additions & 1 deletion examples/streaming/src/streaming.ts
@@ -37,17 +37,22 @@ async function main() {

const asyncChunkGenerator = await engine.chat.completions.create(request);
let message = "";
let lastChunk: webllm.ChatCompletionChunk | undefined = undefined;
for await (const chunk of asyncChunkGenerator) {
console.log(chunk);
if (chunk.choices[0].delta.content) {
// Last chunk has undefined content
message += chunk.choices[0].delta.content;
}
setLabel("generate-label", message);
lastChunk = chunk;
// engine.interruptGenerate(); // works with interrupt as well
}
console.log("Final message:\n", await engine.getMessage()); // the concatenated message
console.log(await engine.runtimeStatsText());
if (lastChunk?.usage) {
// If the stream was interrupted before finishing, the last chunk may not carry usage.
console.log(lastChunk.usage);
}
}

main();
22 changes: 21 additions & 1 deletion src/engine.ts
@@ -446,6 +446,14 @@ export class MLCEngine implements MLCEngineInterface {
) as Array<ChatCompletionChunk.Choice.Delta.ToolCall>;
}

const completion_tokens =
this.getPipeline().getCurRoundDecodingTotalTokens();
const prompt_tokens = this.getPipeline().getCurRoundPrefillTotalTokens();
const prefill_tokens_per_s =
this.getPipeline().getCurRoundPrefillTokensPerSec();
const decode_tokens_per_s =
this.getPipeline().getCurRoundDecodingTokensPerSec();

const lastChunk: ChatCompletionChunk = {
id: id,
choices: [
@@ -456,14 +464,20 @@
tool_calls: tool_calls,
}
: {},
// eslint-disable-next-line @typescript-eslint/no-non-null-assertion
finish_reason: finish_reason,
index: 0,
},
],
model: model,
object: "chat.completion.chunk",
created: created,
usage: {
completion_tokens: completion_tokens,
prompt_tokens: prompt_tokens,
total_tokens: completion_tokens + prompt_tokens,
prefill_tokens_per_s: prefill_tokens_per_s,
decode_tokens_per_s: decode_tokens_per_s,
} as CompletionUsage,
};
yield lastChunk;
}
@@ -522,6 +536,8 @@ export class MLCEngine implements MLCEngineInterface {
const choices: Array<ChatCompletion.Choice> = [];
let completion_tokens = 0;
let prompt_tokens = 0;
let prefill_time = 0;
let decode_time = 0;
for (let i = 0; i < n; i++) {
let outputMessage: string;
if (this.interruptSignal) {
@@ -573,6 +589,8 @@ export class MLCEngine implements MLCEngineInterface {
});
completion_tokens += this.getPipeline().getCurRoundDecodingTotalTokens();
prompt_tokens += this.getPipeline().getCurRoundPrefillTotalTokens();
prefill_time += this.getPipeline().getCurRoundPrefillTotalTime();
decode_time += this.getPipeline().getCurRoundDecodingTotalTime();
}

const response: ChatCompletion = {
@@ -585,6 +603,8 @@
completion_tokens: completion_tokens,
prompt_tokens: prompt_tokens,
total_tokens: completion_tokens + prompt_tokens,
prefill_tokens_per_s: prompt_tokens / prefill_time,
decode_tokens_per_s: completion_tokens / decode_time,
} as CompletionUsage,
};

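
In the non-streaming path, `chatCompletion()` sums tokens and wall-clock time across all `n` choices before dividing, so `prefill_tokens_per_s` and `decode_tokens_per_s` describe the request as a whole; the streaming path instead reads the per-round tokens-per-second getters when it emits the final chunk. A simplified sketch of the non-streaming aggregation (hypothetical helper, not the actual engine code):

```typescript
// Hypothetical helper illustrating how per-choice stats roll up into usage.
interface RoundStats {
  promptTokens: number;      // getCurRoundPrefillTotalTokens()
  completionTokens: number;  // getCurRoundDecodingTotalTokens()
  prefillTime: number;       // getCurRoundPrefillTotalTime(), in seconds
  decodeTime: number;        // getCurRoundDecodingTotalTime(), in seconds
}

function aggregateUsage(rounds: RoundStats[]) {
  const prompt_tokens = rounds.reduce((sum, r) => sum + r.promptTokens, 0);
  const completion_tokens = rounds.reduce((sum, r) => sum + r.completionTokens, 0);
  const prefill_time = rounds.reduce((sum, r) => sum + r.prefillTime, 0);
  const decode_time = rounds.reduce((sum, r) => sum + r.decodeTime, 0);
  return {
    prompt_tokens,
    completion_tokens,
    total_tokens: prompt_tokens + completion_tokens,
    prefill_tokens_per_s: prompt_tokens / prefill_time,
    decode_tokens_per_s: completion_tokens / decode_time,
  };
}
```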
48 changes: 47 additions & 1 deletion src/llm_chat.ts
@@ -70,9 +70,11 @@ export class LLMChatPipeline {
private decodingTotalTokens = 0;
private prefillTotalTime = 0;
private prefillTotalTokens = 0;
// same as `prefillTotalTokens` and `decodingTotalTokens`, but reset at every `prefillStep()`
// same stats as above, but reset at every `prefillStep()`
private curRoundDecodingTotalTokens = 0;
private curRoundPrefillTotalTokens = 0;
private curRoundDecodingTotalTime = 0;
private curRoundPrefillTotalTime = 0;

// LogitProcessor
private logitProcessor?: LogitProcessor = undefined;
@@ -356,6 +358,20 @@ export class LLMChatPipeline {
return this.curRoundPrefillTotalTokens;
}

/**
* @returns the time spent on decode for a single request or a single choice in the request.
*/
getCurRoundDecodingTotalTime(): number {
return this.curRoundDecodingTotalTime;
}

/**
* @returns the time spent on prefill for a single request or a single choice in the request.
*/
getCurRoundPrefillTotalTime(): number {
return this.curRoundPrefillTotalTime;
}

/**
* @returns Runtime stats information.
*/
@@ -366,6 +382,30 @@
);
}

/**
* @returns Runtime stats information, starting from the last prefill performed.
*/
curRoundRuntimeStatsText(): string {
return (
`prefill: ${this.getCurRoundPrefillTokensPerSec().toFixed(4)} tokens/sec, ` +
`decoding: ${this.getCurRoundDecodingTokensPerSec().toFixed(4)} tokens/sec`
);
}

/**
* @returns Prefill tokens per second, starting from the last prefill performed.
*/
getCurRoundPrefillTokensPerSec(): number {
return this.curRoundPrefillTotalTokens / this.curRoundPrefillTotalTime;
}

/**
* @returns Decoding tokens per second, starting from the last prefill performed.
*/
getCurRoundDecodingTokensPerSec(): number {
return this.curRoundDecodingTotalTokens / this.curRoundDecodingTotalTime;
}

/**
* Set the seed for the RNG `this.tvm.rng`.
*/
@@ -411,6 +451,8 @@ export class LLMChatPipeline {
this.tokenLogprobArray = [];
this.curRoundDecodingTotalTokens = 0;
this.curRoundPrefillTotalTokens = 0;
this.curRoundPrefillTotalTime = 0;
this.curRoundDecodingTotalTime = 0;
this.stopTriggered = false;
const conversation = this.conversation;

@@ -481,6 +523,7 @@
this.prefillTotalTime += (tend - tstart) / 1e3;
this.prefillTotalTokens += promptTokens.length;
this.curRoundPrefillTotalTokens += promptTokens.length;
this.curRoundPrefillTotalTime += (tend - tstart) / 1e3;

this.processNextToken(nextToken, genConfig);
}
@@ -508,6 +551,7 @@
this.decodingTotalTime += (tend - tstart) / 1e3;
this.decodingTotalTokens += 1;
this.curRoundDecodingTotalTokens += 1;
this.curRoundDecodingTotalTime += (tend - tstart) / 1e3;

this.processNextToken(nextToken, genConfig);
}
@@ -991,10 +1035,12 @@ export class LLMChatPipeline {
this.prefillTotalTime += (tend - tstart) / 1e3;
this.prefillTotalTokens += inputIds.length;
this.curRoundPrefillTotalTokens += inputIds.length;
this.curRoundPrefillTotalTime += (tend - tstart) / 1e3;
} else {
this.decodingTotalTime += (tend - tstart) / 1e3;
this.decodingTotalTokens += 1;
this.curRoundDecodingTotalTokens += 1;
this.curRoundDecodingTotalTime += (tend - tstart) / 1e3;
}
return nextToken;
}
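
The pipeline-level change is per-round bookkeeping: `curRoundPrefillTotalTime` and `curRoundDecodingTotalTime` are reset at every `prefillStep()`, each prefill or decode step adds its elapsed time and token count, and the new tokens-per-second getters are simply the ratio of the two. A stripped-down sketch of that pattern (illustrative only; `RoundCounter` is a hypothetical class, not part of `LLMChatPipeline`):

```typescript
// Minimal model of the per-round counters added in llm_chat.ts.
class RoundCounter {
  private tokens = 0;
  private seconds = 0;

  // Called at the start of each prefill round, mirroring the resets in prefillStep().
  reset(): void {
    this.tokens = 0;
    this.seconds = 0;
  }

  // Record one prefill/decode step given its wall-clock bounds in milliseconds.
  record(numTokens: number, tstartMs: number, tendMs: number): void {
    this.tokens += numTokens;
    this.seconds += (tendMs - tstartMs) / 1e3;
  }

  tokensPerSec(): number {
    return this.tokens / this.seconds;
  }
}
```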