diff --git a/Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift b/Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift index 3ceee92..e38deca 100644 --- a/Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift +++ b/Sources/GRPCNIOTransportCore/Server/Connection/ServerConnectionManagementHandler.swift @@ -14,10 +14,10 @@ * limitations under the License. */ -internal import GRPCCore -internal import NIOCore -internal import NIOHTTP2 -internal import NIOTLS +private import GRPCCore +package import NIOCore +package import NIOHTTP2 +private import NIOTLS /// A `ChannelHandler` which manages the lifecycle of a gRPC connection over HTTP/2. /// @@ -39,11 +39,11 @@ internal import NIOTLS /// Some of the behaviours are described in: /// - [gRFC A8](https://github.com/grpc/proposal/blob/master/A8-client-side-keepalive.md), and /// - [gRFC A9](https://github.com/grpc/proposal/blob/master/A9-server-side-conn-mgt.md). -final class ServerConnectionManagementHandler: ChannelDuplexHandler { - typealias InboundIn = HTTP2Frame - typealias InboundOut = HTTP2Frame - typealias OutboundIn = HTTP2Frame - typealias OutboundOut = HTTP2Frame +package final class ServerConnectionManagementHandler: ChannelDuplexHandler { + package typealias InboundIn = HTTP2Frame + package typealias InboundOut = HTTP2Frame + package typealias OutboundIn = HTTP2Frame + package typealias OutboundOut = HTTP2Frame /// The `EventLoop` of the `Channel` this handler exists in. private let eventLoop: any EventLoop @@ -98,7 +98,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { /// While NIO's `EmbeddedEventLoop` provides control over its view of time (and therefore any /// events scheduled on it) it doesn't offer a way to get the current time. This is usually done /// via `NIODeadline`. - enum Clock { + package enum Clock { case nio case manual(Manual) @@ -111,14 +111,14 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { } } - final class Manual { + package final class Manual { private(set) var time: NIODeadline - init() { + package init() { self.time = .uptimeNanoseconds(0) } - func advance(by amount: TimeAmount) { + package func advance(by amount: TimeAmount) { self.time = self.time + amount } } @@ -147,7 +147,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { } /// A synchronous view over this handler. - var syncView: SyncView { + package var syncView: SyncView { return SyncView(self) } @@ -155,7 +155,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { /// /// Methods on this view *must* be called from the same `EventLoop` as the `Channel` in which /// this handler exists. - struct SyncView { + package struct SyncView { private let handler: ServerConnectionManagementHandler fileprivate init(_ handler: ServerConnectionManagementHandler) { @@ -163,7 +163,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { } /// Notify the handler that the connection has received a flush event. - func connectionWillFlush() { + package func connectionWillFlush() { // The handler can't rely on `flush(context:)` due to its expected position in the pipeline. // It's expected to be placed after the HTTP/2 handler (i.e. closer to the application) as // it needs to receive HTTP/2 frames. However, flushes from stream channels aren't sent down @@ -178,13 +178,13 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { } /// Notify the handler that a HEADERS frame was written in the last write loop. - func wroteHeadersFrame() { + package func wroteHeadersFrame() { self.handler.eventLoop.assertInEventLoop() self.handler.frameStats.wroteHeaders() } /// Notify the handler that a DATA frame was written in the last write loop. - func wroteDataFrame() { + package func wroteDataFrame() { self.handler.eventLoop.assertInEventLoop() self.handler.frameStats.wroteData() } @@ -208,7 +208,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { /// keep-alive pings. Pings more frequent than this interval count as 'strikes' and the /// connection is closed if there are too many strikes. /// - clock: A clock providing the current time. - init( + package init( eventLoop: any EventLoop, maxIdleTime: TimeAmount?, maxAge: TimeAmount?, @@ -248,16 +248,16 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { self.requireALPN = requireALPN } - func handlerAdded(context: ChannelHandlerContext) { + package func handlerAdded(context: ChannelHandlerContext) { assert(context.eventLoop === self.eventLoop) self.context = context } - func handlerRemoved(context: ChannelHandlerContext) { + package func handlerRemoved(context: ChannelHandlerContext) { self.context = nil } - func channelActive(context: ChannelHandlerContext) { + package func channelActive(context: ChannelHandlerContext) { let view = LoopBoundView(handler: self, context: context) self.maxAgeTimer?.schedule(on: context.eventLoop) { @@ -275,7 +275,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { context.fireChannelActive() } - func channelInactive(context: ChannelHandlerContext) { + package func channelInactive(context: ChannelHandlerContext) { self.maxIdleTimer?.cancel() self.maxAgeTimer?.cancel() self.maxGraceTimer?.cancel() @@ -284,7 +284,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { context.fireChannelInactive() } - func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { + package func userInboundEventTriggered(context: ChannelHandlerContext, event: Any) { switch event { case let event as NIOHTTP2StreamCreatedEvent: self._streamCreated(event.streamID, channel: context.channel) @@ -314,7 +314,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { context.fireUserInboundEventTriggered(event) } - func channelRead(context: ChannelHandlerContext, data: NIOAny) { + package func channelRead(context: ChannelHandlerContext, data: NIOAny) { self.inReadLoop = true // Any read data indicates that the connection is alive so cancel the keep-alive timers. @@ -337,7 +337,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { context.fireChannelRead(data) } - func channelReadComplete(context: ChannelHandlerContext) { + package func channelReadComplete(context: ChannelHandlerContext) { while self.flushPending { self.flushPending = false context.flush() @@ -354,7 +354,7 @@ final class ServerConnectionManagementHandler: ChannelDuplexHandler { context.fireChannelReadComplete() } - func flush(context: ChannelHandlerContext) { + package func flush(context: ChannelHandlerContext) { self.maybeFlush(context: context) } } @@ -383,7 +383,7 @@ extension ServerConnectionManagementHandler { } extension ServerConnectionManagementHandler { - struct HTTP2StreamDelegate: @unchecked Sendable, NIOHTTP2StreamDelegate { + package struct HTTP2StreamDelegate: @unchecked Sendable, NIOHTTP2StreamDelegate { // @unchecked is okay: the only methods do the appropriate event-loop dance. private let handler: ServerConnectionManagementHandler @@ -392,7 +392,7 @@ extension ServerConnectionManagementHandler { self.handler = handler } - func streamCreated(_ id: HTTP2StreamID, channel: any Channel) { + package func streamCreated(_ id: HTTP2StreamID, channel: any Channel) { if self.handler.eventLoop.inEventLoop { self.handler._streamCreated(id, channel: channel) } else { @@ -402,7 +402,7 @@ extension ServerConnectionManagementHandler { } } - func streamClosed(_ id: HTTP2StreamID, channel: any Channel) { + package func streamClosed(_ id: HTTP2StreamID, channel: any Channel) { if self.handler.eventLoop.inEventLoop { self.handler._streamClosed(id, channel: channel) } else { @@ -413,7 +413,7 @@ extension ServerConnectionManagementHandler { } } - var http2StreamDelegate: HTTP2StreamDelegate { + package var http2StreamDelegate: HTTP2StreamDelegate { return HTTP2StreamDelegate(self) } diff --git a/Tests/GRPCNIOTransportCoreTests/Server/Connection/ServerConnectionManagementHandlerTests.swift b/Tests/GRPCNIOTransportCoreTests/Server/Connection/ServerConnectionManagementHandlerTests.swift index 91c0d38..34c0e2d 100644 --- a/Tests/GRPCNIOTransportCoreTests/Server/Connection/ServerConnectionManagementHandlerTests.swift +++ b/Tests/GRPCNIOTransportCoreTests/Server/Connection/ServerConnectionManagementHandlerTests.swift @@ -14,15 +14,15 @@ * limitations under the License. */ +import GRPCNIOTransportCore import NIOCore import NIOEmbedded import NIOHTTP2 -import XCTest +import Testing -@testable import GRPCNIOTransportCore - -final class ServerConnectionManagementHandlerTests: XCTestCase { - func testIdleTimeoutOnNewConnection() throws { +struct ServerConnectionManagementHandlerTests { + @Test("Idle timeout on new connection") + func idleTimeoutOnNewConnection() throws { let connection = try Connection(maxIdleTime: .minutes(1)) try connection.activate() // Hit the max idle time. @@ -35,7 +35,8 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { try connection.waitUntilClosed() } - func testIdleTimerIsCancelledWhenStreamIsOpened() throws { + @Test("Idle timeout is cancelled when stream is opened") + func idleTimerIsCancelledWhenStreamIsOpened() throws { let connection = try Connection(maxIdleTime: .minutes(1)) try connection.activate() @@ -44,17 +45,18 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { connection.advanceTime(by: .minutes(1)) // No GOAWAY frame means the timer was cancelled. - XCTAssertNil(try connection.readFrame()) + #expect(try connection.readFrame() == nil) } - func testIdleTimerStartsWhenAllStreamsAreClosed() throws { + @Test("Idle timer starts when all streams are closed") + func idleTimerStartsWhenAllStreamsAreClosed() throws { let connection = try Connection(maxIdleTime: .minutes(1)) try connection.activate() // Open a stream to cancel the idle timer and run through the max idle time. connection.streamOpened(1) connection.advanceTime(by: .minutes(1)) - XCTAssertNil(try connection.readFrame()) + #expect(try connection.readFrame() == nil) // Close the stream to start the timer again. connection.streamClosed(1) @@ -67,7 +69,8 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { try connection.waitUntilClosed() } - func testMaxAge() throws { + @Test("Connection shutdown after max age is reached") + func maxAge() throws { let connection = try Connection(maxAge: .minutes(1)) try connection.activate() @@ -87,7 +90,8 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { try connection.waitUntilClosed() } - func testGracefulShutdownRatchetsDownStreamID() throws { + @Test("Graceful shutdown ratchets down last stream ID") + func gracefulShutdownRatchetsDownStreamID() throws { // This test uses the idle timeout to trigger graceful shutdown. The mechanism is the same // regardless of how it's triggered. let connection = try Connection(maxIdleTime: .minutes(1)) @@ -106,7 +110,8 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { try connection.waitUntilClosed() } - func testGracefulShutdownGracePeriod() throws { + @Test("Graceful shutdown promoted to close after grace period") + func gracefulShutdownGracePeriod() throws { // This test uses the idle timeout to trigger graceful shutdown. The mechanism is the same // regardless of how it's triggered. let connection = try Connection( @@ -128,7 +133,8 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { try connection.waitUntilClosed() } - func testKeepaliveOnNewConnection() throws { + @Test("Keepalive works on new connection") + func keepaliveOnNewConnection() throws { let connection = try Connection( keepaliveTime: .minutes(5), keepaliveTimeout: .seconds(5) @@ -138,20 +144,20 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { // Wait for the keep alive timer to fire which should cause the server to send a keep // alive PING. connection.advanceTime(by: .minutes(5)) - let frame1 = try XCTUnwrap(connection.readFrame()) - XCTAssertEqual(frame1.streamID, .rootStream) - try XCTAssertPing(frame1.payload) { data, ack in - XCTAssertFalse(ack) - // Data is opaque, send it back. - try connection.ping(data: data, ack: true) - } + let frame1 = try #require(try connection.readFrame()) + #expect(frame1.streamID == .rootStream) + let (data, ack) = try #require(frame1.payload.ping) + #expect(!ack) + // Data is opaque, send it back. + try connection.ping(data: data, ack: true) // Run past the timeout, nothing should happen. connection.advanceTime(by: .seconds(5)) - XCTAssertNil(try connection.readFrame()) + #expect(try connection.readFrame() == nil) } - func testKeepaliveStartsAfterReadLoop() throws { + @Test("Keepalive starts after read loop") + func keepaliveStartsAfterReadLoop() throws { let connection = try Connection( keepaliveTime: .minutes(5), keepaliveTimeout: .seconds(5) @@ -165,21 +171,21 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { // Run out the keep alive timer, it shouldn't fire. connection.advanceTime(by: .minutes(5)) - XCTAssertNil(try connection.readFrame()) + #expect(try connection.readFrame() == nil) // Fire channel read complete to start the keep alive timer again. connection.channel.pipeline.fireChannelReadComplete() // Now expire the keep alive timer again, we should read out a PING frame. connection.advanceTime(by: .minutes(5)) - let frame1 = try XCTUnwrap(connection.readFrame()) - XCTAssertEqual(frame1.streamID, .rootStream) - XCTAssertPing(frame1.payload) { data, ack in - XCTAssertFalse(ack) - } + let frame1 = try #require(try connection.readFrame()) + #expect(frame1.streamID == .rootStream) + let (_, ack) = try #require(frame1.payload.ping) + #expect(!ack) } - func testKeepaliveOnNewConnectionWithoutResponse() throws { + @Test("Keepalive works on new connection without response") + func keepaliveOnNewConnectionWithoutResponse() throws { let connection = try Connection( keepaliveTime: .minutes(5), keepaliveTimeout: .seconds(5) @@ -189,11 +195,10 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { // Wait for the keep alive timer to fire which should cause the server to send a keep // alive PING. connection.advanceTime(by: .minutes(5)) - let frame1 = try XCTUnwrap(connection.readFrame()) - XCTAssertEqual(frame1.streamID, .rootStream) - XCTAssertPing(frame1.payload) { data, ack in - XCTAssertFalse(ack) - } + let frame1 = try #require(try connection.readFrame()) + #expect(frame1.streamID == .rootStream) + let (_, ack) = try #require(frame1.payload.ping) + #expect(!ack) // We didn't ack the PING, the connection should shutdown after the timeout. connection.advanceTime(by: .seconds(5)) @@ -203,7 +208,8 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { try connection.waitUntilClosed() } - func testClientKeepalivePolicing() throws { + @Test("Keepalive sent by client is policed") + func clientKeepalivePolicing() throws { let connection = try Connection( allowKeepaliveWithoutCalls: true, minPingIntervalWithoutCalls: .minutes(1) @@ -213,24 +219,25 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { // The first ping is valid, the second and third are strikes. for _ in 1 ... 3 { try connection.ping(data: HTTP2PingData(), ack: false) - XCTAssertNil(try connection.readFrame()) + #expect(try connection.readFrame() == nil) } // The fourth ping is the third strike and triggers a GOAWAY. try connection.ping(data: HTTP2PingData(), ack: false) - let frame = try XCTUnwrap(connection.readFrame()) - XCTAssertEqual(frame.streamID, .rootStream) - XCTAssertGoAway(frame.payload) { streamID, error, data in - XCTAssertEqual(streamID, .rootStream) - XCTAssertEqual(error, .enhanceYourCalm) - XCTAssertEqual(data, ByteBuffer(string: "too_many_pings")) - } + let frame = try #require(try connection.readFrame()) + #expect(frame.streamID == .rootStream) + let (streamID, error, data) = try #require(frame.payload.goAway) + + #expect(streamID == .rootStream) + #expect(error == .enhanceYourCalm) + #expect(data == ByteBuffer(string: "too_many_pings")) // The server should close the connection. try connection.waitUntilClosed() } - func testClientKeepaliveWithPermissibleIntervals() throws { + @Test("Client keepalive works with permissible intervals") + func clientKeepaliveWithPermissibleIntervals() throws { let connection = try Connection( allowKeepaliveWithoutCalls: true, minPingIntervalWithoutCalls: .minutes(1), @@ -240,14 +247,15 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { for _ in 1 ... 100 { try connection.ping(data: HTTP2PingData(), ack: false) - XCTAssertNil(try connection.readFrame()) + #expect(try connection.readFrame() == nil) // Advance by the ping interval. connection.advanceTime(by: .minutes(1)) } } - func testClientKeepaliveResetState() throws { + @Test("Client keepalive works after reset state") + func clientKeepaliveResetState() throws { let connection = try Connection( allowKeepaliveWithoutCalls: true, minPingIntervalWithoutCalls: .minutes(1) @@ -258,7 +266,7 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { // The first ping is valid, the second and third are strikes. for _ in 1 ... 3 { try connection.ping(data: HTTP2PingData(), ack: false) - XCTAssertNil(try connection.readFrame()) + #expect(try connection.readFrame() == nil) } } @@ -273,13 +281,13 @@ final class ServerConnectionManagementHandlerTests: XCTestCase { // The next ping is the third strike and triggers a GOAWAY. try connection.ping(data: HTTP2PingData(), ack: false) - let frame = try XCTUnwrap(connection.readFrame()) - XCTAssertEqual(frame.streamID, .rootStream) - XCTAssertGoAway(frame.payload) { streamID, error, data in - XCTAssertEqual(streamID, .rootStream) - XCTAssertEqual(error, .enhanceYourCalm) - XCTAssertEqual(data, ByteBuffer(string: "too_many_pings")) - } + let frame = try #require(try connection.readFrame()) + #expect(frame.streamID == .rootStream) + let (streamID, error, data) = try #require(frame.payload.goAway) + + #expect(streamID == .rootStream) + #expect(error == .enhanceYourCalm) + #expect(data == ByteBuffer(string: "too_many_pings")) // The server should close the connection. try connection.waitUntilClosed() @@ -292,18 +300,22 @@ extension ServerConnectionManagementHandlerTests { lastStreamID: HTTP2StreamID, streamToOpenBeforePingAck: HTTP2StreamID? = nil ) throws { - let frame1 = try XCTUnwrap(connection.readFrame()) - XCTAssertEqual(frame1.streamID, .rootStream) - XCTAssertGoAway(frame1.payload) { streamID, errorCode, _ in - XCTAssertEqual(streamID, .maxID) - XCTAssertEqual(errorCode, .noError) + do { + let frame = try #require(try connection.readFrame()) + #expect(frame.streamID == .rootStream) + + let (streamID, errorCode, _) = try #require(frame.payload.goAway) + #expect(streamID == .maxID) + #expect(errorCode == .noError) } // Followed by a PING - let frame2 = try XCTUnwrap(connection.readFrame()) - XCTAssertEqual(frame2.streamID, .rootStream) - try XCTAssertPing(frame2.payload) { data, ack in - XCTAssertFalse(ack) + do { + let frame = try #require(try connection.readFrame()) + #expect(frame.streamID == .rootStream) + + let (data, ack) = try #require(frame.payload.ping) + #expect(!ack) if let id = streamToOpenBeforePingAck { connection.streamOpened(id) @@ -314,11 +326,13 @@ extension ServerConnectionManagementHandlerTests { } // PING ACK triggers another GOAWAY. - let frame3 = try XCTUnwrap(connection.readFrame()) - XCTAssertEqual(frame3.streamID, .rootStream) - XCTAssertGoAway(frame3.payload) { streamID, errorCode, _ in - XCTAssertEqual(streamID, lastStreamID) - XCTAssertEqual(errorCode, .noError) + do { + let frame = try #require(try connection.readFrame()) + #expect(frame.streamID == .rootStream) + + let (streamID, errorCode, _) = try #require(frame.payload.goAway) + #expect(streamID == lastStreamID) + #expect(errorCode == .noError) } } } diff --git a/Tests/GRPCNIOTransportCoreTests/XCTest+FramePayload.swift b/Tests/GRPCNIOTransportCoreTests/XCTest+FramePayload.swift index b6892d0..12b3907 100644 --- a/Tests/GRPCNIOTransportCoreTests/XCTest+FramePayload.swift +++ b/Tests/GRPCNIOTransportCoreTests/XCTest+FramePayload.swift @@ -41,3 +41,23 @@ func XCTAssertPing( XCTFail("Expected '.ping' got '\(payload)'") } } + +extension HTTP2Frame.FramePayload { + var goAway: (lastStreamID: HTTP2StreamID, errorCode: HTTP2ErrorCode, opaqueData: ByteBuffer?)? { + switch self { + case .goAway(let lastStreamID, let errorCode, let opaqueData): + return (lastStreamID, errorCode, opaqueData) + default: + return nil + } + } + + var ping: (data: HTTP2PingData, ack: Bool)? { + switch self { + case .ping(let data, ack: let ack): + return (data, ack) + default: + return nil + } + } +}