Skip to content

Commit

Permalink
Add option to force sending text or binary frames.
Browse files Browse the repository at this point in the history
This adds the same functionality to the threading implemetation
as bc4b8f2 did to the asyncio implementation. Refs #1515.
  • Loading branch information
aaugustin committed Oct 13, 2024
1 parent c5985d5 commit b2f0a76
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 42 deletions.
43 changes: 24 additions & 19 deletions src/websockets/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,19 +409,17 @@ async def send(
# strings and bytes-like objects are iterable.

if isinstance(message, str):
if text is False:
async with self.send_context():
async with self.send_context():
if text is False:
self.protocol.send_binary(message.encode())
else:
async with self.send_context():
else:
self.protocol.send_text(message.encode())

elif isinstance(message, BytesLike):
if text is True:
async with self.send_context():
async with self.send_context():
if text is True:
self.protocol.send_text(message)
else:
async with self.send_context():
else:
self.protocol.send_binary(message)

# Catch a common mistake -- passing a dict to send().
Expand All @@ -443,19 +441,17 @@ async def send(
try:
# First fragment.
if isinstance(chunk, str):
if text is False:
async with self.send_context():
async with self.send_context():
if text is False:
self.protocol.send_binary(chunk.encode(), fin=False)
else:
async with self.send_context():
else:
self.protocol.send_text(chunk.encode(), fin=False)
encode = True
elif isinstance(chunk, BytesLike):
if text is True:
async with self.send_context():
async with self.send_context():
if text is True:
self.protocol.send_text(chunk, fin=False)
else:
async with self.send_context():
else:
self.protocol.send_binary(chunk, fin=False)
encode = False
else:
Expand All @@ -480,7 +476,10 @@ async def send(
# We're half-way through a fragmented message and we can't
# complete it. This makes the connection unusable.
async with self.send_context():
self.protocol.fail(1011, "error in fragmented message")
self.protocol.fail(
CloseCode.INTERNAL_ERROR,
"error in fragmented message",
)
raise

finally:
Expand Down Expand Up @@ -538,7 +537,10 @@ async def send(
# We're half-way through a fragmented message and we can't
# complete it. This makes the connection unusable.
async with self.send_context():
self.protocol.fail(1011, "error in fragmented message")
self.protocol.fail(
CloseCode.INTERNAL_ERROR,
"error in fragmented message",
)
raise

finally:
Expand Down Expand Up @@ -568,7 +570,10 @@ async def close(self, code: int = 1000, reason: str = "") -> None:
# to terminate after calling a method that sends a close frame.
async with self.send_context():
if self.fragmented_send_waiter is not None:
self.protocol.fail(1011, "close during fragmented message")
self.protocol.fail(
CloseCode.INTERNAL_ERROR,
"close during fragmented message",
)
else:
self.protocol.send_close(code, reason)
except ConnectionClosed:
Expand Down
61 changes: 38 additions & 23 deletions src/websockets/sync/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,11 @@ def recv_streaming(self) -> Iterator[Data]:
"is already running recv or recv_streaming"
) from None

def send(self, message: Data | Iterable[Data]) -> None:
def send(
self,
message: Data | Iterable[Data],
text: bool | None = None,
) -> None:
"""
Send a message.
Expand All @@ -262,6 +266,17 @@ def send(self, message: Data | Iterable[Data]) -> None:
.. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
.. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6
You may override this behavior with the ``text`` argument:
* Set ``text=True`` to send a bytestring or bytes-like object
(:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a
Text_ frame. This improves performance when the message is already
UTF-8 encoded, for example if the message contains JSON and you're
using a JSON library that produces a bytestring.
* Set ``text=False`` to send a string (:class:`str`) in a Binary_
frame. This may be useful for servers that expect binary frames
instead of text frames.
:meth:`send` also accepts an iterable of strings, bytestrings, or
bytes-like objects to enable fragmentation_. Each item is treated as a
message fragment and sent in its own frame. All items must be of the
Expand Down Expand Up @@ -300,7 +315,10 @@ def send(self, message: Data | Iterable[Data]) -> None:
"cannot call send while another thread "
"is already running send"
)
self.protocol.send_text(message.encode())
if text is False:
self.protocol.send_binary(message.encode())
else:
self.protocol.send_text(message.encode())

elif isinstance(message, BytesLike):
with self.send_context():
Expand All @@ -309,7 +327,10 @@ def send(self, message: Data | Iterable[Data]) -> None:
"cannot call send while another thread "
"is already running send"
)
self.protocol.send_binary(message)
if text is True:
self.protocol.send_text(message)
else:
self.protocol.send_binary(message)

# Catch a common mistake -- passing a dict to send().

Expand All @@ -328,50 +349,44 @@ def send(self, message: Data | Iterable[Data]) -> None:
try:
# First fragment.
if isinstance(chunk, str):
text = True
with self.send_context():
if self.send_in_progress:
raise ConcurrencyError(
"cannot call send while another thread "
"is already running send"
)
self.send_in_progress = True
self.protocol.send_text(
chunk.encode(),
fin=False,
)
if text is False:
self.protocol.send_binary(chunk.encode(), fin=False)
else:
self.protocol.send_text(chunk.encode(), fin=False)
encode = True
elif isinstance(chunk, BytesLike):
text = False
with self.send_context():
if self.send_in_progress:
raise ConcurrencyError(
"cannot call send while another thread "
"is already running send"
)
self.send_in_progress = True
self.protocol.send_binary(
chunk,
fin=False,
)
if text is True:
self.protocol.send_text(chunk, fin=False)
else:
self.protocol.send_binary(chunk, fin=False)
encode = False
else:
raise TypeError("data iterable must contain bytes or str")

# Other fragments
for chunk in chunks:
if isinstance(chunk, str) and text:
if isinstance(chunk, str) and encode:
with self.send_context():
assert self.send_in_progress
self.protocol.send_continuation(
chunk.encode(),
fin=False,
)
elif isinstance(chunk, BytesLike) and not text:
self.protocol.send_continuation(chunk.encode(), fin=False)
elif isinstance(chunk, BytesLike) and not encode:
with self.send_context():
assert self.send_in_progress
self.protocol.send_continuation(
chunk,
fin=False,
)
self.protocol.send_continuation(chunk, fin=False)
else:
raise TypeError("data iterable must contain uniform types")

Expand Down
28 changes: 28 additions & 0 deletions tests/sync/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,16 @@ def test_send_binary(self):
self.connection.send(b"\x01\x02\xfe\xff")
self.assertEqual(self.remote_connection.recv(), b"\x01\x02\xfe\xff")

def test_send_binary_from_str(self):
"""send sends a binary message from a str."""
self.connection.send("😀", text=False)
self.assertEqual(self.remote_connection.recv(), "😀".encode())

def test_send_text_from_bytes(self):
"""send sends a text message from bytes."""
self.connection.send("😀".encode(), text=True)
self.assertEqual(self.remote_connection.recv(), "😀")

def test_send_fragmented_text(self):
"""send sends a fragmented text message."""
self.connection.send(["😀", "😀"])
Expand All @@ -326,6 +336,24 @@ def test_send_fragmented_binary(self):
[b"\x01\x02", b"\xfe\xff", b""],
)

def test_send_fragmented_binary_from_str(self):
"""send sends a fragmented binary message from a str."""
self.connection.send(["😀", "😀"], text=False)
# websockets sends an trailing empty fragment. That's an implementation detail.
self.assertEqual(
list(self.remote_connection.recv_streaming()),
["😀".encode(), "😀".encode(), b""],
)

def test_send_fragmented_text_from_bytes(self):
"""send sends a fragmented text message from bytes."""
self.connection.send(["😀".encode(), "😀".encode()], text=True)
# websockets sends an trailing empty fragment. That's an implementation detail.
self.assertEqual(
list(self.remote_connection.recv_streaming()),
["😀", "😀", ""],
)

def test_send_connection_closed_ok(self):
"""send raises ConnectionClosedOK after a normal closure."""
self.remote_connection.close()
Expand Down

0 comments on commit b2f0a76

Please sign in to comment.