Skip to content

Commit

Permalink
Add support for hot reloading of client certificates (#1783)
Browse files Browse the repository at this point in the history
* Refactor Unit Tests

* Add new certificate provider interface instead of read only certificates list.

* Update ReleaseNotes.md

* Fix Unit Tests

* Fix build

* unit testing hot swappable certificates (#1787)

* Move test classes to correct namespace

* Apply code style

---------

Co-authored-by: Sean Hanna <[email protected]>
  • Loading branch information
chkr1011 and hannasm authored Aug 19, 2023
1 parent 1a08ec1 commit 111a0d6
Show file tree
Hide file tree
Showing 22 changed files with 991 additions and 163 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ReleaseNotes.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
* [Client] Fixed _PlatformNotSupportedException_ when using Blazor (#1755, thanks to @Nickztar).
* [Client] Added hot reload of client certificates (#1781).
* [Client] Added several new option builders and aligned usage (#1781, BREAKING CHANGE!).
* [Client] Added support for _RemoteCertificateValidationCallback_ for .NET 4.5.2, 4.6.1 and 4.8 (#1806, thanks to @troky).
* [Client] Fixed wrong logging of obsolete feature when connection was not successful (#1801, thanks to @ramonsmits).
* [Client] Fixed _NullReferenceException_ when performing several actions when not connected (#1800, thanks to @ramonsmits).
Expand Down
92 changes: 45 additions & 47 deletions Samples/Client/Client_Connection_Samples.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,32 +70,6 @@ public static async Task Connect_Client()
await mqttClient.DisconnectAsync(mqttClientDisconnectOptions, CancellationToken.None);
}
}

public static async Task Connect_With_Amazon_AWS()
{
/*
* This sample creates a simple MQTT client and connects to an Amazon Web Services broker.
*
* The broker requires special settings which are set here.
*/

var mqttFactory = new MqttFactory();

using (var mqttClient = mqttFactory.CreateMqttClient())
{
var mqttClientOptions = new MqttClientOptionsBuilder()
.WithTcpServer("amazon.web.services.broker")
// Disabling packet fragmentation is very important!
.WithoutPacketFragmentation()
.Build();

await mqttClient.ConnectAsync(mqttClientOptions, CancellationToken.None);

Console.WriteLine("The MQTT client is connected.");

await mqttClient.DisconnectAsync();
}
}

public static async Task Connect_Client_Timeout()
{
Expand Down Expand Up @@ -161,15 +135,15 @@ public static async Task Connect_Client_Using_TLS_1_2()
using (var mqttClient = mqttFactory.CreateMqttClient())
{
var mqttClientOptions = new MqttClientOptionsBuilder().WithTcpServer("mqtt.fluux.io")
.WithTls(
.WithTlsOptions(
o =>
{
// The used public broker sometimes has invalid certificates. This sample accepts all
// certificates. This should not be used in live environments.
o.CertificateValidationHandler = _ => true;
o.WithCertificateValidationHandler(_ => true);
// The default value is determined by the OS. Set manually to force version.
o.SslProtocol = SslProtocols.Tls12;
o.WithSslProtocols(SslProtocols.Tls12);
})
.Build();

Expand All @@ -196,7 +170,7 @@ public static async Task Connect_Client_Using_WebSocket4Net()

using (var mqttClient = mqttFactory.CreateMqttClient())
{
var mqttClientOptions = new MqttClientOptionsBuilder().WithWebSocketServer("broker.hivemq.com:8000/mqtt").Build();
var mqttClientOptions = new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.hivemq.com:8000/mqtt")).Build();

var response = await mqttClient.ConnectAsync(mqttClientOptions, CancellationToken.None);

Expand All @@ -218,7 +192,7 @@ public static async Task Connect_Client_Using_WebSockets()

using (var mqttClient = mqttFactory.CreateMqttClient())
{
var mqttClientOptions = new MqttClientOptionsBuilder().WithWebSocketServer("broker.hivemq.com:8000/mqtt").Build();
var mqttClientOptions = new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.hivemq.com:8000/mqtt")).Build();

var response = await mqttClient.ConnectAsync(mqttClientOptions, CancellationToken.None);

Expand All @@ -241,13 +215,11 @@ public static async Task Connect_Client_With_TLS_Encryption()
using (var mqttClient = mqttFactory.CreateMqttClient())
{
var mqttClientOptions = new MqttClientOptionsBuilder().WithTcpServer("test.mosquitto.org", 8883)
.WithTls(
o =>
{
.WithTlsOptions(
o => o.WithCertificateValidationHandler(
// The used public broker sometimes has invalid certificates. This sample accepts all
// certificates. This should not be used in live environments.
o.CertificateValidationHandler = _ => true;
})
_ => true))
.Build();

// In MQTTv5 the response contains much more information.
Expand All @@ -262,6 +234,31 @@ public static async Task Connect_Client_With_TLS_Encryption()
}
}

public static async Task Connect_With_Amazon_AWS()
{
/*
* This sample creates a simple MQTT client and connects to an Amazon Web Services broker.
*
* The broker requires special settings which are set here.
*/

var mqttFactory = new MqttFactory();

using (var mqttClient = mqttFactory.CreateMqttClient())
{
var mqttClientOptions = new MqttClientOptionsBuilder().WithTcpServer("amazon.web.services.broker")
// Disabling packet fragmentation is very important!
.WithoutPacketFragmentation()
.Build();

await mqttClient.ConnectAsync(mqttClientOptions, CancellationToken.None);

Console.WriteLine("The MQTT client is connected.");

await mqttClient.DisconnectAsync();
}
}

public static async Task Disconnect_Clean()
{
/*
Expand Down Expand Up @@ -317,18 +314,19 @@ public static async Task Inspect_Certificate_Validation_Errors()
using (var mqttClient = mqttFactory.CreateMqttClient())
{
var mqttClientOptions = new MqttClientOptionsBuilder().WithTcpServer("mqtt.fluux.io", 8883)
.WithTls(
.WithTlsOptions(
o =>
{
o.CertificateValidationHandler = eventArgs =>
{
eventArgs.Certificate.Subject.DumpToConsole();
eventArgs.Certificate.GetExpirationDateString().DumpToConsole();
eventArgs.Chain.ChainPolicy.RevocationMode.DumpToConsole();
eventArgs.Chain.ChainStatus.DumpToConsole();
eventArgs.SslPolicyErrors.DumpToConsole();
return true;
};
o.WithCertificateValidationHandler(
eventArgs =>
{
eventArgs.Certificate.Subject.DumpToConsole();
eventArgs.Certificate.GetExpirationDateString().DumpToConsole();
eventArgs.Chain.ChainPolicy.RevocationMode.DumpToConsole();
eventArgs.Chain.ChainStatus.DumpToConsole();
eventArgs.SslPolicyErrors.DumpToConsole();
return true;
});
})
.Build();

Expand Down Expand Up @@ -434,4 +432,4 @@ public static void Reconnect_Using_Timer()
});
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -92,18 +92,7 @@ public Task ConnectAsync(CancellationToken cancellationToken)
var webSocketVersion = WebSocketVersion.None;
var receiveBufferSize = 0;

var certificates = new X509CertificateCollection();
if (_webSocketOptions.TlsOptions?.Certificates != null)
{
foreach (var certificate in _webSocketOptions.TlsOptions.Certificates)
{
#if WINDOWS_UWP
certificates.Add(new X509Certificate(certificate));
#else
certificates.Add(certificate);
#endif
}
}
var certificates = _webSocketOptions.TlsOptions?.ClientCertificatesProvider?.GetCertificates();

_webSocket = new WebSocket(uri, subProtocol, cookies, customHeaders, userAgent, origin, webSocketVersion, proxy, sslProtocols, receiveBufferSize)
{
Expand Down
42 changes: 21 additions & 21 deletions Source/MQTTnet.TestApp/PublicBrokerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public static async Task RunAsync()
{
#if NET5_0_OR_GREATER
// TLS13 is only available in Net5.0
var unsafeTls13 = new MqttClientOptionsBuilderTlsParameters
var unsafeTls13 = new MqttClientTlsOptions
{
UseTls = true,
SslProtocol = SslProtocols.Tls13,
Expand All @@ -29,7 +29,7 @@ public static async Task RunAsync()
};
#endif
// Also defining TLS12 for servers that don't seem no to support TLS13.
var unsafeTls12 = new MqttClientOptionsBuilderTlsParameters
var unsafeTls12 = new MqttClientTlsOptions
{
UseTls = true,
SslProtocol = SslProtocols.Tls12,
Expand All @@ -44,16 +44,16 @@ await ExecuteTestAsync(

await ExecuteTestAsync(
"mqtt.eclipseprojects.io WS",
new MqttClientOptionsBuilder().WithWebSocketServer("mqtt.eclipseprojects.io:80/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build());
new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("mqtt.eclipseprojects.io:80/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build());

#if NET5_0_OR_GREATER
await ExecuteTestAsync("mqtt.eclipseprojects.io WS TLS13",
new MqttClientOptionsBuilder().WithWebSocketServer("mqtt.eclipseprojects.io:443/mqtt")
.WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls13).Build());
new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("mqtt.eclipseprojects.io:443/mqtt"))
.WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls13).Build());

await ExecuteTestAsync("mqtt.eclipseprojects.io WS TLS13 (WebSocket4Net)",
new MqttClientOptionsBuilder().WithWebSocketServer("mqtt.eclipseprojects.io:443/mqtt")
.WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls13).Build(),
new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("mqtt.eclipseprojects.io:443/mqtt"))
.WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls13).Build(),
true);
#endif

Expand All @@ -68,34 +68,34 @@ await ExecuteTestAsync(

await ExecuteTestAsync(
"test.mosquitto.org TCP TLS12",
new MqttClientOptionsBuilder().WithTcpServer("test.mosquitto.org", 8883).WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls12).Build());
new MqttClientOptionsBuilder().WithTcpServer("test.mosquitto.org", 8883).WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls12).Build());

#if NET5_0_OR_GREATER
await ExecuteTestAsync("test.mosquitto.org TCP TLS13",
new MqttClientOptionsBuilder().WithTcpServer("test.mosquitto.org", 8883)
.WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls13).Build());
.WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls13).Build());
#endif

await ExecuteTestAsync(
"test.mosquitto.org TCP TLS12 - Authenticated",
new MqttClientOptionsBuilder().WithTcpServer("test.mosquitto.org", 8885)
.WithCredentials("rw", "readwrite")
.WithProtocolVersion(MqttProtocolVersion.V311)
.WithTls(unsafeTls12)
.WithTlsOptions(unsafeTls12)
.Build());

await ExecuteTestAsync(
"test.mosquitto.org WS",
new MqttClientOptionsBuilder().WithWebSocketServer("test.mosquitto.org:8080/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build());
new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("test.mosquitto.org:8080/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build());

await ExecuteTestAsync(
"test.mosquitto.org WS (WebSocket4Net)",
new MqttClientOptionsBuilder().WithWebSocketServer("test.mosquitto.org:8080/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build(),
new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("test.mosquitto.org:8080/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build(),
true);

await ExecuteTestAsync(
"test.mosquitto.org WS TLS12",
new MqttClientOptionsBuilder().WithWebSocketServer("test.mosquitto.org:8081/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls12).Build());
new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("test.mosquitto.org:8081/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls12).Build());

// await ExecuteTestAsync(
// "test.mosquitto.org WS TLS12 (WebSocket4Net)",
Expand All @@ -109,30 +109,30 @@ await ExecuteTestAsync(

await ExecuteTestAsync(
"broker.emqx.io TCP TLS12",
new MqttClientOptionsBuilder().WithTcpServer("broker.emqx.io", 8883).WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls12).Build());
new MqttClientOptionsBuilder().WithTcpServer("broker.emqx.io", 8883).WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls12).Build());

#if NET5_0_OR_GREATER
await ExecuteTestAsync("broker.emqx.io TCP TLS13",
new MqttClientOptionsBuilder().WithTcpServer("broker.emqx.io", 8883)
.WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls13).Build());
.WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls13).Build());
#endif

await ExecuteTestAsync(
"broker.emqx.io WS",
new MqttClientOptionsBuilder().WithWebSocketServer("broker.emqx.io:8083/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build());
new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.emqx.io:8083/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build());

await ExecuteTestAsync(
"broker.emqx.io WS (WebSocket4Net)",
new MqttClientOptionsBuilder().WithWebSocketServer("broker.emqx.io:8084/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build(),
new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.emqx.io:8084/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build(),
true);

await ExecuteTestAsync(
"broker.emqx.io WS TLS12",
new MqttClientOptionsBuilder().WithWebSocketServer("broker.emqx.io:8084/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls12).Build());
new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.emqx.io:8084/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls12).Build());

await ExecuteTestAsync(
"broker.emqx.io WS TLS12 (WebSocket4Net)",
new MqttClientOptionsBuilder().WithWebSocketServer("broker.emqx.io:8084/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).WithTls(unsafeTls12).Build(),
new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.emqx.io:8084/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).WithTlsOptions(unsafeTls12).Build(),
true);

// broker.hivemq.com
Expand All @@ -142,11 +142,11 @@ await ExecuteTestAsync(

await ExecuteTestAsync(
"broker.hivemq.com WS",
new MqttClientOptionsBuilder().WithWebSocketServer("broker.hivemq.com:8000/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build());
new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.hivemq.com:8000/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build());

await ExecuteTestAsync(
"broker.hivemq.com WS (WebSocket4Net)",
new MqttClientOptionsBuilder().WithWebSocketServer("broker.hivemq.com:8000/mqtt").WithProtocolVersion(MqttProtocolVersion.V311).Build(),
new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("broker.hivemq.com:8000/mqtt")).WithProtocolVersion(MqttProtocolVersion.V311).Build(),
true);

// mqtt.swifitch.cz: Does not seem to operate any more
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ public async Task Subscriptions_Are_Published_Immediately()
var receivingClient = await CreateManagedClientAsync(testEnvironment, null, connectionCheckInterval);
var sendingClient = await testEnvironment.ConnectClient();

await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", Payload = new byte[] { 1 }, Retain = true });
await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", PayloadSegment = new ArraySegment<byte>( new byte[] { 1 }), Retain = true });

var subscribeTime = DateTime.UtcNow;

Expand Down Expand Up @@ -454,7 +454,7 @@ public async Task Subscriptions_Subscribe_Only_New_Subscriptions()
//wait a bit for the subscription to become established
await Task.Delay(500);

await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", Payload = new byte[] { 1 }, Retain = true });
await sendingClient.PublishAsync(new MqttApplicationMessage { Topic = "topic", PayloadSegment = new ArraySegment<byte>(new byte[] { 1 }), Retain = true });

var messages = await SetupReceivingOfMessages(managedClient, 1);

Expand Down
4 changes: 2 additions & 2 deletions Source/MQTTnet.Tests/Clients/MqttClient/MqttClient_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -296,7 +296,7 @@ await receiver.SubscribeAsync(

Assert.IsNotNull(receivedMessage);
Assert.AreEqual("A", receivedMessage.Topic);
Assert.AreEqual(null, receivedMessage.Payload);
Assert.AreEqual(null, receivedMessage.PayloadSegment.Array);
}
}

Expand Down Expand Up @@ -507,7 +507,7 @@ public async Task Publish_QoS_1_In_ApplicationMessageReceiveHandler()

client2.ApplicationMessageReceivedAsync += e =>
{
client2TopicResults.Add(Encoding.UTF8.GetString(e.ApplicationMessage.Payload));
client2TopicResults.Add(Encoding.UTF8.GetString(e.ApplicationMessage.PayloadSegment.ToArray()));
return CompletedTask.Instance;
};

Expand Down
2 changes: 1 addition & 1 deletion Source/MQTTnet.Tests/Extensions/WebSocket4Net_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ public async Task Connect_Failed_With_Invalid_Server()

using (var client = factory.CreateMqttClient())
{
var options = new MqttClientOptionsBuilder().WithWebSocketServer("ws://a.b/mqtt").WithTimeout(TimeSpan.FromSeconds(2)).Build();
var options = new MqttClientOptionsBuilder().WithWebSocketServer(o => o.WithUri("ws://a.b/mqtt")).WithTimeout(TimeSpan.FromSeconds(2)).Build();
await client.ConnectAsync(options).ConfigureAwait(false);
}
}
Expand Down
Loading

0 comments on commit 111a0d6

Please sign in to comment.