Skip to content

Commit

Permalink
Added support for NIOSSLCustomVerificationCallback for client connect…
Browse files Browse the repository at this point in the history
…ion (#1107)

This allows client apps to perform SSL Public Key Pinning, or override the certificate verification logic
  • Loading branch information
franck-clement-ug authored Jan 22, 2021
1 parent 6f92056 commit e70c2cf
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 8 deletions.
22 changes: 16 additions & 6 deletions Sources/GRPC/ClientConnection.swift
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,8 @@ extension Channel {
connectionIdleTimeout: TimeAmount,
errorDelegate: ClientErrorDelegate?,
requiresZeroLengthWriteWorkaround: Bool,
logger: Logger
logger: Logger,
customVerificationCallback: NIOSSLCustomVerificationCallback?
) -> EventLoopFuture<Void> {
// We add at most 8 handlers to the pipeline.
var handlers: [ChannelHandler] = []
Expand All @@ -427,11 +428,20 @@ extension Channel {

if let tlsConfiguration = tlsConfiguration {
do {
let sslClientHandler = try NIOSSLClientHandler(
context: try NIOSSLContext(configuration: tlsConfiguration),
serverHostname: tlsServerHostname
)
handlers.append(sslClientHandler)
if let customVerificationCallback = customVerificationCallback {
let sslClientHandler = try NIOSSLClientHandler(
context: try NIOSSLContext(configuration: tlsConfiguration),
serverHostname: tlsServerHostname,
customVerificationCallback: customVerificationCallback
)
handlers.append(sslClientHandler)
} else {
let sslClientHandler = try NIOSSLClientHandler(
context: try NIOSSLContext(configuration: tlsConfiguration),
serverHostname: tlsServerHostname
)
handlers.append(sslClientHandler)
}
handlers.append(TLSVerificationHandler(logger: logger))
} catch {
return self.eventLoop.makeFailedFuture(error)
Expand Down
3 changes: 2 additions & 1 deletion Sources/GRPC/ConnectionManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,8 @@ extension ConnectionManager {
group: self.eventLoop,
hasTLS: self.configuration.tls != nil
),
logger: self.logger
logger: self.logger,
customVerificationCallback: self.configuration.tls?.customVerificationCallback
)

// Run the debug initializer, if there is one.
Expand Down
9 changes: 9 additions & 0 deletions Sources/GRPC/GRPCChannel/GRPCChannelBuilder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,15 @@ extension ClientConnection.Builder.Secure {
self.tls.certificateVerification = certificateVerification
return self
}

/// A custom verification callback that allows completely overriding the certificate verification logic.
@discardableResult
public func withTLSCustomVerificationCallback(
_ callback: @escaping NIOSSLCustomVerificationCallback
) -> Self {
self.tls.customVerificationCallback = callback
return self
}
}

extension ClientConnection.Builder {
Expand Down
9 changes: 8 additions & 1 deletion Sources/GRPC/TLSConfiguration.swift
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ extension ClientConnection.Configuration {
}
}

/// A custom verification callback that allows completely overriding the certificate verification logic for this connection.
public var customVerificationCallback: NIOSSLCustomVerificationCallback?

/// TLS Configuration with suitable defaults for clients.
///
/// This is a wrapper around `NIOSSL.TLSConfiguration` to restrict input to values which comply
Expand All @@ -83,12 +86,15 @@ extension ClientConnection.Configuration {
/// `.fullVerification`.
/// - Parameter hostnameOverride: Value to use for TLS SNI extension; this must not be an IP
/// address, defaults to `nil`.
/// - Parameter customVerificationCallback: A callback to provide to override the certificate verification logic,
/// defaults to `nil`.
public init(
certificateChain: [NIOSSLCertificateSource] = [],
privateKey: NIOSSLPrivateKeySource? = nil,
trustRoots: NIOSSLTrustRoots = .default,
certificateVerification: CertificateVerification = .fullVerification,
hostnameOverride: String? = nil
hostnameOverride: String? = nil,
customVerificationCallback: NIOSSLCustomVerificationCallback? = nil
) {
self.configuration = .forClient(
minimumTLSVersion: .tlsv12,
Expand All @@ -99,6 +105,7 @@ extension ClientConnection.Configuration {
applicationProtocols: GRPCApplicationProtocolIdentifier.client
)
self.hostnameOverride = hostnameOverride
self.customVerificationCallback = customVerificationCallback
}

/// Creates a TLS Configuration using the given `NIOSSL.TLSConfiguration`.
Expand Down
45 changes: 45 additions & 0 deletions Tests/GRPCTests/ClientTLSFailureTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,49 @@ class ClientTLSFailureTests: GRPCTestCase {
XCTFail("Expected NIOSSLExtraError.failedToValidateHostname")
}
}

func testClientConnectionFailsWhenCertificateValidationDenied() throws {
let errorExpectation = self.expectation(description: "error")
// 2 errors: one for the failed handshake, and another for failing the ready-channel promise
// (because the handshake failed).
errorExpectation.expectedFulfillmentCount = 2

let tlsConfiguration = ClientConnection.Configuration.TLS(
certificateChain: [.certificate(SampleCertificate.client.certificate)],
privateKey: .privateKey(SamplePrivateKey.client),
trustRoots: .certificates([SampleCertificate.ca.certificate]),
hostnameOverride: SampleCertificate.server.commonName,
customVerificationCallback: { _, promise in
// The certificate validation is forced to fail
promise.fail(NIOSSLError.unableToValidateCertificate)
}
)

var configuration = self.makeClientConfiguration(tls: tlsConfiguration)
let errorRecorder = ErrorRecordingDelegate(expectation: errorExpectation)
configuration.errorDelegate = errorRecorder

let stateChangeDelegate = RecordingConnectivityDelegate()
stateChangeDelegate.expectChanges(2) { changes in
XCTAssertEqual(changes, [
Change(from: .idle, to: .connecting),
Change(from: .connecting, to: .shutdown),
])
}
configuration.connectivityStateDelegate = stateChangeDelegate

// Start an RPC to trigger creating a channel.
let echo = Echo_EchoClient(channel: ClientConnection(configuration: configuration))
_ = echo.get(.with { $0.text = "foo" })

self.wait(for: [errorExpectation], timeout: self.defaultTestTimeout)
stateChangeDelegate.waitForExpectedChanges(timeout: .seconds(5))

if let nioSSLError = errorRecorder.errors.first as? NIOSSLError,
case .handshakeFailed(.sslError) = nioSSLError {
// Expected case.
} else {
XCTFail("Expected NIOSSLError.handshakeFailed(BoringSSL.sslError)")
}
}
}

0 comments on commit e70c2cf

Please sign in to comment.