From faf0f9f1e92973a504724e3e2e7c687629d47835 Mon Sep 17 00:00:00 2001 From: Tom Deseyn Date: Thu, 11 Apr 2024 16:04:43 +0200 Subject: [PATCH] Protocol: limit max size hint. --- src/Tmds.DBus.Protocol/MessageWriter.Basic.cs | 60 ++++++++++++++----- src/Tmds.DBus.Protocol/MessageWriter.cs | 9 --- .../Netstandard2_1Extensions.cs | 40 ------------- 3 files changed, 46 insertions(+), 63 deletions(-) diff --git a/src/Tmds.DBus.Protocol/MessageWriter.Basic.cs b/src/Tmds.DBus.Protocol/MessageWriter.Basic.cs index 10ddfeb9..2a6598ae 100644 --- a/src/Tmds.DBus.Protocol/MessageWriter.Basic.cs +++ b/src/Tmds.DBus.Protocol/MessageWriter.Basic.cs @@ -2,6 +2,8 @@ namespace Tmds.DBus.Protocol; public ref partial struct MessageWriter { + private const int MaxSizeHint = 4096; + public void WriteBool(bool value) => WriteUInt32(value ? 1u : 0u); public void WriteByte(byte value) => WritePrimitiveCore(value, DBusType.Byte); @@ -170,33 +172,63 @@ private void WritePrimitiveCore(T value, DBusType type) private int WriteRaw(ReadOnlySpan data) { - int length = data.Length; - var dst = GetSpan(length); - data.CopyTo(dst); - Advance(length); - return length; + int totalLength = data.Length; + if (totalLength <= MaxSizeHint) + { + var dst = GetSpan(totalLength); + data.CopyTo(dst); + Advance(totalLength); + return totalLength; + } + else + { + while (!data.IsEmpty) + { + var dst = GetSpan(1); + int length = Math.Min(data.Length, dst.Length); + data.Slice(0, length).CopyTo(dst); + Advance(length); + data = data.Slice(length); + } + return totalLength; + } } private int WriteRaw(string data) { -#if NETSTANDARD2_1_OR_GREATER || NET - // To use the IBufferWriter we need to flush the Span. - // Avoid it when we're writing small strings. - if (data.Length <= 2048) + const int MaxUtf8BytesPerChar = 3; + + if (data.Length <= MaxSizeHint / MaxUtf8BytesPerChar) { ReadOnlySpan chars = data.AsSpan(); int byteCount = Encoding.UTF8.GetByteCount(chars); var dst = GetSpan(byteCount); - byteCount = Encoding.UTF8.GetBytes(data, dst); + byteCount = Encoding.UTF8.GetBytes(data.AsSpan(), dst); Advance(byteCount); return byteCount; } else -#endif { - int length = (int)Encoding.UTF8.GetBytes(data.AsSpan(), Writer); - _offset += length; - return length; + ReadOnlySpan chars = data.AsSpan(); + Encoder encoder = Encoding.UTF8.GetEncoder(); + int totalLength = 0; + do + { + Debug.Assert(!chars.IsEmpty); + + var dst = GetSpan(MaxUtf8BytesPerChar); + encoder.Convert(chars, dst, flush: true, out int charsUsed, out int bytesUsed, out bool completed); + + Advance(bytesUsed); + totalLength += bytesUsed; + + if (completed) + { + return totalLength; + } + + chars = chars.Slice(charsUsed); + } while (true); } } } diff --git a/src/Tmds.DBus.Protocol/MessageWriter.cs b/src/Tmds.DBus.Protocol/MessageWriter.cs index 83ee1290..950c4b66 100644 --- a/src/Tmds.DBus.Protocol/MessageWriter.cs +++ b/src/Tmds.DBus.Protocol/MessageWriter.cs @@ -54,15 +54,6 @@ public MessageBuffer CreateMessage() return message; } - private IBufferWriter Writer - { - get - { - Flush(); - return _data; - } - } - internal MessageWriter(MessageBufferPool messagePool, uint serial) { _message = messagePool.Rent(); diff --git a/src/Tmds.DBus.Protocol/Netstandard2_1Extensions.cs b/src/Tmds.DBus.Protocol/Netstandard2_1Extensions.cs index a71c3803..70694d03 100644 --- a/src/Tmds.DBus.Protocol/Netstandard2_1Extensions.cs +++ b/src/Tmds.DBus.Protocol/Netstandard2_1Extensions.cs @@ -25,46 +25,6 @@ public static SafeHandle GetSafeHandle(this Socket socket) return null!; } - public static long GetBytes(this Encoding encoding, ReadOnlySpan chars, IBufferWriter writer) - { - if (chars.Length <= MaxInputElementsPerIteration) - { - int byteCount = encoding.GetByteCount(chars); - Span scratchBuffer = writer.GetSpan(byteCount); - - int actualBytesWritten = encoding.GetBytes(chars, scratchBuffer); - - writer.Advance(actualBytesWritten); - return actualBytesWritten; - } - else - { - Convert(encoding.GetEncoder(), chars, writer, flush: true, out long totalBytesWritten, out bool completed); - return totalBytesWritten; - } - } - - public static void Convert(this Encoder encoder, ReadOnlySpan chars, IBufferWriter writer, bool flush, out long bytesUsed, out bool completed) - { - long totalBytesWritten = 0; - do - { - int byteCountForThisSlice = (chars.Length <= MaxInputElementsPerIteration) - ? encoder.GetByteCount(chars, flush) - : encoder.GetByteCount(chars.Slice(0, MaxInputElementsPerIteration), flush: false); - - Span scratchBuffer = writer.GetSpan(byteCountForThisSlice); - - encoder.Convert(chars, scratchBuffer, flush, out int charsUsedJustNow, out int bytesWrittenJustNow, out completed); - - chars = chars.Slice(charsUsedJustNow); - writer.Advance(bytesWrittenJustNow); - totalBytesWritten += bytesWrittenJustNow; - } while (!chars.IsEmpty); - - bytesUsed = totalBytesWritten; - } - public static async Task ConnectAsync(this Socket socket, EndPoint remoteEP, CancellationToken cancellationToken) { using var ctr = cancellationToken.Register(state => ((Socket)state!).Dispose(), socket, useSynchronizationContext: false);