Skip to content

Commit

Permalink
Fix Unit Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
chkr1011 committed Aug 22, 2023
1 parent e604345 commit fc31c30
Showing 1 changed file with 62 additions and 62 deletions.
124 changes: 62 additions & 62 deletions Source/MQTTnet.Tests/Server/HotSwapCerts_Tests.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
#if !(NET452 || NET461)
using System;
using System.Collections.Concurrent;
using System.Diagnostics;
Expand All @@ -12,118 +13,114 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Certificates;
using MQTTnet.Client;
using MQTTnet.Extensions.ManagedClient;
using MQTTnet.Formatter;
using MQTTnet.Protocol;
using MQTTnet.Server;

namespace MQTTnet.Tests.Server
{
// missing certificate builder api means tests won't work for older frameworks
#if !(NET452 || NET461)

[TestClass]
#endif
public sealed class HotSwapCerts_Tests
{
readonly TimeSpan DEFAULT_TIMEOUT = TimeSpan.FromSeconds(10);
static readonly TimeSpan DefaultTimeout = TimeSpan.FromSeconds(10);

[TestMethod]
public void ClientCertChangeWithoutServerUpdateFailsReconnect()
public async Task ClientCertChangeWithoutServerUpdateFailsReconnect()
{
using (var server = new ServerTestHarness())
using (var client01 = new ClientTestHarness())
{
server.InstallNewClientCert(client01.GetCurrentClientCert());
client01.InstallNewServerCert(server.GetCurrentServerCert());

server.StartServer().Wait();
await server.StartServer();

client01.Connect();
await client01.Connect();

client01.WaitForConnectOrFail(DEFAULT_TIMEOUT);
client01.WaitForConnectOrFail(DefaultTimeout);

client01.HotSwapClientCert();
server.ForceDisconnectAsync(client01).Wait(DEFAULT_TIMEOUT);
client01.WaitForDisconnectOrFail(DEFAULT_TIMEOUT);
server.ForceDisconnectAsync(client01).Wait(DefaultTimeout);
client01.WaitForDisconnectOrFail(DefaultTimeout);

client01.WaitForConnectToFail(DEFAULT_TIMEOUT);
client01.WaitForConnectToFail(DefaultTimeout);
}
}

[TestMethod]
public void ClientCertChangeWithServerUpdateAcceptsReconnect()
public async Task ClientCertChangeWithServerUpdateAcceptsReconnect()
{
using (var server = new ServerTestHarness())
using (var client01 = new ClientTestHarness())
{
server.InstallNewClientCert(client01.GetCurrentClientCert());
client01.InstallNewServerCert(server.GetCurrentServerCert());

server.StartServer().Wait();
client01.Connect();
await server.StartServer();

client01.WaitForConnectOrFail(DEFAULT_TIMEOUT);
await client01.Connect();

client01.WaitForConnectOrFail(DefaultTimeout);

client01.HotSwapClientCert();
server.ForceDisconnectAsync(client01).Wait(DEFAULT_TIMEOUT);
client01.WaitForDisconnectOrFail(DEFAULT_TIMEOUT);
server.ForceDisconnectAsync(client01).Wait(DefaultTimeout);
client01.WaitForDisconnectOrFail(DefaultTimeout);

server.InstallNewClientCert(client01.GetCurrentClientCert());

client01.WaitForConnectOrFail(DEFAULT_TIMEOUT);
client01.WaitForConnectOrFail(DefaultTimeout);
}
}

[TestMethod]
public void ServerCertChangeWithClientCertUpdateAllowsReconnect()
public async Task ServerCertChangeWithClientCertUpdateAllowsReconnect()
{
using (var server = new ServerTestHarness())
using (var client01 = new ClientTestHarness())
{
server.InstallNewClientCert(client01.GetCurrentClientCert());
client01.InstallNewServerCert(server.GetCurrentServerCert());

server.StartServer().Wait();
client01.Connect();
await server.StartServer();
await client01.Connect();

client01.WaitForConnectOrFail(DEFAULT_TIMEOUT);
client01.WaitForConnectOrFail(DefaultTimeout);
server.HotSwapServerCert();

server.ForceDisconnectAsync(client01).Wait(DEFAULT_TIMEOUT);
client01.WaitForDisconnectOrFail(DEFAULT_TIMEOUT);
server.ForceDisconnectAsync(client01).Wait(DefaultTimeout);
client01.WaitForDisconnectOrFail(DefaultTimeout);
client01.InstallNewServerCert(server.GetCurrentServerCert());

client01.WaitForConnectOrFail(DEFAULT_TIMEOUT);
client01.WaitForConnectOrFail(DefaultTimeout);
}
}

[TestMethod]
public void ServerCertChangeWithoutClientCertUpdateFailsReconnect()
public async Task ServerCertChangeWithoutClientCertUpdateFailsReconnect()
{
using (var server = new ServerTestHarness())
using (var client01 = new ClientTestHarness())
{
server.InstallNewClientCert(client01.GetCurrentClientCert());
client01.InstallNewServerCert(server.GetCurrentServerCert());

server.StartServer().Wait();
client01.Connect();
await server.StartServer();
await client01.Connect();

client01.WaitForConnectOrFail(DEFAULT_TIMEOUT);
client01.WaitForConnectOrFail(DefaultTimeout);
server.HotSwapServerCert();

server.ForceDisconnectAsync(client01).Wait(DEFAULT_TIMEOUT);
client01.WaitForDisconnectOrFail(DEFAULT_TIMEOUT);
server.ForceDisconnectAsync(client01).Wait(DefaultTimeout);
client01.WaitForDisconnectOrFail(DefaultTimeout);

client01.WaitForConnectToFail(DEFAULT_TIMEOUT);
client01.WaitForConnectToFail(DefaultTimeout);
}
}

static X509Certificate2 CreateSelfSignedCertificate(string oid)
{
#if NET452 || NET461
throw new NotImplementedException();
#else
var sanBuilder = new SubjectAlternativeNameBuilder();
sanBuilder.AddIpAddress(IPAddress.Loopback);
sanBuilder.AddIpAddress(IPAddress.IPv6Loopback);
Expand All @@ -150,24 +147,23 @@ static X509Certificate2 CreateSelfSignedCertificate(string oid)
return pfxCertificate;
}
}
#endif
}

class ClientTestHarness : IDisposable
{
IManagedMqttClient _client;
readonly HotSwappableClientCertProvider _hotSwapClient = new HotSwappableClientCertProvider();
IMqttClient _client;

public string ClientID => _client.InternalClient.Options.ClientId;
public string ClientId => _client.Options.ClientId;

public void ClearServerCerts()
{
_hotSwapClient.ClearServerCerts();
}

public void Connect()
public Task Connect()
{
Run_Client_Connection().Wait();
return Run_Client_Connection();
}

public void Dispose()
Expand All @@ -191,18 +187,12 @@ public void InstallNewServerCert(X509Certificate2 serverCert)
_hotSwapClient.InstallNewServerCert(serverCert);
}

public void WaitForConnect(TimeSpan timeout)
public void WaitForConnectOrFail(TimeSpan timeout)
{
var timer = Stopwatch.StartNew();
while ((_client == null || !_client.IsConnected) && timer.Elapsed < timeout)
if (!_client.IsConnected)
{
Thread.Sleep(5);
_client.ReconnectAsync().Wait(timeout);
}
}

public void WaitForConnectOrFail(TimeSpan timeout)
{
Assert.IsFalse(_client.IsConnected, "Client should be disconnected before waiting for connect.");

WaitForConnect(timeout);

Expand Down Expand Up @@ -231,7 +221,7 @@ public void WaitForDisconnect(TimeSpan timeout)

public void WaitForDisconnectOrFail(TimeSpan timeout)
{
WaitForConnect(timeout);
WaitForDisconnect(timeout);

Assert.IsNotNull(_client, "Client was never initialized");
Assert.IsFalse(_client.IsConnected, $"Client connection should have disconnected after {timeout}");
Expand All @@ -242,28 +232,35 @@ async Task Run_Client_Connection()
var optionsBuilder = new MqttClientOptionsBuilder()
.WithTlsOptions(
o => o.WithClientCertificatesProvider(_hotSwapClient)
.WithCertificateValidationHandler(_hotSwapClient.OnCertifciateValidation)
.WithCertificateValidationHandler(_hotSwapClient.OnCertificateValidation)
.WithSslProtocols(SslProtocols.Tls12))
.WithTcpServer("localhost")
.WithCleanSession()
.WithProtocolVersion(MqttProtocolVersion.V500);
var mqttClientOptions = optionsBuilder.Build();

var managedClientOptionsBuilder = new ManagedMqttClientOptionsBuilder().WithClientOptions(mqttClientOptions);
var managedClientOptions = managedClientOptionsBuilder.Build();
var mqttClientOptions = optionsBuilder.Build();

var factory = new MqttFactory();
var mqttClient = factory.CreateManagedMqttClient();
var mqttClient = factory.CreateMqttClient();
_client = mqttClient;

await mqttClient.StartAsync(managedClientOptions);
await mqttClient.ConnectAsync(mqttClientOptions);
}

void WaitForConnect(TimeSpan timeout)
{
var timer = Stopwatch.StartNew();
while ((_client == null || !_client.IsConnected) && timer.Elapsed < timeout)
{
Thread.Sleep(5);
}
}
}

class ServerTestHarness : IDisposable
{
CancellationTokenSource _cts = new CancellationTokenSource();
readonly HotSwappableServerCertProvider _hotSwapServer = new HotSwappableServerCertProvider();

MqttServer _server;

public void ClearClientCerts()
Expand All @@ -287,7 +284,7 @@ public void Dispose()

public async Task ForceDisconnectAsync(ClientTestHarness client)
{
await _server.DisconnectClientAsync(client.ClientID, MqttDisconnectReasonCode.UnspecifiedError);
await _server.DisconnectClientAsync(client.ClientId, MqttDisconnectReasonCode.UnspecifiedError);
}

public X509Certificate2 GetCurrentServerCert()
Expand All @@ -313,6 +310,7 @@ public async Task StartServer()
.WithRemoteCertificateValidationCallback(_hotSwapServer.RemoteCertificateValidationCallback)
.WithEncryptedEndpoint()
.Build();

mqttServerOptions.TlsEndpointOptions.ClientCertificateRequired = true;
_server = mqttFactory.CreateMqttServer(mqttServerOptions);
await _server.StartAsync();
Expand Down Expand Up @@ -350,7 +348,7 @@ public void InstallNewServerCert(X509Certificate2 serverCert)
ServerCerts.Add(serverCert);
}

public bool OnCertifciateValidation(MqttClientCertificateValidationEventArgs certContext)
public bool OnCertificateValidation(MqttClientCertificateValidationEventArgs certContext)
{
var serverCerts = ServerCerts.ToArray();

Expand Down Expand Up @@ -430,9 +428,9 @@ public bool RemoteCertificateValidationCallback(object sender, X509Certificate c
var providedCert = certificate.GetRawCertData();
for (int i = 0, n = serverCerts.Length; i < n; i++)
{
var currentcert = serverCerts[i];
var currentCert = serverCerts[i];

if (currentcert.RawData.SequenceEqual(providedCert))
if (currentCert.RawData.SequenceEqual(providedCert))
{
return true;
}
Expand All @@ -442,4 +440,6 @@ public bool RemoteCertificateValidationCallback(object sender, X509Certificate c
}
}
}
}
}

#endif

0 comments on commit fc31c30

Please sign in to comment.