Skip to content

Commit

Permalink
feat: support reconnecting websocket
Browse files Browse the repository at this point in the history
Instead of asking a user to refresh the browser, we let the websocket
reconnect and try to restore the page session.

 * Fixes #254
 * Fixes #161
  • Loading branch information
maartenbreddels committed Oct 6, 2023
1 parent f6943c6 commit 4606c46
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 21 deletions.
51 changes: 32 additions & 19 deletions packages/solara-widget-manager/src/manager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,15 @@ export class WidgetManager extends JupyterLabManager {
);
this._registerWidgets();
this._loader = requireLoader;
const commId = base.uuid();
const kernel = context.sessionContext?.session?.kernel;
this.connectControlComm();
if (!kernel) {
throw new Error('No current kernel');
}
}
async connectControlComm() {
const commId = base.uuid();
const kernel = this.context.sessionContext?.session?.kernel;
this.controlComm = kernel.createComm('solara.control', commId);
this.controlCommHandler = {
onMsg: (msg) => {
Expand All @@ -107,31 +111,40 @@ export class WidgetManager extends JupyterLabManager {
}
async check() {
// checks if app is still valid (e.g. server restarted and lost the widget state)
const okPromise = new Promise((resolve, reject) => {
this.controlCommHandler = {
onMsg: (msg) => {
// if we are connected to the same kernel, we'll get a reply instantly
// however, if we are connected to a new kernel, we rely on the timeout
// so every time we create a new comm.

const kernel = this.context.sessionContext?.session?.kernel;
const commId = base.uuid();
const controlComm = kernel.createComm('solara.control', commId);
controlComm.open({}, {}, [])
try {
return await new Promise((resolve, reject) => {
controlComm.onMsg = (msg) => {
const data = msg['content']['data'];
if (data.method === 'finished') {
resolve(data.ok);
if (data.method === 'check') {
if (data.ok === true) {
resolve({ ok: true, message: data.message });
} else {
resolve({ ok: false, message: data.message });
}
}
else {
reject(data.error);
reject({ ok: false, message: "unexpected message" });
}
},
onClose: () => {
}
controlComm.onClose = () => {
console.error("closed solara control comm")
reject()
reject({ ok: false, message: "closed solara control comm" });
}
};
setTimeout(() => {
reject('timeout');
}, CONTROL_COMM_TIMEOUT);
});
this.controlComm.send({ method: 'check' });
try {
return await okPromise;
setTimeout(() => {
reject('timeout');
}, CONTROL_COMM_TIMEOUT);
controlComm.send({ method: 'check' });
});
} catch (e) {
return false;
return { ok: false, message: e };
}
}

Expand Down
10 changes: 10 additions & 0 deletions solara/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,13 +352,23 @@ def on_msg(msg):
comm.send({"method": "finished", "widget_id": context.container._model_id})
elif method == "check":
context = kernel_context.get_current_context()
# if there is no container, we never ran the app
if context.container is not None:
logger.info("Reconnect check: %s is ok", context.id)
comm.send({"method": "check", "ok": True, "message": "All fine"})
else:
logger.info("Reconnect check: %s is not ok", context.id)
comm.send({"method": "check", "ok": False, "message": "Not reconnected"})

elif method == "reload":
assert app is not None
context = kernel_context.get_current_context()
path = data.get("path", "")
with context:
load_app_widget(context.state, app, path)
comm.send({"method": "finished"})
else:
logger.error("Unknown comm method called on solara.control comm: %s", method)

comm.on_msg(on_msg)

Expand Down
10 changes: 8 additions & 2 deletions solara/server/static/main-vuetify.js
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,16 @@ async function solaraInit(mountId, appName) {
}
if (s.connectionStatus == 'connected' && !skipReconnectedCheck) {
(async () => {
let ok = await widgetManager.check()
if (!ok) {
if (app.$data.needsRefresh) {
// give up
return;
}
const msg = await widgetManager.check()
if (!msg.ok) {
app.$data.needsRefresh = true;
await solara.shutdownKernel(kernel);
} else {
await widgetManager.fetchAll();
}
})();
}
Expand Down
109 changes: 109 additions & 0 deletions tests/integration/reconnect_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from pathlib import Path
from typing import Optional

import playwright.sync_api

import solara
import solara.server.kernel_context

# Directory containing this file; the tests below pass it to `extra_include_path`
# before loading the app as "reconnect_test:Page".
HERE = Path(__file__).parent


# Setter returned by `solara.use_state` inside `Page`; `Page` publishes it here via `global`.
set_value = None
# Module-level slot for a kernel context; it starts out as None.
context: Optional["solara.server.kernel_context.VirtualKernelContext"] = None


@solara.component
def Page():
    """Test app: a counter with buttons that close the kernel's websocket.

    Renders the current value, an "Increment" button, a "Disconnect" button
    that closes the single websocket of the current kernel session, and a
    "Disconnect and change" button that closes the websocket and then sets
    the value to 100.

    The state setter and the current kernel context are published to the
    module-level `set_value` and `context` names.
    """
    # The original declared `global set_value, app_context`; `app_context` is not
    # defined anywhere, and the module-level `context` slot was never filled.
    global set_value, context
    value, set_value = solara.use_state(0)
    assert set_value is not None
    # Keep a local reference so the click handlers stay bound to the context of
    # this render, independent of later writes to the module-level name.
    current_context = solara.server.kernel_context.get_current_context()
    assert current_context is not None
    context = current_context
    solara.Text(f"Value {value}")

    def disconnect():
        # Exactly one websocket is expected; closing it simulates a dropped connection.
        assert len(current_context.kernel.session.websockets) == 1
        list(current_context.kernel.session.websockets)[0].close()

    solara.Button("Disconnect", on_click=disconnect)
    solara.Button("Increment", on_click=lambda: set_value(value + 1))

    def disconnect_and_change():
        assert len(current_context.kernel.session.websockets) == 1
        list(current_context.kernel.session.websockets)[0].close()
        # Change state while no websocket is connected, so the update must be
        # delivered after the client reconnects.
        set_value(100)

    solara.Button("Disconnect and change", on_click=disconnect_and_change)


def test_reconnect_simple(browser: playwright.sync_api.Browser, page_session: playwright.sync_api.Page, solara_server, solara_app, extra_include_path):
    """Close the websocket server-side and check the page keeps working.

    After the disconnect a different websocket must be attached to the same
    kernel session, the counter must keep its value, and no additional kernel
    context may have been created.
    """
    contexts = solara.server.kernel_context.contexts
    with extra_include_path(HERE), solara_app("reconnect_test:Page"):
        page_session.goto(solara_server.base_url)
        page_session.locator("text=Value 0").wait_for()
        page_session.locator("text=Increment").click()
        page_session.locator("text=Value 1").wait_for()

        assert len(contexts) == 1
        context = next(iter(contexts.values()))
        assert len(context.kernel.session.websockets) == 1
        original_ws = next(iter(context.kernel.session.websockets))

        page_session.locator("text=Disconnect").nth(0).click()

        # Poll until the original websocket is gone and exactly one (new) websocket
        # is connected: at most 51 checks, each followed by a 100 ms wait.
        for _ in range(51):
            sockets = context.kernel.session.websockets
            if original_ws not in sockets and len(sockets) == 1:
                break
            page_session.wait_for_timeout(100)
        else:
            raise RuntimeError("Timeout waiting for reconnected websocket")

        page_session.locator("text=Value 1").wait_for()
        page_session.locator("text=Increment").click()
        page_session.locator("text=Value 2").wait_for()
        # reconnecting must reuse the existing context rather than create a new one
        assert len(contexts) == 1
        page_session.goto("about:blank")
        assert context.closed_event.wait(10)


def test_reconnect_fail(browser: playwright.sync_api.Browser, page_session: playwright.sync_api.Page, solara_server, solara_app, extra_include_path):
    """Check the page reports a failed session restore when the kernel is gone.

    The kernel cull timeout is set to "0s" before the websocket is closed
    (presumably so the server culls the kernel immediately - confirm against
    the settings documentation). The page must then show
    "Could not restore session" and all kernel contexts must disappear.

    Raises:
        RuntimeError: if kernel contexts are still registered after polling.
    """
    with extra_include_path(HERE), solara_app("reconnect_test:Page"):
        page_session.goto(solara_server.base_url)
        page_session.locator("text=Value 0").wait_for()
        page_session.locator("text=Increment").click()
        page_session.locator("text=Value 1").wait_for()
        cull_timeout_previous = solara.server.settings.kernel.cull_timeout
        context = None
        try:
            solara.server.settings.kernel.cull_timeout = "0s"
            assert len(solara.server.kernel_context.contexts) == 1
            context = list(solara.server.kernel_context.contexts.values())[0]
            assert len(context.kernel.session.websockets) == 1
            page_session.locator("text=Disconnect").nth(0).click()
            page_session.locator("text=Could not restore session").wait_for()
            n = 0
            # we wait until all contexts are closed
            while len(solara.server.kernel_context.contexts):
                page_session.wait_for_timeout(100)
                n += 1
                if n > 50:
                    raise RuntimeError("Timeout waiting for kernel shutdown")

        finally:
            # The cleanup below can itself fail (navigation error, or the context
            # not closing in time); the nested finally guarantees the global
            # setting is restored so later tests do not inherit the "0s" timeout.
            try:
                page_session.goto("about:blank")
                if context is not None:
                    assert context.closed_event.wait(10)
            finally:
                solara.server.settings.kernel.cull_timeout = cull_timeout_previous


def test_reconnect_and_update(browser: playwright.sync_api.Browser, page_session: playwright.sync_api.Page, solara_server, solara_app, extra_include_path):
    """Check a state change made while disconnected reaches the page after reconnect.

    Clicking "Disconnect and change" closes the websocket and then sets the
    value to 100 on the server; the page must eventually show "Value 100".
    """
    with extra_include_path(HERE), solara_app("reconnect_test:Page"):
        page_session.goto(solara_server.base_url)
        page_session.locator("text=Value 0").wait_for()
        page_session.locator("text=Increment").click()
        page_session.locator("text=Value 1").wait_for()
        # Look the context up in the server's registry, as the sibling tests do.
        # The module-level `context` is initialised to None, so dereferencing it
        # would raise AttributeError.
        assert len(solara.server.kernel_context.contexts) == 1
        context = list(solara.server.kernel_context.contexts.values())[0]
        # this will disconnect, and afterwards change something so the websocket queue feature is used
        page_session.locator("text=Disconnect and change").click()
        page_session.locator("text=Value 100").wait_for()
        page_session.goto("about:blank")
        assert context.closed_event.wait(10)

0 comments on commit 4606c46

Please sign in to comment.