diff --git a/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs b/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs index 1c2ecdf10..7476ffea8 100644 --- a/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs +++ b/Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs @@ -3,6 +3,8 @@ // See the LICENSE file in the project root for more information. using System; +using System.Net; +using System.Net.Sockets; using System.Text; using System.Threading; using System.Threading.Tasks; @@ -17,8 +19,8 @@ public class CrossPlatformSocket_Tests [TestMethod] public async Task Connect_Send_Receive() { - var crossPlatformSocket = new CrossPlatformSocket(); - await crossPlatformSocket.ConnectAsync("www.google.de", 80, CancellationToken.None); + var crossPlatformSocket = new CrossPlatformSocket(ProtocolType.Tcp); + await crossPlatformSocket.ConnectAsync(new DnsEndPoint("www.google.de", 80), CancellationToken.None); var requestBuffer = Encoding.UTF8.GetBytes("GET / HTTP/1.1\r\nHost: www.google.de\r\n\r\n"); await crossPlatformSocket.SendAsync(new ArraySegment(requestBuffer), System.Net.Sockets.SocketFlags.None); @@ -36,12 +38,12 @@ public async Task Connect_Send_Receive() [ExpectedException(typeof(OperationCanceledException))] public async Task Try_Connect_Invalid_Host() { - var crossPlatformSocket = new CrossPlatformSocket(); + var crossPlatformSocket = new CrossPlatformSocket(ProtocolType.Tcp); var cancellationToken = new CancellationTokenSource(TimeSpan.FromSeconds(5)); cancellationToken.Token.Register(() => crossPlatformSocket.Dispose()); - await crossPlatformSocket.ConnectAsync("www.google.de", 54321, cancellationToken.Token); + await crossPlatformSocket.ConnectAsync(new DnsEndPoint("www.google.de", 54321), cancellationToken.Token); } //[TestMethod] @@ -65,7 +67,7 @@ public async Task Try_Connect_Invalid_Host() [TestMethod] public void Set_Options() { - var crossPlatformSocket = new CrossPlatformSocket(); + var crossPlatformSocket = new CrossPlatformSocket(ProtocolType.Tcp); Assert.IsFalse(crossPlatformSocket.ReuseAddress); crossPlatformSocket.ReuseAddress = true; diff --git a/Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs b/Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs index 533cbb2bb..7b6276642 100644 --- a/Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs +++ b/Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs @@ -39,7 +39,7 @@ public async Task Dispose_Channel_While_Used() }, ct.Token); var clientSocket = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp); - await clientSocket.ConnectAsync("localhost", 50001, CancellationToken.None); + await clientSocket.ConnectAsync(new DnsEndPoint("localhost", 50001), CancellationToken.None); var tcpChannel = new MqttTcpChannel(clientSocket.GetStream(), "test", null); diff --git a/Source/MQTTnet.Tests/Server/Connection_Tests.cs b/Source/MQTTnet.Tests/Server/Connection_Tests.cs index ba17ae0b9..ef2a482c4 100644 --- a/Source/MQTTnet.Tests/Server/Connection_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Connection_Tests.cs @@ -3,6 +3,7 @@ // See the LICENSE file in the project root for more information. using System; +using System.Net; using System.Net.Sockets; using System.Text; using System.Threading; @@ -24,7 +25,7 @@ public async Task Close_Idle_Connection_On_Connect() await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1))); var client = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp); - await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None); + await client.ConnectAsync(new DnsEndPoint("localhost", testEnvironment.ServerPort), CancellationToken.None); // Don't send anything. The server should close the connection. await Task.Delay(TimeSpan.FromSeconds(3)); @@ -55,7 +56,7 @@ public async Task Send_Garbage() // Send an invalid packet and ensure that the server will close the connection and stay in a waiting state // forever. This is security related. var client = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp); - await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None); + await client.ConnectAsync(new DnsEndPoint("localhost", testEnvironment.ServerPort), CancellationToken.None); var buffer = Encoding.UTF8.GetBytes("Garbage"); await client.SendAsync(new ArraySegment(buffer), SocketFlags.None); diff --git a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs index fb8a56f28..bf3b14797 100644 --- a/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs +++ b/Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs @@ -53,11 +53,26 @@ public MqttClientOptions Build() { if (_port.HasValue) { - _remoteEndPoint = new DnsEndPoint(dns.Host, _port.Value); + _remoteEndPoint = new DnsEndPoint(dns.Host, _port.Value, dns.AddressFamily); } else { - _remoteEndPoint = new DnsEndPoint(dns.Host, tlsOptions?.UseTls == false ? MqttPorts.Default : MqttPorts.Secure); + _remoteEndPoint = new DnsEndPoint(dns.Host, tlsOptions?.UseTls == false ? MqttPorts.Default : MqttPorts.Secure, dns.AddressFamily); + } + } + } + + if (_remoteEndPoint is IPEndPoint ip) + { + if (ip.Port == 0) + { + if (_port.HasValue) + { + _remoteEndPoint = new IPEndPoint(ip.Address, _port.Value); + } + else + { + _remoteEndPoint = new IPEndPoint(ip.Address, tlsOptions?.UseTls == false ? MqttPorts.Default : MqttPorts.Secure); } } } @@ -219,7 +234,7 @@ public MqttClientOptionsBuilder WithNoKeepAlive() } /// - /// Usually the MQTT packets can be send partially. This is done by using multiple TCP packets + /// Usually the MQTT packets can be sent partially. This is done by using multiple TCP packets /// or WebSocket frames etc. Unfortunately not all brokers (like Amazon Web Services (AWS)) do support this feature and /// will close the connection when receiving such packets. If such a service is used this flag must /// be set to _true_. @@ -271,13 +286,27 @@ public MqttClientOptionsBuilder WithSessionExpiryInterval(uint sessionExpiryInte return this; } - public MqttClientOptionsBuilder WithTcpServer(string server, int? port = null) + public MqttClientOptionsBuilder WithTcpServer(string host, int? port = null, AddressFamily addressFamily = AddressFamily.Unspecified) { + if (host == null) + { + throw new ArgumentNullException(nameof(host)); + } + _tcpOptions = new MqttClientTcpOptions(); // The value 0 will be updated when building the options. // This a backward compatibility feature. - _remoteEndPoint = new DnsEndPoint(server, port ?? 0); + + if (IPAddress.TryParse(host, out var ipAddress)) + { + _remoteEndPoint = new IPEndPoint(ipAddress, port ?? 0); + } + else + { + _remoteEndPoint = new DnsEndPoint(host, port ?? 0, addressFamily); + } + _port = port; return this; diff --git a/Source/MQTTnet/Implementations/CrossPlatformSocket.cs b/Source/MQTTnet/Implementations/CrossPlatformSocket.cs index c3d29c169..96f90dc24 100644 --- a/Source/MQTTnet/Implementations/CrossPlatformSocket.cs +++ b/Source/MQTTnet/Implementations/CrossPlatformSocket.cs @@ -10,326 +10,197 @@ using System.Threading.Tasks; using MQTTnet.Exceptions; -namespace MQTTnet.Implementations +namespace MQTTnet.Implementations; + +public sealed class CrossPlatformSocket : IDisposable { - public sealed class CrossPlatformSocket : IDisposable + readonly Socket _socket; + + NetworkStream _networkStream; + + public CrossPlatformSocket(AddressFamily addressFamily, ProtocolType protocolType) { - readonly Socket _socket; + _socket = new Socket(addressFamily, SocketType.Stream, protocolType); + } -#if !NET5_0_OR_GREATER - readonly Action _socketDisposeAction; -#endif + public CrossPlatformSocket(ProtocolType protocolType) + { + // Having this constructor is important because avoiding the address family as parameter + // will make use of dual mode in the .net framework. + _socket = new Socket(SocketType.Stream, protocolType); + } - NetworkStream _networkStream; + CrossPlatformSocket(Socket socket) + { + _socket = socket ?? throw new ArgumentNullException(nameof(socket)); + _networkStream = new NetworkStream(socket, true); + } - public CrossPlatformSocket(AddressFamily addressFamily, ProtocolType protocolType) - { - _socket = new Socket(addressFamily, SocketType.Stream, protocolType); + public bool DualMode + { + get => _socket.DualMode; + set => _socket.DualMode = value; + } -#if !NET5_0_OR_GREATER - _socketDisposeAction = _socket.Dispose; -#endif - } + public bool IsConnected => _socket.Connected; - public CrossPlatformSocket() - { - // Having this constructor is important because avoiding the address family as parameter - // will make use of dual mode in the .net framework. - _socket = new Socket(SocketType.Stream, ProtocolType.Tcp); + public bool KeepAlive + { + get => _socket.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive) as int? == 1; + set => _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, value ? 1 : 0); + } -#if !NET5_0_OR_GREATER - _socketDisposeAction = _socket.Dispose; -#endif - } + public LingerOption LingerState + { + get => _socket.LingerState; + set => _socket.LingerState = value; + } - CrossPlatformSocket(Socket socket) - { - _socket = socket ?? throw new ArgumentNullException(nameof(socket)); - _networkStream = new NetworkStream(socket, true); + public EndPoint LocalEndPoint => _socket.LocalEndPoint; -#if !NET5_0_OR_GREATER - _socketDisposeAction = _socket.Dispose; -#endif - } + public bool NoDelay + { + get => (int?)_socket.GetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay) != 0; + set => _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, value ? 1 : 0); + } - public bool DualMode - { - get => _socket.DualMode; - set => _socket.DualMode = value; - } + public int ReceiveBufferSize + { + get => _socket.ReceiveBufferSize; + set => _socket.ReceiveBufferSize = value; + } - public bool IsConnected => _socket.Connected; + public EndPoint RemoteEndPoint => _socket.RemoteEndPoint; - public bool KeepAlive - { - get => _socket.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive) as int? == 1; - set => _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.KeepAlive, value ? 1 : 0); - } + public bool ReuseAddress + { + get => _socket.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress) as int? != 0; + set => _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, value ? 1 : 0); + } - public int TcpKeepAliveInterval - { -#if NETCOREAPP3_0_OR_GREATER - get => _socket.GetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval) as int? ?? 0; - set => _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, value); -#else - get { throw new NotSupportedException("TcpKeepAliveInterval requires at least netcoreapp3.0."); } - set { throw new NotSupportedException("TcpKeepAliveInterval requires at least netcoreapp3.0."); } -#endif - } + public int SendBufferSize + { + get => _socket.SendBufferSize; + set => _socket.SendBufferSize = value; + } - public int TcpKeepAliveRetryCount - { -#if NETCOREAPP3_0_OR_GREATER - get => _socket.GetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveRetryCount) as int? ?? 0; - set => _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveRetryCount, value); -#else - get { throw new NotSupportedException("TcpKeepAliveRetryCount requires at least netcoreapp3.0."); } - set { throw new NotSupportedException("TcpKeepAliveRetryCount requires at least netcoreapp3.0."); } -#endif - } + public int SendTimeout + { + get => _socket.SendTimeout; + set => _socket.SendTimeout = value; + } - public int TcpKeepAliveTime - { -#if NETCOREAPP3_0_OR_GREATER - get => _socket.GetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime) as int? ?? 0; - set => _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, value); -#else - get { throw new NotSupportedException("TcpKeepAliveTime requires at least netcoreapp3.0."); } - set { throw new NotSupportedException("TcpKeepAliveTime requires at least netcoreapp3.0."); } -#endif - } + public int TcpKeepAliveInterval + { + get => _socket.GetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval) as int? ?? 0; + set => _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveInterval, value); + } - public LingerOption LingerState - { - get => _socket.LingerState; - set => _socket.LingerState = value; - } + public int TcpKeepAliveRetryCount + { + get => _socket.GetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveRetryCount) as int? ?? 0; + set => _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveRetryCount, value); + } - public EndPoint LocalEndPoint => _socket.LocalEndPoint; + public int TcpKeepAliveTime + { + get => _socket.GetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime) as int? ?? 0; + set => _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.TcpKeepAliveTime, value); + } - public bool NoDelay + public async Task AcceptAsync() + { + try { - // We cannot use the _NoDelay_ property from the socket because there is an issue in .NET 4.5.2, 4.6. - // The decompiled code is: this.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.Debug, value ? 1 : 0); - // Which is wrong because the "NoDelay" should be set and not "Debug". - get => (int?)_socket.GetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay) != 0; - set => _socket.SetSocketOption(SocketOptionLevel.Tcp, SocketOptionName.NoDelay, value ? 1 : 0); + var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); + return new CrossPlatformSocket(clientSocket); } - - public int ReceiveBufferSize + catch (ObjectDisposedException) { - get => _socket.ReceiveBufferSize; - set => _socket.ReceiveBufferSize = value; + // This will happen when _socket.EndAccept_ gets called by Task library but the socket is already disposed. + return null; } + } - public EndPoint RemoteEndPoint => _socket.RemoteEndPoint; - - public bool ReuseAddress + public void Bind(EndPoint localEndPoint) + { + if (localEndPoint is null) { - get => _socket.GetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress) as int? != 0; - set => _socket.SetSocketOption(SocketOptionLevel.Socket, SocketOptionName.ReuseAddress, value ? 1 : 0); + throw new ArgumentNullException(nameof(localEndPoint)); } - public int SendBufferSize - { - get => _socket.SendBufferSize; - set => _socket.SendBufferSize = value; - } + _socket.Bind(localEndPoint); + } - public int SendTimeout + public async Task ConnectAsync(EndPoint endPoint, CancellationToken cancellationToken) + { + if (endPoint is null) { - get => _socket.SendTimeout; - set => _socket.SendTimeout = value; + throw new ArgumentNullException(nameof(endPoint)); } - public async Task AcceptAsync() - { - try - { -#if NET452 || NET461 - var clientSocket = await Task.Factory.FromAsync(_socket.BeginAccept, _socket.EndAccept, null).ConfigureAwait(false); -#else - var clientSocket = await _socket.AcceptAsync().ConfigureAwait(false); -#endif - return new CrossPlatformSocket(clientSocket); - } - catch (ObjectDisposedException) - { - // This will happen when _socket.EndAccept_ gets called by Task library but the socket is already disposed. - return null; - } - } + cancellationToken.ThrowIfCancellationRequested(); - public void Bind(EndPoint localEndPoint) + try { - if (localEndPoint is null) + if (_networkStream != null) { - throw new ArgumentNullException(nameof(localEndPoint)); + await _networkStream.DisposeAsync().ConfigureAwait(false); } - _socket.Bind(localEndPoint); + await _socket.ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false); + _networkStream = new NetworkStream(_socket, true); } - - public Task ConnectAsync(string host, int port, CancellationToken cancellationToken) - { - return ConnectAsync(new DnsEndPoint(host, port), cancellationToken); - } - - public async Task ConnectAsync(EndPoint endPoint, CancellationToken cancellationToken) + catch (SocketException socketException) { - if (endPoint is null) + if (socketException.SocketErrorCode == SocketError.OperationAborted) { - throw new ArgumentNullException(nameof(endPoint)); + throw new OperationCanceledException(); } - cancellationToken.ThrowIfCancellationRequested(); - - try + if (socketException.SocketErrorCode == SocketError.TimedOut) { -#if NETCOREAPP3_0_OR_GREATER - if (_networkStream != null) - { - await _networkStream.DisposeAsync().ConfigureAwait(false); - } -#else - _networkStream?.Dispose(); -#endif - -#if NET5_0_OR_GREATER - await _socket.ConnectAsync(endPoint, cancellationToken).ConfigureAwait(false); -#else - // Workaround for: https://github.com/dotnet/corefx/issues/24430 - using (cancellationToken.Register(_socketDisposeAction)) - { -#if NET452 || NET461 - await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, endPoint, null).ConfigureAwait(false); -#else - await _socket.ConnectAsync(endPoint).ConfigureAwait(false); -#endif - } -#endif - _networkStream = new NetworkStream(_socket, true); + throw new MqttCommunicationTimedOutException(); } - catch (SocketException socketException) - { - if (socketException.SocketErrorCode == SocketError.OperationAborted) - { - throw new OperationCanceledException(); - } - - if (socketException.SocketErrorCode == SocketError.TimedOut) - { - throw new MqttCommunicationTimedOutException(); - } - throw new MqttCommunicationException($"Error while connecting host '{endPoint}'.", socketException); - } - catch (ObjectDisposedException) - { - // This will happen when _socket.EndConnect_ gets called by Task library but the socket is already disposed. - cancellationToken.ThrowIfCancellationRequested(); - } + throw new MqttCommunicationException($"Error while connecting host '{endPoint}'.", socketException); } - - public void Dispose() - { - _networkStream?.Dispose(); - _socket?.Dispose(); - } - - public NetworkStream GetStream() - { - var networkStream = _networkStream; - if (networkStream == null) - { - throw new IOException("The socket is not connected."); - } - - return networkStream; - } - - public void Listen(int connectionBacklog) + catch (ObjectDisposedException) { - _socket.Listen(connectionBacklog); + // This will happen when _socket.EndConnect_ gets called by Task library but the socket is already disposed. + cancellationToken.ThrowIfCancellationRequested(); } + } -#if NET452 || NET461 - public async Task ReceiveAsync(ArraySegment buffer, SocketFlags socketFlags) - { - try - { - return await Task.Factory.FromAsync(SocketWrapper.BeginReceive, _socket.EndReceive, new SocketWrapper(_socket, buffer, socketFlags)).ConfigureAwait(false); - } - catch (ObjectDisposedException) - { - // This will happen when _socket.EndReceive_ gets called by Task library but the socket is already disposed. - return -1; - } - } -#else - public Task ReceiveAsync(ArraySegment buffer, SocketFlags socketFlags) - { - return _socket.ReceiveAsync(buffer, socketFlags); - } -#endif + public void Dispose() + { + _networkStream?.Dispose(); + _socket?.Dispose(); + } -#if NET452 || NET461 - public async Task SendAsync(ArraySegment buffer, SocketFlags socketFlags) - { - try - { - await Task.Factory.FromAsync(SocketWrapper.BeginSend, _socket.EndSend, new SocketWrapper(_socket, buffer, socketFlags)).ConfigureAwait(false); - } - catch (ObjectDisposedException) - { - // This will happen when _socket.EndSend_ gets called by Task library but the socket is already disposed. - } - } -#else - public Task SendAsync(ArraySegment buffer, SocketFlags socketFlags) + public NetworkStream GetStream() + { + var networkStream = _networkStream; + if (networkStream == null) { - return _socket.SendAsync(buffer, socketFlags); + throw new IOException("The socket is not connected."); } -#endif -#if NET452 || NET461 - sealed class SocketWrapper - { - readonly ArraySegment _buffer; - readonly Socket _socket; - readonly SocketFlags _socketFlags; + return networkStream; + } - public SocketWrapper(Socket socket, ArraySegment buffer, SocketFlags socketFlags) - { - _socket = socket; - _buffer = buffer; - _socketFlags = socketFlags; - } + public void Listen(int connectionBacklog) + { + _socket.Listen(connectionBacklog); + } - public static IAsyncResult BeginReceive(AsyncCallback callback, object state) - { - var socketWrapper = (SocketWrapper)state; - return socketWrapper._socket.BeginReceive( - socketWrapper._buffer.Array, - socketWrapper._buffer.Offset, - socketWrapper._buffer.Count, - socketWrapper._socketFlags, - callback, - state); - } + public Task ReceiveAsync(ArraySegment buffer, SocketFlags socketFlags) + { + return _socket.ReceiveAsync(buffer, socketFlags); + } - public static IAsyncResult BeginSend(AsyncCallback callback, object state) - { - var socketWrapper = (SocketWrapper)state; - return socketWrapper._socket.BeginSend( - socketWrapper._buffer.Array, - socketWrapper._buffer.Offset, - socketWrapper._buffer.Count, - socketWrapper._socketFlags, - callback, - state); - } - } -#endif + public Task SendAsync(ArraySegment buffer, SocketFlags socketFlags) + { + return _socket.SendAsync(buffer, socketFlags); } } \ No newline at end of file diff --git a/Source/MQTTnet/Implementations/MqttTcpChannel.cs b/Source/MQTTnet/Implementations/MqttTcpChannel.cs index af979c004..10164f60c 100644 --- a/Source/MQTTnet/Implementations/MqttTcpChannel.cs +++ b/Source/MQTTnet/Implementations/MqttTcpChannel.cs @@ -62,7 +62,7 @@ public async Task ConnectAsync(CancellationToken cancellationToken) { if (_tcpOptions.AddressFamily == AddressFamily.Unspecified) { - socket = new CrossPlatformSocket(); + socket = new CrossPlatformSocket(_tcpOptions.ProtocolType); } else {