Skip to content

Commit

Permalink
Use correct hostname when cancelling query. Fixes #1514
Browse files Browse the repository at this point in the history
Connect directly to the cancelled command's session's IP address but maintain the hostname for SSL certificate validation.

Signed-off-by: Bradley Grainger <[email protected]>
  • Loading branch information
bgrainger committed Oct 5, 2024
1 parent f74ce0b commit 9d17f7f
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 9 deletions.
11 changes: 8 additions & 3 deletions src/MySqlConnector/Core/ConnectionSettings.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#if NETCOREAPP3_0_OR_GREATER
using System.Net.Security;
#endif
using System.Net;
using System.Security.Authentication;
using MySqlConnector.Utilities;

Expand Down Expand Up @@ -150,7 +151,9 @@ public ConnectionSettings(MySqlConnectionStringBuilder csb)
static int ToSigned(uint value) => value >= int.MaxValue ? int.MaxValue : (int) value;
}

public ConnectionSettings CloneWith(string host, int port, string userId) => new(this, host, port, userId);
public ConnectionSettings CloneWith(string host, int port, string userId) => new(this, host, port, userId, null);

public ConnectionSettings CloneWith(IPAddress ipAddress) => new(this, HostNames![0], Port, UserID, ipAddress);

private static MySqlGuidFormat GetEffectiveGuidFormat(MySqlGuidFormat guidFormat, bool oldGuids)
{
Expand Down Expand Up @@ -182,6 +185,7 @@ private static MySqlGuidFormat GetEffectiveGuidFormat(MySqlGuidFormat guidFormat
public string ConnectionString { get; }
public MySqlConnectionProtocol ConnectionProtocol { get; }
public IReadOnlyList<string>? HostNames { get; }
public IPAddress? IPAddress { get; }
public MySqlLoadBalance LoadBalance { get; }
public int Port { get; }
public string PipeName { get; }
Expand Down Expand Up @@ -268,17 +272,18 @@ public int ConnectionTimeoutMilliseconds
}
}

private ConnectionSettings(ConnectionSettings other, string host, int port, string userId)
private ConnectionSettings(ConnectionSettings other, string host, int port, string userId, IPAddress? ipAddress)
{
ConnectionStringBuilder = new MySqlConnectionStringBuilder(other.ConnectionString);
ConnectionStringBuilder.Port = (uint)port;
ConnectionStringBuilder.Port = (uint) port;
ConnectionStringBuilder.Server = host;
ConnectionStringBuilder.UserID = userId;

ConnectionString = ConnectionStringBuilder.ConnectionString;

ConnectionProtocol = MySqlConnectionProtocol.Sockets;
HostNames = [host];
IPAddress = ipAddress;
LoadBalance = other.LoadBalance;
Port = port;
PipeName = other.PipeName;
Expand Down
15 changes: 11 additions & 4 deletions src/MySqlConnector/Core/ServerSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1164,13 +1164,20 @@ private async Task<bool> OpenTcpSocketAsync(ConnectionSettings cs, ILoadBalancer
IPAddress[] ipAddresses;
try
{
ipAddresses = ioBehavior == IOBehavior.Asynchronous
if (cs.IPAddress is { } ipAddress)
{
ipAddresses = [ipAddress];
}
else
{
ipAddresses = ioBehavior == IOBehavior.Asynchronous
#if NET6_0_OR_GREATER
? await Dns.GetHostAddressesAsync(hostName, cancellationToken).ConfigureAwait(false)
? await Dns.GetHostAddressesAsync(hostName, cancellationToken).ConfigureAwait(false)
#else
? await Dns.GetHostAddressesAsync(hostName).ConfigureAwait(false)
? await Dns.GetHostAddressesAsync(hostName).ConfigureAwait(false)
#endif
: Dns.GetHostAddresses(hostName);
: Dns.GetHostAddresses(hostName);
}
}
catch (SocketException ex)
{
Expand Down
15 changes: 13 additions & 2 deletions src/MySqlConnector/MySqlConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -896,16 +896,27 @@ internal void Cancel(ICancellableCommand command, int commandId, bool isCancel)
AutoEnlist = false,
Pooling = false,
};
if (session.IPEndPoint is { Address: { } ipAddress, Port: { } port })

// connect directly to the session's IP address to ensure we're cancelling the query on the right server (in a load-balanced scenario)
IPAddress? ipAddress = null;
if (session.IPEndPoint is { Address: { } sessionIpAddress, Port: { } port })
{
csb.Server = ipAddress.ToString();
// set the hostname to the existing session's hostname (for SSL validation)
csb.Server = session.HostName;
csb.Port = (uint) port;
ipAddress = sessionIpAddress;
}
csb.UserID = session.UserID;
var cancellationTimeout = GetConnectionSettings().CancellationTimeout;
csb.ConnectionTimeout = cancellationTimeout < 1 ? 3u : (uint) cancellationTimeout;

// forcibly set the IP address the new connection should use
var connectionSettings = new ConnectionSettings(csb);
if (ipAddress is not null)
connectionSettings = connectionSettings.CloneWith(ipAddress);

using var connection = CloneWith(csb.ConnectionString);
connection.m_connectionSettings = connectionSettings;
connection.Open();
#if NET6_0_OR_GREATER
var killQuerySql = string.Create(CultureInfo.InvariantCulture, $"KILL QUERY {command.Connection!.ServerThread}");
Expand Down

0 comments on commit 9d17f7f

Please sign in to comment.