From 9d17f7fc7cc2349a6bebb4efe82d7051b2b07e6a Mon Sep 17 00:00:00 2001 From: Bradley Grainger Date: Fri, 4 Oct 2024 20:16:34 -0700 Subject: [PATCH] Use correct hostname when cancelling query. Fixes #1514 Connect directly to the cancelled command's session's IP address but maintain the hostname for SSL certificate validation. Signed-off-by: Bradley Grainger --- src/MySqlConnector/Core/ConnectionSettings.cs | 11 ++++++++--- src/MySqlConnector/Core/ServerSession.cs | 15 +++++++++++---- src/MySqlConnector/MySqlConnection.cs | 15 +++++++++++++-- 3 files changed, 32 insertions(+), 9 deletions(-) diff --git a/src/MySqlConnector/Core/ConnectionSettings.cs b/src/MySqlConnector/Core/ConnectionSettings.cs index 8b365833c..f43162fbe 100644 --- a/src/MySqlConnector/Core/ConnectionSettings.cs +++ b/src/MySqlConnector/Core/ConnectionSettings.cs @@ -1,6 +1,7 @@ #if NETCOREAPP3_0_OR_GREATER using System.Net.Security; #endif +using System.Net; using System.Security.Authentication; using MySqlConnector.Utilities; @@ -150,7 +151,9 @@ public ConnectionSettings(MySqlConnectionStringBuilder csb) static int ToSigned(uint value) => value >= int.MaxValue ? int.MaxValue : (int) value; } - public ConnectionSettings CloneWith(string host, int port, string userId) => new(this, host, port, userId); + public ConnectionSettings CloneWith(string host, int port, string userId) => new(this, host, port, userId, null); + + public ConnectionSettings CloneWith(IPAddress ipAddress) => new(this, HostNames![0], Port, UserID, ipAddress); private static MySqlGuidFormat GetEffectiveGuidFormat(MySqlGuidFormat guidFormat, bool oldGuids) { @@ -182,6 +185,7 @@ private static MySqlGuidFormat GetEffectiveGuidFormat(MySqlGuidFormat guidFormat public string ConnectionString { get; } public MySqlConnectionProtocol ConnectionProtocol { get; } public IReadOnlyList? HostNames { get; } + public IPAddress? IPAddress { get; } public MySqlLoadBalance LoadBalance { get; } public int Port { get; } public string PipeName { get; } @@ -268,10 +272,10 @@ public int ConnectionTimeoutMilliseconds } } - private ConnectionSettings(ConnectionSettings other, string host, int port, string userId) + private ConnectionSettings(ConnectionSettings other, string host, int port, string userId, IPAddress? ipAddress) { ConnectionStringBuilder = new MySqlConnectionStringBuilder(other.ConnectionString); - ConnectionStringBuilder.Port = (uint)port; + ConnectionStringBuilder.Port = (uint) port; ConnectionStringBuilder.Server = host; ConnectionStringBuilder.UserID = userId; @@ -279,6 +283,7 @@ private ConnectionSettings(ConnectionSettings other, string host, int port, stri ConnectionProtocol = MySqlConnectionProtocol.Sockets; HostNames = [host]; + IPAddress = ipAddress; LoadBalance = other.LoadBalance; Port = port; PipeName = other.PipeName; diff --git a/src/MySqlConnector/Core/ServerSession.cs b/src/MySqlConnector/Core/ServerSession.cs index 25e36f5a9..93112e16f 100644 --- a/src/MySqlConnector/Core/ServerSession.cs +++ b/src/MySqlConnector/Core/ServerSession.cs @@ -1164,13 +1164,20 @@ private async Task OpenTcpSocketAsync(ConnectionSettings cs, ILoadBalancer IPAddress[] ipAddresses; try { - ipAddresses = ioBehavior == IOBehavior.Asynchronous + if (cs.IPAddress is { } ipAddress) + { + ipAddresses = [ipAddress]; + } + else + { + ipAddresses = ioBehavior == IOBehavior.Asynchronous #if NET6_0_OR_GREATER - ? await Dns.GetHostAddressesAsync(hostName, cancellationToken).ConfigureAwait(false) + ? await Dns.GetHostAddressesAsync(hostName, cancellationToken).ConfigureAwait(false) #else - ? await Dns.GetHostAddressesAsync(hostName).ConfigureAwait(false) + ? await Dns.GetHostAddressesAsync(hostName).ConfigureAwait(false) #endif - : Dns.GetHostAddresses(hostName); + : Dns.GetHostAddresses(hostName); + } } catch (SocketException ex) { diff --git a/src/MySqlConnector/MySqlConnection.cs b/src/MySqlConnector/MySqlConnection.cs index bb6cbe795..b98c330f7 100644 --- a/src/MySqlConnector/MySqlConnection.cs +++ b/src/MySqlConnector/MySqlConnection.cs @@ -896,16 +896,27 @@ internal void Cancel(ICancellableCommand command, int commandId, bool isCancel) AutoEnlist = false, Pooling = false, }; - if (session.IPEndPoint is { Address: { } ipAddress, Port: { } port }) + + // connect directly to the session's IP address to ensure we're cancelling the query on the right server (in a load-balanced scenario) + IPAddress? ipAddress = null; + if (session.IPEndPoint is { Address: { } sessionIpAddress, Port: { } port }) { - csb.Server = ipAddress.ToString(); + // set the hostname to the existing session's hostname (for SSL validation) + csb.Server = session.HostName; csb.Port = (uint) port; + ipAddress = sessionIpAddress; } csb.UserID = session.UserID; var cancellationTimeout = GetConnectionSettings().CancellationTimeout; csb.ConnectionTimeout = cancellationTimeout < 1 ? 3u : (uint) cancellationTimeout; + // forcibly set the IP address the new connection should use + var connectionSettings = new ConnectionSettings(csb); + if (ipAddress is not null) + connectionSettings = connectionSettings.CloneWith(ipAddress); + using var connection = CloneWith(csb.ConnectionString); + connection.m_connectionSettings = connectionSettings; connection.Open(); #if NET6_0_OR_GREATER var killQuerySql = string.Create(CultureInfo.InvariantCulture, $"KILL QUERY {command.Connection!.ServerThread}");