[Fix] Allow concurrent inference for multi model in WebWorker (#546)
This is a follow-up to #542.
Update `examples/multi-models` to use a web worker, and to also showcase
generating responses from two models concurrently from the same engine.
`MLCEngine` already supported this prior to this PR, but
`WebWorkerMLCEngine` needed a patch. Specifically:

- Prior to this PR, `WebWorkerMLCEngineHandler` maintained a single
  `asyncGenerator`, assuming there is only one model.
- Now, to support concurrent streaming requests, we replace
  `this.asyncGenerator` with `this.loadedModelIdToAsyncGenerator`, which
  maps each model id to its dedicated `asyncGenerator` (see the sketch below).
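
A minimal sketch of that bookkeeping, for illustration only: the field name `loadedModelIdToAsyncGenerator` comes from this PR, but the class shape, method names, and generator types here are assumptions, not the actual `WebWorkerMLCEngineHandler` API.

```ts
// Illustrative sketch only; method names and generator types are assumptions.
type ChunkGenerator = AsyncGenerator<unknown, void, void>;

class HandlerSketch {
  // One in-flight streaming generator per loaded model, keyed by model id.
  loadedModelIdToAsyncGenerator = new Map<string, ChunkGenerator>();

  // On a stream-init request, store the generator under its model id.
  startStream(selectedModelId: string, generator: ChunkGenerator) {
    this.loadedModelIdToAsyncGenerator.set(selectedModelId, generator);
  }

  // On each next-chunk request, route to that model's dedicated generator,
  // so two models can stream concurrently without clobbering each other.
  async nextChunk(selectedModelId: string) {
    const generator = this.loadedModelIdToAsyncGenerator.get(selectedModelId);
    if (generator === undefined) {
      throw Error(`No active stream for model ${selectedModelId}`);
    }
    return await generator.next();
  }
}
```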
CharlieFRuan authored Aug 13, 2024
1 parent d926cff commit d351b6a
Showing 6 changed files with 234 additions and 99 deletions.
134 changes: 134 additions & 0 deletions examples/multi-models/src/main.ts
@@ -0,0 +1,134 @@
/**
 * This example demonstrates loading multiple models in the same engine concurrently.
 * sequentialGeneration() runs inference on each model, one at a time.
 * parallelGeneration() runs inference on both models at the same time.
 * This example uses WebWorkerMLCEngine, but the same idea applies to MLCEngine and
 * ServiceWorkerMLCEngine as well.
 */

import * as webllm from "@mlc-ai/web-llm";

function setLabel(id: string, text: string) {
  const label = document.getElementById(id);
  if (label == null) {
    throw Error("Cannot find label " + id);
  }
  label.innerText = text;
}

const initProgressCallback = (report: webllm.InitProgressReport) => {
  setLabel("init-label", report.text);
};

// Prepare request for each model, same for both methods
const selectedModel1 = "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k";
const selectedModel2 = "gemma-2-2b-it-q4f32_1-MLC-1k";
const prompt1 = "Tell me about California in 3 short sentences.";
const prompt2 = "Tell me about New York City in 3 short sentences.";
setLabel("prompt-label-1", `(with model ${selectedModel1})\n` + prompt1);
setLabel("prompt-label-2", `(with model ${selectedModel2})\n` + prompt2);

const request1: webllm.ChatCompletionRequestStreaming = {
  stream: true,
  stream_options: { include_usage: true },
  messages: [{ role: "user", content: prompt1 }],
  model: selectedModel1, // if not specified, an error is thrown due to ambiguity
  max_tokens: 128,
};

const request2: webllm.ChatCompletionRequestStreaming = {
  stream: true,
  stream_options: { include_usage: true },
  messages: [{ role: "user", content: prompt2 }],
  model: selectedModel2, // if not specified, an error is thrown due to ambiguity
  max_tokens: 128,
};

/**
 * Chat completion (OpenAI style) with streaming, with two models in the pipeline,
 * generating one response after the other.
 */
async function sequentialGeneration() {
  const engine = await webllm.CreateWebWorkerMLCEngine(
    new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
    [selectedModel1, selectedModel2],
    { initProgressCallback: initProgressCallback },
  );

  const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
  let message1 = "";
  for await (const chunk of asyncChunkGenerator1) {
    // console.log(chunk);
    message1 += chunk.choices[0]?.delta?.content || "";
    setLabel("generate-label-1", message1);
    if (chunk.usage) {
      console.log(chunk.usage); // only the last chunk has usage
    }
    // engine.interruptGenerate(); // interrupting also works here
  }
  const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
  let message2 = "";
  for await (const chunk of asyncChunkGenerator2) {
    // console.log(chunk);
    message2 += chunk.choices[0]?.delta?.content || "";
    setLabel("generate-label-2", message2);
    if (chunk.usage) {
      console.log(chunk.usage); // only the last chunk has usage
    }
    // engine.interruptGenerate(); // interrupting also works here
  }

  // the model id must be specified here; otherwise an error is thrown due to ambiguity
  console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
  console.log("Final message 2:\n", await engine.getMessage(selectedModel2));
}

/**
 * Chat completion (OpenAI style) with streaming, with two models in the pipeline,
 * generating both responses concurrently.
 */
async function parallelGeneration() {
  const engine = await webllm.CreateWebWorkerMLCEngine(
    new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
    [selectedModel1, selectedModel2],
    { initProgressCallback: initProgressCallback },
  );

  // We can serve the two requests concurrently
  let message1 = "";
  let message2 = "";

  async function getModel1Response() {
    const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
    for await (const chunk of asyncChunkGenerator1) {
      // console.log(chunk);
      message1 += chunk.choices[0]?.delta?.content || "";
      setLabel("generate-label-1", message1);
      if (chunk.usage) {
        console.log(chunk.usage); // only the last chunk has usage
      }
      // engine.interruptGenerate(); // interrupting also works here
    }
  }

  async function getModel2Response() {
    const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
    for await (const chunk of asyncChunkGenerator2) {
      // console.log(chunk);
      message2 += chunk.choices[0]?.delta?.content || "";
      setLabel("generate-label-2", message2);
      if (chunk.usage) {
        console.log(chunk.usage); // only the last chunk has usage
      }
      // engine.interruptGenerate(); // interrupting also works here
    }
  }

  await Promise.all([getModel1Response(), getModel2Response()]);

  // the model id must be specified here; otherwise an error is thrown due to ambiguity
  console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
  console.log("Final message 2:\n", await engine.getMessage(selectedModel2));
}

// Pick one to run
sequentialGeneration();
// parallelGeneration();
17 changes: 12 additions & 5 deletions examples/multi-models/src/multi_models.html
@@ -10,14 +10,21 @@ <h2>WebLLM Test Page</h2>
<br />
<label id="init-label"> </label>

<h3>Prompt</h3>
<label id="prompt-label"> </label>
<h3>Prompt 1</h3>
<label id="prompt-label-1"> </label>

<h3>Response</h3>
<label id="generate-label"> </label>
<h3>Response from model 1</h3>
<label id="generate-label-1"> </label>
<br />

<h3>Prompt 2</h3>
<label id="prompt-label-2"> </label>

<h3>Response from model 2</h3>
<label id="generate-label-2"> </label>
<br />
<label id="stats-label"> </label>

<script type="module" src="./multi_models.ts"></script>
<script type="module" src="./main.ts"></script>
</body>
</html>
76 changes: 0 additions & 76 deletions examples/multi-models/src/multi_models.ts

This file was deleted.

7 changes: 7 additions & 0 deletions examples/multi-models/src/worker.ts
@@ -0,0 +1,7 @@
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

// Hook up an engine to a worker handler
const handler = new WebWorkerMLCEngineHandler();
self.onmessage = (msg: MessageEvent) => {
  handler.onmessage(msg);
};
11 changes: 11 additions & 0 deletions src/message.ts
@@ -65,13 +65,19 @@ export interface ForwardTokensAndSampleParams {
// handler will call reload(). An engine can load multiple models, hence both are lists.
// TODO(webllm-team): should add appConfig here as well if rigorous.
// For more, see https://github.com/mlc-ai/web-llm/pull/471

// Note on the messages with selectedModelId:
// This is the modelId this request uses. It is needed to identify which async generator
// to instantiate or use, since an engine can load multiple models and the handler
// therefore needs to maintain multiple generators.
export interface ChatCompletionNonStreamingParams {
  request: ChatCompletionRequestNonStreaming;
  modelId: string[];
  chatOpts?: ChatOptions[];
}
export interface ChatCompletionStreamInitParams {
  request: ChatCompletionRequestStreaming;
  selectedModelId: string;
  modelId: string[];
  chatOpts?: ChatOptions[];
}
@@ -82,6 +88,7 @@ export interface CompletionNonStreamingParams {
}
export interface CompletionStreamInitParams {
  request: CompletionCreateParamsStreaming;
  selectedModelId: string;
  modelId: string[];
  chatOpts?: ChatOptions[];
}
@@ -90,6 +97,9 @@ export interface EmbeddingParams {
  modelId: string[];
  chatOpts?: ChatOptions[];
}
export interface CompletionStreamNextChunkParams {
  selectedModelId: string;
}

export interface CustomRequestParams {
  requestName: string;
@@ -106,6 +116,7 @@ export type MessageContent =
  | CompletionNonStreamingParams
  | CompletionStreamInitParams
  | EmbeddingParams
  | CompletionStreamNextChunkParams
  | CustomRequestParams
  | InitProgressReport
  | LogLevel

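To make the role of `selectedModelId` concrete, here is a hypothetical handler-side dispatch over these message shapes. The `kind`/`content` envelope and `fakeStream` are assumptions for this sketch, not WebLLM's actual wire format; only the `selectedModelId` fields mirror the interfaces above.

```ts
// Hypothetical dispatch sketch; only the selectedModelId fields mirror src/message.ts.
type StreamInitContent = { selectedModelId: string; modelId: string[] };
type NextChunkContent = { selectedModelId: string };

type WorkerMessage =
  | { kind: "chatCompletionStreamInit"; content: StreamInitContent }
  | { kind: "completionStreamNextChunk"; content: NextChunkContent };

const streams = new Map<string, AsyncGenerator<string, void, void>>();

// Stand-in for the engine's real per-model chunk generator.
async function* fakeStream(modelId: string): AsyncGenerator<string, void, void> {
  yield `chunk from ${modelId}`;
}

async function handle(msg: WorkerMessage): Promise<unknown> {
  switch (msg.kind) {
    case "chatCompletionStreamInit":
      // Key the generator by selectedModelId so concurrent models do not collide.
      streams.set(msg.content.selectedModelId, fakeStream(msg.content.selectedModelId));
      return undefined;
    case "completionStreamNextChunk":
      // Route each next-chunk request to its model's dedicated generator.
      return await streams.get(msg.content.selectedModelId)?.next();
  }
}
```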