From 2a16fb450acd0b3721dab78077d3950dd50593b2 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Fri, 22 Mar 2024 12:54:52 +0100 Subject: [PATCH] feat: make the server (starlette) work without threads for pyodide in pyodide (pycafe) we cannot use threads. We currently have workarounds in pycafe, but it would be easier to just not use threads in the server if they are not available. --- solara/server/kernel_context.py | 35 +++++++++++- solara/server/patch.py | 3 + solara/server/settings.py | 1 + solara/server/starlette.py | 94 +++++++++++++++++++++----------- solara/tasks.py | 7 +-- solara/util.py | 6 ++ tests/integration/server_test.py | 37 +++++++++++++ 7 files changed, 143 insertions(+), 40 deletions(-) diff --git a/solara/server/kernel_context.py b/solara/server/kernel_context.py index d4227a508..9b39464a7 100644 --- a/solara/server/kernel_context.py +++ b/solara/server/kernel_context.py @@ -1,4 +1,10 @@ import asyncio + +try: + import contextvars +except ModuleNotFoundError: + contextvars = None # type: ignore + import dataclasses import enum import logging @@ -6,6 +12,7 @@ import pickle import threading import time +import typing from pathlib import Path from typing import Any, Callable, Dict, List, Optional, cast @@ -255,12 +262,35 @@ def create_dummy_context(): return kernel_context +if contextvars is not None: + if typing.TYPE_CHECKING: + async_context_id = contextvars.ContextVar[str]("async_context_id") + else: + async_context_id = contextvars.ContextVar("async_context_id") + async_context_id.set("default") +else: + async_context_id = None + + def get_current_thread_key() -> str: - thread = threading.current_thread() - return get_thread_key(thread) + if not solara.server.settings.kernel.threaded: + if async_context_id is not None: + try: + key = async_context_id.get() + except LookupError: + raise RuntimeError("no kernel context set") + else: + raise RuntimeError("No threading support, and no contextvars support (Python 3.6 is not supported for this)") + else: + thread = threading.current_thread() + key = get_thread_key(thread) + return key def get_thread_key(thread: threading.Thread) -> str: + if not solara.server.settings.kernel.threaded: + if async_context_id is not None: + return async_context_id.get() thread_key = thread._name + str(thread._ident) # type: ignore return thread_key @@ -318,6 +348,7 @@ def initialize_virtual_kernel(session_id: str, kernel_id: str, websocket: websoc widgets.register_comm_target(kernel) appmodule.register_solara_comm_target(kernel) with context: + assert has_current_context() assert kernel is Kernel.instance() kernel.shell_stream = WebsocketStreamWrapper(websocket, "shell") kernel.control_stream = WebsocketStreamWrapper(websocket, "control") diff --git a/solara/server/patch.py b/solara/server/patch.py index d81211c3e..f54e616d7 100644 --- a/solara/server/patch.py +++ b/solara/server/patch.py @@ -295,7 +295,10 @@ def _WidgetContextAwareThread__bootstrap(self): # we need to call this manually, because set_context_for_thread # uses this, and the original _bootstrap calls it too late for us self._set_ident() + if kernel_context.async_context_id is not None: + kernel_context.async_context_id.set(self.current_context.id) kernel_context.set_context_for_thread(self.current_context, self) + shell = self.current_context.kernel.shell shell.display_pub.register_hook(shell.display_in_reacton_hook) try: diff --git a/solara/server/settings.py b/solara/server/settings.py index f39b822b2..9d4007603 100644 --- a/solara/server/settings.py +++ b/solara/server/settings.py @@ -86,6 +86,7 @@ class Config: class Kernel(BaseSettings): cull_timeout: str = "24h" max_count: Optional[int] = None + threaded: bool = solara.util.has_threads class Config: env_prefix = "solara_kernel_" diff --git a/solara/server/starlette.py b/solara/server/starlette.py index a8e7a026d..6fa77b81a 100644 --- a/solara/server/starlette.py +++ b/solara/server/starlette.py @@ -6,7 +6,7 @@ import sys import threading import typing -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, Set, Union, cast from uuid import uuid4 import anyio @@ -96,10 +96,14 @@ class WebsocketDebugInfo: class WebsocketWrapper(websocket.WebsocketWrapper): ws: starlette.websockets.WebSocket - def __init__(self, ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal) -> None: + def __init__(self, ws: starlette.websockets.WebSocket, portal: Optional[anyio.from_thread.BlockingPortal]) -> None: self.ws = ws self.portal = portal self.to_send: List[Union[str, bytes]] = [] + # following https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + # we store a strong reference + self.tasks: Set[asyncio.Task] = set() + self.event_loop = asyncio.get_event_loop() if settings.main.experimental_performance: self.task = asyncio.ensure_future(self.process_messages_task()) @@ -114,28 +118,44 @@ async def process_messages_task(self): await self.ws.send_text(first) def close(self): - self.portal.call(self.ws.close) # type: ignore + if self.portal is None: + asyncio.ensure_future(self.ws.close()) + else: + self.portal.call(self.ws.close) # type: ignore def send_text(self, data: str) -> None: - if settings.main.experimental_performance: - self.to_send.append(data) + if self.portal is None: + task = self.event_loop.create_task(self.ws.send_text(data)) + self.tasks.add(task) + task.add_done_callback(self.tasks.discard) else: - self.portal.call(self.ws.send_bytes, data) # type: ignore + if settings.main.experimental_performance: + self.to_send.append(data) + else: + self.portal.call(self.ws.send_bytes, data) # type: ignore def send_bytes(self, data: bytes) -> None: - if settings.main.experimental_performance: - self.to_send.append(data) + if self.portal is None: + task = self.event_loop.create_task(self.ws.send_bytes(data)) + self.tasks.add(task) + task.add_done_callback(self.tasks.discard) else: - self.portal.call(self.ws.send_bytes, data) # type: ignore + if settings.main.experimental_performance: + self.to_send.append(data) + else: + self.portal.call(self.ws.send_bytes, data) # type: ignore async def receive(self): - if hasattr(self.portal, "start_task_soon"): - # version 3+ - fut = self.portal.start_task_soon(self.ws.receive) # type: ignore + if self.portal is None: + message = await asyncio.ensure_future(self.ws.receive()) else: - fut = self.portal.spawn_task(self.ws.receive) # type: ignore + if hasattr(self.portal, "start_task_soon"): + # version 3+ + fut = self.portal.start_task_soon(self.ws.receive) # type: ignore + else: + fut = self.portal.spawn_task(self.ws.receive) # type: ignore - message = await asyncio.wrap_future(fut) + message = await asyncio.wrap_future(fut) if "text" in message: return message["text"] elif "bytes" in message: @@ -237,35 +257,45 @@ async def _kernel_connection(ws: starlette.websockets.WebSocket): WebsocketDebugInfo.connecting -= 1 WebsocketDebugInfo.open += 1 - def websocket_thread_runner(ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal): - async def run(): + async def run(ws_wrapper: WebsocketWrapper): + if kernel_context.async_context_id is not None: + kernel_context.async_context_id.set(uuid4().hex) + assert session_id is not None + assert kernel_id is not None + telemetry.connection_open(session_id) + headers_dict: Dict[str, List[str]] = {} + for k, v in ws.headers.items(): + if k not in headers_dict.keys(): + headers_dict[k] = [v] + else: + headers_dict[k].append(v) + await server.app_loop(ws_wrapper, ws.cookies, headers_dict, session_id, kernel_id, page_id, user) + + def websocket_thread_runner(ws_wrapper: WebsocketWrapper, portal: anyio.from_thread.BlockingPortal): + async def run_wrapper(): try: - assert session_id is not None - assert kernel_id is not None - telemetry.connection_open(session_id) - headers_dict: Dict[str, List[str]] = {} - for k, v in ws.headers.items(): - if k not in headers_dict.keys(): - headers_dict[k] = [v] - else: - headers_dict[k].append(v) - await server.app_loop(ws_wrapper, ws.cookies, headers_dict, session_id, kernel_id, page_id, user) + await run(ws_wrapper) except: # noqa - await portal.stop(cancel_remaining=True) + if portal is not None: + await portal.stop(cancel_remaining=True) raise finally: telemetry.connection_close(session_id) # sometimes throws: RuntimeError: Already running asyncio in this thread - anyio.run(run) # type: ignore + anyio.run(run_wrapper) # type: ignore # this portal allows us to sync call the websocket calls from this current event loop we are in # each websocket however, is handled from a separate thread try: - async with anyio.from_thread.BlockingPortal() as portal: - ws_wrapper = WebsocketWrapper(ws, portal) - thread_return = anyio.to_thread.run_sync(websocket_thread_runner, ws, portal, limiter=limiter) # type: ignore - await thread_return + if settings.kernel.threaded: + async with anyio.from_thread.BlockingPortal() as portal: + ws_wrapper = WebsocketWrapper(ws, portal) + thread_return = anyio.to_thread.run_sync(websocket_thread_runner, ws_wrapper, portal, limiter=limiter) # type: ignore + await thread_return + else: + ws_wrapper = WebsocketWrapper(ws, None) + await run(ws_wrapper) finally: if settings.main.experimental_performance: try: diff --git a/solara/tasks.py b/solara/tasks.py index 45949d8da..8aedcddb5 100644 --- a/solara/tasks.py +++ b/solara/tasks.py @@ -32,12 +32,7 @@ logger = logging.getLogger("solara.task") -try: - threading.Thread(target=lambda: None).start() - has_threads = True -except RuntimeError: - has_threads = False -has_threads +has_threads = solara.util.has_threads class TaskState(Enum): diff --git a/solara/util.py b/solara/util.py index 84f32c826..51928f79b 100644 --- a/solara/util.py +++ b/solara/util.py @@ -21,6 +21,12 @@ ipyvuetify_major_version = int(ipyvuetify.__version__.split(".")[0]) ipywidgets_major = int(ipywidgets.__version__.split(".")[0]) +try: + threading.Thread(target=lambda: None).start() + has_threads = True +except RuntimeError: + has_threads = False + def github_url(file): rel_path = os.path.relpath(file, Path(solara.__file__).parent.parent) diff --git a/tests/integration/server_test.py b/tests/integration/server_test.py index 8816287f1..b3a5b7e69 100644 --- a/tests/integration/server_test.py +++ b/tests/integration/server_test.py @@ -155,3 +155,40 @@ def test_run_in_iframe(page_session: playwright.sync_api.Page, solara_server, so iframe = page_session.frame("main") el = iframe.locator(".jupyter-widgets") assert el.text_content() == "Hello world" + + +@solara.component +def ClickTaskButton(): + count = solara.use_reactive(0) + + @solara.lab.use_task(dependencies=None) + def on_click(): + count.value += 1 + + return solara.Button(f"Clicked: {count}", on_click=on_click) + + +def test_kernel_asyncio(browser: playwright.sync_api.Browser, solara_server, solara_app, extra_include_path): + # ClickTaskButton also tests the use of tasks + try: + threaded = solara.server.settings.kernel.threaded + solara.server.settings.kernel.threaded = False + with extra_include_path(HERE), solara_app("server_test:ClickTaskButton"): + context1 = browser.new_context() + page1 = context1.new_page() + page1.goto(solara_server.base_url) + page1.locator("text=Clicked: 0").click() + page1.locator("text=Clicked: 1").click() + context2 = browser.new_context() + page2 = context2.new_page() + page2.goto(solara_server.base_url) + page2.locator("text=Clicked: 0").click() + page2.locator("text=Clicked: 1").click() + page1.locator("text=Clicked: 2").wait_for() + page2.locator("text=Clicked: 2").wait_for() + finally: + page1.close() + page2.close() + context1.close() + context2.close() + solara.server.settings.kernel.threaded = threaded