From d602541e787122b6b5021710c6782ebe7be2762a Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Tue, 3 Oct 2023 19:40:11 +0200 Subject: [PATCH] feat: allow reconnecting to existing kernel and display widget by id This allows an ipypopout or a similar library to open a new browser window and show a widget that is already running in the main window. Note that this is limited to the same browser, because the session_id is required to be the same. This is a security feature. --- packages/solara-widget-manager/src/manager.ts | 5 + solara/server/flask.py | 3 +- solara/server/kernel.py | 7 + solara/server/kernel_context.py | 141 +++++++++++++++++- solara/server/server.py | 58 +++---- solara/server/settings.py | 13 ++ solara/server/starlette.py | 3 +- solara/server/static/main-vuetify.js | 25 +++- solara/util.py | 27 ++++ tests/integration/lifecycle_test.py | 90 +++++++++++ tests/integration/popout_test.py | 52 +++++++ tests/unit/conftest.py | 2 +- tests/unit/lifecycle_test.py | 108 ++++++++++++++ tests/unit/output_widget_test.py | 4 +- tests/unit/patch_test.py | 6 +- tests/unit/shell_test.py | 4 +- tests/unit/toestand_test.py | 8 +- 17 files changed, 501 insertions(+), 55 deletions(-) create mode 100644 tests/integration/lifecycle_test.py create mode 100644 tests/integration/popout_test.py create mode 100644 tests/unit/lifecycle_test.py diff --git a/packages/solara-widget-manager/src/manager.ts b/packages/solara-widget-manager/src/manager.ts index e1419c460..3a9a2c65f 100644 --- a/packages/solara-widget-manager/src/manager.ts +++ b/packages/solara-widget-manager/src/manager.ts @@ -135,6 +135,11 @@ export class WidgetManager extends JupyterLabManager { } } + async fetchAll() { + // fetch all widgets + await this._loadFromKernel(); + } + async run(appName: string, path: string) { // used for routing // should be similar to what we do in navigator.vue diff --git a/solara/server/flask.py b/solara/server/flask.py index 7e7897413..594eac8f2 100644 --- a/solara/server/flask.py +++ b/solara/server/flask.py @@ -143,9 +143,10 @@ def kernels_connection(ws: simple_websocket.Server, kernel_id: str, name: str): @blueprint.route("/_solara/api/close/", methods=["GET", "POST"]) def close(kernel_id: str): + page_id = request.args["session_id"] if kernel_id in kernel_context.contexts: context = kernel_context.contexts[kernel_id] - context.close() + context.page_close(page_id) return "" diff --git a/solara/server/kernel.py b/solara/server/kernel.py index dbdaf302c..cd7af0346 100644 --- a/solara/server/kernel.py +++ b/solara/server/kernel.py @@ -218,6 +218,13 @@ def __init__(self, *args, **kwargs): super(SessionWebsocket, self).__init__(*args, **kwargs) self.websockets: Set[websocket.WebsocketWrapper] = set() # map from .. msg id to websocket? + def close(self): + for ws in list(self.websockets): + try: + ws.close() + except: # noqa + pass + def send(self, stream, msg_or_type, content=None, parent=None, ident=None, buffers=None, track=False, header=None, metadata=None): try: if isinstance(msg_or_type, dict): diff --git a/solara/server/kernel_context.py b/solara/server/kernel_context.py index e7772d145..562503a6d 100644 --- a/solara/server/kernel_context.py +++ b/solara/server/kernel_context.py @@ -1,8 +1,11 @@ +import asyncio import dataclasses +import enum import logging import os import pickle import threading +import time from pathlib import Path from typing import Any, Callable, Dict, List, Optional, cast @@ -10,6 +13,9 @@ import reacton from ipywidgets import DOMWidget, Widget +import solara.server.settings +import solara.util + from . import kernel, kernel_context, websocket from .kernel import Kernel, WebsocketStreamWrapper @@ -24,10 +30,20 @@ class Local(threading.local): local = Local() +class PageStatus(enum.Enum): + CONNECTED = "connected" + DISCONNECTED = "disconnected" + CLOSED = "closed" + + @dataclasses.dataclass class VirtualKernelContext: id: str kernel: kernel.Kernel + # we keep track of the session id to prevent kernel hijacking + # to 'steal' a kernel, one would need to know the session id + # *and* the kernel id + session_id: str control_sockets: List[WebSocket] = dataclasses.field(default_factory=list) # this is the 'private' version of the normally global ipywidgets.Widgets.widget dict # see patch.py @@ -42,6 +58,11 @@ class VirtualKernelContext: reload: Callable = lambda: None state: Any = None container: Optional[DOMWidget] = None + # we track which pages are connected to implement kernel culling + page_status: Dict[str, PageStatus] = dataclasses.field(default_factory=dict) + # only used for testing + _last_kernel_cull_task: "Optional[asyncio.Future[None]]" = None + closed: bool = False def display(self, *args): print(args) # noqa @@ -59,6 +80,7 @@ def __exit__(self, *args): current_context[key] = local.kernel_context_stack.pop() def close(self): + logger.info("Shut down virtual kernel: %s", self.id) with self: if self.app_object is not None: if isinstance(self.app_object, reacton.core._RenderContext): @@ -71,8 +93,10 @@ def close(self): # what if we reference each other # import gc # gc.collect() + self.kernel.session.close() if self.id in contexts: del contexts[self.id] + self.closed = True def _state_reset(self): state_directory = Path(".") / "states" @@ -97,6 +121,94 @@ def state_save(self, state_directory: os.PathLike): logger.debug("State: %r", state) pickle.dump(state, f) + def page_connect(self, page_id: str): + logger.info("Connect page %s for kernel %s", page_id, self.id) + assert self.page_status.get(page_id) != PageStatus.CLOSED, "cannot connect with the same page_id after a close" + self.page_status[page_id] = PageStatus.CONNECTED + if self._last_kernel_cull_task: + self._last_kernel_cull_task.cancel() + + def page_disconnect(self, page_id: str) -> "asyncio.Future[None]": + """Signal that a page has disconnected, and schedule a kernel cull if needed. + + During the kernel reconnect window, we will keep the kernel alive, even if all pages have disconnected. + + Returns a future that is set when the kernel cull is done. + The scheduled kernel cull can be cancelled when a new page connects, a new disconnect is scheduled, + or a page if explicitly closed. + """ + logger.info("Disconnect page %s for kernel %s", page_id, self.id) + future: "asyncio.Future[None]" = asyncio.Future() + self.page_status[page_id] = PageStatus.DISCONNECTED + current_event_loop = asyncio.get_event_loop() + + async def kernel_cull(): + try: + cull_timeout_sleep_seconds = solara.util.parse_timedelta(solara.server.settings.kernel.cull_timeout) + logger.info("Scheduling kernel cull, will wait for max %s before shutting down the virtual kernel %s", cull_timeout_sleep_seconds, self.id) + await asyncio.sleep(cull_timeout_sleep_seconds) + has_connected_pages = PageStatus.CONNECTED in self.page_status.values() + if has_connected_pages: + logger.info("We have (re)connected pages, keeping the virtual kernel %s alive", self.id) + else: + logger.info("No connected pages, and timeout reached, shutting down virtual kernel %s", self.id) + self.close() + current_event_loop.call_soon_threadsafe(future.set_result, None) + except asyncio.CancelledError: + current_event_loop.call_soon_threadsafe(future.cancel, "cancelled because a new cull task was scheduled") + raise + + has_connected_pages = PageStatus.CONNECTED in self.page_status.values() + if not has_connected_pages: + # when we have no connected pages, we will schedule a kernel cull + if self._last_kernel_cull_task: + self._last_kernel_cull_task.cancel() + + async def create_task(): + task = asyncio.create_task(kernel_cull()) + # create a reference to the task so we can cancel it later + self._last_kernel_cull_task = task + await task + + asyncio.run_coroutine_threadsafe(create_task(), keep_alive_event_loop) + else: + future.set_result(None) + return future + + def page_close(self, page_id: str): + """Signal that a page has closed, and close the context if needed. + + Closing the browser tab or a page navigation means an explicit close, which is + different from a websocket/page disconnect, which we might want to recover from. + + """ + self.page_status[page_id] = PageStatus.CLOSED + logger.info("Close page %s for kernel %s", page_id, self.id) + has_connected_pages = PageStatus.CONNECTED in self.page_status.values() + has_disconnected_pages = PageStatus.DISCONNECTED in self.page_status.values() + if not (has_connected_pages or has_disconnected_pages): + logger.info("No connected or disconnected pages, shutting down virtual kernel %s", self.id) + if self._last_kernel_cull_task: + self._last_kernel_cull_task.cancel() + self.close() + + +try: + # Normal Python + keep_alive_event_loop = asyncio.new_event_loop() + + def _run(): + asyncio.set_event_loop(keep_alive_event_loop) + try: + keep_alive_event_loop.run_forever() + except Exception: + logger.exception("Error in keep alive event loop") + raise + + threading.Thread(target=_run, daemon=True).start() +except RuntimeError: + # Emscripten/pyodide/lite + keep_alive_event_loop = asyncio.get_event_loop() contexts: Dict[str, VirtualKernelContext] = {} # maps from thread key to VirtualKernelContext, if VirtualKernelContext is None, it exists, but is not set as current @@ -108,6 +220,7 @@ def create_dummy_context(): kernel_context = VirtualKernelContext( id="dummy", + session_id="dummy", kernel=kernel.Kernel(), ) return kernel_context @@ -154,15 +267,27 @@ def set_current_context(context: Optional[VirtualKernelContext]): current_context[thread_key] = context -def initialize_virtual_kernel(kernel_id: str, websocket: websocket.WebsocketWrapper): - import solara.server.app - - kernel = Kernel() - logger.info("new virtual kernel: %s", kernel_id) - context = contexts[kernel_id] = VirtualKernelContext(id=kernel_id, kernel=kernel, control_sockets=[], widgets={}, templates={}) +def initialize_virtual_kernel(session_id: str, kernel_id: str, websocket: websocket.WebsocketWrapper): + from solara.server import app as appmodule + + if kernel_id in contexts: + logger.info("reusing virtual kernel: %s", kernel_id) + context = contexts[kernel_id] + if context.session_id != session_id: + logger.critical("Session id mismatch when reusing kernel (hack attempt?): %s != %s", context.session_id, session_id) + websocket.send_text("Session id mismatch when reusing kernel (hack attempt?)") + # to avoid very fast reconnects (we are in a thread anyway) + time.sleep(0.5) + raise ValueError("Session id mismatch") + kernel = context.kernel + else: + kernel = Kernel() + logger.info("new virtual kernel: %s", kernel_id) + context = contexts[kernel_id] = VirtualKernelContext(id=kernel_id, session_id=session_id, kernel=kernel, control_sockets=[], widgets={}, templates={}) + with context: + widgets.register_comm_target(kernel) + appmodule.register_solara_comm_target(kernel) with context: - widgets.register_comm_target(kernel) - solara.server.app.register_solara_comm_target(kernel) assert kernel is Kernel.instance() kernel.shell_stream = WebsocketStreamWrapper(websocket, "shell") kernel.control_stream = WebsocketStreamWrapper(websocket, "control") diff --git a/solara/server/server.py b/solara/server/server.py index 6f1f11c69..d03748352 100644 --- a/solara/server/server.py +++ b/solara/server/server.py @@ -107,7 +107,7 @@ def is_ready(url) -> bool: async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, kernel_id: str, page_id: str, user: dict = None): - context = initialize_virtual_kernel(kernel_id, ws) + context = initialize_virtual_kernel(session_id, kernel_id, ws) if context is None: logging.warning("invalid kernel id: %r", kernel_id) # to avoid very fast reconnects (we are in a thread anyway) @@ -124,35 +124,39 @@ async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, kernel_id: s run_context = solara.util.nullcontext() kernel = context.kernel - with run_context, context: - if user: - from solara_enterprise.auth import user as solara_user + try: + context.page_connect(page_id) + with run_context, context: + if user: + from solara_enterprise.auth import user as solara_user - solara_user.set(user) + solara_user.set(user) - while True: - try: - message = await ws.receive() - except websocket.WebSocketDisconnect: + while True: try: - context.kernel.session.websockets.remove(ws) - except KeyError: - pass - logger.debug("Disconnected") - return - t0 = time.time() - if isinstance(message, str): - msg = json.loads(message) - else: - msg = deserialize_binary_message(message) - t1 = time.time() - if not process_kernel_messages(kernel, msg): - # if we shut down the kernel, we do not keep the page session alive - context.close() - return - t2 = time.time() - if settings.main.timing: - print(f"timing: total={t2-t0:.3f}s, deserialize={t1-t0:.3f}s, kernel={t2-t1:.3f}s") # noqa: T201 + message = await ws.receive() + except websocket.WebSocketDisconnect: + try: + context.kernel.session.websockets.remove(ws) + except KeyError: + pass + logger.debug("Disconnected") + break + t0 = time.time() + if isinstance(message, str): + msg = json.loads(message) + else: + msg = deserialize_binary_message(message) + t1 = time.time() + if not process_kernel_messages(kernel, msg): + # if we shut down the kernel, we do not keep the page session alive + context.close() + return + t2 = time.time() + if settings.main.timing: + print(f"timing: total={t2-t0:.3f}s, deserialize={t1-t0:.3f}s, kernel={t2-t1:.3f}s") # noqa: T201 + finally: + context.page_disconnect(page_id) def process_kernel_messages(kernel: Kernel, msg: Dict) -> bool: diff --git a/solara/server/settings.py b/solara/server/settings.py index a3a7ea925..58f8607a0 100644 --- a/solara/server/settings.py +++ b/solara/server/settings.py @@ -10,6 +10,7 @@ from filelock import FileLock +import solara.util from solara.minisettings import BaseSettings from .. import ( # noqa # sidefx is that this module creates the ~/.solara directory @@ -85,6 +86,15 @@ class Config: env_file = ".env" +class Kernel(BaseSettings): + cull_timeout: str = "24h" + + class Config: + env_prefix = "solara_kernel_" + case_sensitive = False + env_file = ".env" + + AUTH0_TEST_CLIENT_ID = "cW7owP5Q52YHMZAnJwT8FPlH2ZKvvL3U" AUTH0_TEST_CLIENT_SECRET = "zxITXxoz54OjuSmdn-PluQgAwbeYyoB7ALlnLoodftvAn81usDXW0quchvoNvUYD" AUTH0_TEST_API_BASE_URL = "dev-y02f2bpr8skxu785.us.auth0.com" @@ -156,6 +166,9 @@ class Config: assets = Assets() oauth = OAuth() session = Session() +kernel = Kernel() +# fail early +solara.util.parse_timedelta(kernel.cull_timeout) if assets.proxy: try: diff --git a/solara/server/starlette.py b/solara/server/starlette.py index 3b2cb0f4b..3d74b39b5 100644 --- a/solara/server/starlette.py +++ b/solara/server/starlette.py @@ -212,9 +212,10 @@ async def run(): def close(request: Request): kernel_id = request.path_params["kernel_id"] + page_id = request.query_params["session_id"] if kernel_id in kernel_context.contexts: context = kernel_context.contexts[kernel_id] - context.close() + context.page_close(page_id) response = HTMLResponse(content="", status_code=200) return response diff --git a/solara/server/static/main-vuetify.js b/solara/server/static/main-vuetify.js index 31950ccc5..cb7e5d25d 100644 --- a/solara/server/static/main-vuetify.js +++ b/solara/server/static/main-vuetify.js @@ -114,18 +114,22 @@ async function solaraInit(mountId, appName) { define("vue", [], () => Vue); define("vuetify", [], { framework: app.$vuetify }); cookies = getCookiesMap(document.cookie); - uuid = generateUuid() + const searchParams = new URLSearchParams(window.location.search); + let kernelId = searchParams.get('kernelid') || generateUuid() let unloading = false; window.addEventListener('beforeunload', function (e) { unloading = true; kernel.dispose() - window.navigator.sendBeacon(close_url); + // allow to opt-out to make testing easier + if (!searchParams.has('solara-no-close-beacon')) { + window.navigator.sendBeacon(close_url); + } }); - let kernel = await solara.connectKernel(solara.rootPath + '/jupyter', uuid) + let kernel = await solara.connectKernel(solara.rootPath + '/jupyter', kernelId) if (!kernel) { return; } - const close_url = solara.rootPath + '/_solara/api/close/' + kernel.clientId; + const close_url = solara.rootPath + '/_solara/api/close/' + kernelId + "?session_id=" + kernel.clientId; let skipReconnectedCheck = true; kernel.statusChanged.connect(() => { app.$data.kernelBusy = kernel.status == 'busy'; @@ -202,8 +206,17 @@ async function solaraInit(mountId, appName) { // it seems if we attach this to early, it will not be called app.$data.loading_text = 'Loading app'; const path = window.location.pathname.slice(solara.rootPath.length); - const widgetId = await widgetManager.run(appName, path); - await solaraMount(widgetManager, mountId || 'content', widgetId); + let widgetModelId = searchParams.get('modelid'); + // if kernelid and modelid are given as query parameters, we will use them + // instead of running the current solara app. This allows usage such as + // ipypopout, which reconnects to an existing kernel and shows a particular + // widget. + if (kernelId && widgetModelId) { + await widgetManager.fetchAll(); + } else { + widgetModelId = await widgetManager.run(appName, path); + } + await solaraMount(widgetManager, mountId || 'content', widgetModelId); skipReconnectedCheck = false; solara.renderMathJax(); } diff --git a/solara/util.py b/solara/util.py index 555a94d66..97730dcc3 100644 --- a/solara/util.py +++ b/solara/util.py @@ -218,3 +218,30 @@ def tracefunc(frame, event, arg): finally: if hasattr(sys, "settrace"): sys.settrace(prev) + + +def parse_timedelta(size: str) -> float: + """Turn a human readable time delta into seconds. + Supports days(d), hours (h), minutes (m) and seconds (s). + If not unit is specified, seconds is assumed. + >>> parse_timedelta("1d") + 86400 + >>> parse_timedelta("1h") + 3600 + >>> parse_timedelta("30m") + 1800 + >>> parse_timedelta("10s") + 10 + >>> parse_timedelta("10") + 10 + """ + if size.endswith("d"): + return float(size[:-1]) * 24 * 60 * 60 + elif size.endswith("h"): + return float(size[:-1]) * 60 * 60 + elif size.endswith("m"): + return float(size[:-1]) * 60 + elif size.endswith("s"): + return float(size[:-1]) + else: + return float(size) diff --git a/tests/integration/lifecycle_test.py b/tests/integration/lifecycle_test.py new file mode 100644 index 000000000..b35a99dc5 --- /dev/null +++ b/tests/integration/lifecycle_test.py @@ -0,0 +1,90 @@ +import threading +from pathlib import Path +from typing import cast + +import playwright.sync_api +import pytest +from reacton.core import _RenderContext + +import solara.server.kernel_context +import solara.server.server +import solara.server.settings +from solara.server import kernel_context + +HERE = Path(__file__).parent + + +@solara.component +def ClickButton(label="Clicks"): + clicks = solara.use_reactive(0) + solara.Button(label=f"{label}-{clicks.value}", on_click=lambda: clicks.set(clicks.value + 1)) + + +@pytest.fixture +def short_cull_timeout(): + cull_timeout_previous = solara.server.settings.kernel.cull_timeout + solara.server.settings.kernel.cull_timeout = "1.0s" + try: + yield + finally: + solara.server.settings.kernel.cull_timeout = cull_timeout_previous + + +def test_kernel_lifecycle_close_single( + short_cull_timeout, + browser: playwright.sync_api.Browser, + page_session: playwright.sync_api.Page, + solara_server, + solara_app, + extra_include_path, +): + with extra_include_path(HERE), solara_app("lifecycle_test:ClickButton"): + page_session.goto(solara_server.base_url) + page_session.locator("text=Clicks-0").click() + contexts = list(kernel_context.contexts.values()) + assert len(contexts) == 1 + context = contexts[0] + assert not context.closed + page_session.goto("about:blank") + page_session.wait_for_timeout(100) + assert context.closed + + +def test_kernel_lifecycle_close_while_disconnected( + short_cull_timeout, + browser: playwright.sync_api.Browser, + page_session: playwright.sync_api.Page, + solara_server, + solara_app, + extra_include_path, +): + with extra_include_path(HERE), solara_app("lifecycle_test:ClickButton"): + page_session.goto(solara_server.base_url + "?solara-no-close-beacon") + page_session.locator("text=Clicks-0").click() + contexts = list(kernel_context.contexts.values()) + assert len(contexts) == 1 + context = contexts[0] + assert not context.closed + page_session.wait_for_timeout(100) + page_session.goto("about:blank") + + kernel_id = context.id + rc = cast(_RenderContext, context.app_object) + widget = rc.container + assert widget is not None + model_id = widget._model_id + + page_session.goto(solara_server.base_url + f"?kernelid={kernel_id}&modelid={model_id}") + # make sure the page is functional + page_session.locator("text=Clicks-1").click() + page_session.locator("text=Clicks-2").wait_for() + page_session.goto("about:blank") + # give a bit of time to make sure the cull task is started + page_session.wait_for_timeout(100) + cull_task_2 = context._last_kernel_cull_task + assert cull_task_2 is not None + # we can't mix do async, so we hook up an event to the Future + event = threading.Event() + cull_task_2.add_done_callback(lambda x: event.set()) + event.wait() + assert context.closed diff --git a/tests/integration/popout_test.py b/tests/integration/popout_test.py new file mode 100644 index 000000000..efb8e4258 --- /dev/null +++ b/tests/integration/popout_test.py @@ -0,0 +1,52 @@ +from pathlib import Path +from typing import cast + +import playwright +import playwright.sync_api +from reacton.core import _RenderContext + +import solara +from solara.server import kernel_context + +HERE = Path(__file__).parent + + +@solara.component +def TwoTexts(): + with solara.Div(classes=["solara-test-div"]): + solara.Text("AAA") + solara.Text("BBB") + + +def test_popout(page_session: playwright.sync_api.Page, solara_server, solara_app, extra_include_path): + with extra_include_path(HERE), solara_app("popout_test:TwoTexts"): + page_session.goto(solara_server.base_url) + el = page_session.locator("text=AAA") + assert el.text_content() == "AAA" + contexts = list(kernel_context.contexts.values()) + assert len(contexts) == 1 + context = contexts[0] + kernel_id = context.id + rc = cast(_RenderContext, context.app_object) + widget = rc.find(children=["BBB"]).widget + model_id = widget._model_id + + # we should not lose the context, it should be kept alive + page_session.goto("about:blank") + page_session.wait_for_timeout(100) + contexts = list(kernel_context.contexts.values()) + assert len(contexts) == 1 + + page_session.goto(solara_server.base_url + f"?kernelid={kernel_id}&modelid={model_id}") + page_session.locator("text=BBB").wait_for() + page_session.locator("text=AAA").wait_for(state="detached") + + cull_timeout_previous = solara.server.settings.kernel.cull_timeout + try: + solara.server.settings.kernel.cull_timeout = "0s" + page_session.goto("about:blank") + page_session.wait_for_timeout(1000) + contexts = list(kernel_context.contexts.values()) + assert len(contexts) == 0 + finally: + solara.server.settings.kernel.cull_timeout = cull_timeout_previous diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 685e7c1e4..df7ad3f3f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -9,7 +9,7 @@ @pytest.fixture(autouse=True) def kernel_context(): kernel_shared = kernel.Kernel() - context = VirtualKernelContext(id="1", kernel=kernel_shared) + context = VirtualKernelContext(id="1", kernel=kernel_shared, session_id="session-1") try: with context: yield context diff --git a/tests/unit/lifecycle_test.py b/tests/unit/lifecycle_test.py new file mode 100644 index 000000000..6b830f6eb --- /dev/null +++ b/tests/unit/lifecycle_test.py @@ -0,0 +1,108 @@ +import asyncio +import time +from unittest.mock import Mock + +import pytest + +import solara.server.server +import solara.server.settings +from solara.server import kernel_context + + +@pytest.fixture +def short_cull_timeout(): + cull_timeout_previous = solara.server.settings.kernel.cull_timeout + solara.server.settings.kernel.cull_timeout = "0.2s" + try: + yield + finally: + solara.server.settings.kernel.cull_timeout = cull_timeout_previous + + +async def test_kernel_lifecycle_reconnect_simple(short_cull_timeout): + # a reconnect should be possible within the reconnect window + websocket = Mock() + context = kernel_context.initialize_virtual_kernel("session-id-1", "kernel-id-1", websocket) + context.page_connect("page-id-1") + cull_task1 = context.page_disconnect("page-id-1") + await asyncio.sleep(0.01) + context.page_connect("page-id-2") + # the new connect should cancel the first cull task + with pytest.raises(asyncio.CancelledError): + await cull_task1 + assert not context.closed + await context.page_disconnect("page-id-2") + assert context.closed + + +async def test_kernel_lifecycle_double_disconnect(short_cull_timeout): + # a reconnect should be possible within the reconnect window + websocket = Mock() + context = kernel_context.initialize_virtual_kernel("session-id-1", "kernel-id-1", websocket) + context.page_connect("page-id-1") + cull_task1 = context.page_disconnect("page-id-1") + + # now after 0.1 we disconnect the 2nd time + await asyncio.sleep(0.1) + context.page_connect("page-id-2") + cull_task2 = context.page_disconnect("page-id-2") + t_disconnect_page_2 = time.time() + t0_disconnect_page_2 = time.time() + + # go over the reconnect window of cull_task1 (with a 0.05 extra to make sure it is really over) + # await asyncio.sleep(0.1 + 0.05) + # but the first disconnect should not have closed the kernel context yet + with pytest.raises(asyncio.CancelledError): + await cull_task1 + assert (time.time() - t_disconnect_page_2) < 0.001, "should be cancelled really quickly" + + assert not context.closed + await cull_task2 + assert context.closed + # the context should be closed AFTER 0.2 seconds, but it could take a bit longer + assert 0.3 >= (time.time() - t0_disconnect_page_2) >= 0.2 + + +@pytest.mark.parametrize("close_first", [True, False]) +async def test_kernel_lifecycle_close_single(close_first, short_cull_timeout): + # a reconnect should be possible within the reconnect window + websocket = Mock() + context = kernel_context.initialize_virtual_kernel("session-id-1", "kernel-id-1", websocket) + context.page_connect("page-id-1") + if close_first: + context.page_close("page-id-1") + assert context.closed + context.page_disconnect("page-id-1") + else: + context.page_disconnect("page-id-1") + assert not context.closed + context.page_close("page-id-1") + assert context.closed + + +@pytest.mark.parametrize("close_first", [True, False]) +async def test_kernel_lifecycle_close_while_disconnected(close_first, short_cull_timeout): + # a reconnect should be possible within the reconnect window + websocket = Mock() + context = kernel_context.initialize_virtual_kernel("session-id-1", "kernel-id-1", websocket) + context.page_connect("page-id-1") + cull_task_1 = context.page_disconnect("page-id-1") + await asyncio.sleep(0.1) + # after 0.1 we connect again, but close it directly + context.page_connect("page-id-2") + if close_first: + context.page_close("page-id-2") + await asyncio.sleep(0.01) + cull_task_2 = context.page_disconnect("page-id-2") + else: + cull_task_2 = context.page_disconnect("page-id-2") + await asyncio.sleep(0.01) + context.page_close("page-id-2") + assert not context.closed + await asyncio.sleep(0.15) + # but even though we closed, the first page is still in the disconnected state + with pytest.raises(asyncio.CancelledError): + await cull_task_1 + assert not context.closed + await cull_task_2 + assert context.closed diff --git a/tests/unit/output_widget_test.py b/tests/unit/output_widget_test.py index 6f061dbe3..5b9dc86d0 100644 --- a/tests/unit/output_widget_test.py +++ b/tests/unit/output_widget_test.py @@ -13,8 +13,8 @@ def test_interactive_shell(no_kernel_context): kernel2 = kernel.Kernel() kernel1.session.websockets.add(ws1) kernel2.session.websockets.add(ws2) - context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel1) - context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel2) + context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel1, session_id="session-1") + context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel2, session_id="session-2") with context1: output1 = widgets.Output() diff --git a/tests/unit/patch_test.py b/tests/unit/patch_test.py index 0970f61eb..15976bc30 100644 --- a/tests/unit/patch_test.py +++ b/tests/unit/patch_test.py @@ -13,7 +13,7 @@ def test_widget_error_message_outside_context(no_kernel_context): theme.get_state() kernel_shared = kernel.Kernel() - context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel_shared) + context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel_shared, session_id="session-1") with pytest.raises(RuntimeError): with context1: assert theme.model_id @@ -21,8 +21,8 @@ def test_widget_error_message_outside_context(no_kernel_context): def test_widget_dict(no_kernel_context): kernel_shared = kernel.Kernel() - context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) - context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) + context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel_shared, session_id="session-1") + context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel_shared, session_id="session-2") with context1: btn1 = widgets.Button(description="context1") diff --git a/tests/unit/shell_test.py b/tests/unit/shell_test.py index 8839b76a2..bd4285006 100644 --- a/tests/unit/shell_test.py +++ b/tests/unit/shell_test.py @@ -12,8 +12,8 @@ def test_shell(no_kernel_context): kernel2 = kernel.Kernel() kernel1.session.websockets.add(ws1) kernel2.session.websockets.add(ws2) - context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel1) - context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel2) + context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel1, session_id="session-1") + context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel2, session_id="session-2") with context1: IPython.display.display("test1") diff --git a/tests/unit/toestand_test.py b/tests/unit/toestand_test.py index 7a97f1fed..68e3f13f0 100644 --- a/tests/unit/toestand_test.py +++ b/tests/unit/toestand_test.py @@ -121,8 +121,8 @@ def test_scopes(no_kernel_context): kernel_shared = kernel.Kernel() assert kernel_context.current_context[kernel_context.get_current_thread_key()] is None - context1 = kernel_context.VirtualKernelContext(id="toestand-1", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) - context2 = kernel_context.VirtualKernelContext(id="toestand-2", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) + context1 = kernel_context.VirtualKernelContext(id="toestand-1", kernel=kernel_shared, session_id="session-1") + context2 = kernel_context.VirtualKernelContext(id="toestand-2", kernel=kernel_shared, session_id="session-2") with context1: mock1 = unittest.mock.Mock() @@ -378,8 +378,8 @@ def App(): # the storage should live in the app context to support multiple users/connections kernel_shared = kernel.Kernel() - context1 = kernel_context.VirtualKernelContext(id="bear-1", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) - context2 = kernel_context.VirtualKernelContext(id="bear-2", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) + context1 = kernel_context.VirtualKernelContext(id="bear-1", kernel=kernel_shared, session_id="session-1") + context2 = kernel_context.VirtualKernelContext(id="bear-2", kernel=kernel_shared, session_id="session-2") rcs = [] for context in [context1, context2]: with context: