[Fix] Implement lock to ensure FCFS of requests to same model (#549)
A model cannot handle more than one concurrent request (e.g. more than one call to `chat.completions.create()`) since we do not support continuous batching, and each request requires its own resources such as the KV cache. (Concurrent requests to different models in the same engine are supported, though.)

As a result, as pointed out in
#522, when users try something
like the following code:

```typescript
const engine = await CreateMLCEngine("Phi-3-mini-4k-instruct-q4f16_1-MLC")
async function sendRequest() {
  const reply = await engine.chat.completions.create({
    messages: [{ role: "user", content: "Hello!" }],
    max_tokens: 64,
  });
  console.log(reply.choices[0].message.content);
}
await Promise.all([sendRequest(), sendRequest()]);
```

the model's internal state and the generation results become garbled.

To resolve this, we implement `CustomLock` using Promises, maintaining a queue to ensure first-come-first-served (FCFS) handling of incoming requests to a model, such that, for a single model, a request only starts after all previous requests have finished. The code above now works.
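
For illustration, here is a minimal sketch of a promise-based FIFO lock in the spirit of `CustomLock`, following the general approach of the blog post linked under "Tested"; the exact class in this PR may differ in naming and details:

```typescript
// Minimal sketch of a promise-based FIFO lock (assumed shape, not the
// exact CustomLock implementation in this PR).
class CustomLock {
  private locked = false;
  private queue: Array<() => void> = []; // resolvers of waiting acquirers, in arrival order

  async acquire(): Promise<void> {
    if (!this.locked) {
      this.locked = true;
      return;
    }
    // Wait in line; the current holder's release() resolves this promise.
    await new Promise<void>((resolve) => this.queue.push(resolve));
  }

  release(): void {
    const next = this.queue.shift();
    if (next) {
      next(); // hand the lock to the earliest waiter, preserving FCFS order
    } else {
      this.locked = false;
    }
  }
}
```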

### Implementation Details
- We add `loadedModelIdToLock` to MLCEngine, maintaining a lock for each loaded model
- Reminder: the need for a critical section is only per model, since
each loaded model has its own `LLMChatPipeline` / `EmbeddingPipeline`
- `loadedModelIdToLock` is cleared in `unload()` and set in `reloadInternal()`
- We acquire the lock at the very beginning of `completion()`, `chatCompletion()`, and `embedding()`, once we know which model the current call will use
- We release the lock at the end of `embedding()`, `completion()`, and `chatCompletion()` (for non-streaming cases), and at the end of `asyncGenerate()` (for streaming cases)
- Since we also want to release the lock when errors occur, we wrap the code in a big `try` / `finally` (see the sketch after this list)
- Since `asyncGenerate()` is an async generator, we add fine-grained `try` / `catch` blocks only around the places that can throw errors
  - This makes the code less readable, but we are not sure there is a better solution
- For WebWorkerMLCEngine, no special handling is needed, since the
WebWorkerMLCEngineHandler calls the underlying engine's APIs (e.g.
`chatCompletion()`), which will block
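
As a rough illustration of the acquire/release pattern above (a sketch only; `withModelLock` is a hypothetical helper, not how the actual MLCEngine code is structured):

```typescript
// Sketch of guarding the per-model critical section; assumes the CustomLock
// sketch above. withModelLock is a hypothetical helper, not actual engine code.
async function withModelLock<T>(
  loadedModelIdToLock: Map<string, CustomLock>, // one lock per loaded model
  modelId: string,
  run: () => Promise<T>, // e.g. the body of chatCompletion() for this model
): Promise<T> {
  const lock = loadedModelIdToLock.get(modelId);
  if (lock === undefined) {
    throw new Error(`Model ${modelId} is not loaded.`);
  }
  await lock.acquire(); // FCFS: wait until all earlier requests to this model finish
  try {
    return await run();
  } finally {
    lock.release(); // release even when run() throws
  }
}
```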

### Tested
- Tested the `CustomLock` implementation with a unit test (the implementation follows [this blog post](https://jackpordi.com/posts/locks-in-js-because-why-not)); a sketch of the kind of ordering check involved follows this list
- The example above now works
- [get-started, get-started-web-worker] x [streaming, non-streaming] x
[concurrent requests, single request]
- examples/simple-chat-ts
- examples/multi-models
- WebLLMChat (with generation interrupts, manual termination of service
worker)
- Opening WebLLMChat in two tabs and sending concurrent requests: the later request waits for the earlier one to finish (prior to this PR, garbage output was generated just like in the simple example above, since the two WebLLMChat tabs share the same service worker and hence the same engine).
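
For reference, a hypothetical ordering check of the kind such a unit test could perform (not the actual test in this PR; assumes the `CustomLock` sketch above):

```typescript
// Hypothetical FCFS check: two "requests" guarded by the same lock must
// finish in arrival order even if the first one is slower.
async function testFcfsOrdering(): Promise<void> {
  const lock = new CustomLock();
  const completionOrder: number[] = [];

  const request = async (id: number, delayMs: number) => {
    await lock.acquire();
    try {
      await new Promise((resolve) => setTimeout(resolve, delayMs));
      completionOrder.push(id);
    } finally {
      lock.release();
    }
  };

  // Request 1 is slower, yet it must still finish first.
  await Promise.all([request(1, 50), request(2, 0)]);
  console.assert(
    JSON.stringify(completionOrder) === JSON.stringify([1, 2]),
    "requests should complete in FCFS order",
  );
}
```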
CharlieFRuan authored Aug 19, 2024 · 1 parent b598fc1 · commit 1b691a5

Showing 5 changed files with 348 additions and 197 deletions.
### examples/multi-models/src/main.ts (6 additions, 3 deletions)

```diff
@@ -93,10 +93,8 @@ async function parallelGeneration() {
   );
 
   // We can serve the two requests concurrently
-  let message1 = "";
-  let message2 = "";
-
   async function getModel1Response() {
+    let message1 = "";
     const asyncChunkGenerator1 = await engine.chat.completions.create(request1);
     for await (const chunk of asyncChunkGenerator1) {
       // console.log(chunk);
@@ -110,6 +108,7 @@ async function parallelGeneration() {
   }
 
   async function getModel2Response() {
+    let message2 = "";
     const asyncChunkGenerator2 = await engine.chat.completions.create(request2);
     for await (const chunk of asyncChunkGenerator2) {
       // console.log(chunk);
@@ -123,6 +122,10 @@ async function parallelGeneration() {
   }
 
   await Promise.all([getModel1Response(), getModel2Response()]);
+  // Note: concurrent requests to the same model are executed sequentially (FCFS),
+  // unlike requests to different models as above.
+  // For more, see https://github.com/mlc-ai/web-llm/pull/549
+  // await Promise.all([getModel1Response(), getModel1Response()]);
 
   // without specifying from which model to get message, error will throw due to ambiguity
   console.log("Final message 1:\n", await engine.getMessage(selectedModel1));
```
