From 83581057a88a43cce2a907be9df566462ce25340 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Tue, 3 Oct 2023 17:01:12 +0200 Subject: [PATCH] refactor: use kernel_id instead of session_id We have been using the jupyter session_id as a unique key for the ill defined concept of a "AppContext". For consistency, we now use the kernel_id as the unique key and renamed the AppContext to VirtualKernelContext. A VirtualKernelContext is a context manager that needs to wrap any code that uses a kernel, either explicitly or implicitly, such as widgets. --- packages/solara-widget-manager/src/kernel.ts | 2 +- solara/scope/__init__.py | 4 +- solara/server/app.py | 186 ++----------------- solara/server/flask.py | 20 +- solara/server/kernel.py | 2 +- solara/server/kernel_context.py | 170 +++++++++++++++++ solara/server/patch.py | 32 ++-- solara/server/server.py | 11 +- solara/server/starlette.py | 33 ++-- solara/server/telemetry.py | 4 +- solara/test/pytest_plugin.py | 4 +- solara/toestand.py | 6 +- tests/integration/ssg_test.py | 4 +- tests/unit/app_test.py | 30 +-- tests/unit/conftest.py | 15 +- tests/unit/no_solara_test.py | 6 +- tests/unit/output_widget_test.py | 8 +- tests/unit/patch_test.py | 12 +- tests/unit/shell_test.py | 8 +- tests/unit/toestand_test.py | 44 ++--- 20 files changed, 312 insertions(+), 289 deletions(-) create mode 100644 solara/server/kernel_context.py diff --git a/packages/solara-widget-manager/src/kernel.ts b/packages/solara-widget-manager/src/kernel.ts index 966b8b061..49f5977a4 100644 --- a/packages/solara-widget-manager/src/kernel.ts +++ b/packages/solara-widget-manager/src/kernel.ts @@ -27,7 +27,7 @@ export async function connectKernel( // if (!model) { // return; // } - const model = { 'id': 'solara-id', 'name': 'solara-name' } + const model = { 'id': kernelId, 'name': 'solara-name' } const kernel = new KernelConnection({ model, serverSettings }); return kernel; } diff --git a/solara/scope/__init__.py b/solara/scope/__init__.py index 60af120c3..efcead322 100644 --- a/solara/scope/__init__.py +++ b/solara/scope/__init__.py @@ -18,9 +18,9 @@ def __init__(self, name="connection"): def _get_dict(self) -> MutableMapping: if _in_solara_server(): - import solara.server.app + import solara.server.kernel_context - context = solara.server.app.get_current_context() + context = solara.server.kernel_context.get_current_context() if self.name not in context.user_dicts: with self.lock: if self.name not in context.user_dicts: diff --git a/solara/server/app.py b/solara/server/app.py index 43c1fdfc1..4a8e02177 100644 --- a/solara/server/app.py +++ b/solara/server/app.py @@ -1,4 +1,3 @@ -import dataclasses import importlib.util import logging import os @@ -8,17 +7,16 @@ import traceback from enum import Enum from pathlib import Path -from typing import Any, Callable, Dict, List, Optional, cast +from typing import Any, Dict, List, Optional, cast import ipywidgets as widgets import reacton -from ipywidgets import DOMWidget, Widget from reacton.core import Element, render import solara -from . import kernel, reload, settings, websocket -from .kernel import Kernel, WebsocketStreamWrapper +from . import kernel_context, reload, settings +from .kernel import Kernel from .utils import pdb_guard WebSocket = Any @@ -30,13 +28,6 @@ reload.reloader.start() -class Local(threading.local): - app_context_stack: Optional[List[Optional["AppContext"]]] = None - - -local = Local() - - class AppType(str, Enum): SCRIPT = "script" NOTEBOOK = "notebook" @@ -68,13 +59,13 @@ def __init__(self, name, default_app_name="Page"): self.name = name self.path: Path = Path(self.name).resolve() try: - context = get_current_context() + context = kernel_context.get_current_context() except RuntimeError: context = None if context is not None: raise RuntimeError(f"We should not have an existing Solara app context when running an app for the first time: {context}") - app_context = create_dummy_context() - with app_context: + dummy_kernel_context = kernel_context.create_dummy_context() + with dummy_kernel_context: app = self._execute() self._first_execute_app = app @@ -85,7 +76,7 @@ def __init__(self, name, default_app_name="Page"): if mod.__file__ is not None: package_root_path = Path(mod.__file__).parent reload.reloader.root_path = package_root_path - app_context.close() + dummy_kernel_context.close() def _execute(self): logger.info("Executing %s", self.name) @@ -187,8 +178,8 @@ def add_path(): def close(self): reload.reloader.on_change = None - context_values = list(contexts.values()) - contexts.clear() + context_values = list(kernel_context.contexts.values()) + kernel_context.contexts.clear() for context in context_values: context.close() @@ -204,7 +195,7 @@ def on_file_change(self, name): if path.suffix == ".vue": logger.info("Vue file changed: %s", name) template_content = path.read_text() - for context in list(contexts.values()): + for context in list(kernel_context.contexts.values()): with context: for filepath, widget in context.templates.items(): if filepath == str(path): @@ -224,7 +215,7 @@ def reload(self): solara.lab.toestand.ConnectionStore._type_counter.clear() - context_values = list(contexts.values()) + context_values = list(kernel_context.contexts.values()) # save states into the context so the hot reload will # keep the same state for context in context_values: @@ -256,136 +247,6 @@ def reload(self): context.reload() -@dataclasses.dataclass -class AppContext: - id: str - kernel: kernel.Kernel - control_sockets: List[WebSocket] = dataclasses.field(default_factory=list) - # this is the 'private' version of the normally global ipywidgets.Widgets.widget dict - # see patch.py - widgets: Dict[str, Widget] = dataclasses.field(default_factory=dict) - # same, for ipyvue templates - # see patch.py - templates: Dict[str, Widget] = dataclasses.field(default_factory=dict) - user_dicts: Dict[str, Dict] = dataclasses.field(default_factory=dict) - # anything we need to attach to the context - # e.g. for a react app the render context, so that we can store/restore the state - app_object: Optional[Any] = None - reload: Callable = lambda: None - state: Any = None - container: Optional[DOMWidget] = None - - def display(self, *args): - print(args) # noqa - - def __enter__(self): - if local.app_context_stack is None: - local.app_context_stack = [] - key = get_current_thread_key() - local.app_context_stack.append(current_context.get(key, None)) - current_context[key] = self - - def __exit__(self, *args): - key = get_current_thread_key() - assert local.app_context_stack is not None - current_context[key] = local.app_context_stack.pop() - - def close(self): - with self: - if self.app_object is not None: - if isinstance(self.app_object, reacton.core._RenderContext): - try: - self.app_object.close() - except Exception as e: - logger.exception("Could not close render context: %s", e) - # we want to continue, so we at least close all widgets - widgets.Widget.close_all() - # what if we reference each other - # import gc - # gc.collect() - if self.id in contexts: - del contexts[self.id] - - def _state_reset(self): - state_directory = Path(".") / "states" - state_directory.mkdir(exist_ok=True) - path = state_directory / f"{self.id}.pickle" - path = path.absolute() - try: - path.unlink() - except: # noqa - pass - del contexts[self.id] - key = get_current_thread_key() - del current_context[key] - - def state_save(self, state_directory: os.PathLike): - path = Path(state_directory) / f"{self.id}.pickle" - render_context = self.app_object - if render_context is not None: - render_context = cast(reacton.core._RenderContext, render_context) - state = render_context.state_get() - with path.open("wb") as f: - logger.debug("State: %r", state) - pickle.dump(state, f) - - -contexts: Dict[str, AppContext] = {} -# maps from thread key to AppContext, if AppContext is None, it exists, but is not set as current -current_context: Dict[str, Optional[AppContext]] = {} - - -def create_dummy_context(): - from . import kernel - - app_context = AppContext( - id="dummy", - kernel=kernel.Kernel(), - ) - return app_context - - -def get_current_thread_key() -> str: - thread = threading.current_thread() - return get_thread_key(thread) - - -def get_thread_key(thread: threading.Thread) -> str: - thread_key = thread._name + str(thread._ident) # type: ignore - return thread_key - - -def set_context_for_thread(context: AppContext, thread: threading.Thread): - key = get_thread_key(thread) - current_context[key] = context - - -def has_current_context() -> bool: - thread_key = get_current_thread_key() - return (thread_key in current_context) and (current_context[thread_key] is not None) - - -def get_current_context() -> AppContext: - thread_key = get_current_thread_key() - if thread_key not in current_context: - raise RuntimeError( - f"Tried to get the current context for thread {thread_key}, but no known context found. This might be a bug in Solara. " - f"(known contexts: {list(current_context.keys())}" - ) - context = current_context[thread_key] - if context is None: - raise RuntimeError( - f"Tried to get the current context for thread {thread_key!r}, although the context is know, it was not set for this thread. " - + "This might be a bug in Solara." - ) - return context - - -def set_current_context(context: Optional[AppContext]): - thread_key = get_current_thread_key() - current_context[thread_key] = context - - def _run_app( app_state, app_script: AppScript, @@ -398,7 +259,7 @@ def _run_app( if app_state: logger.info("Restoring state: %r", app_state) - context = get_current_context() + context = kernel_context.get_current_context() container = context.container if isinstance(main_object, widgets.Widget): return main_object, render_context @@ -438,7 +299,7 @@ def _run_app( def load_app_widget(app_state, app_script: AppScript, pathname: str): # load the app, and set it at the child of the context's container app_state_initial = app_state - context = get_current_context() + context = kernel_context.get_current_context() container = context.container assert container is not None try: @@ -482,7 +343,7 @@ def on_msg(msg): path = data.get("path", "") app_name = data.get("appName") or "__default__" app = apps[app_name] - context = get_current_context() + context = kernel_context.get_current_context() import ipyvuetify container = ipyvuetify.Html(tag="div") @@ -490,10 +351,10 @@ def on_msg(msg): load_app_widget(None, app, path) comm.send({"method": "finished", "widget_id": context.container._model_id}) elif method == "check": - context = get_current_context() + context = kernel_context.get_current_context() elif method == "reload": assert app is not None - context = get_current_context() + context = kernel_context.get_current_context() path = data.get("path", "") with context: load_app_widget(context.state, app, path) @@ -509,7 +370,7 @@ def reload(): logger.debug(f"Send reload to client: {context.id}") comm.send({"method": "reload"}) - context = get_current_context() + context = kernel_context.get_current_context() context.reload = reload @@ -517,19 +378,6 @@ def register_solara_comm_target(kernel: Kernel): kernel.comm_manager.register_target("solara.control", solara_comm_target) -def initialize_virtual_kernel(context_id: str, websocket: websocket.WebsocketWrapper): - kernel = Kernel() - logger.info("new virtual kernel: %s", context_id) - context = contexts[context_id] = AppContext(id=context_id, kernel=kernel, control_sockets=[], widgets={}, templates={}) - with context: - widgets.register_comm_target(kernel) - register_solara_comm_target(kernel) - assert kernel is Kernel.instance() - kernel.shell_stream = WebsocketStreamWrapper(websocket, "shell") - kernel.control_stream = WebsocketStreamWrapper(websocket, "control") - kernel.session.websockets.add(websocket) - - from . import patch # noqa patch.patch() diff --git a/solara/server/flask.py b/solara/server/flask.py index 091363a3e..7e7897413 100644 --- a/solara/server/flask.py +++ b/solara/server/flask.py @@ -42,7 +42,7 @@ def allowed(): from solara.server.threaded import ServerBase from . import app as appmod -from . import cdn_helper, server, settings, websocket +from . import cdn_helper, kernel_context, server, settings, websocket os.environ["SERVER_SOFTWARE"] = "solara/" + str(solara.__version__) @@ -111,8 +111,8 @@ def kernels(id): return {"name": "lala", "id": "dsa"} -@websocket_extension.route("/jupyter/api/kernels//") -def kernels_connection(ws: simple_websocket.Server, id: str, name: str): +@websocket_extension.route("/jupyter/api/kernels//") +def kernels_connection(ws: simple_websocket.Server, kernel_id: str, name: str): if not settings.main.base_url: settings.main.base_url = url_for("blueprint-solara.read_root", _external=True) if settings.oauth.private and not has_solara_enterprise: @@ -127,24 +127,24 @@ def kernels_connection(ws: simple_websocket.Server, id: str, name: str): user = None try: - connection_id = request.args["session_id"] + page_id = request.args["session_id"] session_id = request.cookies.get(server.COOKIE_KEY_SESSION_ID) - logger.info("Solara kernel requested for session_id=%s connection_id=%s", session_id, connection_id) + logger.info("Solara kernel requested for session_id=%s kernel_id=%s", session_id, kernel_id) if session_id is None: logger.error("no session cookie") ws.close() return ws_wrapper = WebsocketWrapper(ws) - asyncio.run(server.app_loop(ws_wrapper, session_id=session_id, connection_id=connection_id, user=user)) + asyncio.run(server.app_loop(ws_wrapper, session_id=session_id, kernel_id=kernel_id, page_id=page_id, user=user)) except: # noqa logger.exception("Error in kernel handler") raise -@blueprint.route("/_solara/api/close/", methods=["GET", "POST"]) -def close(connection_id: str): - if connection_id in appmod.contexts: - context = appmod.contexts[connection_id] +@blueprint.route("/_solara/api/close/", methods=["GET", "POST"]) +def close(kernel_id: str): + if kernel_id in kernel_context.contexts: + context = kernel_context.contexts[kernel_id] context.close() return "" diff --git a/solara/server/kernel.py b/solara/server/kernel.py index 9a0c10cfb..dbdaf302c 100644 --- a/solara/server/kernel.py +++ b/solara/server/kernel.py @@ -97,7 +97,7 @@ def publish_msg(self, msg_type, data=None, metadata=None, buffers=None, **keys): comm.create_comm = Comm def get_comm_manager(): - from .app import get_current_context, has_current_context + from .kernel_context import get_current_context, has_current_context if has_current_context(): return get_current_context().kernel.comm_manager diff --git a/solara/server/kernel_context.py b/solara/server/kernel_context.py new file mode 100644 index 000000000..e7772d145 --- /dev/null +++ b/solara/server/kernel_context.py @@ -0,0 +1,170 @@ +import dataclasses +import logging +import os +import pickle +import threading +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, cast + +import ipywidgets as widgets +import reacton +from ipywidgets import DOMWidget, Widget + +from . import kernel, kernel_context, websocket +from .kernel import Kernel, WebsocketStreamWrapper + +WebSocket = Any +logger = logging.getLogger("solara.server.app") + + +class Local(threading.local): + kernel_context_stack: Optional[List[Optional["kernel_context.VirtualKernelContext"]]] = None + + +local = Local() + + +@dataclasses.dataclass +class VirtualKernelContext: + id: str + kernel: kernel.Kernel + control_sockets: List[WebSocket] = dataclasses.field(default_factory=list) + # this is the 'private' version of the normally global ipywidgets.Widgets.widget dict + # see patch.py + widgets: Dict[str, Widget] = dataclasses.field(default_factory=dict) + # same, for ipyvue templates + # see patch.py + templates: Dict[str, Widget] = dataclasses.field(default_factory=dict) + user_dicts: Dict[str, Dict] = dataclasses.field(default_factory=dict) + # anything we need to attach to the context + # e.g. for a react app the render context, so that we can store/restore the state + app_object: Optional[Any] = None + reload: Callable = lambda: None + state: Any = None + container: Optional[DOMWidget] = None + + def display(self, *args): + print(args) # noqa + + def __enter__(self): + if local.kernel_context_stack is None: + local.kernel_context_stack = [] + key = get_current_thread_key() + local.kernel_context_stack.append(current_context.get(key, None)) + current_context[key] = self + + def __exit__(self, *args): + key = get_current_thread_key() + assert local.kernel_context_stack is not None + current_context[key] = local.kernel_context_stack.pop() + + def close(self): + with self: + if self.app_object is not None: + if isinstance(self.app_object, reacton.core._RenderContext): + try: + self.app_object.close() + except Exception as e: + logger.exception("Could not close render context: %s", e) + # we want to continue, so we at least close all widgets + widgets.Widget.close_all() + # what if we reference each other + # import gc + # gc.collect() + if self.id in contexts: + del contexts[self.id] + + def _state_reset(self): + state_directory = Path(".") / "states" + state_directory.mkdir(exist_ok=True) + path = state_directory / f"{self.id}.pickle" + path = path.absolute() + try: + path.unlink() + except: # noqa + pass + del contexts[self.id] + key = get_current_thread_key() + del current_context[key] + + def state_save(self, state_directory: os.PathLike): + path = Path(state_directory) / f"{self.id}.pickle" + render_context = self.app_object + if render_context is not None: + render_context = cast(reacton.core._RenderContext, render_context) + state = render_context.state_get() + with path.open("wb") as f: + logger.debug("State: %r", state) + pickle.dump(state, f) + + +contexts: Dict[str, VirtualKernelContext] = {} +# maps from thread key to VirtualKernelContext, if VirtualKernelContext is None, it exists, but is not set as current +current_context: Dict[str, Optional[VirtualKernelContext]] = {} + + +def create_dummy_context(): + from . import kernel + + kernel_context = VirtualKernelContext( + id="dummy", + kernel=kernel.Kernel(), + ) + return kernel_context + + +def get_current_thread_key() -> str: + thread = threading.current_thread() + return get_thread_key(thread) + + +def get_thread_key(thread: threading.Thread) -> str: + thread_key = thread._name + str(thread._ident) # type: ignore + return thread_key + + +def set_context_for_thread(context: VirtualKernelContext, thread: threading.Thread): + key = get_thread_key(thread) + current_context[key] = context + + +def has_current_context() -> bool: + thread_key = get_current_thread_key() + return (thread_key in current_context) and (current_context[thread_key] is not None) + + +def get_current_context() -> VirtualKernelContext: + thread_key = get_current_thread_key() + if thread_key not in current_context: + raise RuntimeError( + f"Tried to get the current context for thread {thread_key}, but no known context found. This might be a bug in Solara. " + f"(known contexts: {list(current_context.keys())}" + ) + context = current_context[thread_key] + if context is None: + raise RuntimeError( + f"Tried to get the current context for thread {thread_key!r}, although the context is know, it was not set for this thread. " + + "This might be a bug in Solara." + ) + return context + + +def set_current_context(context: Optional[VirtualKernelContext]): + thread_key = get_current_thread_key() + current_context[thread_key] = context + + +def initialize_virtual_kernel(kernel_id: str, websocket: websocket.WebsocketWrapper): + import solara.server.app + + kernel = Kernel() + logger.info("new virtual kernel: %s", kernel_id) + context = contexts[kernel_id] = VirtualKernelContext(id=kernel_id, kernel=kernel, control_sockets=[], widgets={}, templates={}) + with context: + widgets.register_comm_target(kernel) + solara.server.app.register_solara_comm_target(kernel) + assert kernel is Kernel.instance() + kernel.shell_stream = WebsocketStreamWrapper(websocket, "shell") + kernel.control_stream = WebsocketStreamWrapper(websocket, "control") + kernel.session.websockets.add(websocket) + return context diff --git a/solara/server/patch.py b/solara/server/patch.py index cad646865..4a6c6f07b 100644 --- a/solara/server/patch.py +++ b/solara/server/patch.py @@ -13,7 +13,7 @@ import ipywidgets.widgets.widget_output from IPython.core.interactiveshell import InteractiveShell -from . import app, reload, settings +from . import app, kernel_context, reload, settings from .utils import pdb_guard logger = logging.getLogger("solara.server.patch") @@ -28,7 +28,7 @@ class FakeIPython: - def __init__(self, context: app.AppContext): + def __init__(self, context: kernel_context.VirtualKernelContext): self.context = context self.kernel = context.kernel self.display_pub = self.kernel.shell.display_pub @@ -78,8 +78,8 @@ def set_custom_exc(self, exc_tuple, handler): def kernel_instance_dispatch(cls, *args, **kwargs): - if app.has_current_context(): - context = app.get_current_context() + if kernel_context.has_current_context(): + context = kernel_context.get_current_context() return context.kernel else: return Kernel_instance_original(cls, *args, **kwargs) @@ -91,7 +91,7 @@ def kernel_instance_dispatch(cls, *args, **kwargs): def kernel_initialized_dispatch(cls): if app is None: # python is shutting down, and the comm dtor wants to send a close message return False - if app.has_current_context(): + if kernel_context.has_current_context(): return True else: return Kernel_initialized_initial(cls) @@ -101,8 +101,8 @@ def kernel_initialized_dispatch(cls): def interactive_shell_instance_dispatch(cls, *args, **kwargs): - if app.has_current_context(): - context = app.get_current_context() + if kernel_context.has_current_context(): + context = kernel_context.get_current_context() return context.kernel.shell else: return InteractiveShell_instance_initial(cls, *args, **kwargs) @@ -169,8 +169,8 @@ def display_solara( def get_ipython(): - if app.has_current_context(): - context = app.get_current_context() + if kernel_context.has_current_context(): + context = kernel_context.get_current_context() our_fake_ipython = FakeIPython(context) return our_fake_ipython else: @@ -199,8 +199,8 @@ def __setitem__(self, key, value): class context_dict_widgets(context_dict): def _get_context_dict(self) -> dict: - if app.has_current_context(): - context = app.get_current_context() + if kernel_context.has_current_context(): + context = kernel_context.get_current_context() return context.widgets else: return global_widgets_dict @@ -208,8 +208,8 @@ def _get_context_dict(self) -> dict: class context_dict_templates(context_dict): def _get_context_dict(self) -> dict: - if app.has_current_context(): - context = app.get_current_context() + if kernel_context.has_current_context(): + context = kernel_context.get_current_context() return context.templates else: return global_templates_dict @@ -220,7 +220,7 @@ def __init__(self, name): self.name = name def _get_context_dict(self) -> dict: - context = app.get_current_context() + context = kernel_context.get_current_context() if self.name not in context.user_dicts: context.user_dicts[self.name] = {} return context.user_dicts[self.name] @@ -245,14 +245,14 @@ def WidgetContextAwareThread__init__(self, *args, **kwargs): Thread__init__(self, *args, **kwargs) self.current_context = None try: - self.current_context = app.get_current_context() + self.current_context = kernel_context.get_current_context() except RuntimeError: logger.debug(f"No context for thread {self}") def Thread_debug_run(self): if self.current_context: - app.set_context_for_thread(self.current_context, self) + kernel_context.set_context_for_thread(self.current_context, self) with pdb_guard(): Thread__run(self) diff --git a/solara/server/server.py b/solara/server/server.py index ccf834411..6f1f11c69 100644 --- a/solara/server/server.py +++ b/solara/server/server.py @@ -16,8 +16,8 @@ import solara.routing from . import app, jupytertools, settings, websocket -from .app import initialize_virtual_kernel from .kernel import Kernel, deserialize_binary_message +from .kernel_context import initialize_virtual_kernel COOKIE_KEY_SESSION_ID = "solara-session-id" @@ -106,11 +106,10 @@ def is_ready(url) -> bool: return False -async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, connection_id: str, user: dict = None): - initialize_virtual_kernel(connection_id, ws) - context = app.contexts.get(connection_id) +async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, kernel_id: str, page_id: str, user: dict = None): + context = initialize_virtual_kernel(kernel_id, ws) if context is None: - logging.warning("invalid context id: %r", connection_id) + logging.warning("invalid kernel id: %r", kernel_id) # to avoid very fast reconnects (we are in a thread anyway) time.sleep(0.5) return @@ -118,7 +117,7 @@ async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, connection_i if settings.main.tracer: import viztracer - output_file = f"viztracer-{connection_id}.html" + output_file = f"viztracer-{page_id}.html" run_context = viztracer.VizTracer(output_file=output_file, max_stack_depth=10) logger.warning(f"Running with tracer: {output_file}") else: diff --git a/solara/server/starlette.py b/solara/server/starlette.py index ac7eb0760..b9d57c5db 100644 --- a/solara/server/starlette.py +++ b/solara/server/starlette.py @@ -46,7 +46,7 @@ from solara.server.threaded import ServerBase from . import app as appmod -from . import server, settings, telemetry, websocket +from . import kernel_context, server, settings, telemetry, websocket from .cdn_helper import cdn_url_path, get_path os.environ["SERVER_SOFTWARE"] = "solara/" + str(solara.__version__) @@ -166,12 +166,17 @@ async def kernel_connection(ws: starlette.websockets.WebSocket): logger.error("no session cookie") await ws.close() return - connection_id = ws.query_params["session_id"] - if not connection_id: - logger.error("no session_id/connection_id") + # we use the jupyter session_id query parameter as the key/id + # for a page scope. + page_id = ws.query_params["session_id"] + if not page_id: + logger.error("no page_id") + kernel_id = ws.path_params["kernel_id"] + if not kernel_id: + logger.error("no kernel_id") await ws.close() return - logger.info("Solara kernel requested for session_id=%s connection_id=%s", session_id, connection_id) + logger.info("Solara kernel requested for session_id=%s kernel_id=%s", session_id, kernel_id) await ws.accept() def websocket_thread_runner(ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal): @@ -180,14 +185,14 @@ def websocket_thread_runner(ws: starlette.websockets.WebSocket, portal: anyio.fr async def run(): try: assert session_id is not None - assert connection_id is not None - telemetry.connection_open(session_id, connection_id) - await server.app_loop(ws_wrapper, session_id, connection_id, user) + assert kernel_id is not None + telemetry.connection_open(session_id) + await server.app_loop(ws_wrapper, session_id, kernel_id, page_id, user) except: # noqa await portal.stop(cancel_remaining=True) raise finally: - telemetry.connection_close(session_id, connection_id) + telemetry.connection_close(session_id) # sometimes throws: RuntimeError: Already running asyncio in this thread anyio.run(run) @@ -206,9 +211,9 @@ async def run(): def close(request: Request): - connection_id = request.path_params["connection_id"] - if connection_id in appmod.contexts: - context = appmod.contexts[connection_id] + kernel_id = request.path_params["kernel_id"] + if kernel_id in kernel_context.contexts: + context = kernel_context.contexts[kernel_id] context.close() response = HTMLResponse(content="", status_code=200) return response @@ -381,10 +386,10 @@ def readyz(request: Request): Route("/readyz", endpoint=readyz), *routes_auth, Route("/jupyter/api/kernels/{id}", endpoint=kernels), - WebSocketRoute("/jupyter/api/kernels/{id}/{name}", endpoint=kernel_connection), + WebSocketRoute("/jupyter/api/kernels/{kernel_id}/{name}", endpoint=kernel_connection), Route("/", endpoint=root), Route("/{fullpath}", endpoint=root), - Route("/_solara/api/close/{connection_id}", endpoint=close, methods=["POST"]), + Route("/_solara/api/close/{kernel_id}", endpoint=close, methods=["POST"]), # only enable when the proxy is turned on, otherwise if the directory does not exists we will get an exception *([Mount(f"/{cdn_url_path}", app=StaticCdn(directory=settings.assets.proxy_cache_dir))] if settings.assets.proxy else []), Mount(f"{prefix}/static/public", app=StaticPublic()), diff --git a/solara/server/telemetry.py b/solara/server/telemetry.py index 9b2dcdafd..52b1a54c1 100644 --- a/solara/server/telemetry.py +++ b/solara/server/telemetry.py @@ -146,12 +146,12 @@ def server_stop(): track("Solara server stop", {"duration_seconds": duration, **_usage_stats()}) -def connection_open(session_id, connection_id): +def connection_open(session_id): _connections_per_session_daily[session_id] += 1 _connections_per_session_cumulative[session_id] += 1 -def connection_close(session_id, connection_id): +def connection_close(session_id): pass diff --git a/solara/test/pytest_plugin.py b/solara/test/pytest_plugin.py index e6d195812..5138e0946 100644 --- a/solara/test/pytest_plugin.py +++ b/solara/test/pytest_plugin.py @@ -173,9 +173,9 @@ def solara_test(solara_server, solara_app, page_session: "playwright.sync_api.Pa run_event.wait() try: assert run_calls == 1 - keys = list(solara.server.app.contexts) + keys = list(solara.server.kernel_context.contexts) assert len(keys) == 1, "expected only one context, got %s" % keys - context = solara.server.app.contexts[keys[0]] + context = solara.server.kernel_context.contexts[keys[0]] with context: test_output_warmup = widgets.Output() test_output = widgets.Output() diff --git a/solara/toestand.py b/solara/toestand.py index a64c903a5..65c75f0aa 100644 --- a/solara/toestand.py +++ b/solara/toestand.py @@ -217,11 +217,11 @@ def _get_dict(self): scope_dict = self._global_dict scope_id = "global" if _using_solara_server(): - import solara.server.app + import solara.server.kernel_context try: - context = solara.server.app.get_current_context() - except: # noqa + context = solara.server.kernel_context.get_current_context() + except RuntimeError: # noqa pass # do we need to be more strict? else: scope_dict = cast(Dict[str, S], context.user_dicts) diff --git a/tests/integration/ssg_test.py b/tests/integration/ssg_test.py index 7f4f8b14b..d99211fa0 100644 --- a/tests/integration/ssg_test.py +++ b/tests/integration/ssg_test.py @@ -7,14 +7,14 @@ import solara from solara.server import settings -from solara.server.app import AppContext, get_current_context +from solara.server.kernel_context import VirtualKernelContext, get_current_context HERE = Path(__file__).parent text_ssg = "# SSG Test" text_live = "# Live render" -context: Optional[AppContext] = None +context: Optional[VirtualKernelContext] = None def set_value(x: str): diff --git a/tests/unit/app_test.py b/tests/unit/app_test.py index ae52b2789..2cb303d5f 100644 --- a/tests/unit/app_test.py +++ b/tests/unit/app_test.py @@ -20,11 +20,11 @@ reload.reloader.start() -def test_notebook_element(app_context, no_app_context): +def test_notebook_element(kernel_context, no_kernel_context): name = str(HERE / "solara_test_apps" / "notebookapp_element.ipynb") app = AppScript(name) try: - with app_context: + with kernel_context: el = app.run() assert isinstance(el, reacton.core.Element) el2 = app.run() @@ -33,11 +33,11 @@ def test_notebook_element(app_context, no_app_context): app.close() -def test_notebook_component(app_context, no_app_context): +def test_notebook_component(kernel_context, no_kernel_context): name = str(HERE / "solara_test_apps" / "notebookapp_component.ipynb") app = AppScript(name) try: - with app_context: + with kernel_context: el = app.run() assert isinstance(el, reacton.core.Element) el2 = app.run() @@ -46,11 +46,11 @@ def test_notebook_component(app_context, no_app_context): app.close() -def test_notebook_widget(app_context, no_app_context): +def test_notebook_widget(kernel_context, no_kernel_context): name = str(HERE / "solara_test_apps" / "notebookapp_widget.ipynb") app = AppScript(name) try: - with app_context: + with kernel_context: el = app.run() root = solara.RoutingProvider(children=[el], routes=app.routes, pathname="/") _box, rc = solara.render(root, handle_error=False) @@ -63,11 +63,11 @@ def test_notebook_widget(app_context, no_app_context): app.close() -def test_sidebar_single_file_multiple_routes(app_context, no_app_context): +def test_sidebar_single_file_multiple_routes(kernel_context, no_kernel_context): name = str(HERE / "solara_test_apps" / "single_file_multiple_routes.py") app = AppScript(name) try: - with app_context: + with kernel_context: c = app.run() root = solara.RoutingProvider(children=[c], routes=app.routes, pathname="/") box, rc = solara.render(root, handle_error=False) @@ -76,11 +76,11 @@ def test_sidebar_single_file_multiple_routes(app_context, no_app_context): app.close() -def test_sidebar_single_file(app_context, no_app_context): +def test_sidebar_single_file(kernel_context, no_kernel_context): name = str(HERE / "solara_test_apps" / "single_file.py") app = AppScript(name) try: - with app_context: + with kernel_context: c = app.run() root = solara.RoutingProvider(children=[c], routes=app.routes, pathname="/") box, rc = solara.render(root, handle_error=False) @@ -89,11 +89,11 @@ def test_sidebar_single_file(app_context, no_app_context): app.close() -def test_sidebar_single_file_missing(app_context, no_app_context): +def test_sidebar_single_file_missing(kernel_context, no_kernel_context): name = str(HERE / "solara_test_apps" / "single_file.py:doesnotexist") app = AppScript(name) try: - with app_context: + with kernel_context: c = app.run() root = solara.RoutingProvider(children=[c], routes=app.routes, pathname="/") box, rc = solara.render(root, handle_error=False) @@ -103,7 +103,7 @@ def test_sidebar_single_file_missing(app_context, no_app_context): # these make other test fail on CI (vaex is used, which causes a blake3 reload, which fails) -def test_watch_module_reload(tmpdir, app_context, extra_include_path, no_app_context): +def test_watch_module_reload(tmpdir, kernel_context, extra_include_path, no_kernel_context): import ipyvuetify as v with extra_include_path(str(tmpdir)): @@ -150,7 +150,7 @@ def test_watch_module_reload(tmpdir, app_context, extra_include_path, no_app_con reload.reloader.watched_modules.remove("somemod") -# def test_script_reload_component(tmpdir, app_context, extra_include_path, no_app_context): +# def test_script_reload_component(tmpdir, kernel_context, extra_include_path, no_kernel_context): # import ipyvuetify as v # with extra_include_path(str(tmpdir)): @@ -176,7 +176,7 @@ def test_watch_module_reload(tmpdir, app_context, extra_include_path, no_app_con # app.close() -# def test_watch_module_import_error(tmpdir, app_context, extra_include_path, no_app_context): +# def test_watch_module_import_error(tmpdir, kernel_context, extra_include_path, no_kernel_context): # import ipyvuetify as v # with extra_include_path(str(tmpdir)): diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 670deeb69..685e7c1e4 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,14 +1,15 @@ import pytest import solara.server.app +import solara.server.kernel_context from solara.server import kernel -from solara.server.app import AppContext +from solara.server.kernel_context import VirtualKernelContext @pytest.fixture(autouse=True) -def app_context(): +def kernel_context(): kernel_shared = kernel.Kernel() - context = AppContext(id="1", kernel=kernel_shared) + context = VirtualKernelContext(id="1", kernel=kernel_shared) try: with context: yield context @@ -18,10 +19,10 @@ def app_context(): @pytest.fixture() -def no_app_context(app_context): - context = solara.server.app.get_current_context() - solara.server.app.set_current_context(None) +def no_kernel_context(kernel_context): + context = solara.server.kernel_context.get_current_context() + solara.server.kernel_context.set_current_context(None) try: yield finally: - solara.server.app.set_current_context(context) + solara.server.kernel_context.set_current_context(context) diff --git a/tests/unit/no_solara_test.py b/tests/unit/no_solara_test.py index 40b540e89..2b2e245c1 100644 --- a/tests/unit/no_solara_test.py +++ b/tests/unit/no_solara_test.py @@ -6,16 +6,16 @@ # test if normal widget code works with no app context -def test_create_widget(no_app_context): +def test_create_widget(no_kernel_context): button = widgets.Button(description="Click me") button.layout.close() button.close() -def test_vue_template(no_app_context): +def test_vue_template(no_kernel_context): widget = solara.components.file_download.FileDownloadWidget() widget.close() -def test_display(no_app_context): +def test_display(no_kernel_context): display("test") diff --git a/tests/unit/output_widget_test.py b/tests/unit/output_widget_test.py index a07273384..6f061dbe3 100644 --- a/tests/unit/output_widget_test.py +++ b/tests/unit/output_widget_test.py @@ -3,18 +3,18 @@ import IPython.display import ipywidgets as widgets -from solara.server import app, kernel +from solara.server import kernel, kernel_context -def test_interactive_shell(no_app_context): +def test_interactive_shell(no_kernel_context): ws1 = Mock() ws2 = Mock() kernel1 = kernel.Kernel() kernel2 = kernel.Kernel() kernel1.session.websockets.add(ws1) kernel2.session.websockets.add(ws2) - context1 = app.AppContext(id="1", kernel=kernel1) - context2 = app.AppContext(id="2", kernel=kernel2) + context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel1) + context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel2) with context1: output1 = widgets.Output() diff --git a/tests/unit/patch_test.py b/tests/unit/patch_test.py index 910ce0ecc..0970f61eb 100644 --- a/tests/unit/patch_test.py +++ b/tests/unit/patch_test.py @@ -3,26 +3,26 @@ import ipywidgets as widgets import pytest -from solara.server import app, kernel +from solara.server import kernel, kernel_context # with python 3.6 we don't use the comm package @pytest.mark.skipif(sys.version_info < (3, 7, 0), reason="ipykernel version too low") -def test_widget_error_message_outside_context(no_app_context): +def test_widget_error_message_outside_context(no_kernel_context): from ipyvuetify.Themes import theme theme.get_state() kernel_shared = kernel.Kernel() - context1 = app.AppContext(id="1", kernel=kernel_shared) + context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel_shared) with pytest.raises(RuntimeError): with context1: assert theme.model_id -def test_widget_dict(no_app_context): +def test_widget_dict(no_kernel_context): kernel_shared = kernel.Kernel() - context1 = app.AppContext(id="1", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) - context2 = app.AppContext(id="2", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) + context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) + context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) with context1: btn1 = widgets.Button(description="context1") diff --git a/tests/unit/shell_test.py b/tests/unit/shell_test.py index c1de3e4ef..8839b76a2 100644 --- a/tests/unit/shell_test.py +++ b/tests/unit/shell_test.py @@ -2,18 +2,18 @@ import IPython.display -from solara.server import app, kernel +from solara.server import kernel, kernel_context -def test_shell(no_app_context): +def test_shell(no_kernel_context): ws1 = Mock() ws2 = Mock() kernel1 = kernel.Kernel() kernel2 = kernel.Kernel() kernel1.session.websockets.add(ws1) kernel2.session.websockets.add(ws2) - context1 = app.AppContext(id="1", kernel=kernel1) - context2 = app.AppContext(id="2", kernel=kernel2) + context1 = kernel_context.VirtualKernelContext(id="1", kernel=kernel1) + context2 = kernel_context.VirtualKernelContext(id="2", kernel=kernel2) with context1: IPython.display.display("test1") diff --git a/tests/unit/toestand_test.py b/tests/unit/toestand_test.py index 36d3cf68a..7a97f1fed 100644 --- a/tests/unit/toestand_test.py +++ b/tests/unit/toestand_test.py @@ -10,7 +10,7 @@ import solara import solara as sol import solara.lab -from solara.server import app, kernel +from solara.server import kernel, kernel_context from solara.toestand import Reactive, Ref, State, use_sync_external_store from .common import click @@ -112,17 +112,17 @@ def test_subscribe(): u() -def test_scopes(no_app_context): +def test_scopes(no_kernel_context): bear_store = BearReactive(bears) mock_global = unittest.mock.Mock() unsub = [] unsub += [bear_store.subscribe(mock_global)] kernel_shared = kernel.Kernel() - assert app.current_context[app.get_current_thread_key()] is None + assert kernel_context.current_context[kernel_context.get_current_thread_key()] is None - context1 = app.AppContext(id="toestand-1", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) - context2 = app.AppContext(id="toestand-2", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) + context1 = kernel_context.VirtualKernelContext(id="toestand-1", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) + context2 = kernel_context.VirtualKernelContext(id="toestand-2", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) with context1: mock1 = unittest.mock.Mock() @@ -378,8 +378,8 @@ def App(): # the storage should live in the app context to support multiple users/connections kernel_shared = kernel.Kernel() - context1 = app.AppContext(id="bear-1", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) - context2 = app.AppContext(id="bear-2", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) + context1 = kernel_context.VirtualKernelContext(id="bear-1", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) + context2 = kernel_context.VirtualKernelContext(id="bear-2", kernel=kernel_shared, control_sockets=[], widgets={}, templates={}) rcs = [] for context in [context1, context2]: with context: @@ -781,7 +781,7 @@ def test3(): t.join() -def test_reactive_auto_subscribe(app_context): +def test_reactive_auto_subscribe(kernel_context): x = Reactive(1) y = Reactive("hi") extra = Reactive("extra") @@ -815,10 +815,10 @@ def Main(): count.value = 2 assert len(rc.find(v.Slider)) == 2 - assert len(x._storage.listeners2[app_context.id]) == 2 + assert len(x._storage.listeners2[kernel_context.id]) == 2 x.value = 3 assert rc.find(v.Slider)[0].widget.v_model == 3 - assert len(x._storage.listeners2[app_context.id]) == 2 + assert len(x._storage.listeners2[kernel_context.id]) == 2 count.value = 1 assert len(rc.find(v.Slider)) == 1 @@ -826,8 +826,8 @@ def Main(): assert len(rc.find(v.Slider)) == 0 rc.close() - assert not x._storage.listeners[app_context.id] - assert not x._storage.listeners2[app_context.id] + assert not x._storage.listeners[kernel_context.id] + assert not x._storage.listeners2[kernel_context.id] def test_reactive_auto_subscribe_sub(): @@ -851,7 +851,7 @@ def Test(): assert renders == renders_before -def test_reactive_auto_subscribe_cleanup(app_context): +def test_reactive_auto_subscribe_cleanup(kernel_context): x = Reactive(1) y = Reactive("hi") renders = 0 @@ -874,22 +874,22 @@ def Test(): assert len(y._storage.listeners2) == 0 x.value = 42 assert renders == 2 - assert len(x._storage.listeners2[app_context.id]) == 1 - assert len(y._storage.listeners2[app_context.id]) == 1 + assert len(x._storage.listeners2[kernel_context.id]) == 1 + assert len(y._storage.listeners2[kernel_context.id]) == 1 # this triggers two renders, where during the first one we use y, but the seconds we don't x.value = 0 assert rc.find(v.Slider).widget.v_model == 100 - assert len(x._storage.listeners2[app_context.id]) == 1 + assert len(x._storage.listeners2[kernel_context.id]) == 1 # which means we shouldn't have a listener on y - assert len(y._storage.listeners2[app_context.id]) == 0 + assert len(y._storage.listeners2[kernel_context.id]) == 0 rc.close() - assert not x._storage.listeners[app_context.id] - assert not y._storage.listeners2[app_context.id] + assert not x._storage.listeners[kernel_context.id] + assert not y._storage.listeners2[kernel_context.id] -def test_reactive_auto_subscribe_subfield_limit(app_context): +def test_reactive_auto_subscribe_subfield_limit(kernel_context): bears = Reactive(Bears(type="brown", count=1)) renders = 0 @@ -906,8 +906,8 @@ def Test(): Ref(bears.fields.count).value = 2 assert renders == 2 rc.close() - assert not bears._storage.listeners[app_context.id] - assert not bears._storage.listeners2[app_context.id] + assert not bears._storage.listeners[kernel_context.id] + assert not bears._storage.listeners2[kernel_context.id] def test_reactive_batch_update():