diff --git a/src/CommandLine/TunnelDashboard.cs b/src/CommandLine/TunnelDashboard.cs index 16fdc93..01963eb 100644 --- a/src/CommandLine/TunnelDashboard.cs +++ b/src/CommandLine/TunnelDashboard.cs @@ -127,7 +127,7 @@ private static void UpdateConnections(Tunnel tunnel, Stack con foreach (var connection in connectionHistory) { var httpConnection = connection as ProxiedHttpTunnelConnection; - var requestMessage = httpConnection?.RequestMessage; + var requestMessage = httpConnection?.HttpRequest; var bytesIn = httpConnection?.Statistics.BytesIn / 1024F; var bytesOut = httpConnection?.Statistics.BytesIn / 1024F; diff --git a/src/Connections/ProxiedHttpTunnelConnection.cs b/src/Connections/ProxiedHttpTunnelConnection.cs index 81a7606..7b59502 100644 --- a/src/Connections/ProxiedHttpTunnelConnection.cs +++ b/src/Connections/ProxiedHttpTunnelConnection.cs @@ -28,7 +28,7 @@ public ProxiedHttpTunnelConnection(TunnelConnectionHandle handle, ProxiedHttpTun public ProxiedHttpTunnelOptions Options { get; } - public HttpRequestMessage? RequestMessage { get; private set; } + public HttpRequestMessage? HttpRequest { get; private set; } public ConnectionStatistics Statistics => _statistics; @@ -75,7 +75,6 @@ protected override void Dispose(bool disposing) return; } - RequestMessage?.Dispose(); _proxyStream?.Dispose(); _proxySocket?.Dispose(); ArrayPool.Shared.Return(_receiveBuffer); @@ -104,28 +103,31 @@ private void BeginRead() private void ProcessRequest(ref ArraySegment data) { - var memoryStream = new MemoryStream(data.Array!, data.Offset, data.Array!.Length); + var requestBuffer = data.Array!; + var requestBody = (ReadOnlySpan)data; - using (var streamReader = new StreamReader(memoryStream, leaveOpen: true)) - { - RequestMessage = RequestReader.Parse(streamReader, BaseUri)!; - } + HttpRequest = RequestReader.Parse(ref requestBody, BaseUri)!; + Options.RequestProcessor!.Process(this, HttpRequest); - // save request body as span - var requestBody = data.Array.AsSpan(data.Offset + (int)memoryStream.Position); - memoryStream.Position = 0; - - Options.RequestProcessor!.Process(this, RequestMessage); + var pooledBuffer = Tunnel.ArrayPool.Rent(data.Count + 8096); // write request back - using (var streamWriter = new StreamWriter(memoryStream, leaveOpen: true)) + int requestLength; + using (var memoryStream = new MemoryStream(pooledBuffer)) { - RequestWriter.WriteRequest(streamWriter, RequestMessage); + using (var streamWriter = new StreamWriter(memoryStream, leaveOpen: true)) + { + RequestWriter.WriteRequest(streamWriter, HttpRequest, requestBody.Length); + } + + requestLength = (int)memoryStream.Position; } - // write request body - memoryStream.Write(requestBody); - data = new(data.Array!, data.Offset, (int)memoryStream.Position); + requestBody.CopyTo(pooledBuffer.AsSpan(requestLength)); + data = new(pooledBuffer, 0, requestLength + requestBody.Length); + + // return current buffer + Tunnel.ArrayPool.Return(requestBuffer); } private void ReceiveCallbackInternal(IAsyncResult asyncResult) diff --git a/src/Http/RequestReader.cs b/src/Http/RequestReader.cs index 260c90e..0f13dc1 100644 --- a/src/Http/RequestReader.cs +++ b/src/Http/RequestReader.cs @@ -1,16 +1,32 @@ namespace Localtunnel.Http { using System; - using System.IO; using System.Net; using System.Net.Http; using System.Net.Http.Headers; + using System.Text; internal static class RequestReader { - public static HttpRequestMessage? Parse(TextReader textReader, Uri baseUri) + private static readonly byte[] _eol = new byte[] { (byte)'\r', (byte)'\n' }; + + private static string? ReadLine(ref ReadOnlySpan span) + { + var start = span.IndexOf(_eol); + + if (start is -1) + { + return null; + } + + var content = span[0..start]; + span = span[(start + 2)..]; + return Encoding.UTF8.GetString(content); + } + + public static HttpRequestMessage? Parse(ref ReadOnlySpan span, Uri baseUri) { - var statusLine = textReader.ReadLine(); + var statusLine = ReadLine(ref span); if (string.IsNullOrWhiteSpace(statusLine)) { @@ -32,7 +48,7 @@ internal static class RequestReader }; // read headers - ReadHttpHeaders(textReader, requestMessage.Headers); + ReadHttpHeaders(ref span, requestMessage.Headers); return requestMessage; } @@ -45,10 +61,10 @@ internal static class RequestReader _ => HttpVersion.Unknown, }; - private static void ReadHttpHeaders(TextReader textReader, HttpRequestHeaders headers) + private static void ReadHttpHeaders(ref ReadOnlySpan span, HttpRequestHeaders headers) { string? line; - while (!string.IsNullOrWhiteSpace(line = textReader.ReadLine())) + while (!string.IsNullOrWhiteSpace(line = ReadLine(ref span))) { var index = line.IndexOf(':'); diff --git a/src/Http/RequestWriter.cs b/src/Http/RequestWriter.cs index dc09fdc..a52f049 100644 --- a/src/Http/RequestWriter.cs +++ b/src/Http/RequestWriter.cs @@ -7,14 +7,18 @@ internal static class RequestWriter { private const string HTTP_EOL = "\r\n"; - public static void WriteRequest(TextWriter writer, HttpRequestMessage request) + public static void WriteRequest(TextWriter writer, HttpRequestMessage request, long contentLength) { // status line writer.Write(request.Method); writer.Write(' '); writer.Write(request.RequestUri!.PathAndQuery); - writer.Write(" HTTP/"); - writer.Write(request.Version.ToString(2)); + writer.Write(" HTTP/1.1"); + writer.Write(HTTP_EOL); + + // content length + writer.Write("Content-Length: "); + writer.Write(contentLength); writer.Write(HTTP_EOL); // headers @@ -27,7 +31,6 @@ public static void WriteRequest(TextWriter writer, HttpRequestMessage request) } writer.Write(HTTP_EOL); - writer.Write(HTTP_EOL); } } } diff --git a/src/Tunnels/TunnelSocketContext.cs b/src/Tunnels/TunnelSocketContext.cs index a18a087..e70e49f 100644 --- a/src/Tunnels/TunnelSocketContext.cs +++ b/src/Tunnels/TunnelSocketContext.cs @@ -145,7 +145,16 @@ private void NotifyCompletedReceive(SocketAsyncEventArgs eventArgs) // initialize connection var handle = new TunnelConnectionHandle(this); connection = _connection = Tunnel.ConnectionFactory(handle); - connection.Open(); + + try + { + connection.Open(); + } + catch (Exception) + { + Dispose(); + return; + } } // capture buffer