Skip to content

Commit

Permalink
Protocol: limit max size hint.
Browse files Browse the repository at this point in the history
  • Loading branch information
tmds committed Apr 11, 2024
1 parent 1efa047 commit faf0f9f
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 63 deletions.
60 changes: 46 additions & 14 deletions src/Tmds.DBus.Protocol/MessageWriter.Basic.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ namespace Tmds.DBus.Protocol;

public ref partial struct MessageWriter
{
private const int MaxSizeHint = 4096;

public void WriteBool(bool value) => WriteUInt32(value ? 1u : 0u);

public void WriteByte(byte value) => WritePrimitiveCore<byte>(value, DBusType.Byte);
Expand Down Expand Up @@ -170,33 +172,63 @@ private void WritePrimitiveCore<T>(T value, DBusType type)

private int WriteRaw(ReadOnlySpan<byte> data)
{
int length = data.Length;
var dst = GetSpan(length);
data.CopyTo(dst);
Advance(length);
return length;
int totalLength = data.Length;
if (totalLength <= MaxSizeHint)
{
var dst = GetSpan(totalLength);
data.CopyTo(dst);
Advance(totalLength);
return totalLength;
}
else
{
while (!data.IsEmpty)
{
var dst = GetSpan(1);
int length = Math.Min(data.Length, dst.Length);
data.Slice(0, length).CopyTo(dst);
Advance(length);
data = data.Slice(length);
}
return totalLength;
}
}

private int WriteRaw(string data)
{
#if NETSTANDARD2_1_OR_GREATER || NET
// To use the IBufferWriter we need to flush the Span.
// Avoid it when we're writing small strings.
if (data.Length <= 2048)
const int MaxUtf8BytesPerChar = 3;

if (data.Length <= MaxSizeHint / MaxUtf8BytesPerChar)
{
ReadOnlySpan<char> chars = data.AsSpan();
int byteCount = Encoding.UTF8.GetByteCount(chars);
var dst = GetSpan(byteCount);
byteCount = Encoding.UTF8.GetBytes(data, dst);
byteCount = Encoding.UTF8.GetBytes(data.AsSpan(), dst);
Advance(byteCount);
return byteCount;
}
else
#endif
{
int length = (int)Encoding.UTF8.GetBytes(data.AsSpan(), Writer);
_offset += length;
return length;
ReadOnlySpan<char> chars = data.AsSpan();
Encoder encoder = Encoding.UTF8.GetEncoder();
int totalLength = 0;
do
{
Debug.Assert(!chars.IsEmpty);

var dst = GetSpan(MaxUtf8BytesPerChar);
encoder.Convert(chars, dst, flush: true, out int charsUsed, out int bytesUsed, out bool completed);

Advance(bytesUsed);
totalLength += bytesUsed;

if (completed)
{
return totalLength;
}

chars = chars.Slice(charsUsed);
} while (true);
}
}
}
9 changes: 0 additions & 9 deletions src/Tmds.DBus.Protocol/MessageWriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,6 @@ public MessageBuffer CreateMessage()
return message;
}

private IBufferWriter<byte> Writer
{
get
{
Flush();
return _data;
}
}

internal MessageWriter(MessageBufferPool messagePool, uint serial)
{
_message = messagePool.Rent();
Expand Down
40 changes: 0 additions & 40 deletions src/Tmds.DBus.Protocol/Netstandard2_1Extensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -25,46 +25,6 @@ public static SafeHandle GetSafeHandle(this Socket socket)
return null!;
}

public static long GetBytes(this Encoding encoding, ReadOnlySpan<char> chars, IBufferWriter<byte> writer)
{
if (chars.Length <= MaxInputElementsPerIteration)
{
int byteCount = encoding.GetByteCount(chars);
Span<byte> scratchBuffer = writer.GetSpan(byteCount);

int actualBytesWritten = encoding.GetBytes(chars, scratchBuffer);

writer.Advance(actualBytesWritten);
return actualBytesWritten;
}
else
{
Convert(encoding.GetEncoder(), chars, writer, flush: true, out long totalBytesWritten, out bool completed);
return totalBytesWritten;
}
}

public static void Convert(this Encoder encoder, ReadOnlySpan<char> chars, IBufferWriter<byte> writer, bool flush, out long bytesUsed, out bool completed)
{
long totalBytesWritten = 0;
do
{
int byteCountForThisSlice = (chars.Length <= MaxInputElementsPerIteration)
? encoder.GetByteCount(chars, flush)
: encoder.GetByteCount(chars.Slice(0, MaxInputElementsPerIteration), flush: false);

Span<byte> scratchBuffer = writer.GetSpan(byteCountForThisSlice);

encoder.Convert(chars, scratchBuffer, flush, out int charsUsedJustNow, out int bytesWrittenJustNow, out completed);

chars = chars.Slice(charsUsedJustNow);
writer.Advance(bytesWrittenJustNow);
totalBytesWritten += bytesWrittenJustNow;
} while (!chars.IsEmpty);

bytesUsed = totalBytesWritten;
}

public static async Task ConnectAsync(this Socket socket, EndPoint remoteEP, CancellationToken cancellationToken)
{
using var ctr = cancellationToken.Register(state => ((Socket)state!).Dispose(), socket, useSynchronizationContext: false);
Expand Down

0 comments on commit faf0f9f

Please sign in to comment.