diff --git a/solara/server/kernel_context.py b/solara/server/kernel_context.py index d4227a508..9b39464a7 100644 --- a/solara/server/kernel_context.py +++ b/solara/server/kernel_context.py @@ -1,4 +1,10 @@ import asyncio + +try: + import contextvars +except ModuleNotFoundError: + contextvars = None # type: ignore + import dataclasses import enum import logging @@ -6,6 +12,7 @@ import pickle import threading import time +import typing from pathlib import Path from typing import Any, Callable, Dict, List, Optional, cast @@ -255,12 +262,35 @@ def create_dummy_context(): return kernel_context +if contextvars is not None: + if typing.TYPE_CHECKING: + async_context_id = contextvars.ContextVar[str]("async_context_id") + else: + async_context_id = contextvars.ContextVar("async_context_id") + async_context_id.set("default") +else: + async_context_id = None + + def get_current_thread_key() -> str: - thread = threading.current_thread() - return get_thread_key(thread) + if not solara.server.settings.kernel.threaded: + if async_context_id is not None: + try: + key = async_context_id.get() + except LookupError: + raise RuntimeError("no kernel context set") + else: + raise RuntimeError("No threading support, and no contextvars support (Python 3.6 is not supported for this)") + else: + thread = threading.current_thread() + key = get_thread_key(thread) + return key def get_thread_key(thread: threading.Thread) -> str: + if not solara.server.settings.kernel.threaded: + if async_context_id is not None: + return async_context_id.get() thread_key = thread._name + str(thread._ident) # type: ignore return thread_key @@ -318,6 +348,7 @@ def initialize_virtual_kernel(session_id: str, kernel_id: str, websocket: websoc widgets.register_comm_target(kernel) appmodule.register_solara_comm_target(kernel) with context: + assert has_current_context() assert kernel is Kernel.instance() kernel.shell_stream = WebsocketStreamWrapper(websocket, "shell") kernel.control_stream = WebsocketStreamWrapper(websocket, "control") diff --git a/solara/server/patch.py b/solara/server/patch.py index d81211c3e..f54e616d7 100644 --- a/solara/server/patch.py +++ b/solara/server/patch.py @@ -295,7 +295,10 @@ def _WidgetContextAwareThread__bootstrap(self): # we need to call this manually, because set_context_for_thread # uses this, and the original _bootstrap calls it too late for us self._set_ident() + if kernel_context.async_context_id is not None: + kernel_context.async_context_id.set(self.current_context.id) kernel_context.set_context_for_thread(self.current_context, self) + shell = self.current_context.kernel.shell shell.display_pub.register_hook(shell.display_in_reacton_hook) try: diff --git a/solara/server/settings.py b/solara/server/settings.py index f39b822b2..9d4007603 100644 --- a/solara/server/settings.py +++ b/solara/server/settings.py @@ -86,6 +86,7 @@ class Config: class Kernel(BaseSettings): cull_timeout: str = "24h" max_count: Optional[int] = None + threaded: bool = solara.util.has_threads class Config: env_prefix = "solara_kernel_" diff --git a/solara/server/starlette.py b/solara/server/starlette.py index a8e7a026d..6fa77b81a 100644 --- a/solara/server/starlette.py +++ b/solara/server/starlette.py @@ -6,7 +6,7 @@ import sys import threading import typing -from typing import Any, Dict, List, Optional, Union, cast +from typing import Any, Dict, List, Optional, Set, Union, cast from uuid import uuid4 import anyio @@ -96,10 +96,14 @@ class WebsocketDebugInfo: class WebsocketWrapper(websocket.WebsocketWrapper): ws: starlette.websockets.WebSocket - def __init__(self, ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal) -> None: + def __init__(self, ws: starlette.websockets.WebSocket, portal: Optional[anyio.from_thread.BlockingPortal]) -> None: self.ws = ws self.portal = portal self.to_send: List[Union[str, bytes]] = [] + # following https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task + # we store a strong reference + self.tasks: Set[asyncio.Task] = set() + self.event_loop = asyncio.get_event_loop() if settings.main.experimental_performance: self.task = asyncio.ensure_future(self.process_messages_task()) @@ -114,28 +118,44 @@ async def process_messages_task(self): await self.ws.send_text(first) def close(self): - self.portal.call(self.ws.close) # type: ignore + if self.portal is None: + asyncio.ensure_future(self.ws.close()) + else: + self.portal.call(self.ws.close) # type: ignore def send_text(self, data: str) -> None: - if settings.main.experimental_performance: - self.to_send.append(data) + if self.portal is None: + task = self.event_loop.create_task(self.ws.send_text(data)) + self.tasks.add(task) + task.add_done_callback(self.tasks.discard) else: - self.portal.call(self.ws.send_bytes, data) # type: ignore + if settings.main.experimental_performance: + self.to_send.append(data) + else: + self.portal.call(self.ws.send_bytes, data) # type: ignore def send_bytes(self, data: bytes) -> None: - if settings.main.experimental_performance: - self.to_send.append(data) + if self.portal is None: + task = self.event_loop.create_task(self.ws.send_bytes(data)) + self.tasks.add(task) + task.add_done_callback(self.tasks.discard) else: - self.portal.call(self.ws.send_bytes, data) # type: ignore + if settings.main.experimental_performance: + self.to_send.append(data) + else: + self.portal.call(self.ws.send_bytes, data) # type: ignore async def receive(self): - if hasattr(self.portal, "start_task_soon"): - # version 3+ - fut = self.portal.start_task_soon(self.ws.receive) # type: ignore + if self.portal is None: + message = await asyncio.ensure_future(self.ws.receive()) else: - fut = self.portal.spawn_task(self.ws.receive) # type: ignore + if hasattr(self.portal, "start_task_soon"): + # version 3+ + fut = self.portal.start_task_soon(self.ws.receive) # type: ignore + else: + fut = self.portal.spawn_task(self.ws.receive) # type: ignore - message = await asyncio.wrap_future(fut) + message = await asyncio.wrap_future(fut) if "text" in message: return message["text"] elif "bytes" in message: @@ -237,35 +257,45 @@ async def _kernel_connection(ws: starlette.websockets.WebSocket): WebsocketDebugInfo.connecting -= 1 WebsocketDebugInfo.open += 1 - def websocket_thread_runner(ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal): - async def run(): + async def run(ws_wrapper: WebsocketWrapper): + if kernel_context.async_context_id is not None: + kernel_context.async_context_id.set(uuid4().hex) + assert session_id is not None + assert kernel_id is not None + telemetry.connection_open(session_id) + headers_dict: Dict[str, List[str]] = {} + for k, v in ws.headers.items(): + if k not in headers_dict.keys(): + headers_dict[k] = [v] + else: + headers_dict[k].append(v) + await server.app_loop(ws_wrapper, ws.cookies, headers_dict, session_id, kernel_id, page_id, user) + + def websocket_thread_runner(ws_wrapper: WebsocketWrapper, portal: anyio.from_thread.BlockingPortal): + async def run_wrapper(): try: - assert session_id is not None - assert kernel_id is not None - telemetry.connection_open(session_id) - headers_dict: Dict[str, List[str]] = {} - for k, v in ws.headers.items(): - if k not in headers_dict.keys(): - headers_dict[k] = [v] - else: - headers_dict[k].append(v) - await server.app_loop(ws_wrapper, ws.cookies, headers_dict, session_id, kernel_id, page_id, user) + await run(ws_wrapper) except: # noqa - await portal.stop(cancel_remaining=True) + if portal is not None: + await portal.stop(cancel_remaining=True) raise finally: telemetry.connection_close(session_id) # sometimes throws: RuntimeError: Already running asyncio in this thread - anyio.run(run) # type: ignore + anyio.run(run_wrapper) # type: ignore # this portal allows us to sync call the websocket calls from this current event loop we are in # each websocket however, is handled from a separate thread try: - async with anyio.from_thread.BlockingPortal() as portal: - ws_wrapper = WebsocketWrapper(ws, portal) - thread_return = anyio.to_thread.run_sync(websocket_thread_runner, ws, portal, limiter=limiter) # type: ignore - await thread_return + if settings.kernel.threaded: + async with anyio.from_thread.BlockingPortal() as portal: + ws_wrapper = WebsocketWrapper(ws, portal) + thread_return = anyio.to_thread.run_sync(websocket_thread_runner, ws_wrapper, portal, limiter=limiter) # type: ignore + await thread_return + else: + ws_wrapper = WebsocketWrapper(ws, None) + await run(ws_wrapper) finally: if settings.main.experimental_performance: try: diff --git a/solara/tasks.py b/solara/tasks.py index 45949d8da..8aedcddb5 100644 --- a/solara/tasks.py +++ b/solara/tasks.py @@ -32,12 +32,7 @@ logger = logging.getLogger("solara.task") -try: - threading.Thread(target=lambda: None).start() - has_threads = True -except RuntimeError: - has_threads = False -has_threads +has_threads = solara.util.has_threads class TaskState(Enum): diff --git a/solara/util.py b/solara/util.py index 84f32c826..51928f79b 100644 --- a/solara/util.py +++ b/solara/util.py @@ -21,6 +21,12 @@ ipyvuetify_major_version = int(ipyvuetify.__version__.split(".")[0]) ipywidgets_major = int(ipywidgets.__version__.split(".")[0]) +try: + threading.Thread(target=lambda: None).start() + has_threads = True +except RuntimeError: + has_threads = False + def github_url(file): rel_path = os.path.relpath(file, Path(solara.__file__).parent.parent) diff --git a/tests/integration/server_test.py b/tests/integration/server_test.py index 8816287f1..b3a5b7e69 100644 --- a/tests/integration/server_test.py +++ b/tests/integration/server_test.py @@ -155,3 +155,40 @@ def test_run_in_iframe(page_session: playwright.sync_api.Page, solara_server, so iframe = page_session.frame("main") el = iframe.locator(".jupyter-widgets") assert el.text_content() == "Hello world" + + +@solara.component +def ClickTaskButton(): + count = solara.use_reactive(0) + + @solara.lab.use_task(dependencies=None) + def on_click(): + count.value += 1 + + return solara.Button(f"Clicked: {count}", on_click=on_click) + + +def test_kernel_asyncio(browser: playwright.sync_api.Browser, solara_server, solara_app, extra_include_path): + # ClickTaskButton also tests the use of tasks + try: + threaded = solara.server.settings.kernel.threaded + solara.server.settings.kernel.threaded = False + with extra_include_path(HERE), solara_app("server_test:ClickTaskButton"): + context1 = browser.new_context() + page1 = context1.new_page() + page1.goto(solara_server.base_url) + page1.locator("text=Clicked: 0").click() + page1.locator("text=Clicked: 1").click() + context2 = browser.new_context() + page2 = context2.new_page() + page2.goto(solara_server.base_url) + page2.locator("text=Clicked: 0").click() + page2.locator("text=Clicked: 1").click() + page1.locator("text=Clicked: 2").wait_for() + page2.locator("text=Clicked: 2").wait_for() + finally: + page1.close() + page2.close() + context1.close() + context2.close() + solara.server.settings.kernel.threaded = threaded