Skip to content

Commit

Permalink
Implement enhanced authentication for client and server.
Browse files Browse the repository at this point in the history
  • Loading branch information
chkr1011 committed Oct 27, 2024
1 parent 63d8518 commit 7784d58
Show file tree
Hide file tree
Showing 17 changed files with 544 additions and 409 deletions.
372 changes: 203 additions & 169 deletions Source/MQTTnet.Server/Events/ValidatingConnectionEventArgs.cs

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions Source/MQTTnet.Server/Internal/MqttClientSessionsManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ public async Task HandleClientConnectionAsync(IMqttChannelAdapter channelAdapter
return;
}

var validatingConnectionEventArgs = await ValidateConnection(connectPacket, channelAdapter).ConfigureAwait(false);
var validatingConnectionEventArgs = await ValidateConnection(connectPacket, channelAdapter, cancellationToken).ConfigureAwait(false);
var connAckPacket = MqttConnAckPacketFactory.Create(validatingConnectionEventArgs);

if (validatingConnectionEventArgs.ReasonCode != MqttConnectReasonCode.Success)
Expand Down Expand Up @@ -710,11 +710,11 @@ static bool ShouldPersistSession(MqttConnectedClient connectedClient)
}
}

async Task<ValidatingConnectionEventArgs> ValidateConnection(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter)
async Task<ValidatingConnectionEventArgs> ValidateConnection(MqttConnectPacket connectPacket, IMqttChannelAdapter channelAdapter, CancellationToken cancellationToken)
{
// TODO: Load session items from persisted sessions in the future.
var sessionItems = new ConcurrentDictionary<object, object>();
var eventArgs = new ValidatingConnectionEventArgs(connectPacket, channelAdapter, sessionItems);
var eventArgs = new ValidatingConnectionEventArgs(connectPacket, channelAdapter, sessionItems, cancellationToken);
await _eventContainer.ValidatingConnectionEvent.InvokeAsync(eventArgs).ConfigureAwait(false);

// Check the client ID and set a random one if supported.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

using System.Collections.ObjectModel;
using MQTTnet.Diagnostics.Logger;
using MQTTnet.Internal;

Expand Down
276 changes: 156 additions & 120 deletions Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Connection_Tests.cs

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions Source/MQTTnet.Tests/MQTTnet.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@

<ItemGroup>
<PackageReference Include="Microsoft.IO.RecyclableMemoryStream" Version="3.0.1" />
<PackageReference Include="MSTest.TestAdapter" Version="3.3.1" />
<PackageReference Include="MSTest.TestFramework" Version="3.3.1" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.9.0" />
<PackageReference Include="MSTest.TestAdapter" Version="3.6.1" />
<PackageReference Include="MSTest.TestFramework" Version="3.6.1" />
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.11.1" />
<FrameworkReference Include="Microsoft.AspNetCore.App" />
</ItemGroup>

Expand Down

This file was deleted.

35 changes: 17 additions & 18 deletions Source/MQTTnet/Formatter/ReadFixedHeaderResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,23 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

namespace MQTTnet.Formatter
namespace MQTTnet.Formatter;

public struct ReadFixedHeaderResult
{
public struct ReadFixedHeaderResult
public static ReadFixedHeaderResult Canceled { get; } = new()
{
IsCanceled = true
};

public static ReadFixedHeaderResult ConnectionClosed { get; } = new()
{
public static ReadFixedHeaderResult Canceled { get; } = new ReadFixedHeaderResult
{
IsCanceled = true
};

public static ReadFixedHeaderResult ConnectionClosed { get; } = new ReadFixedHeaderResult
{
IsConnectionClosed = true
};

public bool IsCanceled { get; set; }

public bool IsConnectionClosed { get; set; }
IsConnectionClosed = true
};

public bool IsCanceled { get; set; }

public bool IsConnectionClosed { get; init; }

public MqttFixedHeader FixedHeader { get; set; }
}
}
public MqttFixedHeader FixedHeader { get; init; }
}
2 changes: 1 addition & 1 deletion Source/MQTTnet/Formatter/V5/MqttV5PacketDecoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace MQTTnet.Formatter.V5
{
public sealed class MqttV5PacketDecoder
{
readonly MqttBufferReader _bufferReader = new MqttBufferReader();
readonly MqttBufferReader _bufferReader = new();

public MqttPacket Decode(ReceivedMqttPacket receivedMqttPacket)
{
Expand Down
2 changes: 1 addition & 1 deletion Source/MQTTnet/IMqttClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public interface IMqttClient : IDisposable

Task<MqttClientPublishResult> PublishAsync(MqttApplicationMessage applicationMessage, CancellationToken cancellationToken = default);

Task SendExtendedAuthenticationExchangeDataAsync(MqttExtendedAuthenticationExchangeData data, CancellationToken cancellationToken = default);
Task SendEnhancedAuthenticationExchangeDataAsync(MqttEnhancedAuthenticationExchangeData data, CancellationToken cancellationToken = default);

Task<MqttClientSubscribeResult> SubscribeAsync(MqttClientSubscribeOptions options, CancellationToken cancellationToken = default);

Expand Down
60 changes: 41 additions & 19 deletions Source/MQTTnet/MqttClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,7 @@ public Task<MqttClientPublishResult> PublishAsync(MqttApplicationMessage applica
}
}

public Task SendExtendedAuthenticationExchangeDataAsync(MqttExtendedAuthenticationExchangeData data, CancellationToken cancellationToken = default)
public Task SendEnhancedAuthenticationExchangeDataAsync(MqttEnhancedAuthenticationExchangeData data, CancellationToken cancellationToken = default)
{
ArgumentNullException.ThrowIfNull(data);

Expand Down Expand Up @@ -437,27 +437,30 @@ async Task<MqttClientConnectResult> Authenticate(IMqttChannelAdapter channelAdap
var connectPacket = MqttConnectPacketFactory.Create(options);
await Send(connectPacket, cancellationToken).ConfigureAwait(false);

var receivedPacket = await Receive(cancellationToken).ConfigureAwait(false);

switch (receivedPacket)
while (true)
{
case MqttConnAckPacket connAckPacket:
{
result = MqttClientResultFactory.ConnectResult.Create(connAckPacket, channelAdapter.PacketFormatterAdapter.ProtocolVersion);
break;
}
case MqttAuthPacket _:
cancellationToken.ThrowIfCancellationRequested();

var receivedPacket = await Receive(cancellationToken).ConfigureAwait(false);

if (receivedPacket is MqttAuthPacket authPacket)
{
throw new NotSupportedException("Extended authentication handler is not yet supported");
await HandleEnhancedAuthentication(authPacket);
continue;
}
case null:

if (receivedPacket is MqttConnAckPacket connAckPacket)
{
throw new MqttCommunicationException("Connection closed.");
result = MqttClientResultFactory.ConnectResult.Create(connAckPacket, channelAdapter.PacketFormatterAdapter.ProtocolVersion);
break;
}
default:

if (receivedPacket != null)
{
throw new InvalidOperationException($"Received an unexpected MQTT packet ({receivedPacket}).");
throw new MqttProtocolViolationException($"Received other packet than CONNACK or AUTH while connecting ({receivedPacket}).");
}

throw new MqttCommunicationException("Connection closed.");
}
}
catch (Exception exception)
Expand All @@ -470,6 +473,12 @@ async Task<MqttClientConnectResult> Authenticate(IMqttChannelAdapter channelAdap
return result;
}

async Task HandleEnhancedAuthentication(MqttAuthPacket authPacket)
{
var eventArgs = new MqttEnhancedAuthenticationEventArgs(authPacket, _adapter);
await Options.EnhancedAuthenticationHandler.HandleEnhancedAuthenticationAsync(eventArgs);
}

void Cleanup()
{
try
Expand Down Expand Up @@ -648,10 +657,23 @@ Task OnConnected(MqttClientConnectResult connectResult)

Task ProcessReceivedAuthPacket(MqttAuthPacket authPacket)
{
var extendedAuthenticationExchangeHandler = Options.ExtendedAuthenticationExchangeHandler;
return extendedAuthenticationExchangeHandler != null
? extendedAuthenticationExchangeHandler.HandleRequestAsync(new MqttExtendedAuthenticationExchangeContext(authPacket, this))
: CompletedTask.Instance;
if (Options.EnhancedAuthenticationHandler == null)
{
// From RFC: If the re-authentication fails, the Client or Server SHOULD send DISCONNECT with an appropriate Reason Code
// as described in section 4.13, and MUST close the Network Connection [MQTT-4.12.1-2].
//
// Since we have no handler there is no chance to fulfil the re-authentication request.
_ = DisconnectAsync(new MqttClientDisconnectOptions
{
Reason = MqttClientDisconnectOptionsReason.ImplementationSpecificError,
ReasonString = "Unable to handle AUTH packet"
});

return CompletedTask.Instance;
}

var eventArgs = new MqttEnhancedAuthenticationEventArgs(authPacket, _adapter);
return Options.EnhancedAuthenticationHandler.HandleEnhancedAuthenticationAsync(eventArgs);
}

Task ProcessReceivedDisconnectPacket(MqttDisconnectPacket disconnectPacket)
Expand Down
4 changes: 2 additions & 2 deletions Source/MQTTnet/MqttClientExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,11 @@ public static Task ReconnectAsync(this IMqttClient client, CancellationToken can
return client.ConnectAsync(client.Options, cancellationToken);
}

public static Task SendExtendedAuthenticationExchangeDataAsync(this IMqttClient client, MqttExtendedAuthenticationExchangeData data)
public static Task SendEnhancedAuthenticationExchangeDataAsync(this IMqttClient client, MqttEnhancedAuthenticationExchangeData data)
{
ArgumentNullException.ThrowIfNull(client);

return client.SendExtendedAuthenticationExchangeDataAsync(data, CancellationToken.None);
return client.SendEnhancedAuthenticationExchangeDataAsync(data, CancellationToken.None);
}

public static Task<MqttClientSubscribeResult> SubscribeAsync(this IMqttClient mqttClient, MqttTopicFilter topicFilter, CancellationToken cancellationToken = default)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

namespace MQTTnet;

public interface IMqttExtendedAuthenticationExchangeHandler
public interface IMqttEnhancedAuthenticationHandler
{
Task HandleRequestAsync(MqttExtendedAuthenticationExchangeContext context);
Task HandleEnhancedAuthenticationAsync(MqttEnhancedAuthenticationEventArgs eventArgs);
}
13 changes: 9 additions & 4 deletions Source/MQTTnet/Options/MqttClientOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ namespace MQTTnet;
public sealed class MqttClientOptions
{
/// <summary>
/// Usually the MQTT packets can be send partially. This is done by using multiple TCP packets
/// Usually the MQTT packets can be sent partially. This is done by using multiple TCP packets
/// or WebSocket frames etc. Unfortunately not all brokers (like Amazon Web Services (AWS)) do support this feature and
/// will close the connection when receiving such packets. If such a service is used this flag must
/// be set to _false_.
Expand All @@ -38,7 +38,7 @@ public sealed class MqttClientOptions
/// Gets or sets a value indicating whether clean sessions are used or not.
/// When a client connects to a broker it can connect using either a non persistent connection (clean session) or a
/// persistent connection.
/// With a non persistent connection the broker doesn't store any subscription information or undelivered messages for
/// With a non-persistent connection the broker doesn't store any subscription information or undelivered messages for
/// the client.
/// This mode is ideal when the client only publishes messages.
/// It can also connect as a durable client using a persistent connection.
Expand All @@ -54,7 +54,12 @@ public sealed class MqttClientOptions

public IMqttClientCredentialsProvider Credentials { get; set; }

public IMqttExtendedAuthenticationExchangeHandler ExtendedAuthenticationExchangeHandler { get; set; }
/// <summary>
/// Gets or sets the handler for AUTH packets.
/// This can happen when connecting or at any time while being already connected.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public IMqttEnhancedAuthenticationHandler EnhancedAuthenticationHandler { get; set; }

/// <summary>
/// Gets or sets the keep alive period.
Expand All @@ -80,7 +85,7 @@ public sealed class MqttClientOptions

/// <summary>
/// Gets or sets the receive maximum.
/// This gives the maximum length of the receive messages.
/// This gives the maximum length of the received messages.
/// <remarks>MQTT 5.0.0+ feature.</remarks>
/// </summary>
public ushort ReceiveMaximum { get; set; }
Expand Down
6 changes: 3 additions & 3 deletions Source/MQTTnet/Options/MqttClientOptionsBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ public MqttClientOptionsBuilder WithAddressFamily(AddressFamily addressFamily)
return this;
}

public MqttClientOptionsBuilder WithAuthentication(string method, byte[] data)
public MqttClientOptionsBuilder WithAuthentication(string method, byte[] data = null)
{
_options.AuthenticationMethod = method;
_options.AuthenticationData = data;
Expand Down Expand Up @@ -205,9 +205,9 @@ public MqttClientOptionsBuilder WithEndPoint(EndPoint endPoint)
return this;
}

public MqttClientOptionsBuilder WithExtendedAuthenticationExchangeHandler(IMqttExtendedAuthenticationExchangeHandler handler)
public MqttClientOptionsBuilder WithEnhancedAuthenticationHandler(IMqttEnhancedAuthenticationHandler handler)
{
_options.ExtendedAuthenticationExchangeHandler = handler;
_options.EnhancedAuthenticationHandler = handler;
return this;
}

Expand Down
5 changes: 5 additions & 0 deletions Source/MQTTnet/Options/MqttClientOptionsValidator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ public static void ThrowIfNotSupported(MqttClientOptions options)
{
Throw(nameof(options.WillUserProperties));
}

if (options.EnhancedAuthenticationHandler != null)
{
Throw(nameof(options.EnhancedAuthenticationHandler));
}
}

static void Throw(string featureName)
Expand Down
Loading

0 comments on commit 7784d58

Please sign in to comment.