Skip to content

Commit

Permalink
Add ConnectionOpenedCallback. Fixes #1508 (#1515)
Browse files Browse the repository at this point in the history
This differs from the existing StateChanged event in that:

- it supports an async callback
- it's only invoked when a connection is opened
- it provides information about new vs existing and whether the connection was reset

Signed-off-by: Bradley Grainger <[email protected]>
  • Loading branch information
bgrainger authored Oct 13, 2024
1 parent 633a65b commit b78e43f
Show file tree
Hide file tree
Showing 10 changed files with 273 additions and 4 deletions.
5 changes: 5 additions & 0 deletions src/MySqlConnector/Core/ServerSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,12 @@ public ServerSession(ILogger logger, IConnectionPoolMetadata pool)
public bool ProcAccessDenied { get; set; }
public ICollection<KeyValuePair<string, object?>> ActivityTags => m_activityTags;
public MySqlDataReader DataReader { get; set; }
public MySqlConnectionOpenedConditions Conditions { get; private set; }

public ValueTask ReturnToPoolAsync(IOBehavior ioBehavior, MySqlConnection? owningConnection)
{
Log.ReturningToPool(m_logger, Id, Pool?.Id ?? 0);
Conditions = MySqlConnectionOpenedConditions.None;
LastReturnedTimestamp = Stopwatch.GetTimestamp();
if (Pool is null)
return default;
Expand Down Expand Up @@ -414,6 +416,7 @@ public async Task DisposeAsync(IOBehavior ioBehavior, CancellationToken cancella
}
}

Conditions = MySqlConnectionOpenedConditions.New;
var connected = cs.ConnectionProtocol switch
{
MySqlConnectionProtocol.Sockets => await OpenTcpSocketAsync(cs, loadBalancer ?? throw new ArgumentNullException(nameof(loadBalancer)), activity, ioBehavior, cancellationToken).ConfigureAwait(false),
Expand Down Expand Up @@ -747,6 +750,7 @@ public static async ValueTask<ServerSession> ConnectAndRedirectAsync(ILogger con
public async Task<bool> TryResetConnectionAsync(ConnectionSettings cs, MySqlConnection connection, IOBehavior ioBehavior, CancellationToken cancellationToken)
{
VerifyState(State.Connected);
Conditions |= MySqlConnectionOpenedConditions.Reset;

try
{
Expand Down Expand Up @@ -829,6 +833,7 @@ public async Task<bool> TryResetConnectionAsync(ConnectionSettings cs, MySqlConn
Log.IgnoringFailureInTryResetConnectionAsync(m_logger, ex, Id, "SocketException");
}

Conditions &= ~MySqlConnectionOpenedConditions.Reset;
return false;
}

Expand Down
20 changes: 20 additions & 0 deletions src/MySqlConnector/MySqlConnection.cs
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,13 @@ internal async Task OpenAsync(IOBehavior? ioBehavior, CancellationToken cancella
ActivitySourceHelper.CopyTags(m_session!.ActivityTags, activity);
m_hasBeenOpened = true;
SetState(ConnectionState.Open);

if (ConnectionOpenedCallback is { } autoEnlistConnectionOpenedCallback)
{
cancellationToken.ThrowIfCancellationRequested();
await autoEnlistConnectionOpenedCallback(new(this, MySqlConnectionOpenedConditions.None), cancellationToken).ConfigureAwait(false);
}

return;
}
}
Expand Down Expand Up @@ -582,6 +589,12 @@ internal async Task OpenAsync(IOBehavior? ioBehavior, CancellationToken cancella

if (m_connectionSettings.AutoEnlist && System.Transactions.Transaction.Current is not null)
EnlistTransaction(System.Transactions.Transaction.Current);

if (ConnectionOpenedCallback is { } connectionOpenedCallback)
{
cancellationToken.ThrowIfCancellationRequested();
await connectionOpenedCallback(new(this, m_session.Conditions), cancellationToken).ConfigureAwait(false);
}
}
catch (Exception ex) when (activity is { IsAllDataRequested: true })
{
Expand Down Expand Up @@ -917,6 +930,11 @@ internal void Cancel(ICancellableCommand command, int commandId, bool isCancel)

using var connection = CloneWith(csb.ConnectionString);
connection.m_connectionSettings = connectionSettings;

// clear the callback because this is not intended to be a user-visible MySqlConnection that will execute setup logic; it's a
// non-pooled connection that will execute "KILL QUERY" then immediately be closed
connection.ConnectionOpenedCallback = null;

connection.Open();
#if NET6_0_OR_GREATER
var killQuerySql = string.Create(CultureInfo.InvariantCulture, $"KILL QUERY {command.Connection!.ServerThread}");
Expand Down Expand Up @@ -992,6 +1010,7 @@ internal void Cancel(ICancellableCommand command, int commandId, bool isCancel)
internal MySqlTransaction? CurrentTransaction { get; set; }
internal MySqlConnectorLoggingConfiguration LoggingConfiguration { get; }
internal ZstandardPlugin? ZstandardPlugin { get; set; }
internal MySqlConnectionOpenedCallback? ConnectionOpenedCallback { get; set; }
internal bool AllowLoadLocalInfile => GetInitializedConnectionSettings().AllowLoadLocalInfile;
internal bool AllowUserVariables => GetInitializedConnectionSettings().AllowUserVariables;
internal bool AllowZeroDateTime => GetInitializedConnectionSettings().AllowZeroDateTime;
Expand Down Expand Up @@ -1142,6 +1161,7 @@ private MySqlConnection(MySqlConnection other, MySqlDataSource? dataSource, stri
ProvideClientCertificatesCallback = other.ProvideClientCertificatesCallback;
ProvidePasswordCallback = other.ProvidePasswordCallback;
RemoteCertificateValidationCallback = other.RemoteCertificateValidationCallback;
ConnectionOpenedCallback = other.ConnectionOpenedCallback;
}

private void VerifyNotDisposed()
Expand Down
9 changes: 9 additions & 0 deletions src/MySqlConnector/MySqlConnectionOpenedCallback.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
namespace MySqlConnector;

/// <summary>
/// A callback that is invoked when a new <see cref="MySqlConnection"/> is opened.
/// </summary>
/// <param name="context">A <see cref="MySqlConnectionOpenedContext"/> giving information about the connection being opened.</param>
/// <param name="cancellationToken">A <see cref="CancellationToken"/> that can be used to cancel the asynchronous operation.</param>
/// <returns>A <see cref="ValueTask"/> representing the result of the possibly-asynchronous operation.</returns>
public delegate ValueTask MySqlConnectionOpenedCallback(MySqlConnectionOpenedContext context, CancellationToken cancellationToken);
23 changes: 23 additions & 0 deletions src/MySqlConnector/MySqlConnectionOpenedConditions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
namespace MySqlConnector;

/// <summary>
/// Bitflags giving the conditions under which a connection was opened.
/// </summary>
[Flags]
public enum MySqlConnectionOpenedConditions
{
/// <summary>
/// No specific conditions apply. This value may be used when an existing pooled connection is reused without being reset.
/// </summary>
None = 0,

/// <summary>
/// A new physical connection to a MySQL Server was opened. This value is mutually exclusive with <see cref="Reset"/>.
/// </summary>
New = 1,

/// <summary>
/// An existing pooled connection to a MySQL Server was reset. This value is mutually exclusive with <see cref="New"/>.
/// </summary>
Reset = 2,
}
23 changes: 23 additions & 0 deletions src/MySqlConnector/MySqlConnectionOpenedContext.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
namespace MySqlConnector;

/// <summary>
/// Contains information passed to <see cref="MySqlConnectionOpenedCallback"/> when a new <see cref="MySqlConnection"/> is opened.
/// </summary>
public sealed class MySqlConnectionOpenedContext
{
/// <summary>
/// The <see cref="MySqlConnection"/> that was opened.
/// </summary>
public MySqlConnection Connection { get; }

/// <summary>
/// Bitflags giving the conditions under which a connection was opened.
/// </summary>
public MySqlConnectionOpenedConditions Conditions { get; }

internal MySqlConnectionOpenedContext(MySqlConnection connection, MySqlConnectionOpenedConditions conditions)
{
Connection = connection;
Conditions = conditions;
}
}
8 changes: 6 additions & 2 deletions src/MySqlConnector/MySqlDataSource.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public sealed class MySqlDataSource : DbDataSource
/// <param name="connectionString">The connection string for the MySQL Server. This parameter is required.</param>
/// <exception cref="ArgumentNullException">Thrown if <paramref name="connectionString"/> is <c>null</c>.</exception>
public MySqlDataSource(string connectionString)
: this(connectionString ?? throw new ArgumentNullException(nameof(connectionString)), MySqlConnectorLoggingConfiguration.NullConfiguration, null, null, null, null, default, default, default)
: this(connectionString ?? throw new ArgumentNullException(nameof(connectionString)), MySqlConnectorLoggingConfiguration.NullConfiguration, null, null, null, null, default, default, default, default)
{
}

Expand All @@ -31,7 +31,8 @@ internal MySqlDataSource(string connectionString,
Func<MySqlProvidePasswordContext, CancellationToken, ValueTask<string>>? periodicPasswordProvider,
TimeSpan periodicPasswordProviderSuccessRefreshInterval,
TimeSpan periodicPasswordProviderFailureRefreshInterval,
ZstandardPlugin? zstandardPlugin)
ZstandardPlugin? zstandardPlugin,
MySqlConnectionOpenedCallback? connectionOpenedCallback)
{
m_connectionString = connectionString;
LoggingConfiguration = loggingConfiguration;
Expand All @@ -40,6 +41,7 @@ internal MySqlDataSource(string connectionString,
m_remoteCertificateValidationCallback = remoteCertificateValidationCallback;
m_logger = loggingConfiguration.DataSourceLogger;
m_zstandardPlugin = zstandardPlugin;
m_connectionOpenedCallback = connectionOpenedCallback;

Pool = ConnectionPool.CreatePool(m_connectionString, LoggingConfiguration, name);
m_id = Interlocked.Increment(ref s_lastId);
Expand Down Expand Up @@ -142,6 +144,7 @@ protected override DbConnection CreateDbConnection()
ProvideClientCertificatesCallback = m_clientCertificatesCallback,
ProvidePasswordCallback = m_providePasswordCallback,
RemoteCertificateValidationCallback = m_remoteCertificateValidationCallback,
ConnectionOpenedCallback = m_connectionOpenedCallback,
};
}

Expand Down Expand Up @@ -225,6 +228,7 @@ private string ProvidePasswordFromInitialRefreshTask(MySqlProvidePasswordContext
private readonly TimeSpan m_periodicPasswordProviderSuccessRefreshInterval;
private readonly TimeSpan m_periodicPasswordProviderFailureRefreshInterval;
private readonly ZstandardPlugin? m_zstandardPlugin;
private readonly MySqlConnectionOpenedCallback? m_connectionOpenedCallback;
private readonly MySqlProvidePasswordContext? m_providePasswordContext;
private readonly CancellationTokenSource? m_passwordProviderTimerCancellationTokenSource;
private readonly Timer? m_passwordProviderTimer;
Expand Down
15 changes: 14 additions & 1 deletion src/MySqlConnector/MySqlDataSourceBuilder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,17 @@ public MySqlDataSourceBuilder UseRemoteCertificateValidationCallback(RemoteCerti
return this;
}

/// <summary>
/// Adds a callback that is invoked when a new <see cref="MySqlConnection"/> is opened.
/// </summary>
/// <param name="callback">The callback to invoke.</param>
/// <returns>This builder, so that method calls can be chained.</returns>
public MySqlDataSourceBuilder UseConnectionOpenedCallback(MySqlConnectionOpenedCallback callback)
{
m_connectionOpenedCallback += callback;
return this;
}

/// <summary>
/// Builds a <see cref="MySqlDataSource"/> which is ready for use.
/// </summary>
Expand All @@ -104,7 +115,8 @@ public MySqlDataSource Build()
m_periodicPasswordProvider,
m_periodicPasswordProviderSuccessRefreshInterval,
m_periodicPasswordProviderFailureRefreshInterval,
ZstandardPlugin
ZstandardPlugin,
m_connectionOpenedCallback
);
}

Expand All @@ -122,4 +134,5 @@ public MySqlDataSource Build()
private Func<MySqlProvidePasswordContext, CancellationToken, ValueTask<string>>? m_periodicPasswordProvider;
private TimeSpan m_periodicPasswordProviderSuccessRefreshInterval;
private TimeSpan m_periodicPasswordProviderFailureRefreshInterval;
private MySqlConnectionOpenedCallback? m_connectionOpenedCallback;
}
32 changes: 32 additions & 0 deletions tests/IntegrationTests/TransactionScopeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -886,6 +886,38 @@ public void Bug1348()

Assert.True(rollbacked, $"First branch transaction '{xid}1' not rolled back");
}

[Fact]
public void ConnectionOpenedCallbackAutoEnlistInTransaction()
{
var connectionOpenedCallbackCount = 0;
var connectionOpenedConditions = MySqlConnectionOpenedConditions.None;
using var dataSource = new MySqlDataSourceBuilder(AppConfig.ConnectionString)
.UseConnectionOpenedCallback((ctx, token) =>
{
connectionOpenedCallbackCount++;
connectionOpenedConditions = ctx.Conditions;
return default;
})
.Build();

using (var transactionScope = new TransactionScope())
{
using (var conn = dataSource.OpenConnection())
{
Assert.Equal(1, connectionOpenedCallbackCount);
Assert.Equal(MySqlConnectionOpenedConditions.New, connectionOpenedConditions);
}

using (var conn = dataSource.OpenConnection())
{
Assert.Equal(2, connectionOpenedCallbackCount);
Assert.Equal(MySqlConnectionOpenedConditions.None, connectionOpenedConditions);
}

transactionScope.Complete();
}
}
#endif

readonly DatabaseFixture m_database;
Expand Down
Loading

0 comments on commit b78e43f

Please sign in to comment.