Skip to content

Commit

Permalink
Fix client endpoint handling
Browse files Browse the repository at this point in the history
  • Loading branch information
chkr1011 committed May 17, 2024
1 parent 4ea3ac7 commit f3d3b70
Show file tree
Hide file tree
Showing 6 changed files with 183 additions and 280 deletions.
12 changes: 7 additions & 5 deletions Source/MQTTnet.Tests/Internal/CrossPlatformSocket_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -17,8 +19,8 @@ public class CrossPlatformSocket_Tests
[TestMethod]
public async Task Connect_Send_Receive()
{
var crossPlatformSocket = new CrossPlatformSocket();
await crossPlatformSocket.ConnectAsync("www.google.de", 80, CancellationToken.None);
var crossPlatformSocket = new CrossPlatformSocket(ProtocolType.Tcp);
await crossPlatformSocket.ConnectAsync(new DnsEndPoint("www.google.de", 80), CancellationToken.None);

var requestBuffer = Encoding.UTF8.GetBytes("GET / HTTP/1.1\r\nHost: www.google.de\r\n\r\n");
await crossPlatformSocket.SendAsync(new ArraySegment<byte>(requestBuffer), System.Net.Sockets.SocketFlags.None);
Expand All @@ -36,12 +38,12 @@ public async Task Connect_Send_Receive()
[ExpectedException(typeof(OperationCanceledException))]
public async Task Try_Connect_Invalid_Host()
{
var crossPlatformSocket = new CrossPlatformSocket();
var crossPlatformSocket = new CrossPlatformSocket(ProtocolType.Tcp);

var cancellationToken = new CancellationTokenSource(TimeSpan.FromSeconds(5));
cancellationToken.Token.Register(() => crossPlatformSocket.Dispose());

await crossPlatformSocket.ConnectAsync("www.google.de", 54321, cancellationToken.Token);
await crossPlatformSocket.ConnectAsync(new DnsEndPoint("www.google.de", 54321), cancellationToken.Token);
}

//[TestMethod]
Expand All @@ -65,7 +67,7 @@ public async Task Try_Connect_Invalid_Host()
[TestMethod]
public void Set_Options()
{
var crossPlatformSocket = new CrossPlatformSocket();
var crossPlatformSocket = new CrossPlatformSocket(ProtocolType.Tcp);

Assert.IsFalse(crossPlatformSocket.ReuseAddress);
crossPlatformSocket.ReuseAddress = true;
Expand Down
2 changes: 1 addition & 1 deletion Source/MQTTnet.Tests/MqttTcpChannel_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public async Task Dispose_Channel_While_Used()
}, ct.Token);

var clientSocket = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp);
await clientSocket.ConnectAsync("localhost", 50001, CancellationToken.None);
await clientSocket.ConnectAsync(new DnsEndPoint("localhost", 50001), CancellationToken.None);

var tcpChannel = new MqttTcpChannel(clientSocket.GetStream(), "test", null);

Expand Down
5 changes: 3 additions & 2 deletions Source/MQTTnet.Tests/Server/Connection_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Net;
using System.Net.Sockets;
using System.Text;
using System.Threading;
Expand All @@ -24,7 +25,7 @@ public async Task Close_Idle_Connection_On_Connect()
await testEnvironment.StartServer(new MqttServerOptionsBuilder().WithDefaultCommunicationTimeout(TimeSpan.FromSeconds(1)));

var client = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp);
await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None);
await client.ConnectAsync(new DnsEndPoint("localhost", testEnvironment.ServerPort), CancellationToken.None);

// Don't send anything. The server should close the connection.
await Task.Delay(TimeSpan.FromSeconds(3));
Expand Down Expand Up @@ -55,7 +56,7 @@ public async Task Send_Garbage()
// Send an invalid packet and ensure that the server will close the connection and stay in a waiting state
// forever. This is security related.
var client = new CrossPlatformSocket(AddressFamily.InterNetwork, ProtocolType.Tcp);
await client.ConnectAsync("localhost", testEnvironment.ServerPort, CancellationToken.None);
await client.ConnectAsync(new DnsEndPoint("localhost", testEnvironment.ServerPort), CancellationToken.None);

var buffer = Encoding.UTF8.GetBytes("Garbage");
await client.SendAsync(new ArraySegment<byte>(buffer), SocketFlags.None);
Expand Down
39 changes: 34 additions & 5 deletions Source/MQTTnet/Client/Options/MqttClientOptionsBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,26 @@ public MqttClientOptions Build()
{
if (_port.HasValue)
{
_remoteEndPoint = new DnsEndPoint(dns.Host, _port.Value);
_remoteEndPoint = new DnsEndPoint(dns.Host, _port.Value, dns.AddressFamily);
}
else
{
_remoteEndPoint = new DnsEndPoint(dns.Host, tlsOptions?.UseTls == false ? MqttPorts.Default : MqttPorts.Secure);
_remoteEndPoint = new DnsEndPoint(dns.Host, tlsOptions?.UseTls == false ? MqttPorts.Default : MqttPorts.Secure, dns.AddressFamily);
}
}
}

if (_remoteEndPoint is IPEndPoint ip)
{
if (ip.Port == 0)
{
if (_port.HasValue)
{
_remoteEndPoint = new IPEndPoint(ip.Address, _port.Value);
}
else
{
_remoteEndPoint = new IPEndPoint(ip.Address, tlsOptions?.UseTls == false ? MqttPorts.Default : MqttPorts.Secure);
}
}
}
Expand Down Expand Up @@ -219,7 +234,7 @@ public MqttClientOptionsBuilder WithNoKeepAlive()
}

/// <summary>
/// Usually the MQTT packets can be send partially. This is done by using multiple TCP packets
/// Usually the MQTT packets can be sent partially. This is done by using multiple TCP packets
/// or WebSocket frames etc. Unfortunately not all brokers (like Amazon Web Services (AWS)) do support this feature and
/// will close the connection when receiving such packets. If such a service is used this flag must
/// be set to _true_.
Expand Down Expand Up @@ -271,13 +286,27 @@ public MqttClientOptionsBuilder WithSessionExpiryInterval(uint sessionExpiryInte
return this;
}

public MqttClientOptionsBuilder WithTcpServer(string server, int? port = null)
public MqttClientOptionsBuilder WithTcpServer(string host, int? port = null, AddressFamily addressFamily = AddressFamily.Unspecified)
{
if (host == null)
{
throw new ArgumentNullException(nameof(host));
}

_tcpOptions = new MqttClientTcpOptions();

// The value 0 will be updated when building the options.
// This a backward compatibility feature.
_remoteEndPoint = new DnsEndPoint(server, port ?? 0);

if (IPAddress.TryParse(host, out var ipAddress))
{
_remoteEndPoint = new IPEndPoint(ipAddress, port ?? 0);
}
else
{
_remoteEndPoint = new DnsEndPoint(host, port ?? 0, addressFamily);
}

_port = port;

return this;
Expand Down
Loading

0 comments on commit f3d3b70

Please sign in to comment.