Skip to content

Commit

Permalink
Fix issue with connection management.
Browse files Browse the repository at this point in the history
  • Loading branch information
chkr1011 committed Aug 17, 2017
1 parent 598ed66 commit c90bdbf
Show file tree
Hide file tree
Showing 7 changed files with 159 additions and 92 deletions.
1 change: 1 addition & 0 deletions Build/MQTTnet.nuspec
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* [Server] Providing the used protocol version of connected clients
* [Client] Added support for protocol version 3.1.0
* [Core] Several minor performance improvements
* [Core] Fixed an issue with connection management (Thanks to wuzhenda; Zuendelmeister)
</releaseNotes>
<copyright>Copyright Christian Kratky 2016-2017</copyright>
<tags>MQTT MQTTClient MQTTServer MQTTBroker Broker NETStandard IoT InternetOfThings Messaging Queue Hardware Arduino</tags>
Expand Down
14 changes: 10 additions & 4 deletions Frameworks/MQTTnet.NetFramework/Implementations/MqttTcpChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@ namespace MQTTnet.Implementations
{
public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable
{
private readonly Socket _socket;
private Socket _socket;
private SslStream _sslStream;

public MqttTcpChannel()
{
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
}

public MqttTcpChannel(Socket socket, SslStream sslStream)
Expand All @@ -31,6 +30,11 @@ public async Task ConnectAsync(MqttClientOptions options)
if (options == null) throw new ArgumentNullException(nameof(options));
try
{
if (_socket == null)
{
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
}

await Task.Factory.FromAsync(_socket.BeginConnect, _socket.EndConnect, options.Server, options.GetPort(), null);

if (options.TlsOptions.UseTls)
Expand All @@ -49,8 +53,7 @@ public Task DisconnectAsync()
{
try
{
_sslStream.Dispose();
_socket.Dispose();
Dispose();
return Task.FromResult(0);
}
catch (SocketException exception)
Expand Down Expand Up @@ -108,6 +111,9 @@ public void Dispose()
{
_socket?.Dispose();
_sslStream?.Dispose();

_socket = null;
_sslStream = null;
}

private static X509CertificateCollection LoadCertificates(MqttClientOptions options)
Expand Down
14 changes: 10 additions & 4 deletions Frameworks/MQTTnet.NetStandard/Implementations/MqttTcpChannel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,11 @@ namespace MQTTnet.Implementations
{
public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable
{
private readonly Socket _socket;
private Socket _socket;
private SslStream _sslStream;

public MqttTcpChannel()
{
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
}

public MqttTcpChannel(Socket socket, SslStream sslStream)
Expand All @@ -31,6 +30,11 @@ public async Task ConnectAsync(MqttClientOptions options)
if (options == null) throw new ArgumentNullException(nameof(options));
try
{
if (_socket == null)
{
_socket = new Socket(SocketType.Stream, ProtocolType.Tcp);
}

await _socket.ConnectAsync(options.Server, options.GetPort());

if (options.TlsOptions.UseTls)
Expand All @@ -49,8 +53,7 @@ public Task DisconnectAsync()
{
try
{
_sslStream.Dispose();
_socket.Dispose();
Dispose();
return Task.FromResult(0);
}
catch (SocketException exception)
Expand Down Expand Up @@ -101,6 +104,9 @@ public void Dispose()
{
_socket?.Dispose();
_sslStream?.Dispose();

_socket = null;
_sslStream = null;
}

private static X509CertificateCollection LoadCertificates(MqttClientOptions options)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,10 @@ namespace MQTTnet.Implementations
{
public sealed class MqttTcpChannel : IMqttCommunicationChannel, IDisposable
{
private readonly StreamSocket _socket;
private StreamSocket _socket;

public MqttTcpChannel()
{
_socket = new StreamSocket();
}

public MqttTcpChannel(StreamSocket socket)
Expand All @@ -32,6 +31,11 @@ public async Task ConnectAsync(MqttClientOptions options)
if (options == null) throw new ArgumentNullException(nameof(options));
try
{
if (_socket == null)
{
_socket = new StreamSocket();
}

if (!options.TlsOptions.UseTls)
{
await _socket.ConnectAsync(new HostName(options.Server), options.GetPort().ToString());
Expand Down Expand Up @@ -59,7 +63,7 @@ public Task DisconnectAsync()
{
try
{
_socket.Dispose();
Dispose();
return Task.FromResult(0);
}
catch (SocketException exception)
Expand Down Expand Up @@ -100,6 +104,8 @@ public async Task ReadAsync(byte[] buffer)
public void Dispose()
{
_socket?.Dispose();

_socket = null;
}

private static Certificate LoadCertificate(MqttClientOptions options)
Expand Down
63 changes: 33 additions & 30 deletions MQTTnet.Core/Adapter/MqttChannelCommunicationAdapter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,7 @@ public MqttChannelCommunicationAdapter(IMqttCommunicationChannel channel, IMqttP

public async Task ConnectAsync(MqttClientOptions options, TimeSpan timeout)
{
var task = _channel.ConnectAsync(options);
if (await Task.WhenAny(Task.Delay(timeout), task) != task)
{
throw new MqttCommunicationTimedOutException();
}
await ExecuteWithTimeoutAsync(_channel.ConnectAsync(options), timeout);
}

public async Task DisconnectAsync()
Expand All @@ -39,38 +35,15 @@ public async Task SendPacketAsync(MqttBasePacket packet, TimeSpan timeout)
{
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), $"TX >>> {packet} [Timeout={timeout}]");

bool hasTimeout;
try
{
var task = PacketSerializer.SerializeAsync(packet, _channel);
hasTimeout = await Task.WhenAny(Task.Delay(timeout), task) != task;
}
catch (Exception exception)
{
throw new MqttCommunicationException(exception);
}

if (hasTimeout)
{
throw new MqttCommunicationTimedOutException();
}
await ExecuteWithTimeoutAsync(PacketSerializer.SerializeAsync(packet, _channel), timeout);
}

public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout)
{
MqttBasePacket packet;
if (timeout > TimeSpan.Zero)
{
var workerTask = PacketSerializer.DeserializeAsync(_channel);
var timeoutTask = Task.Delay(timeout);
var hasTimeout = Task.WhenAny(timeoutTask, workerTask) == timeoutTask;

if (hasTimeout)
{
throw new MqttCommunicationTimedOutException();
}

packet = workerTask.Result;
packet = await ExecuteWithTimeoutAsync(PacketSerializer.DeserializeAsync(_channel), timeout);
}
else
{
Expand All @@ -85,5 +58,35 @@ public async Task<MqttBasePacket> ReceivePacketAsync(TimeSpan timeout)
MqttTrace.Information(nameof(MqttChannelCommunicationAdapter), $"RX <<< {packet}");
return packet;
}

private static async Task<TResult> ExecuteWithTimeoutAsync<TResult>(Task<TResult> task, TimeSpan timeout)
{
var timeoutTask = Task.Delay(timeout);
if (await Task.WhenAny(timeoutTask, task) == timeoutTask)
{
throw new MqttCommunicationTimedOutException();
}

if (task.IsFaulted)
{
throw new MqttCommunicationException(task.Exception);
}

return task.Result;
}

private static async Task ExecuteWithTimeoutAsync(Task task, TimeSpan timeout)
{
var timeoutTask = Task.Delay(timeout);
if (await Task.WhenAny(timeoutTask, task) == timeoutTask)
{
throw new MqttCommunicationTimedOutException();
}

if (task.IsFaulted)
{
throw new MqttCommunicationException(task.Exception);
}
}
}
}
Loading

0 comments on commit c90bdbf

Please sign in to comment.