From 09b88f886f250360408f65bad108789bd910802d Mon Sep 17 00:00:00 2001 From: Pushkar Kulkarni Date: Mon, 22 Jul 2019 23:27:09 +0530 Subject: [PATCH] Remove a race condition from the WebSocket upgrade An upgrade to WebSocket invokes three handlers in an order. First, the `shouldUpgrade` handler supplied by the user is run to decide if an upgrade must go through. This handler also supplies additional headers to be sent in the response. Next, after the WebSocket upgrader is done upgrading the pipeline, the `completionHandler` is called. We remove Kitura-NIO's `HTTPRequestHandler` here. Next, the `upgradePipelineHandler` is invoked. This handler allows us to add all the WebSocket related `ChannelHandler`s. For an undocumented reason, we saved the `ChannelHandlerContext` received by the `completionHandler` in the `HTTPServer` and later used it upgrade the pipeline in `upgradePipelineHandler`. This can easily lead to a race condition where we saved the `ChannelHandlerContext` for a connection, into the `HTTPServer` but before it could be used in `upgradePipelineHandler`, it was overwritten by the upgrade happening on another connection. Consequently, we never upgraded the pipeline of the former connection. This could lead to different kinds of failures. The `upgradePipelineHandler` has `Channel` as one of its parameters. Hence there is no need to store the `ChannelHandlerContext` for use in this closure. Consequently, we have to use a `Channel` to initialize an `HTTPServerRequest`, which, in turn, is modified to accept a `Channel` instead of a `ChannelHandlerContext`. --- .../KituraNet/HTTP/HTTPRequestHandler.swift | 2 +- Sources/KituraNet/HTTP/HTTPServer.swift | 20 ++++++++----------- .../KituraNet/HTTP/HTTPServerRequest.swift | 14 ++++++------- 3 files changed, 16 insertions(+), 20 deletions(-) diff --git a/Sources/KituraNet/HTTP/HTTPRequestHandler.swift b/Sources/KituraNet/HTTP/HTTPRequestHandler.swift index 46dd83d7..b1317840 100644 --- a/Sources/KituraNet/HTTP/HTTPRequestHandler.swift +++ b/Sources/KituraNet/HTTP/HTTPRequestHandler.swift @@ -79,7 +79,7 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle switch request { case .head(let header): - serverRequest = HTTPServerRequest(ctx: context, requestHead: header, enableSSL: enableSSLVerification) + serverRequest = HTTPServerRequest(channel: context.channel, requestHead: header, enableSSL: enableSSLVerification) self.clientRequestedKeepAlive = header.isKeepAlive case .body(var buffer): guard let serverRequest = serverRequest else { diff --git a/Sources/KituraNet/HTTP/HTTPServer.swift b/Sources/KituraNet/HTTP/HTTPServer.swift index 58be06de..4340735a 100644 --- a/Sources/KituraNet/HTTP/HTTPServer.swift +++ b/Sources/KituraNet/HTTP/HTTPServer.swift @@ -124,8 +124,6 @@ public class HTTPServer: Server { /// The event loop group on which the HTTP handler runs private let eventLoopGroup: MultiThreadedEventLoopGroup - private var ctx: ChannelHandlerContext? - /** Creates an HTTP server object. @@ -193,19 +191,18 @@ public class HTTPServer: Server { } /// Creates upgrade request and adds WebSocket handler to pipeline - private func upgradeHandler(webSocketHandlerFactory: ProtocolHandlerFactory, request: HTTPRequestHead) -> EventLoopFuture { - guard let ctx = self.ctx else { fatalError("The channel was probably closed during a protocol upgrade.") } - return ctx.eventLoop.submit { - let request = HTTPServerRequest(ctx: ctx, requestHead: request, enableSSL: false) + private func upgradeHandler(channel: Channel, webSocketHandlerFactory: ProtocolHandlerFactory, request: HTTPRequestHead) -> EventLoopFuture { + return channel.eventLoop.submit { + let request = HTTPServerRequest(channel: channel, requestHead: request, enableSSL: false) return webSocketHandlerFactory.handler(for: request) }.flatMap { (handler: ChannelHandler) -> EventLoopFuture in - return ctx.channel.pipeline.addHandler(handler).flatMap { + return channel.pipeline.addHandler(handler).flatMap { if let _extensions = request.headers["Sec-WebSocket-Extensions"].first { let handlers = webSocketHandlerFactory.extensionHandlers(header: _extensions) - return ctx.channel.pipeline.addHandlers(handlers, position: .before(handler)) + return channel.pipeline.addHandlers(handlers, position: .before(handler)) } else { // No extensions. We must return success. - return ctx.channel.eventLoop.makeSucceededFuture(()) + return channel.eventLoop.makeSucceededFuture(()) } } } @@ -222,7 +219,7 @@ public class HTTPServer: Server { private func generateUpgradePipelineHandler(_ webSocketHandlerFactory: ProtocolHandlerFactory) -> UpgradePipelineHandlerFunction { return { (channel: Channel, request: HTTPRequestHead) in - return self.upgradeHandler(webSocketHandlerFactory: webSocketHandlerFactory, request: request) + return self.upgradeHandler(channel: channel, webSocketHandlerFactory: webSocketHandlerFactory, request: request) } } @@ -304,8 +301,7 @@ public class HTTPServer: Server { .serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: allowPortReuse ? 1 : 0) .childChannelInitializer { channel in let httpHandler = HTTPRequestHandler(for: self) - let config: NIOHTTPServerUpgradeConfiguration = (upgraders: upgraders, completionHandler: { ctx in - self.ctx = ctx + let config: NIOHTTPServerUpgradeConfiguration = (upgraders: upgraders, completionHandler: { _ in _ = channel.pipeline.removeHandler(httpHandler) }) return channel.pipeline.configureHTTPServerPipeline(withServerUpgrade: config, withErrorHandling: true).flatMap { diff --git a/Sources/KituraNet/HTTP/HTTPServerRequest.swift b/Sources/KituraNet/HTTP/HTTPServerRequest.swift index d071b778..937c3d69 100644 --- a/Sources/KituraNet/HTTP/HTTPServerRequest.swift +++ b/Sources/KituraNet/HTTP/HTTPServerRequest.swift @@ -180,7 +180,7 @@ public class HTTPServerRequest: ServerRequest { */ public var method: String - private let ctx: ChannelHandlerContext + private let channel: Channel private var enableSSL: Bool = false @@ -208,20 +208,20 @@ public class HTTPServerRequest: ServerRequest { } } - init(ctx: ChannelHandlerContext, requestHead: HTTPRequestHead, enableSSL: Bool) { + init(channel: Channel, requestHead: HTTPRequestHead, enableSSL: Bool) { // An HTTPServerRequest may be created only on the EventLoop assigned to handle // the connection on which the HTTP request arrived. - assert(ctx.eventLoop.inEventLoop) - self.ctx = ctx + assert(channel.eventLoop.inEventLoop) + self.channel = channel self.headers = HeadersContainer(with: requestHead.headers) self.method = requestHead.method.rawValue self.httpVersionMajor = UInt16(requestHead.version.major) self.httpVersionMinor = UInt16(requestHead.version.minor) self.rawURLString = requestHead.uri self.enableSSL = enableSSL - self.localAddressHost = HTTPServerRequest.host(socketAddress: ctx.localAddress) - self.localAddressPort = ctx.localAddress?.port ?? 0 - self.remoteAddress = HTTPServerRequest.host(socketAddress: ctx.remoteAddress) + self.localAddressHost = HTTPServerRequest.host(socketAddress: channel.localAddress) + self.localAddressPort = channel.localAddress?.port ?? 0 + self.remoteAddress = HTTPServerRequest.host(socketAddress: channel.remoteAddress) } var buffer: BufferList?