[Fix] Allow concurrent inference for multi model in WebWorker (#546)
This is a follow-up to #542. It updates `examples/multi-model` to use a web worker and to showcase generating responses from two models concurrently from the same engine. This was already supported for `MLCEngine` prior to this PR, but `WebWorkerMLCEngine` needed a patch. Specifically:
- Prior to this PR, `WebWorkerMLCEngineHandler` maintained a single `asyncGenerator`, assuming there is only one model.
- Now, to support concurrent streaming requests, we replace `this.asyncGenerator` with `this.loadedModelIdToAsyncGenerator`, which maps each model id to its dedicated `asyncGenerator`.
1 parent d926cff · commit d351b6a
Showing 6 changed files with 234 additions and 99 deletions.
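For illustration, here is a minimal sketch of the bookkeeping change described in the commit message: keep one async generator per loaded model, keyed by model id, instead of a single shared one. Only the field name `loadedModelIdToAsyncGenerator` mirrors the commit description; the class, its methods, and the `ChunkGenerator` alias below are hypothetical names for this sketch, and the real `WebWorkerMLCEngineHandler` carries additional state and message handling.

```typescript
// Hypothetical sketch of per-model stream bookkeeping; not web-llm's actual handler.
type ChunkGenerator = AsyncGenerator<string, void, void>;

class MultiModelStreamRegistry {
  // Before the fix: a single `asyncGenerator` field, so only one model could
  // stream at a time. Keying generators by model id lifts that restriction.
  private loadedModelIdToAsyncGenerator = new Map<string, ChunkGenerator>();

  register(modelId: string, generator: ChunkGenerator): void {
    this.loadedModelIdToAsyncGenerator.set(modelId, generator);
  }

  // Advance only the requested model's stream; other models are unaffected.
  async next(modelId: string): Promise<IteratorResult<string, void>> {
    const generator = this.loadedModelIdToAsyncGenerator.get(modelId);
    if (generator === undefined) {
      throw new Error(`No active stream for model ${modelId}`);
    }
    const result = await generator.next();
    if (result.done) {
      // Clean up once a model's stream completes.
      this.loadedModelIdToAsyncGenerator.delete(modelId);
    }
    return result;
  }
}
```

With this layout, streaming requests for different models can interleave their chunk deliveries without clobbering each other's generator state.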
@@ -0,0 +1,134 @@
/**
 * This example demonstrates loading multiple models in the same engine concurrently.
 * sequentialGeneration() runs inference on each model one at a time.
 * parallelGeneration() runs inference on both models at the same time.
 * This example uses WebWorkerMLCEngine, but the same idea applies to MLCEngine and
 * ServiceWorkerMLCEngine as well.
 */

import * as webllm from "@mlc-ai/web-llm";

function setLabel(id: string, text: string) {
  const label = document.getElementById(id);
  if (label == null) {
    throw Error("Cannot find label " + id);
  }
  label.innerText = text;
}

const initProgressCallback = (report: webllm.InitProgressReport) => {
  setLabel("init-label", report.text);
};

// Prepare a request for each model, shared by both methods
const selectedModel1 = "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k";
const selectedModel2 = "gemma-2-2b-it-q4f32_1-MLC-1k";
const prompt1 = "Tell me about California in 3 short sentences.";
const prompt2 = "Tell me about New York City in 3 short sentences.";
setLabel("prompt-label-1", `(with model ${selectedModel1})\n` + prompt1);
setLabel("prompt-label-2", `(with model ${selectedModel2})\n` + prompt2);

const request1: webllm.ChatCompletionRequestStreaming = {
  stream: true,
  stream_options: { include_usage: true },
  messages: [{ role: "user", content: prompt1 }],
  model: selectedModel1, // without specifying it, an error is thrown due to ambiguity
  max_tokens: 128,
};

const request2: webllm.ChatCompletionRequestStreaming = {
  stream: true,
  stream_options: { include_usage: true },
  messages: [{ role: "user", content: prompt2 }],
  model: selectedModel2, // without specifying it, an error is thrown due to ambiguity
  max_tokens: 128,
};

/**
 * Chat completion (OpenAI style) with streaming, with two models in the pipeline,
 * generating one after the other.
 */
async function sequentialGeneration() {
  const engine = await webllm.CreateWebWorkerMLCEngine(
    new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
    [selectedModel1, selectedModel2],
    { initProgressCallback: initProgressCallback },
  );

  const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
  let message1 = "";
  for await (const chunk of asyncChunkGenerator1) {
    // console.log(chunk);
    message1 += chunk.choices[0]?.delta?.content || "";
    setLabel("generate-label-1", message1);
    if (chunk.usage) {
      console.log(chunk.usage); // only the last chunk has usage
    }
    // engine.interruptGenerate(); // works with interrupt as well
  }
  const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
  let message2 = "";
  for await (const chunk of asyncChunkGenerator2) {
    // console.log(chunk);
    message2 += chunk.choices[0]?.delta?.content || "";
    setLabel("generate-label-2", message2);
    if (chunk.usage) {
      console.log(chunk.usage); // only the last chunk has usage
    }
    // engine.interruptGenerate(); // works with interrupt as well
  }

  // without specifying which model to get the message from, an error is thrown due to ambiguity
  console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
  console.log("Final message 2:\n", await engine.getMessage(selectedModel2));
}

/**
 * Chat completion (OpenAI style) with streaming, with two models in the pipeline,
 * generating from both models concurrently.
 */
async function parallelGeneration() {
  const engine = await webllm.CreateWebWorkerMLCEngine(
    new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
    [selectedModel1, selectedModel2],
    { initProgressCallback: initProgressCallback },
  );

  // We can serve the two requests concurrently
  let message1 = "";
  let message2 = "";

  async function getModel1Response() {
    const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
    for await (const chunk of asyncChunkGenerator1) {
      // console.log(chunk);
      message1 += chunk.choices[0]?.delta?.content || "";
      setLabel("generate-label-1", message1);
      if (chunk.usage) {
        console.log(chunk.usage); // only the last chunk has usage
      }
      // engine.interruptGenerate(); // works with interrupt as well
    }
  }

  async function getModel2Response() {
    const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
    for await (const chunk of asyncChunkGenerator2) {
      // console.log(chunk);
      message2 += chunk.choices[0]?.delta?.content || "";
      setLabel("generate-label-2", message2);
      if (chunk.usage) {
        console.log(chunk.usage); // only the last chunk has usage
      }
      // engine.interruptGenerate(); // works with interrupt as well
    }
  }

  await Promise.all([getModel1Response(), getModel2Response()]);

  // without specifying which model to get the message from, an error is thrown due to ambiguity
  console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
  console.log("Final message 2:\n", await engine.getMessage(selectedModel2));
}

// Pick one to run
sequentialGeneration();
// parallelGeneration();
This file was deleted.
@@ -0,0 +1,7 @@
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

// Hook up an engine to a worker handler
const handler = new WebWorkerMLCEngineHandler();
self.onmessage = (msg: MessageEvent) => {
  handler.onmessage(msg);
};