From 6be2483270ac5a1af8781829e4d265b7bd2f6544 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Wed, 4 Oct 2023 21:42:26 +0200 Subject: [PATCH] feat: support reconnecting websocket Instead of asking a user to refresh the browser, we let the websocket reconnect and try to restore the page session. * Fixes #254 * Fixes #161 --- packages/solara-widget-manager/src/manager.ts | 51 +++++---- solara/server/app.py | 10 ++ solara/server/static/main-vuetify.js | 10 +- tests/integration/reconnect_test.py | 101 ++++++++++++++++++ 4 files changed, 151 insertions(+), 21 deletions(-) create mode 100644 tests/integration/reconnect_test.py diff --git a/packages/solara-widget-manager/src/manager.ts b/packages/solara-widget-manager/src/manager.ts index 3a9a2c65f..656aa7557 100644 --- a/packages/solara-widget-manager/src/manager.ts +++ b/packages/solara-widget-manager/src/manager.ts @@ -78,11 +78,15 @@ export class WidgetManager extends JupyterLabManager { ); this._registerWidgets(); this._loader = requireLoader; - const commId = base.uuid(); const kernel = context.sessionContext?.session?.kernel; + this.connectControlComm(); if (!kernel) { throw new Error('No current kernel'); } + } + async connectControlComm() { + const commId = base.uuid(); + const kernel = this.context.sessionContext?.session?.kernel; this.controlComm = kernel.createComm('solara.control', commId); this.controlCommHandler = { onMsg: (msg) => { @@ -107,31 +111,40 @@ export class WidgetManager extends JupyterLabManager { } async check() { // checks if app is still valid (e.g. server restarted and lost the widget state) - const okPromise = new Promise((resolve, reject) => { - this.controlCommHandler = { - onMsg: (msg) => { + // if we are connected to the same kernel, we'll get a reply instantly + // however, if we are connected to a new kernel, we rely on the timeout + // so every time we create a new comm. + + const kernel = this.context.sessionContext?.session?.kernel; + const commId = base.uuid(); + const controlComm = kernel.createComm('solara.control', commId); + controlComm.open({}, {}, []) + try { + return await new Promise((resolve, reject) => { + controlComm.onMsg = (msg) => { const data = msg['content']['data']; - if (data.method === 'finished') { - resolve(data.ok); + if (data.method === 'check') { + if (data.ok === true) { + resolve({ ok: true, message: data.message }); + } else { + resolve({ ok: false, message: data.message }); + } } else { - reject(data.error); + reject({ ok: false, message: "unexpected message" }); } - }, - onClose: () => { + } + controlComm.onClose = () => { console.error("closed solara control comm") - reject() + reject({ ok: false, message: "closed solara control comm" }); } - }; - setTimeout(() => { - reject('timeout'); - }, CONTROL_COMM_TIMEOUT); - }); - this.controlComm.send({ method: 'check' }); - try { - return await okPromise; + setTimeout(() => { + reject('timeout'); + }, CONTROL_COMM_TIMEOUT); + controlComm.send({ method: 'check' }); + }); } catch (e) { - return false; + return { ok: false, message: e }; } } diff --git a/solara/server/app.py b/solara/server/app.py index 4a8e02177..bdaee2665 100644 --- a/solara/server/app.py +++ b/solara/server/app.py @@ -352,6 +352,14 @@ def on_msg(msg): comm.send({"method": "finished", "widget_id": context.container._model_id}) elif method == "check": context = kernel_context.get_current_context() + # if there is no container, we never ran the app + if context.container is not None: + logger.info("Reconnect check: %s is ok", context.id) + comm.send({"method": "check", "ok": True, "message": "All fine"}) + else: + logger.info("Reconnect check: %s is not ok", context.id) + comm.send({"method": "check", "ok": False, "message": "Not reconnected"}) + elif method == "reload": assert app is not None context = kernel_context.get_current_context() @@ -359,6 +367,8 @@ def on_msg(msg): with context: load_app_widget(context.state, app, path) comm.send({"method": "finished"}) + else: + logger.error("Unknown comm method called on solara.control comm: %s", method) comm.on_msg(on_msg) diff --git a/solara/server/static/main-vuetify.js b/solara/server/static/main-vuetify.js index cb7e5d25d..6e2e55a9a 100644 --- a/solara/server/static/main-vuetify.js +++ b/solara/server/static/main-vuetify.js @@ -158,10 +158,16 @@ async function solaraInit(mountId, appName) { } if (s.connectionStatus == 'connected' && !skipReconnectedCheck) { (async () => { - let ok = await widgetManager.check() - if (!ok) { + if (app.$data.needsRefresh) { + // give up + return; + } + const msg = await widgetManager.check() + if (!msg.ok) { app.$data.needsRefresh = true; await solara.shutdownKernel(kernel); + } else { + await widgetManager.fetchAll(); } })(); } diff --git a/tests/integration/reconnect_test.py b/tests/integration/reconnect_test.py new file mode 100644 index 000000000..27864c931 --- /dev/null +++ b/tests/integration/reconnect_test.py @@ -0,0 +1,101 @@ +from pathlib import Path +from typing import Optional + +import playwright.sync_api + +import solara +import solara.server.kernel_context + +HERE = Path(__file__).parent + + +set_value = None +context: Optional["solara.server.kernel_context.VirtualKernelContext"] = None + + +@solara.component +def Page(): + global set_value, app_context + value, set_value = solara.use_state(0) + assert set_value is not None + context = solara.server.kernel_context.get_current_context() + assert context is not None + solara.Text(f"Value {value}") + + def disconnect(): + assert len(context.kernel.session.websockets) == 1 + list(context.kernel.session.websockets)[0].close() + + solara.Button("Disconnect", on_click=disconnect) + solara.Button("Increment", on_click=lambda: set_value(value + 1)) + + def disconnect_and_change(): + assert len(context.kernel.session.websockets) == 1 + list(context.kernel.session.websockets)[0].close() + set_value(100) + + solara.Button("Disconnect and change", on_click=disconnect_and_change) + + +def test_reconnect_simple(browser: playwright.sync_api.Browser, page_session: playwright.sync_api.Page, solara_server, solara_app, extra_include_path): + with extra_include_path(HERE), solara_app("reconnect_test:Page"): + page_session.goto(solara_server.base_url) + page_session.locator("text=Value 0").wait_for() + page_session.locator("text=Increment").click() + page_session.locator("text=Value 1").wait_for() + assert len(solara.server.kernel_context.contexts) == 1 + context = list(solara.server.kernel_context.contexts.values())[0] + assert len(context.kernel.session.websockets) == 1 + ws = list(context.kernel.session.websockets)[0] + page_session.locator("text=Disconnect").nth(0).click() + n = 0 + # we wait till the current websocket is not connected anymore, and a different one is connected + while not (ws not in context.kernel.session.websockets and len(context.kernel.session.websockets) == 1): + page_session.wait_for_timeout(100) + n += 1 + if n > 50: + raise RuntimeError("Timeout waiting for reconnected websocket") + page_session.locator("text=Value 1").wait_for() + page_session.locator("text=Increment").click() + page_session.locator("text=Value 2").wait_for() + # we should not have created a new context + assert len(solara.server.kernel_context.contexts) == 1 + + +def test_reconnect_fail(browser: playwright.sync_api.Browser, page_session: playwright.sync_api.Page, solara_server, solara_app, extra_include_path): + with extra_include_path(HERE), solara_app("reconnect_test:Page"): + # import reconnect_test as module + + page_session.goto(solara_server.base_url) + page_session.locator("text=Value 0").wait_for() + page_session.locator("text=Increment").click() + page_session.locator("text=Value 1").wait_for() + cull_timeout_previous = solara.server.settings.kernel.cull_timeout + try: + solara.server.settings.kernel.cull_timeout = "0s" + assert len(solara.server.kernel_context.contexts) == 1 + context = list(solara.server.kernel_context.contexts.values())[0] + assert len(context.kernel.session.websockets) == 1 + page_session.locator("text=Disconnect").nth(0).click() + page_session.locator("text=Could not restore session").wait_for() + n = 0 + # we wait till the all contexts are closed + while len(solara.server.kernel_context.contexts): + page_session.wait_for_timeout(100) + n += 1 + if n > 50: + raise RuntimeError("Timeout waiting for kernel shutdown") + + finally: + solara.server.settings.kernel.cull_timeout = cull_timeout_previous + + +def test_reconnect_and_update(browser: playwright.sync_api.Browser, page_session: playwright.sync_api.Page, solara_server, solara_app, extra_include_path): + with extra_include_path(HERE), solara_app("reconnect_test:Page"): + page_session.goto(solara_server.base_url) + page_session.locator("text=Value 0").wait_for() + page_session.locator("text=Increment").click() + page_session.locator("text=Value 1").wait_for() + # this will disconnect, and aftwards change something so the websocket queue feature is used + page_session.locator("text=Disconnect and change").click() + page_session.locator("text=Value 100").wait_for()