[Fix] Allow concurrent inference for multi model in WebWorker #546

Merged
merged 2 commits on Aug 13, 2024
134 changes: 134 additions & 0 deletions examples/multi-models/src/main.ts
@@ -0,0 +1,134 @@
/**
* This example demonstrates loading multiple models in the same engine concurrently.
* sequentialGeneration() runs inference with each model one at a time.
* parallelGeneration() runs inference with both models at the same time.
* This example uses WebWorkerMLCEngine, but the same idea applies to MLCEngine and
* ServiceWorkerMLCEngine as well (a sketch of the MLCEngine variant follows this file).
*/

import * as webllm from "@mlc-ai/web-llm";

function setLabel(id: string, text: string) {
const label = document.getElementById(id);
if (label == null) {
throw Error("Cannot find label " + id);
}
label.innerText = text;
}

const initProgressCallback = (report: webllm.InitProgressReport) => {
setLabel("init-label", report.text);
};

// Prepare request for each model, same for both methods
const selectedModel1 = "Phi-3-mini-4k-instruct-q4f32_1-MLC-1k";
const selectedModel2 = "gemma-2-2b-it-q4f32_1-MLC-1k";
const prompt1 = "Tell me about California in 3 short sentences.";
const prompt2 = "Tell me about New York City in 3 short sentences.";
setLabel("prompt-label-1", `(with model ${selectedModel1})\n` + prompt1);
setLabel("prompt-label-2", `(with model ${selectedModel2})\n` + prompt2);

const request1: webllm.ChatCompletionRequestStreaming = {
stream: true,
stream_options: { include_usage: true },
messages: [{ role: "user", content: prompt1 }],
model: selectedModel1, // without specifying it, an error is thrown due to ambiguity
max_tokens: 128,
};

const request2: webllm.ChatCompletionRequestStreaming = {
stream: true,
stream_options: { include_usage: true },
messages: [{ role: "user", content: prompt2 }],
model: selectedModel2, // without specifying it, an error is thrown due to ambiguity
max_tokens: 128,
};

/**
* Chat completion (OpenAI style) with streaming, querying the two models one after another.
*/
async function sequentialGeneration() {
const engine = await webllm.CreateWebWorkerMLCEngine(
new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
[selectedModel1, selectedModel2],
{ initProgressCallback: initProgressCallback },
);

const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
let message1 = "";
for await (const chunk of asyncChunkGenerator1) {
// console.log(chunk);
message1 += chunk.choices[0]?.delta?.content || "";
setLabel("generate-label-1", message1);
if (chunk.usage) {
console.log(chunk.usage); // only last chunk has usage
}
// engine.interruptGenerate(); // works with interrupt as well
}
const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
let message2 = "";
for await (const chunk of asyncChunkGenerator2) {
// console.log(chunk);
message2 += chunk.choices[0]?.delta?.content || "";
setLabel("generate-label-2", message2);
if (chunk.usage) {
console.log(chunk.usage); // only last chunk has usage
}
// engine.interruptGenerate(); // works with interrupt as well
}

// without specifying which model to get the message from, an error is thrown due to ambiguity
console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
console.log("Final message 2:\n", await engine.getMessage(selectedModel2));
}

/**
* Chat completion (OpenAI style) with streaming, querying the two models concurrently.
*/
async function parallelGeneration() {
const engine = await webllm.CreateWebWorkerMLCEngine(
new Worker(new URL("./worker.ts", import.meta.url), { type: "module" }),
[selectedModel1, selectedModel2],
{ initProgressCallback: initProgressCallback },
);

// We can serve the two requests concurrently
let message1 = "";
let message2 = "";

async function getModel1Response() {
const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
for await (const chunk of asyncChunkGenerator1) {
// console.log(chunk);
message1 += chunk.choices[0]?.delta?.content || "";
setLabel("generate-label-1", message1);
if (chunk.usage) {
console.log(chunk.usage); // only last chunk has usage
}
// engine.interruptGenerate(); // works with interrupt as well
}
}

async function getModel2Response() {
const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
for await (const chunk of asyncChunkGenerator2) {
// console.log(chunk);
message2 += chunk.choices[0]?.delta?.content || "";
setLabel("generate-label-2", message2);
if (chunk.usage) {
console.log(chunk.usage); // only last chunk has usage
}
// engine.interruptGenerate(); // works with interrupt as well
}
}

await Promise.all([getModel1Response(), getModel2Response()]);

// without specifying which model to get the message from, an error is thrown due to ambiguity
console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
console.log("Final message 2:\n", await engine.getMessage(selectedModel2));
}

// Pick one to run
sequentialGeneration();
// parallelGeneration();
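
The header comment of main.ts states that the same idea applies to MLCEngine without a worker. Below is a minimal sketch of the parallel case on the main thread, assuming CreateMLCEngine accepts the same list of model IDs as its WebWorker counterpart; it reuses selectedModel1/selectedModel2 and request1/request2 from the example above and is an illustration, not part of this PR.

import * as webllm from "@mlc-ai/web-llm";

async function parallelGenerationMainThread() {
  // Load both models into a single engine running on the main thread.
  const engine = await webllm.CreateMLCEngine(
    [selectedModel1, selectedModel2],
    { initProgressCallback: initProgressCallback },
  );

  // Stream one request and collect its text; each request already names its model,
  // so the engine knows which pipeline to use.
  async function run(request: webllm.ChatCompletionRequestStreaming) {
    let message = "";
    const chunks = await engine.chat.completions.create(request);
    for await (const chunk of chunks) {
      message += chunk.choices[0]?.delta?.content || "";
    }
    return message;
  }

  // Serve both models concurrently, mirroring parallelGeneration() above.
  const [reply1, reply2] = await Promise.all([run(request1), run(request2)]);
  console.log("Final message 1:\n", reply1);
  console.log("Final message 2:\n", reply2);
}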
17 changes: 12 additions & 5 deletions examples/multi-models/src/multi_models.html
@@ -10,14 +10,21 @@ <h2>WebLLM Test Page</h2>
 <br />
 <label id="init-label"> </label>

-<h3>Prompt</h3>
-<label id="prompt-label"> </label>
+<h3>Prompt 1</h3>
+<label id="prompt-label-1"> </label>

-<h3>Response</h3>
-<label id="generate-label"> </label>
+<h3>Response from model 1</h3>
+<label id="generate-label-1"> </label>
+<br />
+
+<h3>Prompt 2</h3>
+<label id="prompt-label-2"> </label>
+
+<h3>Response from model 2</h3>
+<label id="generate-label-2"> </label>
 <br />
 <label id="stats-label"> </label>

-<script type="module" src="./multi_models.ts"></script>
+<script type="module" src="./main.ts"></script>
 </body>
 </html>
76 changes: 0 additions & 76 deletions examples/multi-models/src/multi_models.ts

This file was deleted.

7 changes: 7 additions & 0 deletions examples/multi-models/src/worker.ts
@@ -0,0 +1,7 @@
import { WebWorkerMLCEngineHandler } from "@mlc-ai/web-llm";

// Hook up an engine to a worker handler
const handler = new WebWorkerMLCEngineHandler();
self.onmessage = (msg: MessageEvent) => {
handler.onmessage(msg);
};
11 changes: 11 additions & 0 deletions src/message.ts
@@ -65,13 +65,19 @@ export interface ForwardTokensAndSampleParams
// handler will call reload(). An engine can load multiple models, hence both are lists.
// TODO(webllm-team): should add appConfig here as well if rigorous.
// For more, see https://github.com/mlc-ai/web-llm/pull/471

// Note on the messages with selectedModelId:
// This is the modelId that this request uses. It is needed to identify which async
// generator to instantiate or use, since an engine can load multiple models and the
// handler therefore needs to maintain multiple generators (a routing sketch follows
// this diff).
export interface ChatCompletionNonStreamingParams {
request: ChatCompletionRequestNonStreaming;
modelId: string[];
chatOpts?: ChatOptions[];
}
export interface ChatCompletionStreamInitParams {
request: ChatCompletionRequestStreaming;
selectedModelId: string;
modelId: string[];
chatOpts?: ChatOptions[];
}
@@ -82,6 +88,7 @@ export interface CompletionNonStreamingParams {
}
export interface CompletionStreamInitParams {
request: CompletionCreateParamsStreaming;
selectedModelId: string;
modelId: string[];
chatOpts?: ChatOptions[];
}
@@ -90,6 +97,9 @@ export interface EmbeddingParams {
modelId: string[];
chatOpts?: ChatOptions[];
}
export interface CompletionStreamNextChunkParams {
selectedModelId: string;
}

export interface CustomRequestParams {
requestName: string;
@@ -106,6 +116,7 @@ export type MessageContent =
| CompletionNonStreamingParams
| CompletionStreamInitParams
| EmbeddingParams
| CompletionStreamNextChunkParams
| CustomRequestParams
| InitProgressReport
| LogLevel
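
The selectedModelId note in message.ts explains that the worker handler must keep one async chunk generator per loaded model. A minimal routing sketch of that idea follows; the asyncGenerators map and the onStreamInit/onNextChunk helpers are hypothetical names, and the actual fields of WebWorkerMLCEngineHandler may differ.

import * as webllm from "@mlc-ai/web-llm";

// Hypothetical handler state: one in-flight stream per model ID.
const asyncGenerators = new Map<
  string,
  AsyncGenerator<webllm.ChatCompletionChunk, void, void>
>();

// On a stream-init message, store the generator under the model the request targets.
function onStreamInit(
  selectedModelId: string,
  generator: AsyncGenerator<webllm.ChatCompletionChunk, void, void>,
) {
  asyncGenerators.set(selectedModelId, generator);
}

// On a CompletionStreamNextChunkParams message, pull the next chunk from the
// generator that belongs to selectedModelId, so two concurrent streams never
// consume each other's output.
async function onNextChunk(selectedModelId: string) {
  const generator = asyncGenerators.get(selectedModelId);
  if (generator === undefined) {
    throw Error("No active stream for model " + selectedModelId);
  }
  const { value, done } = await generator.next();
  if (done) {
    asyncGenerators.delete(selectedModelId);
  }
  return value; // a webllm.ChatCompletionChunk, or undefined once the stream ends
}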