diff --git a/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs b/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs index 64282ffb2..9c8df7008 100644 --- a/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs +++ b/Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs @@ -1,3 +1,4 @@ +#if !(NET452 || NET461) using System; using System.Collections.Concurrent; using System.Diagnostics; @@ -12,7 +13,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Certificates; using MQTTnet.Client; -using MQTTnet.Extensions.ManagedClient; using MQTTnet.Formatter; using MQTTnet.Protocol; using MQTTnet.Server; @@ -20,15 +20,14 @@ namespace MQTTnet.Tests.Server { // missing certificate builder api means tests won't work for older frameworks -#if !(NET452 || NET461) + [TestClass] -#endif public sealed class HotSwapCerts_Tests { - readonly TimeSpan DEFAULT_TIMEOUT = TimeSpan.FromSeconds(10); + static readonly TimeSpan DefaultTimeout = TimeSpan.FromSeconds(10); [TestMethod] - public void ClientCertChangeWithoutServerUpdateFailsReconnect() + public async Task ClientCertChangeWithoutServerUpdateFailsReconnect() { using (var server = new ServerTestHarness()) using (var client01 = new ClientTestHarness()) @@ -36,22 +35,22 @@ public void ClientCertChangeWithoutServerUpdateFailsReconnect() server.InstallNewClientCert(client01.GetCurrentClientCert()); client01.InstallNewServerCert(server.GetCurrentServerCert()); - server.StartServer().Wait(); + await server.StartServer(); - client01.Connect(); + await client01.Connect(); - client01.WaitForConnectOrFail(DEFAULT_TIMEOUT); + client01.WaitForConnectOrFail(DefaultTimeout); client01.HotSwapClientCert(); - server.ForceDisconnectAsync(client01).Wait(DEFAULT_TIMEOUT); - client01.WaitForDisconnectOrFail(DEFAULT_TIMEOUT); + server.ForceDisconnectAsync(client01).Wait(DefaultTimeout); + client01.WaitForDisconnectOrFail(DefaultTimeout); - client01.WaitForConnectToFail(DEFAULT_TIMEOUT); + client01.WaitForConnectToFail(DefaultTimeout); } } [TestMethod] - public void ClientCertChangeWithServerUpdateAcceptsReconnect() + public async Task ClientCertChangeWithServerUpdateAcceptsReconnect() { using (var server = new ServerTestHarness()) using (var client01 = new ClientTestHarness()) @@ -59,23 +58,24 @@ public void ClientCertChangeWithServerUpdateAcceptsReconnect() server.InstallNewClientCert(client01.GetCurrentClientCert()); client01.InstallNewServerCert(server.GetCurrentServerCert()); - server.StartServer().Wait(); - client01.Connect(); + await server.StartServer(); - client01.WaitForConnectOrFail(DEFAULT_TIMEOUT); + await client01.Connect(); + + client01.WaitForConnectOrFail(DefaultTimeout); client01.HotSwapClientCert(); - server.ForceDisconnectAsync(client01).Wait(DEFAULT_TIMEOUT); - client01.WaitForDisconnectOrFail(DEFAULT_TIMEOUT); + server.ForceDisconnectAsync(client01).Wait(DefaultTimeout); + client01.WaitForDisconnectOrFail(DefaultTimeout); server.InstallNewClientCert(client01.GetCurrentClientCert()); - client01.WaitForConnectOrFail(DEFAULT_TIMEOUT); + client01.WaitForConnectOrFail(DefaultTimeout); } } [TestMethod] - public void ServerCertChangeWithClientCertUpdateAllowsReconnect() + public async Task ServerCertChangeWithClientCertUpdateAllowsReconnect() { using (var server = new ServerTestHarness()) using (var client01 = new ClientTestHarness()) @@ -83,22 +83,22 @@ public void ServerCertChangeWithClientCertUpdateAllowsReconnect() server.InstallNewClientCert(client01.GetCurrentClientCert()); client01.InstallNewServerCert(server.GetCurrentServerCert()); - server.StartServer().Wait(); - client01.Connect(); + await server.StartServer(); + await client01.Connect(); - client01.WaitForConnectOrFail(DEFAULT_TIMEOUT); + client01.WaitForConnectOrFail(DefaultTimeout); server.HotSwapServerCert(); - server.ForceDisconnectAsync(client01).Wait(DEFAULT_TIMEOUT); - client01.WaitForDisconnectOrFail(DEFAULT_TIMEOUT); + server.ForceDisconnectAsync(client01).Wait(DefaultTimeout); + client01.WaitForDisconnectOrFail(DefaultTimeout); client01.InstallNewServerCert(server.GetCurrentServerCert()); - client01.WaitForConnectOrFail(DEFAULT_TIMEOUT); + client01.WaitForConnectOrFail(DefaultTimeout); } } [TestMethod] - public void ServerCertChangeWithoutClientCertUpdateFailsReconnect() + public async Task ServerCertChangeWithoutClientCertUpdateFailsReconnect() { using (var server = new ServerTestHarness()) using (var client01 = new ClientTestHarness()) @@ -106,24 +106,21 @@ public void ServerCertChangeWithoutClientCertUpdateFailsReconnect() server.InstallNewClientCert(client01.GetCurrentClientCert()); client01.InstallNewServerCert(server.GetCurrentServerCert()); - server.StartServer().Wait(); - client01.Connect(); + await server.StartServer(); + await client01.Connect(); - client01.WaitForConnectOrFail(DEFAULT_TIMEOUT); + client01.WaitForConnectOrFail(DefaultTimeout); server.HotSwapServerCert(); - server.ForceDisconnectAsync(client01).Wait(DEFAULT_TIMEOUT); - client01.WaitForDisconnectOrFail(DEFAULT_TIMEOUT); + server.ForceDisconnectAsync(client01).Wait(DefaultTimeout); + client01.WaitForDisconnectOrFail(DefaultTimeout); - client01.WaitForConnectToFail(DEFAULT_TIMEOUT); + client01.WaitForConnectToFail(DefaultTimeout); } } static X509Certificate2 CreateSelfSignedCertificate(string oid) { -#if NET452 || NET461 - throw new NotImplementedException(); -#else var sanBuilder = new SubjectAlternativeNameBuilder(); sanBuilder.AddIpAddress(IPAddress.Loopback); sanBuilder.AddIpAddress(IPAddress.IPv6Loopback); @@ -150,24 +147,23 @@ static X509Certificate2 CreateSelfSignedCertificate(string oid) return pfxCertificate; } } -#endif } class ClientTestHarness : IDisposable { - IManagedMqttClient _client; readonly HotSwappableClientCertProvider _hotSwapClient = new HotSwappableClientCertProvider(); + IMqttClient _client; - public string ClientID => _client.InternalClient.Options.ClientId; + public string ClientId => _client.Options.ClientId; public void ClearServerCerts() { _hotSwapClient.ClearServerCerts(); } - public void Connect() + public Task Connect() { - Run_Client_Connection().Wait(); + return Run_Client_Connection(); } public void Dispose() @@ -191,18 +187,12 @@ public void InstallNewServerCert(X509Certificate2 serverCert) _hotSwapClient.InstallNewServerCert(serverCert); } - public void WaitForConnect(TimeSpan timeout) + public void WaitForConnectOrFail(TimeSpan timeout) { - var timer = Stopwatch.StartNew(); - while ((_client == null || !_client.IsConnected) && timer.Elapsed < timeout) + if (!_client.IsConnected) { - Thread.Sleep(5); + _client.ReconnectAsync().Wait(timeout); } - } - - public void WaitForConnectOrFail(TimeSpan timeout) - { - Assert.IsFalse(_client.IsConnected, "Client should be disconnected before waiting for connect."); WaitForConnect(timeout); @@ -231,7 +221,7 @@ public void WaitForDisconnect(TimeSpan timeout) public void WaitForDisconnectOrFail(TimeSpan timeout) { - WaitForConnect(timeout); + WaitForDisconnect(timeout); Assert.IsNotNull(_client, "Client was never initialized"); Assert.IsFalse(_client.IsConnected, $"Client connection should have disconnected after {timeout}"); @@ -242,28 +232,35 @@ async Task Run_Client_Connection() var optionsBuilder = new MqttClientOptionsBuilder() .WithTlsOptions( o => o.WithClientCertificatesProvider(_hotSwapClient) - .WithCertificateValidationHandler(_hotSwapClient.OnCertifciateValidation) + .WithCertificateValidationHandler(_hotSwapClient.OnCertificateValidation) .WithSslProtocols(SslProtocols.Tls12)) .WithTcpServer("localhost") .WithCleanSession() .WithProtocolVersion(MqttProtocolVersion.V500); - var mqttClientOptions = optionsBuilder.Build(); - var managedClientOptionsBuilder = new ManagedMqttClientOptionsBuilder().WithClientOptions(mqttClientOptions); - var managedClientOptions = managedClientOptionsBuilder.Build(); + var mqttClientOptions = optionsBuilder.Build(); var factory = new MqttFactory(); - var mqttClient = factory.CreateManagedMqttClient(); + var mqttClient = factory.CreateMqttClient(); _client = mqttClient; - await mqttClient.StartAsync(managedClientOptions); + await mqttClient.ConnectAsync(mqttClientOptions); + } + + void WaitForConnect(TimeSpan timeout) + { + var timer = Stopwatch.StartNew(); + while ((_client == null || !_client.IsConnected) && timer.Elapsed < timeout) + { + Thread.Sleep(5); + } } } class ServerTestHarness : IDisposable { - CancellationTokenSource _cts = new CancellationTokenSource(); readonly HotSwappableServerCertProvider _hotSwapServer = new HotSwappableServerCertProvider(); + MqttServer _server; public void ClearClientCerts() @@ -287,7 +284,7 @@ public void Dispose() public async Task ForceDisconnectAsync(ClientTestHarness client) { - await _server.DisconnectClientAsync(client.ClientID, MqttDisconnectReasonCode.UnspecifiedError); + await _server.DisconnectClientAsync(client.ClientId, MqttDisconnectReasonCode.UnspecifiedError); } public X509Certificate2 GetCurrentServerCert() @@ -313,6 +310,7 @@ public async Task StartServer() .WithRemoteCertificateValidationCallback(_hotSwapServer.RemoteCertificateValidationCallback) .WithEncryptedEndpoint() .Build(); + mqttServerOptions.TlsEndpointOptions.ClientCertificateRequired = true; _server = mqttFactory.CreateMqttServer(mqttServerOptions); await _server.StartAsync(); @@ -350,7 +348,7 @@ public void InstallNewServerCert(X509Certificate2 serverCert) ServerCerts.Add(serverCert); } - public bool OnCertifciateValidation(MqttClientCertificateValidationEventArgs certContext) + public bool OnCertificateValidation(MqttClientCertificateValidationEventArgs certContext) { var serverCerts = ServerCerts.ToArray(); @@ -430,9 +428,9 @@ public bool RemoteCertificateValidationCallback(object sender, X509Certificate c var providedCert = certificate.GetRawCertData(); for (int i = 0, n = serverCerts.Length; i < n; i++) { - var currentcert = serverCerts[i]; + var currentCert = serverCerts[i]; - if (currentcert.RawData.SequenceEqual(providedCert)) + if (currentCert.RawData.SequenceEqual(providedCert)) { return true; } @@ -442,4 +440,6 @@ public bool RemoteCertificateValidationCallback(object sender, X509Certificate c } } } -} \ No newline at end of file +} + +#endif \ No newline at end of file