From ba08f8540a9ea1541f0a27712e8e00bf9f3ee4b4 Mon Sep 17 00:00:00 2001 From: "Maarten A. Breddels" Date: Tue, 21 Nov 2023 21:28:44 +0100 Subject: [PATCH] feat: --experimental-performance flag for faster performance This will only send websocket messages after a kernel messages is processed. --- solara/server/flask.py | 4 ++++ solara/server/server.py | 11 +++++++---- solara/server/settings.py | 1 + solara/server/starlette.py | 24 ++++++++++++++++++++++-- solara/server/websocket.py | 18 ++++++++++++++++++ 5 files changed, 52 insertions(+), 6 deletions(-) diff --git a/solara/server/flask.py b/solara/server/flask.py index dd653101c..f33884e18 100644 --- a/solara/server/flask.py +++ b/solara/server/flask.py @@ -57,6 +57,7 @@ class WebsocketWrapper(websocket.WebsocketWrapper): def __init__(self, ws: simple_websocket.Server) -> None: self.ws = ws self.lock = threading.Lock() + super().__init__() def close(self): with self.lock: @@ -78,6 +79,9 @@ async def receive(self): except simple_websocket.ws.ConnectionClosed: raise websocket.WebSocketDisconnect() + def flush(self): + pass # we do not implement queueing messages in flask (yet) + class ServerFlask(ServerBase): server: Any diff --git a/solara/server/server.py b/solara/server/server.py index 9cc959d2d..b0c783988 100644 --- a/solara/server/server.py +++ b/solara/server/server.py @@ -152,11 +152,14 @@ async def app_loop(ws: websocket.WebsocketWrapper, session_id: str, kernel_id: s else: msg = deserialize_binary_message(message) t1 = time.time() - if not process_kernel_messages(kernel, msg): - # if we shut down the kernel, we do not keep the page session alive - context.close() - return + hold_messages = ws.hold_messages() if settings.main.experimental_performance else solara.util.nullcontext() + with hold_messages: + if not process_kernel_messages(kernel, msg): + # if we shut down the kernel, we do not keep the page session alive + context.close() + return t2 = time.time() + if settings.main.timing: widgets_ids_after = set(patch.widgets) created_widgets_count = len(widgets_ids_after - widgets_ids) diff --git a/solara/server/settings.py b/solara/server/settings.py index 58f8607a0..c3b52cc2b 100644 --- a/solara/server/settings.py +++ b/solara/server/settings.py @@ -151,6 +151,7 @@ class MainSettings(BaseSettings): base_url: str = "" # e.g. https://myapp.solara.run/myapp/ platform: str = sys.platform host: str = HOST_DEFAULT + experimental_performance: bool = True class Config: env_prefix = "solara_" diff --git a/solara/server/starlette.py b/solara/server/starlette.py index fd178d34c..06ca2bd2c 100644 --- a/solara/server/starlette.py +++ b/solara/server/starlette.py @@ -72,17 +72,37 @@ class WebsocketWrapper(websocket.WebsocketWrapper): ws: starlette.websockets.WebSocket def __init__(self, ws: starlette.websockets.WebSocket, portal: anyio.from_thread.BlockingPortal) -> None: + super().__init__() self.ws = ws self.portal = portal + self.to_send: List[Union[str, bytes]] = [] + + def flush(self): + async def _flush(): + to_send, self.to_send = self.to_send, [] + for data in to_send: + if isinstance(data, bytes): + await self.ws.send_bytes(data) + else: + await self.ws.send_text(data) + + if settings.main.experimental_performance: + self.portal.call(_flush) def close(self): self.portal.call(self.ws.close) def send_text(self, data: str) -> None: - self.portal.call(self.ws.send_text, data) + if self._queuing_messages: + self.to_send.append(data) + else: + self.portal.call(self.ws.send_bytes, data) def send_bytes(self, data: bytes) -> None: - self.portal.call(self.ws.send_bytes, data) + if self._queuing_messages: + self.to_send.append(data) + else: + self.portal.call(self.ws.send_bytes, data) async def receive(self): if hasattr(self.portal, "start_task_soon"): diff --git a/solara/server/websocket.py b/solara/server/websocket.py index 7bbd5e8a7..44634f985 100644 --- a/solara/server/websocket.py +++ b/solara/server/websocket.py @@ -3,6 +3,7 @@ Async implementation have to come up with a way how to do this sync (see e.g. the starlette implementation) """ import abc +import contextlib import json from typing import Union @@ -12,6 +13,9 @@ class WebSocketDisconnect(Exception): class WebsocketWrapper(abc.ABC): + def __init__(self): + self._queuing_messages = False + @abc.abstractmethod def send_text(self, data: str) -> None: pass @@ -42,3 +46,17 @@ async def receive(self) -> Union[str, bytes]: async def receive_json(self): text = await self.receive() return json.loads(text) + + @abc.abstractmethod + def flush(self): + pass + + @contextlib.contextmanager + def hold_messages(self): + # we're assuming this only get used from a single thread (only at server.py app loop) + self._queuing_messages = True + try: + yield + finally: + self._queuing_messages = False + self.flush()