From 1ef1de991b573ff1a39b121e70fbb2e0027624f3 Mon Sep 17 00:00:00 2001 From: Brent Yi Date: Tue, 1 Aug 2023 17:08:24 -0700 Subject: [PATCH] Fix client API version bugs --- viser/infra/_infra.py | 28 ++++++++++++++++++---------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/viser/infra/_infra.py b/viser/infra/_infra.py index a2ed12a1b..46c0369ff 100644 --- a/viser/infra/_infra.py +++ b/viser/infra/_infra.py @@ -14,7 +14,6 @@ Callable, Dict, List, - Literal, NewType, Optional, Sequence, @@ -32,6 +31,7 @@ from rich import box, style from rich.panel import Panel from rich.table import Table +from typing_extensions import Literal, assert_never from websockets.legacy.server import WebSocketServerProtocol from ._async_message_buffer import AsyncMessageBuffer, MessageWindow @@ -134,7 +134,7 @@ def __init__( self._message_class = message_class self._http_server_root = http_server_root self._verbose = verbose - self._client_api_version = client_api_version + self._client_api_version: Literal[0, 1] = client_api_version self._thread_executor = ThreadPoolExecutor(max_workers=32) @@ -332,7 +332,7 @@ async def _client_producer( websocket: WebSocketServerProtocol, client_id: ClientId, get_next: Callable[[], Awaitable[Message]], - client_api_version: int, + client_api_version: Literal[0, 1], ) -> None: """Infinite loop to send messages from a buffer to a single client.""" @@ -343,34 +343,42 @@ async def _client_producer( message_future = asyncio.ensure_future(get_next()) outgoing = window.get_window_to_send() if outgoing is not None: - if client_api_version: + if client_api_version == 1: serialized = msgpack.packb( tuple(message.as_serializable_dict() for message in outgoing) ) assert isinstance(serialized, bytes) await websocket.send(serialized) - else: + elif client_api_version == 0: for msg in outgoing: - await websocket.send(msg.as_serializable_dict()) + serialized = msgpack.packb(msg.as_serializable_dict()) + assert isinstance(serialized, bytes) + await websocket.send(serialized) + else: + assert_never(client_api_version) async def _broadcast_producer( websocket: WebSocketServerProtocol, get_next_window: Callable[[], Awaitable[Sequence[Message]]], - client_api_version: int, + client_api_version: Literal[0, 1], ) -> None: """Infinite loop to broadcast windows of messages from a buffer.""" while True: outgoing = await get_next_window() - if client_api_version: + if client_api_version == 1: serialized = msgpack.packb( tuple(message.as_serializable_dict() for message in outgoing) ) assert isinstance(serialized, bytes) await websocket.send(serialized) - else: + elif client_api_version == 0: for msg in outgoing: - await websocket.send(msg.as_serializable_dict()) + serialized = msgpack.packb(msg.as_serializable_dict()) + assert isinstance(serialized, bytes) + await websocket.send(serialized) + else: + assert_never(client_api_version) async def _consumer(