diff --git a/Sources/GRPC/ClientConnection.swift b/Sources/GRPC/ClientConnection.swift index ba8480270..4488c1b52 100644 --- a/Sources/GRPC/ClientConnection.swift +++ b/Sources/GRPC/ClientConnection.swift @@ -411,7 +411,8 @@ extension Channel { connectionIdleTimeout: TimeAmount, errorDelegate: ClientErrorDelegate?, requiresZeroLengthWriteWorkaround: Bool, - logger: Logger + logger: Logger, + customVerificationCallback: NIOSSLCustomVerificationCallback? ) -> EventLoopFuture { // We add at most 8 handlers to the pipeline. var handlers: [ChannelHandler] = [] @@ -427,11 +428,20 @@ extension Channel { if let tlsConfiguration = tlsConfiguration { do { - let sslClientHandler = try NIOSSLClientHandler( - context: try NIOSSLContext(configuration: tlsConfiguration), - serverHostname: tlsServerHostname - ) - handlers.append(sslClientHandler) + if let customVerificationCallback = customVerificationCallback { + let sslClientHandler = try NIOSSLClientHandler( + context: try NIOSSLContext(configuration: tlsConfiguration), + serverHostname: tlsServerHostname, + customVerificationCallback: customVerificationCallback + ) + handlers.append(sslClientHandler) + } else { + let sslClientHandler = try NIOSSLClientHandler( + context: try NIOSSLContext(configuration: tlsConfiguration), + serverHostname: tlsServerHostname + ) + handlers.append(sslClientHandler) + } handlers.append(TLSVerificationHandler(logger: logger)) } catch { return self.eventLoop.makeFailedFuture(error) diff --git a/Sources/GRPC/ConnectionManager.swift b/Sources/GRPC/ConnectionManager.swift index 84b50483c..14404d195 100644 --- a/Sources/GRPC/ConnectionManager.swift +++ b/Sources/GRPC/ConnectionManager.swift @@ -853,7 +853,8 @@ extension ConnectionManager { group: self.eventLoop, hasTLS: self.configuration.tls != nil ), - logger: self.logger + logger: self.logger, + customVerificationCallback: self.configuration.tls?.customVerificationCallback ) // Run the debug initializer, if there is one. diff --git a/Sources/GRPC/GRPCChannel/GRPCChannelBuilder.swift b/Sources/GRPC/GRPCChannel/GRPCChannelBuilder.swift index aeab44830..80fd65da7 100644 --- a/Sources/GRPC/GRPCChannel/GRPCChannelBuilder.swift +++ b/Sources/GRPC/GRPCChannel/GRPCChannelBuilder.swift @@ -235,6 +235,15 @@ extension ClientConnection.Builder.Secure { self.tls.certificateVerification = certificateVerification return self } + + /// A custom verification callback that allows completely overriding the certificate verification logic. + @discardableResult + public func withTLSCustomVerificationCallback( + _ callback: @escaping NIOSSLCustomVerificationCallback + ) -> Self { + self.tls.customVerificationCallback = callback + return self + } } extension ClientConnection.Builder { diff --git a/Sources/GRPC/TLSConfiguration.swift b/Sources/GRPC/TLSConfiguration.swift index 746ffede1..46e28f260 100644 --- a/Sources/GRPC/TLSConfiguration.swift +++ b/Sources/GRPC/TLSConfiguration.swift @@ -68,6 +68,9 @@ extension ClientConnection.Configuration { } } + /// A custom verification callback that allows completely overriding the certificate verification logic for this connection. + public var customVerificationCallback: NIOSSLCustomVerificationCallback? + /// TLS Configuration with suitable defaults for clients. /// /// This is a wrapper around `NIOSSL.TLSConfiguration` to restrict input to values which comply @@ -83,12 +86,15 @@ extension ClientConnection.Configuration { /// `.fullVerification`. /// - Parameter hostnameOverride: Value to use for TLS SNI extension; this must not be an IP /// address, defaults to `nil`. + /// - Parameter customVerificationCallback: A callback to provide to override the certificate verification logic, + /// defaults to `nil`. public init( certificateChain: [NIOSSLCertificateSource] = [], privateKey: NIOSSLPrivateKeySource? = nil, trustRoots: NIOSSLTrustRoots = .default, certificateVerification: CertificateVerification = .fullVerification, - hostnameOverride: String? = nil + hostnameOverride: String? = nil, + customVerificationCallback: NIOSSLCustomVerificationCallback? = nil ) { self.configuration = .forClient( minimumTLSVersion: .tlsv12, @@ -99,6 +105,7 @@ extension ClientConnection.Configuration { applicationProtocols: GRPCApplicationProtocolIdentifier.client ) self.hostnameOverride = hostnameOverride + self.customVerificationCallback = customVerificationCallback } /// Creates a TLS Configuration using the given `NIOSSL.TLSConfiguration`. diff --git a/Tests/GRPCTests/ClientTLSFailureTests.swift b/Tests/GRPCTests/ClientTLSFailureTests.swift index 5c6adcb74..22e886cae 100644 --- a/Tests/GRPCTests/ClientTLSFailureTests.swift +++ b/Tests/GRPCTests/ClientTLSFailureTests.swift @@ -180,4 +180,49 @@ class ClientTLSFailureTests: GRPCTestCase { XCTFail("Expected NIOSSLExtraError.failedToValidateHostname") } } + + func testClientConnectionFailsWhenCertificateValidationDenied() throws { + let errorExpectation = self.expectation(description: "error") + // 2 errors: one for the failed handshake, and another for failing the ready-channel promise + // (because the handshake failed). + errorExpectation.expectedFulfillmentCount = 2 + + let tlsConfiguration = ClientConnection.Configuration.TLS( + certificateChain: [.certificate(SampleCertificate.client.certificate)], + privateKey: .privateKey(SamplePrivateKey.client), + trustRoots: .certificates([SampleCertificate.ca.certificate]), + hostnameOverride: SampleCertificate.server.commonName, + customVerificationCallback: { _, promise in + // The certificate validation is forced to fail + promise.fail(NIOSSLError.unableToValidateCertificate) + } + ) + + var configuration = self.makeClientConfiguration(tls: tlsConfiguration) + let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation) + configuration.errorDelegate = errorRecorder + + let stateChangeDelegate = RecordingConnectivityDelegate() + stateChangeDelegate.expectChanges(2) { changes in + XCTAssertEqual(changes, [ + Change(from: .idle, to: .connecting), + Change(from: .connecting, to: .shutdown), + ]) + } + configuration.connectivityStateDelegate = stateChangeDelegate + + // Start an RPC to trigger creating a channel. + let echo = Echo_EchoClient(channel: ClientConnection(configuration: configuration)) + _ = echo.get(.with { $0.text = "foo" }) + + self.wait(for: [errorExpectation], timeout: self.defaultTestTimeout) + stateChangeDelegate.waitForExpectedChanges(timeout: .seconds(5)) + + if let nioSSLError = errorRecorder.errors.first as? NIOSSLError, + case .handshakeFailed(.sslError) = nioSSLError { + // Expected case. + } else { + XCTFail("Expected NIOSSLError.handshakeFailed(BoringSSL.sslError)") + } + } }