diff --git a/src/engine.ts b/src/engine.ts index 82df2127..754aaf07 100644 --- a/src/engine.ts +++ b/src/engine.ts @@ -745,6 +745,8 @@ export class MLCEngine implements MLCEngineInterface { async chatCompletion( request: ChatCompletionRequest, ): Promise | ChatCompletion> { + const initialInterruptSignal = this.interruptSignal; + // 0. Check model loaded and preprocess inputs const [selectedModelId, selectedPipeline, selectedChatConfig] = this.getLLMStates("ChatCompletionRequest", request.model); @@ -771,6 +773,10 @@ export class MLCEngine implements MLCEngineInterface { // 0.5 Block wait until this pipeline finishes all previous requests const lock = this.loadedModelIdToLock.get(selectedModelId)!; await lock.acquire(); + // If interruptGenerate was called while waiting for the lock, respect it. But reset the signal if it was already set by an earlier interruptGenerate call before this request started + if (initialInterruptSignal && this.interruptSignal) { + this.interruptSignal = false; + } // 1. If request is streaming, return an AsyncIterable (an iterable version of `_generate()`) if (request.stream) { @@ -901,6 +907,8 @@ export class MLCEngine implements MLCEngineInterface { async completion( request: CompletionCreateParams, ): Promise | Completion> { + const initialInterruptSignal = this.interruptSignal; + // 0. Check model loaded and preprocess inputs const [selectedModelId, selectedPipeline, selectedChatConfig] = this.getLLMStates("CompletionCreateParams", request.model); @@ -920,6 +928,10 @@ export class MLCEngine implements MLCEngineInterface { // 0.5 Block wait until this pipeline finishes all previous requests const lock = this.loadedModelIdToLock.get(selectedModelId)!; await lock.acquire(); + // If interruptGenerate was called while waiting for the lock, respect it. But reset the signal if it was already set by an earlier interruptGenerate call before this request started + if (initialInterruptSignal && this.interruptSignal) { + this.interruptSignal = false; + } // 1. 
If request is streaming, return an AsyncIterable (an iterable version of `_generate()`) if (request.stream) {