Skip to content

Commit

Permalink
feat: make the server (starlette) work without threads for pyodide (#569
Browse files Browse the repository at this point in the history
)

in pyodide (pycafe) we cannot use threads. We currently have workarounds
in pycafe, but it would be easier to just not use threads in the server
if they are not available.
  • Loading branch information
maartenbreddels authored Mar 25, 2024
1 parent 14cbd50 commit f2f06bc
Show file tree
Hide file tree
Showing 7 changed files with 147 additions and 40 deletions.
35 changes: 33 additions & 2 deletions solara/server/kernel_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import asyncio

try:
import contextvars
except ModuleNotFoundError:
contextvars = None # type: ignore

import dataclasses
import enum
import inspect
Expand All @@ -7,6 +13,7 @@
import pickle
import threading
import time
import typing
from pathlib import Path
from types import FrameType, ModuleType
from typing import Any, Callable, Dict, List, NamedTuple, Optional, cast
Expand Down Expand Up @@ -292,12 +299,35 @@ def create_dummy_context():
return kernel_context


if contextvars is not None:
if typing.TYPE_CHECKING:
async_context_id = contextvars.ContextVar[str]("async_context_id")
else:
async_context_id = contextvars.ContextVar("async_context_id")
async_context_id.set("default")
else:
async_context_id = None


def get_current_thread_key() -> str:
thread = threading.current_thread()
return get_thread_key(thread)
if not solara.server.settings.kernel.threaded:
if async_context_id is not None:
try:
key = async_context_id.get()
except LookupError:
raise RuntimeError("no kernel context set")
else:
raise RuntimeError("No threading support, and no contextvars support (Python 3.6 is not supported for this)")
else:
thread = threading.current_thread()
key = get_thread_key(thread)
return key


def get_thread_key(thread: threading.Thread) -> str:
if not solara.server.settings.kernel.threaded:
if async_context_id is not None:
return async_context_id.get()
thread_key = thread._name + str(thread._ident) # type: ignore
return thread_key

Expand Down Expand Up @@ -355,6 +385,7 @@ def initialize_virtual_kernel(session_id: str, kernel_id: str, websocket: websoc
widgets.register_comm_target(kernel)
appmodule.register_solara_comm_target(kernel)
with context:
assert has_current_context()
assert kernel is Kernel.instance()
kernel.shell_stream = WebsocketStreamWrapper(websocket, "shell")
kernel.control_stream = WebsocketStreamWrapper(websocket, "control")
Expand Down
3 changes: 3 additions & 0 deletions solara/server/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,10 @@ def _WidgetContextAwareThread__bootstrap(self):
# we need to call this manually, because set_context_for_thread
# uses this, and the original _bootstrap calls it too late for us
self._set_ident()
if kernel_context.async_context_id is not None:
kernel_context.async_context_id.set(self.current_context.id)
kernel_context.set_context_for_thread(self.current_context, self)

shell = self.current_context.kernel.shell
shell.display_pub.register_hook(shell.display_in_reacton_hook)
try:
Expand Down
1 change: 1 addition & 0 deletions solara/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class Config:
class Kernel(BaseSettings):
cull_timeout: str = "24h"
max_count: Optional[int] = None
threaded: bool = solara.util.has_threads

class Config:
env_prefix = "solara_kernel_"
Expand Down
94 changes: 62 additions & 32 deletions solara/server/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import sys
import threading
import typing
from typing import Any, Dict, List, Optional, Union, cast
from typing import Any, Dict, List, Optional, Set, Union, cast
from uuid import uuid4

import anyio
Expand Down Expand Up @@ -96,10 +96,14 @@ class WebsocketDebugInfo:
class WebsocketWrapper(websocket.WebsocketWrapper):
ws: starlette.websockets.WebSocket

def __init__(self, ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal) -> None:
def __init__(self, ws: starlette.websockets.WebSocket, portal: Optional[anyio.from_thread.BlockingPortal]) -> None:
self.ws = ws
self.portal = portal
self.to_send: List[Union[str, bytes]] = []
# following https://docs.python.org/3/library/asyncio-task.html#asyncio.create_task
# we store a strong reference
self.tasks: Set[asyncio.Task] = set()
self.event_loop = asyncio.get_event_loop()
if settings.main.experimental_performance:
self.task = asyncio.ensure_future(self.process_messages_task())

Expand All @@ -114,28 +118,44 @@ async def process_messages_task(self):
await self.ws.send_text(first)

def close(self):
self.portal.call(self.ws.close) # type: ignore
if self.portal is None:
asyncio.ensure_future(self.ws.close())
else:
self.portal.call(self.ws.close) # type: ignore

def send_text(self, data: str) -> None:
if settings.main.experimental_performance:
self.to_send.append(data)
if self.portal is None:
task = self.event_loop.create_task(self.ws.send_text(data))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
else:
self.portal.call(self.ws.send_bytes, data) # type: ignore
if settings.main.experimental_performance:
self.to_send.append(data)
else:
self.portal.call(self.ws.send_bytes, data) # type: ignore

def send_bytes(self, data: bytes) -> None:
if settings.main.experimental_performance:
self.to_send.append(data)
if self.portal is None:
task = self.event_loop.create_task(self.ws.send_bytes(data))
self.tasks.add(task)
task.add_done_callback(self.tasks.discard)
else:
self.portal.call(self.ws.send_bytes, data) # type: ignore
if settings.main.experimental_performance:
self.to_send.append(data)
else:
self.portal.call(self.ws.send_bytes, data) # type: ignore

async def receive(self):
if hasattr(self.portal, "start_task_soon"):
# version 3+
fut = self.portal.start_task_soon(self.ws.receive) # type: ignore
if self.portal is None:
message = await asyncio.ensure_future(self.ws.receive())
else:
fut = self.portal.spawn_task(self.ws.receive) # type: ignore
if hasattr(self.portal, "start_task_soon"):
# version 3+
fut = self.portal.start_task_soon(self.ws.receive) # type: ignore
else:
fut = self.portal.spawn_task(self.ws.receive) # type: ignore

message = await asyncio.wrap_future(fut)
message = await asyncio.wrap_future(fut)
if "text" in message:
return message["text"]
elif "bytes" in message:
Expand Down Expand Up @@ -237,35 +257,45 @@ async def _kernel_connection(ws: starlette.websockets.WebSocket):
WebsocketDebugInfo.connecting -= 1
WebsocketDebugInfo.open += 1

def websocket_thread_runner(ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal):
async def run():
async def run(ws_wrapper: WebsocketWrapper):
if kernel_context.async_context_id is not None:
kernel_context.async_context_id.set(uuid4().hex)
assert session_id is not None
assert kernel_id is not None
telemetry.connection_open(session_id)
headers_dict: Dict[str, List[str]] = {}
for k, v in ws.headers.items():
if k not in headers_dict.keys():
headers_dict[k] = [v]
else:
headers_dict[k].append(v)
await server.app_loop(ws_wrapper, ws.cookies, headers_dict, session_id, kernel_id, page_id, user)

def websocket_thread_runner(ws_wrapper: WebsocketWrapper, portal: anyio.from_thread.BlockingPortal):
async def run_wrapper():
try:
assert session_id is not None
assert kernel_id is not None
telemetry.connection_open(session_id)
headers_dict: Dict[str, List[str]] = {}
for k, v in ws.headers.items():
if k not in headers_dict.keys():
headers_dict[k] = [v]
else:
headers_dict[k].append(v)
await server.app_loop(ws_wrapper, ws.cookies, headers_dict, session_id, kernel_id, page_id, user)
await run(ws_wrapper)
except: # noqa
await portal.stop(cancel_remaining=True)
if portal is not None:
await portal.stop(cancel_remaining=True)
raise
finally:
telemetry.connection_close(session_id)

# sometimes throws: RuntimeError: Already running asyncio in this thread
anyio.run(run) # type: ignore
anyio.run(run_wrapper) # type: ignore

# this portal allows us to sync call the websocket calls from this current event loop we are in
# each websocket however, is handled from a separate thread
try:
async with anyio.from_thread.BlockingPortal() as portal:
ws_wrapper = WebsocketWrapper(ws, portal)
thread_return = anyio.to_thread.run_sync(websocket_thread_runner, ws, portal, limiter=limiter) # type: ignore
await thread_return
if settings.kernel.threaded:
async with anyio.from_thread.BlockingPortal() as portal:
ws_wrapper = WebsocketWrapper(ws, portal)
thread_return = anyio.to_thread.run_sync(websocket_thread_runner, ws_wrapper, portal, limiter=limiter) # type: ignore
await thread_return
else:
ws_wrapper = WebsocketWrapper(ws, None)
await run(ws_wrapper)
finally:
if settings.main.experimental_performance:
try:
Expand Down
7 changes: 1 addition & 6 deletions solara/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,7 @@

logger = logging.getLogger("solara.task")

try:
threading.Thread(target=lambda: None).start()
has_threads = True
except RuntimeError:
has_threads = False
has_threads
has_threads = solara.util.has_threads


class TaskState(Enum):
Expand Down
6 changes: 6 additions & 0 deletions solara/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,12 @@
ipyvuetify_major_version = int(ipyvuetify.__version__.split(".")[0])
ipywidgets_major = int(ipywidgets.__version__.split(".")[0])

try:
threading.Thread(target=lambda: None).start()
has_threads = True
except RuntimeError:
has_threads = False


def github_url(file):
rel_path = os.path.relpath(file, Path(solara.__file__).parent.parent)
Expand Down
41 changes: 41 additions & 0 deletions tests/integration/server_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import playwright
import playwright.sync_api
import pytest
import reacton.ipywidgets as w

import solara
Expand Down Expand Up @@ -154,3 +155,43 @@ def test_run_in_iframe(page_session: playwright.sync_api.Page, solara_server, so
iframe = page_session.frame("main")
el = iframe.locator(".jupyter-widgets")
assert el.text_content() == "Hello world"


@solara.component
def ClickTaskButton():
count = solara.use_reactive(0)

@solara.lab.use_task(dependencies=None)
def on_click():
count.value += 1

return solara.Button(f"Clicked: {count}", on_click=on_click)


def test_kernel_asyncio(browser: playwright.sync_api.Browser, solara_server, solara_app, extra_include_path, request):
if request.node.callspec.params["solara_server"] != "starlette":
pytest.skip("Async is only supported on starlette.")
return
# ClickTaskButton also tests the use of tasks
try:
threaded = solara.server.settings.kernel.threaded
solara.server.settings.kernel.threaded = False
with extra_include_path(HERE), solara_app("server_test:ClickTaskButton"):
context1 = browser.new_context()
page1 = context1.new_page()
page1.goto(solara_server.base_url)
page1.locator("text=Clicked: 0").click()
page1.locator("text=Clicked: 1").click()
context2 = browser.new_context()
page2 = context2.new_page()
page2.goto(solara_server.base_url)
page2.locator("text=Clicked: 0").click()
page2.locator("text=Clicked: 1").click()
page1.locator("text=Clicked: 2").wait_for()
page2.locator("text=Clicked: 2").wait_for()
finally:
page1.close()
page2.close()
context1.close()
context2.close()
solara.server.settings.kernel.threaded = threaded

0 comments on commit f2f06bc

Please sign in to comment.