Skip to content

Commit

Permalink
Fix support for unix sockets (#1995)
Browse files Browse the repository at this point in the history
* Add support for changing the protocol type on socket level

* Do not set NoDelay when not supported by protocol

* Update release notes

* Fix release notes
  • Loading branch information
chkr1011 authored May 14, 2024
1 parent b476977 commit 41d5b70
Show file tree
Hide file tree
Showing 8 changed files with 64 additions and 32 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ReleaseNotes.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
* [Core] Optimized packet serialization of PUBACK and PUBREC packets for protocol version 5.0.0 (#1939, thanks to @Y-Sindo).
* [Core] The package inspector is now fully async (#1941).
* [Core] Fixed decoding of DISCONNECT packet with empty body (#1994, thanks to @Y-Sindo).
* [Client] Exposed the _EndPoint_ type to support other endpoint types (like Unix Domain Sockets) in client options (#1919).
* [Client] Fixed support for unix sockets by exposing more options (#1995).
* [Client] Added a dedicated exception when the client is not connected (#1954, thanks to @marcpiulachs).
* [Client] The client will now throw a _MqttClientUnexpectedDisconnectReceivedException_ when publishing a QoS 0 message which leads to a server disconnect (BREAKING CHANGE!, #1974, thanks to @fazho).
* [Client] Exposed the certificate selection event handler in client options (#1984).
Expand Down
4 changes: 2 additions & 2 deletions Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class MqttTcpChannel_Tests
public async Task Dispose_Channel_While_Used()
{
var ct = new CancellationTokenSource();
var serverSocket = new CrossPlatformSocket(AddressFamily.InterNetwork);
var serverSocket = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp);

try
{
Expand All @@ -38,7 +38,7 @@ public async Task Dispose_Channel_While_Used()
}
}, ct.Token);

var clientSocket = new CrossPlatformSocket(AddressFamily.InterNetwork);
var clientSocket = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp);
await clientSocket.ConnectAsync("localhost", 50001, CancellationToken.None);

var tcpChannel = new MqttTcpChannel(clientSocket.GetStream(), "test", null);
Expand Down
4 changes: 2 additions & 2 deletions Source/MQTTnet.Tests/Server/Connection_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public async Task Close_Idle_Connection_On_Connect()
{
await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1)));

var client = new CrossPlatformSocket(AddressFamily.InterNetwork);
var client = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp);
await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None);

// Don't send anything. The server should close the connection.
Expand Down Expand Up @@ -54,7 +54,7 @@ public async Task Send_Garbage()

// Send an invalid packet and ensure that the server will close the connection and stay in a waiting state
// forever. This is security related.
var client = new CrossPlatformSocket(AddressFamily.InterNetwork);
var client = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp);
await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None);

var buffer = Encoding.UTF8.GetBytes("Garbage");
Expand Down
37 changes: 25 additions & 12 deletions Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Text;
using MQTTnet.Formatter;
using MQTTnet.Packets;
Expand All @@ -18,14 +19,14 @@ namespace MQTTnet.Client
public sealed class MqttClientOptionsBuilder
{
readonly MqttClientOptions _options = new MqttClientOptions();
int? _port;

[Obsolete] MqttClientWebSocketProxyOptions _proxyOptions;
EndPoint _remoteEndPoint;

MqttClientTcpOptions _tcpOptions;
MqttClientTlsOptions _tlsOptions;
EndPoint _remoteEndPoint;
int? _port;


[Obsolete] MqttClientOptionsBuilderTlsParameters _tlsParameters;

MqttClientWebSocketOptions _webSocketOptions;
Expand Down Expand Up @@ -89,7 +90,7 @@ public MqttClientOptions Build()

if (_tcpOptions.RemoteEndpoint == null)
{
_tcpOptions.RemoteEndpoint = _remoteEndPoint;
_tcpOptions.RemoteEndpoint = _remoteEndPoint;
}
}
else if (_webSocketOptions != null)
Expand All @@ -114,6 +115,12 @@ public MqttClientOptions Build()
return _options;
}

public MqttClientOptionsBuilder WithAddressFamily(AddressFamily addressFamily)
{
_tcpOptions.AddressFamily = addressFamily;
return this;
}

public MqttClientOptionsBuilder WithAuthentication(string method, byte[] data)
{
_options.AuthenticationMethod = method;
Expand Down Expand Up @@ -218,6 +225,14 @@ public MqttClientOptionsBuilder WithCredentials(IMqttClientCredentialsProvider c
return this;
}

public MqttClientOptionsBuilder WithEndPoint(EndPoint endPoint)
{
_remoteEndPoint = endPoint ?? throw new ArgumentNullException(nameof(endPoint));
_tcpOptions = new MqttClientTcpOptions();

return this;
}

public MqttClientOptionsBuilder WithExtendedAuthenticationExchangeHandler(IMqttExtendedAuthenticationExchangeHandler handler)
{
_options.ExtendedAuthenticationExchangeHandler = handler;
Expand Down Expand Up @@ -263,6 +278,12 @@ public MqttClientOptionsBuilder WithoutThrowOnNonSuccessfulConnectResponse()
return this;
}

public MqttClientOptionsBuilder WithProtocolType(ProtocolType protocolType)
{
_tcpOptions.ProtocolType = protocolType;
return this;
}

public MqttClientOptionsBuilder WithProtocolVersion(MqttProtocolVersion value)
{
if (value == MqttProtocolVersion.Unknown)
Expand Down Expand Up @@ -344,14 +365,6 @@ public MqttClientOptionsBuilder WithTcpServer(string server, int? port = null)

return this;
}

public MqttClientOptionsBuilder WithEndPoint(EndPoint endPoint)
{
_remoteEndPoint = endPoint ?? throw new ArgumentNullException(nameof(endPoint));
_tcpOptions = new MqttClientTcpOptions();

return this;
}

public MqttClientOptionsBuilder WithTcpServer(Action<MqttClientTcpOptions> optionsBuilder)
{
Expand Down
12 changes: 12 additions & 0 deletions Source/MQTTnet/Client/Options/MqttClientTcpOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,20 @@ public sealed class MqttClientTcpOptions : IMqttClientChannelOptions
/// </summary>
public EndPoint LocalEndpoint { get; set; }

/// <summary>
/// Enables or disables the Nagle algorithm for the socket.
/// This is only supported for TCP.
/// For other protocol types the value is ignored.
/// Default: true
/// </summary>
public bool NoDelay { get; set; } = true;

/// <summary>
/// The MQTT transport is usually TCP but when using other endpoint types like
/// unix sockets it must be changed (IP for unix sockets).
/// </summary>
public ProtocolType ProtocolType { get; set; } = ProtocolType.Tcp;

public EndPoint RemoteEndpoint { get; set; }

public MqttClientTlsOptions TlsOptions { get; set; } = new MqttClientTlsOptions();
Expand Down
4 changes: 2 additions & 2 deletions Source/MQTTnet/Implementations/CrossPlatformSocket.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@ public sealed class CrossPlatformSocket : IDisposable

NetworkStream _networkStream;

public CrossPlatformSocket(AddressFamily addressFamily)
public CrossPlatformSocket(AddressFamily addressFamily, ProtocolType protocolType)
{
_socket = new Socket(addressFamily, SocketType.Stream, ProtocolType.Tcp);
_socket = new Socket(addressFamily, SocketType.Stream, protocolType);

#if !NET5_0_OR_GREATER
_socketDisposeAction = _socket.Dispose;
Expand Down
9 changes: 7 additions & 2 deletions Source/MQTTnet/Implementations/MqttTcpChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
}
else
{
socket = new CrossPlatformSocket(_tcpOptions.AddressFamily);
socket = new CrossPlatformSocket(_tcpOptions.AddressFamily, _tcpOptions.ProtocolType);
}

if (_tcpOptions.LocalEndpoint != null)
Expand All @@ -78,7 +78,12 @@ public async Task ConnectAsync(CancellationToken cancellationToken)
socket.ReceiveBufferSize = _tcpOptions.BufferSize;
socket.SendBufferSize = _tcpOptions.BufferSize;
socket.SendTimeout = (int)_clientOptions.Timeout.TotalMilliseconds;
socket.NoDelay = _tcpOptions.NoDelay;

if (_tcpOptions.ProtocolType == ProtocolType.Tcp)
{
// Other protocol types do not support the Nagle algorithm.
socket.NoDelay = _tcpOptions.NoDelay;
}

if (socket.LingerState != null)
{
Expand Down
24 changes: 12 additions & 12 deletions Source/MQTTnet/Implementations/MqttTcpServerListener.cs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public sealed class MqttTcpServerListener : IDisposable
readonly MqttServerOptions _serverOptions;
readonly MqttServerTcpEndpointBaseOptions _options;
readonly MqttServerTlsTcpEndpointOptions _tlsOptions;

CrossPlatformSocket _socket;
IPEndPoint _localEndPoint;

Expand Down Expand Up @@ -65,7 +65,7 @@ public bool Start(bool treatErrorsAsWarning, CancellationToken cancellationToken

_logger.Info("Starting TCP listener (Endpoint={0}, TLS={1})", _localEndPoint, _tlsOptions?.CertificateProvider != null);

_socket = new CrossPlatformSocket(_addressFamily);
_socket = new CrossPlatformSocket(_addressFamily, ProtocolType.Tcp);

// Usage of socket options is described here: https://docs.microsoft.com/en-us/dotnet/api/system.net.sockets.socket.setsocketoption?view=netcore-2.2
if (_options.ReuseAddress)
Expand All @@ -87,33 +87,33 @@ public bool Start(bool treatErrorsAsWarning, CancellationToken cancellationToken
{
_socket.KeepAlive = _options.KeepAlive.Value;
}

if (_options.TcpKeepAliveInterval.HasValue)
{
_socket.TcpKeepAliveInterval = _options.TcpKeepAliveInterval.Value;
}

if (_options.TcpKeepAliveRetryCount.HasValue)
{
_socket.TcpKeepAliveInterval = _options.TcpKeepAliveRetryCount.Value;
}

if (_options.TcpKeepAliveTime.HasValue)
{
_socket.TcpKeepAliveTime = _options.TcpKeepAliveTime.Value;
}

_socket.Bind(_localEndPoint);

// Get the local endpoint back from the socket. The port may have changed.
// This can happen when port 0 is used. Then the OS will choose the next free port.
_localEndPoint = (IPEndPoint)_socket.LocalEndPoint;
_options.Port = _localEndPoint.Port;

_socket.Listen(_options.ConnectionBacklog);

_logger.Verbose("TCP listener started (Endpoint={0})", _localEndPoint);

Task.Run(() => AcceptClientConnectionsAsync(cancellationToken), cancellationToken).RunInBackground(_logger);

return true;
Expand Down Expand Up @@ -183,7 +183,7 @@ async Task TryHandleClientConnectionAsync(CrossPlatformSocket clientSocket)
clientSocket.NoDelay = _options.NoDelay;
stream = clientSocket.GetStream();
var clientCertificate = _tlsOptions?.CertificateProvider?.GetCertificate();

if (clientCertificate != null)
{
if (!clientCertificate.HasPrivateKey)
Expand Down Expand Up @@ -228,7 +228,7 @@ await sslStream.AuthenticateAsServerAsync(
var tcpChannel = new MqttTcpChannel(stream, remoteEndPoint, clientCertificate);
var bufferWriter = new MqttBufferWriter(_serverOptions.WriterBufferSize, _serverOptions.WriterBufferSizeMax);
var packetFormatterAdapter = new MqttPacketFormatterAdapter(bufferWriter);

using (var clientAdapter = new MqttChannelAdapter(tcpChannel, packetFormatterAdapter, _rootLogger))
{
clientAdapter.AllowPacketFragmentation = _options.AllowPacketFragmentation;
Expand Down

0 comments on commit 41d5b70

Please sign in to comment.