Skip to content

Commit

Permalink
feat: make the server (starlette) work without threads for pyodide
Browse files Browse the repository at this point in the history
in pyodide (pycafe) we cannot use threads. We currently have workarounds
in pycafe, but it would be easier to just not use threads in the server
if they are not available.
  • Loading branch information
maartenbreddels committed Mar 22, 2024
1 parent 0c33e33 commit f98e884
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 32 deletions.
28 changes: 26 additions & 2 deletions solara/server/kernel_context.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
import asyncio

try:
import contextvars
except ModuleNotFoundError:
contextvars = None

import dataclasses
import enum
import logging
Expand Down Expand Up @@ -255,9 +261,26 @@ def create_dummy_context():
return kernel_context


if contextvars is not None:
async_context_id = contextvars.ContextVar[str]("async_context_id")
async_context_id.set("default")
else:
async_context_id = None


def get_current_thread_key() -> str:
thread = threading.current_thread()
return get_thread_key(thread)
if not solara.util.has_threads:
if async_context_id is not None:
try:
key = async_context_id.get()
except LookupError:
raise RuntimeError("no kernel context set")
else:
raise RuntimeError("No threading support, and no contextvars support (Python 3.6 is not supported for this)")
else:
thread = threading.current_thread()
key = get_thread_key(thread)
return key


def get_thread_key(thread: threading.Thread) -> str:
Expand Down Expand Up @@ -318,6 +341,7 @@ def initialize_virtual_kernel(session_id: str, kernel_id: str, websocket: websoc
widgets.register_comm_target(kernel)
appmodule.register_solara_comm_target(kernel)
with context:
assert has_current_context()
assert kernel is Kernel.instance()
kernel.shell_stream = WebsocketStreamWrapper(websocket, "shell")
kernel.control_stream = WebsocketStreamWrapper(websocket, "control")
Expand Down
68 changes: 44 additions & 24 deletions solara/server/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class WebsocketDebugInfo:
class WebsocketWrapper(websocket.WebsocketWrapper):
ws: starlette.websockets.WebSocket

def __init__(self, ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal) -> None:
def __init__(self, ws: starlette.websockets.WebSocket, portal: Optional[anyio.from_thread.BlockingPortal]) -> None:
self.ws = ws
self.portal = portal
self.to_send: List[Union[str, bytes]] = []
Expand All @@ -114,28 +114,38 @@ async def process_messages_task(self):
await self.ws.send_text(first)

def close(self):
self.portal.call(self.ws.close) # type: ignore
if self.portal is None:
asyncio.ensure_future(self.ws.close())
else:
self.portal.call(self.ws.close) # type: ignore

def send_text(self, data: str) -> None:
if self.portal is None:
asyncio.ensure_future(self.ws.send_text(data))
if settings.main.experimental_performance:
self.to_send.append(data)
else:
self.portal.call(self.ws.send_bytes, data) # type: ignore

def send_bytes(self, data: bytes) -> None:
if self.portal is None:
asyncio.ensure_future(self.ws.send_bytes(data))
if settings.main.experimental_performance:
self.to_send.append(data)
else:
self.portal.call(self.ws.send_bytes, data) # type: ignore

async def receive(self):
if hasattr(self.portal, "start_task_soon"):
# version 3+
fut = self.portal.start_task_soon(self.ws.receive) # type: ignore
if self.portal is None:
message = await self.ws.receive()
else:
fut = self.portal.spawn_task(self.ws.receive) # type: ignore
if hasattr(self.portal, "start_task_soon"):
# version 3+
fut = self.portal.start_task_soon(self.ws.receive) # type: ignore
else:
fut = self.portal.spawn_task(self.ws.receive) # type: ignore

message = await asyncio.wrap_future(fut)
message = await asyncio.wrap_future(fut)
if "text" in message:
return message["text"]
elif "bytes" in message:
Expand Down Expand Up @@ -237,35 +247,45 @@ async def _kernel_connection(ws: starlette.websockets.WebSocket):
WebsocketDebugInfo.connecting -= 1
WebsocketDebugInfo.open += 1

async def run(ws_wrapper: WebsocketWrapper):
if kernel_context.async_context_id is not None:
kernel_context.async_context_id.set(uuid4().hex)
assert session_id is not None
assert kernel_id is not None
telemetry.connection_open(session_id)
headers_dict: Dict[str, List[str]] = {}
for k, v in ws.headers.items():
if k not in headers_dict.keys():
headers_dict[k] = [v]
else:
headers_dict[k].append(v)
await server.app_loop(ws_wrapper, ws.cookies, headers_dict, session_id, kernel_id, page_id, user)

def websocket_thread_runner(ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal):
async def run():
async def run_wrapper():
try:
assert session_id is not None
assert kernel_id is not None
telemetry.connection_open(session_id)
headers_dict: Dict[str, List[str]] = {}
for k, v in ws.headers.items():
if k not in headers_dict.keys():
headers_dict[k] = [v]
else:
headers_dict[k].append(v)
await server.app_loop(ws_wrapper, ws.cookies, headers_dict, session_id, kernel_id, page_id, user)
ws_wrapper = WebsocketWrapper(ws, portal)
await run(ws_wrapper)
except: # noqa
await portal.stop(cancel_remaining=True)
if portal is not None:
await portal.stop(cancel_remaining=True)
raise
finally:
telemetry.connection_close(session_id)

# sometimes throws: RuntimeError: Already running asyncio in this thread
anyio.run(run) # type: ignore
anyio.run(run_wrapper) # type: ignore

# this portal allows us to sync call the websocket calls from this current event loop we are in
# each websocket however, is handled from a separate thread
try:
async with anyio.from_thread.BlockingPortal() as portal:
ws_wrapper = WebsocketWrapper(ws, portal)
thread_return = anyio.to_thread.run_sync(websocket_thread_runner, ws, portal, limiter=limiter) # type: ignore
await thread_return
if solara.util.has_threads:
async with anyio.from_thread.BlockingPortal() as portal:
ws_wrapper = WebsocketWrapper(ws, portal)
thread_return = anyio.to_thread.run_sync(websocket_thread_runner, ws, portal, limiter=limiter) # type: ignore
await thread_return
else:
await run(WebsocketWrapper(ws, None))
finally:
if settings.main.experimental_performance:
try:
Expand Down
7 changes: 1 addition & 6 deletions solara/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,7 @@

logger = logging.getLogger("solara.task")

try:
threading.Thread(target=lambda: None).start()
has_threads = True
except RuntimeError:
has_threads = False
has_threads
has_threads = solara.util.has_threads


class TaskState(Enum):
Expand Down
6 changes: 6 additions & 0 deletions solara/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
ipyvuetify_major_version = int(ipyvuetify.__version__.split(".")[0])
ipywidgets_major = int(ipywidgets.__version__.split(".")[0])

try:
threading.Thread(target=lambda: None).start()
has_threads = True
except RuntimeError:
has_threads = False


def github_url(file):
rel_path = os.path.relpath(file, Path(solara.__file__).parent.parent)
Expand Down

0 comments on commit f98e884

Please sign in to comment.