Skip to content

Commit

Permalink
feat: --experimental-performance flag for faster performance
Browse files Browse the repository at this point in the history
This will only send websocket messages after a kernel messages is
processed.
  • Loading branch information
maartenbreddels committed Nov 24, 2023
1 parent aaa5431 commit 55e9e2d
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 6 deletions.
4 changes: 4 additions & 0 deletions solara/server/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class WebsocketWrapper(websocket.WebsocketWrapper):
def __init__(self, ws: simple_websocket.Server) -> None:
self.ws = ws
self.lock = threading.Lock()
super().__init__()

def close(self):
with self.lock:
Expand All @@ -78,6 +79,9 @@ async def receive(self):
except simple_websocket.ws.ConnectionClosed:
raise websocket.WebSocketDisconnect()

def flush(self):
pass # we do not implement queueing messages in flask (yet)


class ServerFlask(ServerBase):
server: Any
Expand Down
11 changes: 7 additions & 4 deletions solara/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,14 @@ async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, kernel_id: s
else:
msg = deserialize_binary_message(message)
t1 = time.time()
if not process_kernel_messages(kernel, msg):
# if we shut down the kernel, we do not keep the page session alive
context.close()
return
hold_messages = ws.hold_messages() if settings.main.experimental_performance else solara.util.nullcontext()
with hold_messages:
if not process_kernel_messages(kernel, msg):
# if we shut down the kernel, we do not keep the page session alive
context.close()
return
t2 = time.time()

if settings.main.timing:
widgets_ids_after = set(patch.widgets)
created_widgets_count = len(widgets_ids_after - widgets_ids)
Expand Down
1 change: 1 addition & 0 deletions solara/server/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ class MainSettings(BaseSettings):
base_url: str = "" # e.g. https://myapp.solara.run/myapp/
platform: str = sys.platform
host: str = HOST_DEFAULT
experimental_performance: bool = True

class Config:
env_prefix = "solara_"
Expand Down
24 changes: 22 additions & 2 deletions solara/server/starlette.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,37 @@ class WebsocketWrapper(websocket.WebsocketWrapper):
ws: starlette.websockets.WebSocket

def __init__(self, ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal) -> None:
super().__init__()
self.ws = ws
self.portal = portal
self.to_send: List[Union[str, bytes]] = []

def flush(self):
async def _flush():
to_send, self.to_send = self.to_send, []
for data in to_send:
if isinstance(data, bytes):
await self.ws.send_bytes(data)
else:
await self.ws.send_text(data)

if settings.main.experimental_performance:
self.portal.call(_flush)

def close(self):
self.portal.call(self.ws.close)

def send_text(self, data: str) -> None:
self.portal.call(self.ws.send_text, data)
if self._queuing_messages:
self.to_send.append(data)
else:
self.portal.call(self.ws.send_bytes, data)

def send_bytes(self, data: bytes) -> None:
self.portal.call(self.ws.send_bytes, data)
if self._queuing_messages:
self.to_send.append(data)
else:
self.portal.call(self.ws.send_bytes, data)

async def receive(self):
if hasattr(self.portal, "start_task_soon"):
Expand Down
18 changes: 18 additions & 0 deletions solara/server/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Async implementation have to come up with a way how to do this sync (see e.g. the starlette implementation)
"""
import abc
import contextlib
import json
from typing import Union

Expand All @@ -12,6 +13,9 @@ class WebSocketDisconnect(Exception):


class WebsocketWrapper(abc.ABC):
def __init__(self):
self._queuing_messages = False

@abc.abstractmethod
def send_text(self, data: str) -> None:
pass
Expand Down Expand Up @@ -42,3 +46,17 @@ async def receive(self) -> Union[str, bytes]:
async def receive_json(self):
text = await self.receive()
return json.loads(text)

@abc.abstractmethod
def flush(self):
pass

@contextlib.contextmanager
def hold_messages(self):
# we're assuming this only get used from a single thread (only at server.py app loop)
self._queuing_messages = True
try:
yield
finally:
self._queuing_messages = False
self.flush()

0 comments on commit 55e9e2d

Please sign in to comment.