From b2f0a7647f1402c84a8dabb391c3ca7371975eb3 Mon Sep 17 00:00:00 2001 From: Aymeric Augustin Date: Sun, 13 Oct 2024 21:32:30 +0200 Subject: [PATCH] Add option to force sending text or binary frames. This adds the same functionality to the threading implemetation as bc4b8f2 did to the asyncio implementation. Refs #1515. --- src/websockets/asyncio/connection.py | 43 +++++++++++--------- src/websockets/sync/connection.py | 61 +++++++++++++++++----------- tests/sync/test_connection.py | 28 +++++++++++++ 3 files changed, 90 insertions(+), 42 deletions(-) diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 12871e4b..3b81e386 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -409,19 +409,17 @@ async def send( # strings and bytes-like objects are iterable. if isinstance(message, str): - if text is False: - async with self.send_context(): + async with self.send_context(): + if text is False: self.protocol.send_binary(message.encode()) - else: - async with self.send_context(): + else: self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): - if text is True: - async with self.send_context(): + async with self.send_context(): + if text is True: self.protocol.send_text(message) - else: - async with self.send_context(): + else: self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -443,19 +441,17 @@ async def send( try: # First fragment. if isinstance(chunk, str): - if text is False: - async with self.send_context(): + async with self.send_context(): + if text is False: self.protocol.send_binary(chunk.encode(), fin=False) - else: - async with self.send_context(): + else: self.protocol.send_text(chunk.encode(), fin=False) encode = True elif isinstance(chunk, BytesLike): - if text is True: - async with self.send_context(): + async with self.send_context(): + if text is True: self.protocol.send_text(chunk, fin=False) - else: - async with self.send_context(): + else: self.protocol.send_binary(chunk, fin=False) encode = False else: @@ -480,7 +476,10 @@ async def send( # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. async with self.send_context(): - self.protocol.fail(1011, "error in fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) raise finally: @@ -538,7 +537,10 @@ async def send( # We're half-way through a fragmented message and we can't # complete it. This makes the connection unusable. async with self.send_context(): - self.protocol.fail(1011, "error in fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) raise finally: @@ -568,7 +570,10 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # to terminate after calling a method that sends a close frame. async with self.send_context(): if self.fragmented_send_waiter is not None: - self.protocol.fail(1011, "close during fragmented message") + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) else: self.protocol.send_close(code, reason) except ConnectionClosed: diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index 8c5df959..3f4cac09 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -251,7 +251,11 @@ def recv_streaming(self) -> Iterator[Data]: "is already running recv or recv_streaming" ) from None - def send(self, message: Data | Iterable[Data]) -> None: + def send( + self, + message: Data | Iterable[Data], + text: bool | None = None, + ) -> None: """ Send a message. @@ -262,6 +266,17 @@ def send(self, message: Data | Iterable[Data]) -> None: .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send a bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + :meth:`send` also accepts an iterable of strings, bytestrings, or bytes-like objects to enable fragmentation_. Each item is treated as a message fragment and sent in its own frame. All items must be of the @@ -300,7 +315,10 @@ def send(self, message: Data | Iterable[Data]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_text(message.encode()) + if text is False: + self.protocol.send_binary(message.encode()) + else: + self.protocol.send_text(message.encode()) elif isinstance(message, BytesLike): with self.send_context(): @@ -309,7 +327,10 @@ def send(self, message: Data | Iterable[Data]) -> None: "cannot call send while another thread " "is already running send" ) - self.protocol.send_binary(message) + if text is True: + self.protocol.send_text(message) + else: + self.protocol.send_binary(message) # Catch a common mistake -- passing a dict to send(). @@ -328,7 +349,6 @@ def send(self, message: Data | Iterable[Data]) -> None: try: # First fragment. if isinstance(chunk, str): - text = True with self.send_context(): if self.send_in_progress: raise ConcurrencyError( @@ -336,12 +356,12 @@ def send(self, message: Data | Iterable[Data]) -> None: "is already running send" ) self.send_in_progress = True - self.protocol.send_text( - chunk.encode(), - fin=False, - ) + if text is False: + self.protocol.send_binary(chunk.encode(), fin=False) + else: + self.protocol.send_text(chunk.encode(), fin=False) + encode = True elif isinstance(chunk, BytesLike): - text = False with self.send_context(): if self.send_in_progress: raise ConcurrencyError( @@ -349,29 +369,24 @@ def send(self, message: Data | Iterable[Data]) -> None: "is already running send" ) self.send_in_progress = True - self.protocol.send_binary( - chunk, - fin=False, - ) + if text is True: + self.protocol.send_text(chunk, fin=False) + else: + self.protocol.send_binary(chunk, fin=False) + encode = False else: raise TypeError("data iterable must contain bytes or str") # Other fragments for chunk in chunks: - if isinstance(chunk, str) and text: + if isinstance(chunk, str) and encode: with self.send_context(): assert self.send_in_progress - self.protocol.send_continuation( - chunk.encode(), - fin=False, - ) - elif isinstance(chunk, BytesLike) and not text: + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: with self.send_context(): assert self.send_in_progress - self.protocol.send_continuation( - chunk, - fin=False, - ) + self.protocol.send_continuation(chunk, fin=False) else: raise TypeError("data iterable must contain uniform types") diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index 16f92e16..87333fd3 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -308,6 +308,16 @@ def test_send_binary(self): self.connection.send(b"\x01\x02\xfe\xff") self.assertEqual(self.remote_connection.recv(), b"\x01\x02\xfe\xff") + def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + self.connection.send("😀", text=False) + self.assertEqual(self.remote_connection.recv(), "😀".encode()) + + def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + self.connection.send("😀".encode(), text=True) + self.assertEqual(self.remote_connection.recv(), "😀") + def test_send_fragmented_text(self): """send sends a fragmented text message.""" self.connection.send(["😀", "😀"]) @@ -326,6 +336,24 @@ def test_send_fragmented_binary(self): [b"\x01\x02", b"\xfe\xff", b""], ) + def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + list(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + def test_send_connection_closed_ok(self): """send raises ConnectionClosedOK after a normal closure.""" self.remote_connection.close()