Skip to content

Commit

Permalink
Remove a race condition from the WebSocket upgrade
Browse files Browse the repository at this point in the history
An upgrade to WebSocket invokes three handlers in an order. First, the
`shouldUpgrade` handler supplied by the user is run to decide if an upgrade
must go through. This handler also supplies additional headers to be sent in
the response. Next, after the WebSocket upgrader is done upgrading the
pipeline, the `completionHandler` is called. We remove Kitura-NIO's
`HTTPRequestHandler` here. Next, the `upgradePipelineHandler` is invoked. This
handler allows us to add all the WebSocket related `ChannelHandler`s.

For an undocumented reason, we saved the `ChannelHandlerContext` received by
the `completionHandler` in the `HTTPServer` and later used it upgrade the
pipeline in `upgradePipelineHandler`. This can easily lead to a race condition
where we saved the `ChannelHandlerContext` for a connection, into the
`HTTPServer` but before it could be used in `upgradePipelineHandler`, it was
overwritten by the upgrade happening on another connection. Consequently, we
never upgraded the pipeline of the former connection. This could lead to
different kinds of failures.

The `upgradePipelineHandler` has `Channel` as one of its parameters. Hence
there is no need to store the `ChannelHandlerContext` for use in this closure.
Consequently, we have to use a `Channel` to initialize an `HTTPServerRequest`,
which, in turn, is modified to accept a `Channel` instead of a
`ChannelHandlerContext`.
  • Loading branch information
Pushkar Kulkarni committed Jul 22, 2019
1 parent dfdbff2 commit 09b88f8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Sources/KituraNet/HTTP/HTTPRequestHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ internal class HTTPRequestHandler: ChannelInboundHandler, RemovableChannelHandle

switch request {
case .head(let header):
serverRequest = HTTPServerRequest(ctx: context, requestHead: header, enableSSL: enableSSLVerification)
serverRequest = HTTPServerRequest(channel: context.channel, requestHead: header, enableSSL: enableSSLVerification)
self.clientRequestedKeepAlive = header.isKeepAlive
case .body(var buffer):
guard let serverRequest = serverRequest else {
Expand Down
20 changes: 8 additions & 12 deletions Sources/KituraNet/HTTP/HTTPServer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -124,8 +124,6 @@ public class HTTPServer: Server {
/// The event loop group on which the HTTP handler runs
private let eventLoopGroup: MultiThreadedEventLoopGroup

private var ctx: ChannelHandlerContext?

/**
Creates an HTTP server object.
Expand Down Expand Up @@ -193,19 +191,18 @@ public class HTTPServer: Server {
}

/// Creates upgrade request and adds WebSocket handler to pipeline
private func upgradeHandler(webSocketHandlerFactory: ProtocolHandlerFactory, request: HTTPRequestHead) -> EventLoopFuture<Void> {
guard let ctx = self.ctx else { fatalError("The channel was probably closed during a protocol upgrade.") }
return ctx.eventLoop.submit {
let request = HTTPServerRequest(ctx: ctx, requestHead: request, enableSSL: false)
private func upgradeHandler(channel: Channel, webSocketHandlerFactory: ProtocolHandlerFactory, request: HTTPRequestHead) -> EventLoopFuture<Void> {
return channel.eventLoop.submit {
let request = HTTPServerRequest(channel: channel, requestHead: request, enableSSL: false)
return webSocketHandlerFactory.handler(for: request)
}.flatMap { (handler: ChannelHandler) -> EventLoopFuture<Void> in
return ctx.channel.pipeline.addHandler(handler).flatMap {
return channel.pipeline.addHandler(handler).flatMap {
if let _extensions = request.headers["Sec-WebSocket-Extensions"].first {
let handlers = webSocketHandlerFactory.extensionHandlers(header: _extensions)
return ctx.channel.pipeline.addHandlers(handlers, position: .before(handler))
return channel.pipeline.addHandlers(handlers, position: .before(handler))
} else {
// No extensions. We must return success.
return ctx.channel.eventLoop.makeSucceededFuture(())
return channel.eventLoop.makeSucceededFuture(())
}
}
}
Expand All @@ -222,7 +219,7 @@ public class HTTPServer: Server {

private func generateUpgradePipelineHandler(_ webSocketHandlerFactory: ProtocolHandlerFactory) -> UpgradePipelineHandlerFunction {
return { (channel: Channel, request: HTTPRequestHead) in
return self.upgradeHandler(webSocketHandlerFactory: webSocketHandlerFactory, request: request)
return self.upgradeHandler(channel: channel, webSocketHandlerFactory: webSocketHandlerFactory, request: request)
}
}

Expand Down Expand Up @@ -304,8 +301,7 @@ public class HTTPServer: Server {
.serverChannelOption(ChannelOptions.socket(SocketOptionLevel(SOL_SOCKET), SO_REUSEPORT), value: allowPortReuse ? 1 : 0)
.childChannelInitializer { channel in
let httpHandler = HTTPRequestHandler(for: self)
let config: NIOHTTPServerUpgradeConfiguration = (upgraders: upgraders, completionHandler: { ctx in
self.ctx = ctx
let config: NIOHTTPServerUpgradeConfiguration = (upgraders: upgraders, completionHandler: { _ in
_ = channel.pipeline.removeHandler(httpHandler)
})
return channel.pipeline.configureHTTPServerPipeline(withServerUpgrade: config, withErrorHandling: true).flatMap {
Expand Down
14 changes: 7 additions & 7 deletions Sources/KituraNet/HTTP/HTTPServerRequest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ public class HTTPServerRequest: ServerRequest {
*/
public var method: String

private let ctx: ChannelHandlerContext
private let channel: Channel

private var enableSSL: Bool = false

Expand Down Expand Up @@ -208,20 +208,20 @@ public class HTTPServerRequest: ServerRequest {
}
}

init(ctx: ChannelHandlerContext, requestHead: HTTPRequestHead, enableSSL: Bool) {
init(channel: Channel, requestHead: HTTPRequestHead, enableSSL: Bool) {
// An HTTPServerRequest may be created only on the EventLoop assigned to handle
// the connection on which the HTTP request arrived.
assert(ctx.eventLoop.inEventLoop)
self.ctx = ctx
assert(channel.eventLoop.inEventLoop)
self.channel = channel
self.headers = HeadersContainer(with: requestHead.headers)
self.method = requestHead.method.rawValue
self.httpVersionMajor = UInt16(requestHead.version.major)
self.httpVersionMinor = UInt16(requestHead.version.minor)
self.rawURLString = requestHead.uri
self.enableSSL = enableSSL
self.localAddressHost = HTTPServerRequest.host(socketAddress: ctx.localAddress)
self.localAddressPort = ctx.localAddress?.port ?? 0
self.remoteAddress = HTTPServerRequest.host(socketAddress: ctx.remoteAddress)
self.localAddressHost = HTTPServerRequest.host(socketAddress: channel.localAddress)
self.localAddressPort = channel.localAddress?.port ?? 0
self.remoteAddress = HTTPServerRequest.host(socketAddress: channel.remoteAddress)
}

var buffer: BufferList?
Expand Down

0 comments on commit 09b88f8

Please sign in to comment.