Skip to content

Commit

Permalink
Add onUnexpectedConnectionClose callback to pool
Browse files Browse the repository at this point in the history
Backports `onUnexpectedConnectionClose` on the pool to 1.x
  • Loading branch information
marius-se authored and Mordil committed Jun 9, 2023
1 parent 3bd5940 commit 5875335
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 75 deletions.
28 changes: 16 additions & 12 deletions Sources/RediStack/Configuration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,20 @@ extension RedisConnection.Configuration {
var localizedDescription: String { self.kind.localizedDescription }

private let kind: Kind

private init(_ kind: Kind) { self.kind = kind }

public static func ==(lhs: ValidationError, rhs: ValidationError) -> Bool {
return lhs.kind == rhs.kind
}

private enum Kind: LocalizedError {
case invalidURLString
case missingURLScheme
case invalidURLScheme
case missingURLHost
case outOfBoundsDatabaseID

var localizedDescription: String {
let message: String = {
switch self {
Expand All @@ -66,7 +66,7 @@ extension RedisConnection {
///
/// See [https://redis.io/topics/quickstart](https://redis.io/topics/quickstart)
public static var defaultPort = 6379

internal static let defaultLogger = Logger.redisBaseConnectionLogger

/// The hostname of the connection address. If the address is a Unix socket, then it will be `nil`.
Expand All @@ -85,9 +85,9 @@ extension RedisConnection {
public let initialDatabase: Int?
/// The logger prototype that will be used by the connection by default when generating logs.
public let defaultLogger: Logger

internal let address: SocketAddress

/// Creates a new connection configuration with the provided details.
/// - Parameters:
/// - address: The socket address information to use for creating the Redis connection.
Expand All @@ -106,7 +106,7 @@ extension RedisConnection {
if initialDatabase != nil && initialDatabase! < 0 {
throw ValidationError.outOfBoundsDatabaseID
}

self.address = address
self.password = password
self.initialDatabase = initialDatabase
Expand Down Expand Up @@ -182,9 +182,9 @@ extension RedisConnection {
try Self.validateRedisURL(url)

guard let host = url.host, !host.isEmpty else { throw ValidationError.missingURLHost }

let databaseID = Int(url.lastPathComponent)

try self.init(
address: try .makeAddressResolvingHost(host, port: url.port ?? Self.defaultPort),
password: url.password,
Expand Down Expand Up @@ -219,7 +219,7 @@ extension RedisConnectionPool {
public let connectionInitialDatabase: Int?
/// The pre-configured TCP client for connections to use.
public let tcpClient: ClientBootstrap?

/// Creates a new connection factory configuration with the provided options.
/// - Parameters:
/// - connectionInitialDatabase: The optional database index to initially connect to. The default is `nil`.
Expand Down Expand Up @@ -255,11 +255,13 @@ extension RedisConnectionPool {
public let maximumConnectionCount: RedisConnectionPoolSize
/// The configuration object that controls the connection retry behavior.
public let connectionRetryConfiguration: (backoff: (initialDelay: TimeAmount, factor: Float32), timeout: TimeAmount)
/// Called when a connection in the pool is closed unexpectedly.
public let onUnexpectedConnectionClose: ((RedisConnection) -> Void)?
// these need to be var so they can be updated by the pool in some cases
public internal(set) var factoryConfiguration: ConnectionFactoryConfiguration
/// The logger prototype that will be used by the connection pool by default when generating logs.
public internal(set) var poolDefaultLogger: Logger

/// Creates a new connection configuration with the provided options.
/// - Parameters:
/// - initialServerConnectionAddresses: The set of Redis servers to which this pool is initially willing to connect.
Expand All @@ -284,6 +286,7 @@ extension RedisConnectionPool {
connectionBackoffFactor: Float32 = 2,
initialConnectionBackoffDelay: TimeAmount = .milliseconds(100),
connectionRetryTimeout: TimeAmount? = .seconds(60),
onUnexpectedConnectionClose: ((RedisConnection) -> Void)? = nil,
poolDefaultLogger: Logger? = nil
) {
self.initialConnectionAddresses = initialServerConnectionAddresses
Expand All @@ -294,6 +297,7 @@ extension RedisConnectionPool {
(initialConnectionBackoffDelay, connectionBackoffFactor),
connectionRetryTimeout ?? .milliseconds(10) // always default to a baseline 10ms
)
self.onUnexpectedConnectionClose = onUnexpectedConnectionClose
self.poolDefaultLogger = poolDefaultLogger ?? .redisBaseConnectionPoolLogger
}
}
Expand Down
78 changes: 39 additions & 39 deletions Sources/RediStack/RedisConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import NIO
import NIOConcurrencyHelpers

extension RedisConnection {

/// Creates a new connection with provided configuration and sychronization objects.
///
/// If you would like to specialize the `NIO.ClientBootstrap` that the connection communicates on, override the default by passing it in as `configuredTCPClient`.
Expand Down Expand Up @@ -55,7 +55,7 @@ extension RedisConnection {
configuredTCPClient client: ClientBootstrap? = nil
) -> EventLoopFuture<RedisConnection> {
let client = client ?? .makeRedisTCPClient(group: eventLoop)

var future = client
.connect(to: config.address)
.map { return RedisConnection(configuredRESPChannel: $0, context: config.defaultLogger) }
Expand All @@ -73,7 +73,7 @@ extension RedisConnection {
return connection.select(database: database).map { connection }
}
}

return future
}
}
Expand Down Expand Up @@ -157,14 +157,14 @@ public final class RedisConnection: RedisClient, RedisClientWithUserContext {
get { return _stateLock.withLock { self._state } }
set(newValue) { _stateLock.withLockVoid { self._state = newValue } }
}

deinit {
if isConnected {
assertionFailure("close() was not called before deinit!")
self.logger.warning("connection was not properly shutdown before deinit")
}
}

internal init(configuredRESPChannel: Channel, context: Context) {
self.channel = configuredRESPChannel
// there is a mix of verbiage here as the API is forward thinking towards "baggage context"
Expand All @@ -176,14 +176,14 @@ public final class RedisConnection: RedisClient, RedisClientWithUserContext {

RedisMetrics.activeConnectionCount.increment()
RedisMetrics.totalConnectionCount.increment()

// attach a callback to the channel to capture situations where the channel might be closed out from under
// the connection
self.channel.closeFuture.whenSuccess {
// if our state is still open, that means we didn't cause the closeFuture to resolve.
// update state, metrics, and logging
guard self.state.isConnected else { return }

self.state = .closed
self.logger.error("connection was closed unexpectedly")
RedisMetrics.activeConnectionCount.decrement()
Expand All @@ -192,13 +192,13 @@ public final class RedisConnection: RedisClient, RedisClientWithUserContext {

self.logger.trace("connection created")
}

internal enum ConnectionState {
case open
case pubsub(RedisPubSubHandler)
case shuttingDown
case closed

var isConnected: Bool {
switch self {
case .open, .pubsub: return true
Expand Down Expand Up @@ -242,40 +242,40 @@ extension RedisConnection {
return self.channel.eventLoop.makeFailedFuture(error)
}
logger.trace("received command request")

logger.debug("sending command", metadata: [
RedisLogging.MetadataKeys.commandKeyword: "\(command)",
RedisLogging.MetadataKeys.commandArguments: "\(arguments)"
])

var message: [RESPValue] = [.init(bulk: command)]
message.append(contentsOf: arguments)

let promise = channel.eventLoop.makePromise(of: RESPValue.self)
let command = RedisCommand(
message: .array(message),
responsePromise: promise
)

let startTime = DispatchTime.now().uptimeNanoseconds
promise.futureResult.whenComplete { result in
let duration = DispatchTime.now().uptimeNanoseconds - startTime
RedisMetrics.commandRoundTripTime.recordNanoseconds(duration)

// log data based on the result
switch result {
case let .failure(error):
logger.error("command failed", metadata: [
RedisLogging.MetadataKeys.error: "\(error.localizedDescription)"
])

case let .success(value):
logger.debug("command succeeded", metadata: [
RedisLogging.MetadataKeys.commandResult: "\(value)"
])
}
}

defer { logger.trace("command sent") }

if self.sendCommandsImmediately {
Expand Down Expand Up @@ -310,10 +310,10 @@ extension RedisConnection {

// we're now in a shutdown state, starting with the command queue.
self.state = .shuttingDown

let notification = self.sendQuitCommand(logger: logger) // send "QUIT" so that all the responses are written out
.flatMap { self.closeChannel() } // close the channel from our end

notification.whenFailure {
logger.error("error while closing connection", metadata: [
RedisLogging.MetadataKeys.error: "\($0)"
Expand All @@ -324,10 +324,10 @@ extension RedisConnection {
logger.trace("connection is now closed")
RedisMetrics.activeConnectionCount.decrement()
}

return notification
}

/// Bypasses everything for a normal command and explicitly just sends a "QUIT" command to Redis.
/// - Note: If the command fails, the `NIO.EventLoopFuture` will still succeed - as it's not critical for the command to succeed.
private func sendQuitCommand(logger: Logger) -> EventLoopFuture<Void> {
Expand All @@ -344,22 +344,22 @@ extension RedisConnection {
.map { _ in logger.trace("sent QUIT command") } // ignore the result's value
.recover { _ in logger.debug("recovered from error sending QUIT") } // if there's an error, just return to void
}

/// Attempts to close the `NIO.Channel`.
/// SwiftNIO throws a `NIO.EventLoopError.shutdown` if the channel is already closed,
/// so that case is captured to let this method's `NIO.EventLoopFuture` still succeed.
private func closeChannel() -> EventLoopFuture<Void> {
let promise = self.channel.eventLoop.makePromise(of: Void.self)

self.channel.close(promise: promise)

// if we succeed, great, if not - check the error that happened
return promise.futureResult
.flatMapError { error in
guard let e = error as? EventLoopError else {
return self.eventLoop.makeFailedFuture(error)
}

// if the error is that the channel is already closed, great - just succeed.
// otherwise, fail the chain
switch e {
Expand Down Expand Up @@ -395,7 +395,7 @@ extension RedisConnection {
) -> EventLoopFuture<Void> {
return self._subscribe(.channels(channels), receiver, subscribeHandler, unsubscribeHandler, nil)
}

public func psubscribe(
to patterns: [String],
messageReceiver receiver: @escaping RedisSubscriptionMessageReceiver,
Expand All @@ -404,7 +404,7 @@ extension RedisConnection {
) -> EventLoopFuture<Void> {
return self._subscribe(.patterns(patterns), receiver, subscribeHandler, unsubscribeHandler, nil)
}

internal func subscribe(
to channels: [RedisChannelName],
messageReceiver receiver: @escaping RedisSubscriptionMessageReceiver,
Expand All @@ -414,7 +414,7 @@ extension RedisConnection {
) -> EventLoopFuture<Void> {
return self._subscribe(.channels(channels), receiver, subscribeHandler, unsubscribeHandler, context)
}

internal func psubscribe(
to patterns: [String],
messageReceiver receiver: @escaping RedisSubscriptionMessageReceiver,
Expand All @@ -424,7 +424,7 @@ extension RedisConnection {
) -> EventLoopFuture<Void> {
return self._subscribe(.patterns(patterns), receiver, subscribeHandler, unsubscribeHandler, context)
}

private func _subscribe(
_ target: RedisSubscriptionTarget,
_ receiver: @escaping RedisSubscriptionMessageReceiver,
Expand All @@ -433,9 +433,9 @@ extension RedisConnection {
_ logger: Logger?
) -> EventLoopFuture<Void> {
let logger = self.prepareLoggerForUse(logger)

logger.trace("received subscribe request")

// if we're closed, just error out
guard self.state.isConnected else { return self.eventLoop.makeFailedFuture(RedisClientError.connectionClosed) }

Expand Down Expand Up @@ -483,7 +483,7 @@ extension RedisConnection {
logger.debug("the connection is now in pubsub mode")
}
}

// add the subscription and just ignore the subscription count
return handler
.addSubscription(for: target, messageReceiver: receiver, onSubscribe: onSubscribe, onUnsubscribe: onUnsubscribe)
Expand All @@ -497,27 +497,27 @@ extension RedisConnection {
public func unsubscribe(from channels: [RedisChannelName]) -> EventLoopFuture<Void> {
return self._unsubscribe(.channels(channels), nil)
}

public func punsubscribe(from patterns: [String]) -> EventLoopFuture<Void> {
return self._unsubscribe(.patterns(patterns), nil)
}

internal func unsubscribe(from channels: [RedisChannelName], context: Context?) -> EventLoopFuture<Void> {
return self._unsubscribe(.channels(channels), context)
}

internal func punsubscribe(from patterns: [String], context: Context?) -> EventLoopFuture<Void> {
return self._unsubscribe(.patterns(patterns), context)
}

private func _unsubscribe(_ target: RedisSubscriptionTarget, _ logger: Logger?) -> EventLoopFuture<Void> {
let logger = self.prepareLoggerForUse(logger)

logger.trace("received unsubscribe request")

// if we're closed, just error out
guard self.state.isConnected else { return self.eventLoop.makeFailedFuture(RedisClientError.connectionClosed) }

// if we're not in pubsub mode, then we just succeed as a no-op
guard case let .pubsub(handler) = self.state else {
// but we still assert just to give some notification to devs at debug
Expand All @@ -526,11 +526,11 @@ extension RedisConnection {
])
return self.eventLoop.makeSucceededFuture(())
}

logger.trace("removing subscription", metadata: [
RedisLogging.MetadataKeys.pubsubTarget: "\(target.debugDescription)"
])

// remove the subscription
return handler.removeSubscription(for: target)
.flatMap {
Expand Down
Loading

0 comments on commit 5875335

Please sign in to comment.