diff --git a/src/MySqlConnector/Core/ServerSession.cs b/src/MySqlConnector/Core/ServerSession.cs index 5d3e09dfd..8338a9bb6 100644 --- a/src/MySqlConnector/Core/ServerSession.cs +++ b/src/MySqlConnector/Core/ServerSession.cs @@ -65,10 +65,12 @@ public ServerSession(ILogger logger, IConnectionPoolMetadata pool) public bool ProcAccessDenied { get; set; } public ICollection> ActivityTags => m_activityTags; public MySqlDataReader DataReader { get; set; } + public MySqlConnectionOpenedConditions Conditions { get; private set; } public ValueTask ReturnToPoolAsync(IOBehavior ioBehavior, MySqlConnection? owningConnection) { Log.ReturningToPool(m_logger, Id, Pool?.Id ?? 0); + Conditions = MySqlConnectionOpenedConditions.None; LastReturnedTimestamp = Stopwatch.GetTimestamp(); if (Pool is null) return default; @@ -414,6 +416,7 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella } } + Conditions = MySqlConnectionOpenedConditions.New; var connected = cs.ConnectionProtocol switch { MySqlConnectionProtocol.Sockets => await OpenTcpSocketAsync(cs, loadBalancer ?? throw new ArgumentNullException(nameof(loadBalancer)), activity, ioBehavior, cancellationToken).ConfigureAwait(false), @@ -747,6 +750,7 @@ public static async ValueTask ConnectAndRedirectAsync(ILogger con public async Task TryResetConnectionAsync(ConnectionSettings cs, MySqlConnection connection, IOBehavior ioBehavior, CancellationToken cancellationToken) { VerifyState(State.Connected); + Conditions |= MySqlConnectionOpenedConditions.Reset; try { @@ -829,6 +833,7 @@ public async Task TryResetConnectionAsync(ConnectionSettings cs, MySqlConn Log.IgnoringFailureInTryResetConnectionAsync(m_logger, ex, Id, "SocketException"); } + Conditions &= ~MySqlConnectionOpenedConditions.Reset; return false; } diff --git a/src/MySqlConnector/MySqlConnection.cs b/src/MySqlConnector/MySqlConnection.cs index a037a3eed..e7c0989de 100644 --- a/src/MySqlConnector/MySqlConnection.cs +++ b/src/MySqlConnector/MySqlConnection.cs @@ -551,6 +551,13 @@ internal async Task OpenAsync(IOBehavior? ioBehavior, CancellationToken cancella ActivitySourceHelper.CopyTags(m_session!.ActivityTags, activity); m_hasBeenOpened = true; SetState(ConnectionState.Open); + + if (ConnectionOpenedCallback is { } autoEnlistConnectionOpenedCallback) + { + cancellationToken.ThrowIfCancellationRequested(); + await autoEnlistConnectionOpenedCallback(new(this, MySqlConnectionOpenedConditions.None), cancellationToken).ConfigureAwait(false); + } + return; } } @@ -582,6 +589,12 @@ internal async Task OpenAsync(IOBehavior? ioBehavior, CancellationToken cancella if (m_connectionSettings.AutoEnlist && System.Transactions.Transaction.Current is not null) EnlistTransaction(System.Transactions.Transaction.Current); + + if (ConnectionOpenedCallback is { } connectionOpenedCallback) + { + cancellationToken.ThrowIfCancellationRequested(); + await connectionOpenedCallback(new(this, m_session.Conditions), cancellationToken).ConfigureAwait(false); + } } catch (Exception ex) when (activity is { IsAllDataRequested: true }) { @@ -917,6 +930,11 @@ internal void Cancel(ICancellableCommand command, int commandId, bool isCancel) using var connection = CloneWith(csb.ConnectionString); connection.m_connectionSettings = connectionSettings; + + // clear the callback because this is not intended to be a user-visible MySqlConnection that will execute setup logic; it's a + // non-pooled connection that will execute "KILL QUERY" then immediately be closed + connection.ConnectionOpenedCallback = null; + connection.Open(); #if NET6_0_OR_GREATER var killQuerySql = string.Create(CultureInfo.InvariantCulture, $"KILL QUERY {command.Connection!.ServerThread}"); @@ -992,6 +1010,7 @@ internal void Cancel(ICancellableCommand command, int commandId, bool isCancel) internal MySqlTransaction? CurrentTransaction { get; set; } internal MySqlConnectorLoggingConfiguration LoggingConfiguration { get; } internal ZstandardPlugin? ZstandardPlugin { get; set; } + internal MySqlConnectionOpenedCallback? ConnectionOpenedCallback { get; set; } internal bool AllowLoadLocalInfile => GetInitializedConnectionSettings().AllowLoadLocalInfile; internal bool AllowUserVariables => GetInitializedConnectionSettings().AllowUserVariables; internal bool AllowZeroDateTime => GetInitializedConnectionSettings().AllowZeroDateTime; @@ -1142,6 +1161,7 @@ private MySqlConnection(MySqlConnection other, MySqlDataSource? dataSource, stri ProvideClientCertificatesCallback = other.ProvideClientCertificatesCallback; ProvidePasswordCallback = other.ProvidePasswordCallback; RemoteCertificateValidationCallback = other.RemoteCertificateValidationCallback; + ConnectionOpenedCallback = other.ConnectionOpenedCallback; } private void VerifyNotDisposed() diff --git a/src/MySqlConnector/MySqlConnectionOpenedCallback.cs b/src/MySqlConnector/MySqlConnectionOpenedCallback.cs new file mode 100644 index 000000000..c1b9c3941 --- /dev/null +++ b/src/MySqlConnector/MySqlConnectionOpenedCallback.cs @@ -0,0 +1,9 @@ +namespace MySqlConnector; + +/// +/// A callback that is invoked when a new is opened. +/// +/// A giving information about the connection being opened. +/// A that can be used to cancel the asynchronous operation. +/// A representing the result of the possibly-asynchronous operation. +public delegate ValueTask MySqlConnectionOpenedCallback(MySqlConnectionOpenedContext context, CancellationToken cancellationToken); diff --git a/src/MySqlConnector/MySqlConnectionOpenedConditions.cs b/src/MySqlConnector/MySqlConnectionOpenedConditions.cs new file mode 100644 index 000000000..49c4c670a --- /dev/null +++ b/src/MySqlConnector/MySqlConnectionOpenedConditions.cs @@ -0,0 +1,23 @@ +namespace MySqlConnector; + +/// +/// Bitflags giving the conditions under which a connection was opened. +/// +[Flags] +public enum MySqlConnectionOpenedConditions +{ + /// + /// No specific conditions apply. This value may be used when an existing pooled connection is reused without being reset. + /// + None = 0, + + /// + /// A new physical connection to a MySQL Server was opened. This value is mutually exclusive with . + /// + New = 1, + + /// + /// An existing pooled connection to a MySQL Server was reset. This value is mutually exclusive with . + /// + Reset = 2, +} diff --git a/src/MySqlConnector/MySqlConnectionOpenedContext.cs b/src/MySqlConnector/MySqlConnectionOpenedContext.cs new file mode 100644 index 000000000..46fe9adf3 --- /dev/null +++ b/src/MySqlConnector/MySqlConnectionOpenedContext.cs @@ -0,0 +1,23 @@ +namespace MySqlConnector; + +/// +/// Contains information passed to when a new is opened. +/// +public sealed class MySqlConnectionOpenedContext +{ + /// + /// The that was opened. + /// + public MySqlConnection Connection { get; } + + /// + /// Bitflags giving the conditions under which a connection was opened. + /// + public MySqlConnectionOpenedConditions Conditions { get; } + + internal MySqlConnectionOpenedContext(MySqlConnection connection, MySqlConnectionOpenedConditions conditions) + { + Connection = connection; + Conditions = conditions; + } +} diff --git a/src/MySqlConnector/MySqlDataSource.cs b/src/MySqlConnector/MySqlDataSource.cs index 0c4372fd4..56e9718af 100644 --- a/src/MySqlConnector/MySqlDataSource.cs +++ b/src/MySqlConnector/MySqlDataSource.cs @@ -19,7 +19,7 @@ public sealed class MySqlDataSource : DbDataSource /// The connection string for the MySQL Server. This parameter is required. /// Thrown if is null. public MySqlDataSource(string connectionString) - : this(connectionString ?? throw new ArgumentNullException(nameof(connectionString)), MySqlConnectorLoggingConfiguration.NullConfiguration, null, null, null, null, default, default, default) + : this(connectionString ?? throw new ArgumentNullException(nameof(connectionString)), MySqlConnectorLoggingConfiguration.NullConfiguration, null, null, null, null, default, default, default, default) { } @@ -31,7 +31,8 @@ internal MySqlDataSource(string connectionString, Func>? periodicPasswordProvider, TimeSpan periodicPasswordProviderSuccessRefreshInterval, TimeSpan periodicPasswordProviderFailureRefreshInterval, - ZstandardPlugin? zstandardPlugin) + ZstandardPlugin? zstandardPlugin, + MySqlConnectionOpenedCallback? connectionOpenedCallback) { m_connectionString = connectionString; LoggingConfiguration = loggingConfiguration; @@ -40,6 +41,7 @@ internal MySqlDataSource(string connectionString, m_remoteCertificateValidationCallback = remoteCertificateValidationCallback; m_logger = loggingConfiguration.DataSourceLogger; m_zstandardPlugin = zstandardPlugin; + m_connectionOpenedCallback = connectionOpenedCallback; Pool = ConnectionPool.CreatePool(m_connectionString, LoggingConfiguration, name); m_id = Interlocked.Increment(ref s_lastId); @@ -142,6 +144,7 @@ protected override DbConnection CreateDbConnection() ProvideClientCertificatesCallback = m_clientCertificatesCallback, ProvidePasswordCallback = m_providePasswordCallback, RemoteCertificateValidationCallback = m_remoteCertificateValidationCallback, + ConnectionOpenedCallback = m_connectionOpenedCallback, }; } @@ -225,6 +228,7 @@ private string ProvidePasswordFromInitialRefreshTask(MySqlProvidePasswordContext private readonly TimeSpan m_periodicPasswordProviderSuccessRefreshInterval; private readonly TimeSpan m_periodicPasswordProviderFailureRefreshInterval; private readonly ZstandardPlugin? m_zstandardPlugin; + private readonly MySqlConnectionOpenedCallback? m_connectionOpenedCallback; private readonly MySqlProvidePasswordContext? m_providePasswordContext; private readonly CancellationTokenSource? m_passwordProviderTimerCancellationTokenSource; private readonly Timer? m_passwordProviderTimer; diff --git a/src/MySqlConnector/MySqlDataSourceBuilder.cs b/src/MySqlConnector/MySqlDataSourceBuilder.cs index e1d15d183..4bceec180 100644 --- a/src/MySqlConnector/MySqlDataSourceBuilder.cs +++ b/src/MySqlConnector/MySqlDataSourceBuilder.cs @@ -89,6 +89,17 @@ public MySqlDataSourceBuilder UseRemoteCertificateValidationCallback(RemoteCerti return this; } + /// + /// Adds a callback that is invoked when a new is opened. + /// + /// The callback to invoke. + /// This builder, so that method calls can be chained. + public MySqlDataSourceBuilder UseConnectionOpenedCallback(MySqlConnectionOpenedCallback callback) + { + m_connectionOpenedCallback += callback; + return this; + } + /// /// Builds a which is ready for use. /// @@ -104,7 +115,8 @@ public MySqlDataSource Build() m_periodicPasswordProvider, m_periodicPasswordProviderSuccessRefreshInterval, m_periodicPasswordProviderFailureRefreshInterval, - ZstandardPlugin + ZstandardPlugin, + m_connectionOpenedCallback ); } @@ -122,4 +134,5 @@ public MySqlDataSource Build() private Func>? m_periodicPasswordProvider; private TimeSpan m_periodicPasswordProviderSuccessRefreshInterval; private TimeSpan m_periodicPasswordProviderFailureRefreshInterval; + private MySqlConnectionOpenedCallback? m_connectionOpenedCallback; } diff --git a/tests/IntegrationTests/TransactionScopeTests.cs b/tests/IntegrationTests/TransactionScopeTests.cs index ffe14b594..5a808f193 100644 --- a/tests/IntegrationTests/TransactionScopeTests.cs +++ b/tests/IntegrationTests/TransactionScopeTests.cs @@ -886,6 +886,38 @@ public void Bug1348() Assert.True(rollbacked, $"First branch transaction '{xid}1' not rolled back"); } + + [Fact] + public void ConnectionOpenedCallbackAutoEnlistInTransaction() + { + var connectionOpenedCallbackCount = 0; + var connectionOpenedConditions = MySqlConnectionOpenedConditions.None; + using var dataSource = new MySqlDataSourceBuilder(AppConfig.ConnectionString) + .UseConnectionOpenedCallback((ctx, token) => + { + connectionOpenedCallbackCount++; + connectionOpenedConditions = ctx.Conditions; + return default; + }) + .Build(); + + using (var transactionScope = new TransactionScope()) + { + using (var conn = dataSource.OpenConnection()) + { + Assert.Equal(1, connectionOpenedCallbackCount); + Assert.Equal(MySqlConnectionOpenedConditions.New, connectionOpenedConditions); + } + + using (var conn = dataSource.OpenConnection()) + { + Assert.Equal(2, connectionOpenedCallbackCount); + Assert.Equal(MySqlConnectionOpenedConditions.None, connectionOpenedConditions); + } + + transactionScope.Complete(); + } + } #endif readonly DatabaseFixture m_database; diff --git a/tests/MySqlConnector.Tests/ConnectionOpenedCallbackTests.cs b/tests/MySqlConnector.Tests/ConnectionOpenedCallbackTests.cs new file mode 100644 index 000000000..bf9b72e84 --- /dev/null +++ b/tests/MySqlConnector.Tests/ConnectionOpenedCallbackTests.cs @@ -0,0 +1,140 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading.Tasks; + +namespace MySqlConnector.Tests; + +public class ConnectionOpenedCallbackTests : IDisposable +{ + public ConnectionOpenedCallbackTests() + { + m_server = new(); + m_server.Start(); + + m_csb = new MySqlConnectionStringBuilder() + { + Server = "localhost", + Port = (uint) m_server.Port, + }; + m_dataSource = new MySqlDataSourceBuilder(m_csb.ConnectionString) + .UseConnectionOpenedCallback(OnConnectionOpenedAsync) + .Build(); + } + + public void Dispose() + { + m_dataSource.Dispose(); + m_server.Stop(); + } + + [Fact] + public void CallbackIsInvoked() + { + using (var connection = m_dataSource.CreateConnection()) + { + Assert.Equal(0, m_connectionOpenedCount); + Assert.Equal(MySqlConnectionOpenedConditions.None, m_connectionOpenedConditions); + + connection.Open(); + + Assert.Equal(1, m_connectionOpenedCount); + Assert.Equal(MySqlConnectionOpenedConditions.New, m_connectionOpenedConditions); + } + } + + [Fact] + public void CallbackIsInvokedForPooledConnection() + { + using (var connection = m_dataSource.CreateConnection()) + { + Assert.Equal(0, m_connectionOpenedCount); + Assert.Equal(MySqlConnectionOpenedConditions.None, m_connectionOpenedConditions); + + connection.Open(); + + Assert.Equal(1, m_connectionOpenedCount); + Assert.Equal(MySqlConnectionOpenedConditions.New, m_connectionOpenedConditions); + } + + using (var connection = m_dataSource.OpenConnection()) + { + Assert.Equal(2, m_connectionOpenedCount); + Assert.Equal(MySqlConnectionOpenedConditions.Reset, m_connectionOpenedConditions); + } + + using (var connection = m_dataSource.OpenConnection()) + { + Assert.Equal(3, m_connectionOpenedCount); + Assert.Equal(MySqlConnectionOpenedConditions.Reset, m_connectionOpenedConditions); + } + } + + [Fact] + public void CallbackIsInvokedForNonPooledConnection() + { + var csb = new MySqlConnectionStringBuilder(m_csb.ConnectionString) + { + Pooling = false, + }; + using var dataSource = new MySqlDataSourceBuilder(csb.ConnectionString) + .UseConnectionOpenedCallback(OnConnectionOpenedAsync) + .Build(); + + using (var connection = dataSource.OpenConnection()) + { + Assert.Equal(1, m_connectionOpenedCount); + Assert.Equal(MySqlConnectionOpenedConditions.New, m_connectionOpenedConditions); + } + + using (var connection = dataSource.OpenConnection()) + { + Assert.Equal(2, m_connectionOpenedCount); + Assert.Equal(MySqlConnectionOpenedConditions.New, m_connectionOpenedConditions); + } + + using (var connection = dataSource.OpenConnection()) + { + Assert.Equal(3, m_connectionOpenedCount); + Assert.Equal(MySqlConnectionOpenedConditions.New, m_connectionOpenedConditions); + } + } + + [Fact] + public void ConditionsForNonResetConnection() + { + var csb = new MySqlConnectionStringBuilder(m_csb.ConnectionString) + { + ConnectionReset = false, + }; + using var dataSource = new MySqlDataSourceBuilder(csb.ConnectionString) + .UseConnectionOpenedCallback(OnConnectionOpenedAsync) + .Build(); + + using (var connection = dataSource.OpenConnection()) + { + Assert.Equal(1, m_connectionOpenedCount); + Assert.Equal(MySqlConnectionOpenedConditions.New, m_connectionOpenedConditions); + } + using (var connection = dataSource.OpenConnection()) + { + Assert.Equal(2, m_connectionOpenedCount); + Assert.Equal(MySqlConnectionOpenedConditions.None, m_connectionOpenedConditions); + } + } + + private ValueTask OnConnectionOpenedAsync(MySqlConnectionOpenedContext context, CancellationToken cancellationToken) + { + m_connectionOpenedCount++; + m_connectionOpenedConditions = context.Conditions; + return default; + } + + private readonly FakeMySqlServer m_server; + private readonly MySqlConnectionStringBuilder m_csb; + private readonly MySqlDataSource m_dataSource; + + private int m_connectionOpenedCount; + private MySqlConnectionOpenedConditions m_connectionOpenedConditions; +} diff --git a/tests/MySqlConnector.Tests/MySqlConnector.Tests.csproj b/tests/MySqlConnector.Tests/MySqlConnector.Tests.csproj index 5d14a2cce..8de95ae49 100644 --- a/tests/MySqlConnector.Tests/MySqlConnector.Tests.csproj +++ b/tests/MySqlConnector.Tests/MySqlConnector.Tests.csproj @@ -47,7 +47,7 @@ - +