diff --git a/src/service_worker.ts b/src/service_worker.ts index eba3e564..9c800ba6 100644 --- a/src/service_worker.ts +++ b/src/service_worker.ts @@ -2,7 +2,7 @@ import * as tvmjs from "tvmjs"; import log from "loglevel"; import { AppConfig, ChatOptions, MLCEngineConfig } from "./config"; import { ReloadParams, WorkerRequest, WorkerResponse } from "./message"; -import { MLCEngineInterface, InitProgressReport, LogLevel } from "./types"; +import { MLCEngineInterface, InitProgressReport } from "./types"; import { MLCEngineWorkerHandler, WebWorkerMLCEngine, @@ -158,12 +158,19 @@ export class ServiceWorkerMLCEngineHandler extends MLCEngineWorkerHandler { * PostMessageHandler wrapper for sending message from client to service worker */ export class ServiceWorker implements ChatWorker { - serviceWorker: IServiceWorker; - onmessage: () => void; + _onmessage: (event: MessageEvent) => void = () => {}; - constructor(serviceWorker: IServiceWorker) { - this.serviceWorker = serviceWorker; - this.onmessage = () => {}; + get onmessage() { + return this._onmessage; + } + + set onmessage(handler: (event: any) => void) { + this._onmessage = handler; + + if (!("serviceWorker" in navigator)) { + throw new Error("Service worker API is not available"); + } + (navigator.serviceWorker as ServiceWorkerContainer).onmessage = handler; } postMessage(message: WorkerRequest) { @@ -206,7 +213,7 @@ export async function CreateServiceWorkerMLCEngine( "Please refresh the page to retry initializing the service worker.", ); } - const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine(serviceWorker); + const serviceWorkerMLCEngine = new ServiceWorkerMLCEngine(); if (engineConfig?.logLevel) { serviceWorkerMLCEngine.setLogLevel(engineConfig.logLevel); } @@ -227,34 +234,11 @@ export async function CreateServiceWorkerMLCEngine( export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { missedHeatbeat = 0; - constructor(worker: IServiceWorker, keepAliveMs = 10000) { + constructor(keepAliveMs = 10000) { if (!("serviceWorker" in navigator)) { throw new Error("Service worker API is not available"); } - super(new ServiceWorker(worker)); - const onmessage = this.onmessage.bind(this); - - (navigator.serviceWorker as ServiceWorkerContainer).addEventListener( - "message", - (event: MessageEvent) => { - const msg = event.data; - log.trace( - `MLC client message: [${msg.kind}] ${JSON.stringify(msg.content)}`, - ); - try { - if (msg.kind === "heartbeat") { - this.missedHeatbeat = 0; - return; - } - onmessage(msg); - } catch (err: any) { - // This is expected to throw if user has multiple windows open - if (!err.message.startsWith("return from a unknown uuid")) { - log.error("CreateWebServiceWorkerMLCEngine.onmessage", err); - } - } - }, - ); + super(new ServiceWorker()); setInterval(() => { this.worker.postMessage({ kind: "keepAlive", uuid: crypto.randomUUID() }); @@ -263,6 +247,25 @@ export class ServiceWorkerMLCEngine extends WebWorkerMLCEngine { }, keepAliveMs); } + onmessage(event: any): void { + const msg = event.data; + log.trace( + `MLC client message: [${msg.kind}] ${JSON.stringify(msg.content)}`, + ); + try { + if (msg.kind === "heartbeat") { + this.missedHeatbeat = 0; + return; + } + super.onmessage(msg); + } catch (err: any) { + // This is expected to throw if user has multiple windows open + if (!err.message.startsWith("return from a unknown uuid")) { + log.error("CreateWebServiceWorkerMLCEngine.onmessage", err); + } + } + } + /** * Initialize the chat with a model. *