From dd2c13a9d71754b279e3c3e8d610b382cc370332 Mon Sep 17 00:00:00 2001 From: 0xG0nz0 <8682922+0xg0nz0@users.noreply.github.com> Date: Sun, 31 Mar 2024 22:04:42 +0000 Subject: [PATCH] Improve test coverage for SSLOptions --- sdk/net/crypto/ssl.cc | 157 ++++++++++++++++++++++-------------------- sdk/net/crypto/ssl.h | 45 +++++++++++- tests/ssl_test.cc | 60 +++++++++++++--- 3 files changed, 179 insertions(+), 83 deletions(-) diff --git a/sdk/net/crypto/ssl.cc b/sdk/net/crypto/ssl.cc index d57c80f..1c79a03 100644 --- a/sdk/net/crypto/ssl.cc +++ b/sdk/net/crypto/ssl.cc @@ -50,8 +50,9 @@ const std::vector iggy::ssl::SSLOptions::getDefaultCipherList(iggy: ciphers.push_back("TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305"); #ifndef NO_RSA ciphers.push_back("TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305"); -#endif -#endif +#endif // NO_RSA +#endif // HAVE_CHACHA && HAVE_POLY1305 +#endif // WOLFSSL_NO_TLS12 } else { auto protocolVersionName = iggy::ssl::getProtocolVersionName(protocolVersion); throw std::runtime_error(fmt::format("Unsupported protocol version: {}", protocolVersionName)); @@ -62,92 +63,102 @@ const std::vector iggy::ssl::SSLOptions::getDefaultCipherList(iggy: throw std::runtime_error(fmt::format("No ciphers available for the specified protocol version: {}", protocolVersionName)); } return ciphers; -#endif +} + +iggy::ssl::SSLContext::SSLContext(const SSLOptions& options, const PKIEnvironment& pkiEnv) + : options(options) + , pkiEnv(pkiEnv) { + // before we make any other wolfSSL calls, make sure library is initialized once and only once + std::call_once(sslInitDone, []() { wolfSSL_Init(); }); + + // for now we only support a TLS 1.3 client context; if we generalize this code we can expand + this->ctx = wolfSSL_CTX_new(wolfTLSv1_3_client_method()); + if (!this->ctx) { + char* errMsg = wolfSSL_ERR_error_string(wolfSSL_ERR_get_error(), nullptr); + throw std::runtime_error(fmt::format("Failed to allocate WolfSSL TLS context: {}", errMsg)); } + this->cm = wolfSSL_CTX_GetCertManager(ctx); - iggy::ssl::SSLContext::SSLContext(const SSLOptions& options, const PKIEnvironment& pkiEnv) - : options(options) - , pkiEnv(pkiEnv) { - // before we make any other wolfSSL calls, make sure library is initialized once and only once - std::call_once(sslInitDone, []() { wolfSSL_Init(); }); - - // for now we only support a TLS 1.3 client context; if we generalize this code we can expand - this->ctx = wolfSSL_CTX_new(wolfTLSv1_3_client_method()); - if (!this->ctx) { - char* errMsg = wolfSSL_ERR_error_string(wolfSSL_ERR_get_error(), nullptr); - throw std::runtime_error(fmt::format("Failed to allocate WolfSSL TLS context: {}", errMsg)); - } - this->cm = wolfSSL_CTX_GetCertManager(ctx); + // set up the supported ciphers + std::string delimiter = ":"; + std::string joinedCiphers; - // set up the supported ciphers - std::string delimiter = ":"; - std::string joinedCiphers; + auto supportedCiphers = options.getCiphers(); + if (!supportedCiphers.empty()) { + joinedCiphers = std::accumulate(std::next(supportedCiphers.begin()), supportedCiphers.end(), supportedCiphers[0], + [delimiter](std::string a, std::string b) { return a + delimiter + b; }); + } + int ret = wolfSSL_CTX_set_cipher_list(this->ctx, joinedCiphers.c_str()); + if (ret != SSL_SUCCESS) { + char* errMsg = wolfSSL_ERR_error_string(wolfSSL_ERR_get_error(), nullptr); + throw std::runtime_error(fmt::format("Failed to set cipher list: {}", errMsg)); + } +} - auto supportedCiphers = options.getCiphers(); - if (!supportedCiphers.empty()) { - joinedCiphers = std::accumulate(std::next(supportedCiphers.begin()), supportedCiphers.end(), supportedCiphers[0], - [delimiter](std::string a, std::string b) { return a + delimiter + b; }); - } - int ret = wolfSSL_CTX_set_cipher_list(this->ctx, joinedCiphers.c_str()); - if (ret != SSL_SUCCESS) { - char* errMsg = wolfSSL_ERR_error_string(wolfSSL_ERR_get_error(), nullptr); - throw std::runtime_error(fmt::format("Failed to set cipher list: {}", errMsg)); +void iggy::ssl::SSLOptions::validate(bool strict) const { + if (strict) { + if (this->minimumSupportedProtocolVersion != iggy::ssl::ProtocolVersion::TLSV1_3) { + throw std::runtime_error("Only TLS 1.3 is supported in strict mode"); } } - - iggy::ssl::SSLContext::SSLContext(const SSLContext& other) - : options(other.options) - , pkiEnv(other.pkiEnv) { - this->ctx = wolfSSL_CTX_new(wolfTLSv1_3_client_method()); - this->cm = wolfSSL_CTX_GetCertManager(ctx); + if (this->peerType == iggy::ssl::PeerType::SERVER && !this->peerCertPath.has_value()) { + throw std::runtime_error("Server mode requires a peer certificate path"); } - - iggy::ssl::SSLContext::SSLContext(SSLContext && other) - : options(other.options) - , pkiEnv(other.pkiEnv) { - this->ctx = other.ctx; - this->cm = other.cm; - other.ctx = nullptr; - other.cm = nullptr; +} + +iggy::ssl::SSLContext::SSLContext(const SSLContext& other) + : options(other.options) + , pkiEnv(other.pkiEnv) { + this->ctx = wolfSSL_CTX_new(wolfTLSv1_3_client_method()); + this->cm = wolfSSL_CTX_GetCertManager(ctx); +} + +iggy::ssl::SSLContext::SSLContext(SSLContext&& other) + : options(other.options) + , pkiEnv(other.pkiEnv) { + this->ctx = other.ctx; + this->cm = other.cm; + other.ctx = nullptr; + other.cm = nullptr; +} + +iggy::ssl::SSLContext::~SSLContext() { + if (this->ctx) { + wolfSSL_CTX_free(this->ctx); } +} - iggy::ssl::SSLContext::~SSLContext() { +iggy::ssl::SSLContext& iggy::ssl::SSLContext::operator=(const iggy::ssl::SSLContext& other) { + if (this != &other) { if (this->ctx) { wolfSSL_CTX_free(this->ctx); } + this->ctx = wolfSSL_CTX_new(wolfTLSv1_3_client_method()); } + return *this; +} - iggy::ssl::SSLContext& iggy::ssl::SSLContext::operator=(const iggy::ssl::SSLContext& other) { - if (this != &other) { - if (this->ctx) { - wolfSSL_CTX_free(this->ctx); - } - this->ctx = wolfSSL_CTX_new(wolfTLSv1_3_client_method()); - } - return *this; - } - - iggy::ssl::SSLContext& iggy::ssl::SSLContext::operator=(SSLContext&& other) { - if (this != &other) { - if (this->ctx) { - wolfSSL_CTX_free(this->ctx); - } - this->ctx = other.ctx; - other.ctx = nullptr; +iggy::ssl::SSLContext& iggy::ssl::SSLContext::operator=(SSLContext&& other) { + if (this != &other) { + if (this->ctx) { + wolfSSL_CTX_free(this->ctx); } - return *this; + this->ctx = other.ctx; + other.ctx = nullptr; } - - std::string iggy::ssl::getProtocolVersionName(iggy::ssl::ProtocolVersion protocolVersion) { - switch (protocolVersion) { - case iggy::ssl::ProtocolVersion::TLSV1_3: - return "TLSV1_3"; - case iggy::ssl::ProtocolVersion::TLSV1_2: - return "TLSV1_2"; - default: - int protocolVersionInt = static_cast(protocolVersion); - throw std::runtime_error(fmt::format("Unsupported protocol version code: {}", protocolVersionInt)); - } + return *this; +} + +std::string iggy::ssl::getProtocolVersionName(iggy::ssl::ProtocolVersion protocolVersion) { + switch (protocolVersion) { + case iggy::ssl::ProtocolVersion::TLSV1_3: + return "TLSV1_3"; + case iggy::ssl::ProtocolVersion::TLSV1_2: + return "TLSV1_2"; + default: + int protocolVersionInt = static_cast(protocolVersion); + throw std::runtime_error(fmt::format("Unsupported protocol version code: {}", protocolVersionInt)); } +} - std::once_flag iggy::ssl::SSLContext::sslInitDone = std::once_flag(); +std::once_flag iggy::ssl::SSLContext::sslInitDone = std::once_flag(); diff --git a/sdk/net/crypto/ssl.h b/sdk/net/crypto/ssl.h index 97085f8..d4a9888 100644 --- a/sdk/net/crypto/ssl.h +++ b/sdk/net/crypto/ssl.h @@ -37,8 +37,8 @@ std::string getProtocolVersionName(iggy::ssl::ProtocolVersion protocolVersion); */ class SSLOptions { private: + PeerType peerType; std::optional peerCertPath = std::nullopt; - PeerType peerType = PeerType::CLIENT; ProtocolVersion minimumSupportedProtocolVersion = ProtocolVersion::TLSV1_3; std::vector ciphers = getDefaultCipherList(ProtocolVersion::TLSV1_3); @@ -46,7 +46,8 @@ class SSLOptions { /** * Creates a default set of options for a TLS 1.3-compatible client. */ - SSLOptions() = default; + explicit SSLOptions(PeerType peerType = PeerType::CLIENT) + : peerType(peerType) {} /** * @brief Gets the default cipher list for use in SSL/TLS contexts. @@ -54,6 +55,38 @@ class SSLOptions { */ static const std::vector getDefaultCipherList(ProtocolVersion protocolVersion); + /** + * @brief Gets the type of peer endpoint represented by this end of the socket. + */ + PeerType getPeerType() const { return this->peerType; } + + /** + * @brief Sets the type of peer endpoint represented by the local end of the socket. + */ + void setPeerType(PeerType peerType) { this->peerType = peerType; } + + /** + * @brief Sets the type of peer endpoint represented by the local end of the socket. + */ + std::optional getPeerCertificatePath() const { return this->peerCertPath; } + + /** + * @brief Sets the path to the peer's certificate, if any, to use for verifying the peer's identity. + */ + void setPeerCertificatePath(const std::string& peerCertPath) { this->peerCertPath = peerCertPath; } + + /** + * @brief Gets the minimum supported protocol version for the SSL/TLS context. + */ + ProtocolVersion getMinimumSupportedProtocolVersion() const { return this->minimumSupportedProtocolVersion; } + + /** + * @brief Sets the minimum supported protocol version for the SSL/TLS context. + */ + void setMinimumSupportedProtocolVersion(ProtocolVersion minimumSupportedProtocolVersion) { + this->minimumSupportedProtocolVersion = minimumSupportedProtocolVersion; + } + /** * @brief Gets the list of requested supported ciphers; will be validated by the context during init. */ @@ -63,6 +96,14 @@ class SSLOptions { * @brief Sets the list of requested supported ciphers; will be validated by the context during init. */ void setCiphers(const std::vector& ciphers) { this->ciphers = ciphers; } + + /** + * @brief Sanity checks the combination of options configured by the user. + * @param strict if true, will apply additional validations that may be more restrictive. + * + * Applies basic validations to the SSL options, e.g. if PeerType::SERVER is set, then a peer certificate path must be provided. + */ + void validate(bool strict = true) const; }; /** diff --git a/tests/ssl_test.cc b/tests/ssl_test.cc index 5c5561f..2ca8c48 100644 --- a/tests/ssl_test.cc +++ b/tests/ssl_test.cc @@ -3,14 +3,58 @@ TEST_CASE("SSL configuration", UT_TAG) { iggy::ssl::SSLOptions options; - auto cipherListTLSV1_2 = options.getDefaultCipherList(iggy::ssl::ProtocolVersion::TLSV1_2); - auto cipherListTLSV1_3 = options.getDefaultCipherList(iggy::ssl::ProtocolVersion::TLSV1_3); - REQUIRE(cipherListTLSV1_2.size() == 6); - REQUIRE(cipherListTLSV1_3.size() == 3); + SECTION("expected basic default settings") { + // default options should always be strictly valid + REQUIRE_NOTHROW(options.validate(true)); - std::string tls12Cipher = "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"; - std::string tls13Cipher = "TLS_CHACHA20_POLY1305_SHA256"; - CHECK(std::find(cipherListTLSV1_2.begin(), cipherListTLSV1_2.end(), tls12Cipher) != cipherListTLSV1_2.end()); - CHECK(std::find(cipherListTLSV1_3.begin(), cipherListTLSV1_3.end(), tls13Cipher) != cipherListTLSV1_3.end()); + REQUIRE(options.getPeerType() == iggy::ssl::PeerType::CLIENT); + REQUIRE(options.getPeerCertificatePath().has_value() == false); + REQUIRE(options.getMinimumSupportedProtocolVersion() == iggy::ssl::ProtocolVersion::TLSV1_3); + } + + SECTION("default cipher list configured") { + auto cipherListTLSV1_2 = options.getDefaultCipherList(iggy::ssl::ProtocolVersion::TLSV1_2); + auto cipherListTLSV1_3 = options.getDefaultCipherList(iggy::ssl::ProtocolVersion::TLSV1_3); + + REQUIRE(cipherListTLSV1_2.size() == 6); + REQUIRE(cipherListTLSV1_3.size() == 3); + + std::string tls12Cipher = "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"; + std::string tls13Cipher = "TLS_CHACHA20_POLY1305_SHA256"; + CHECK(std::find(cipherListTLSV1_2.begin(), cipherListTLSV1_2.end(), tls12Cipher) != cipherListTLSV1_2.end()); + CHECK(std::find(cipherListTLSV1_3.begin(), cipherListTLSV1_3.end(), tls13Cipher) != cipherListTLSV1_3.end()); + } + + SECTION("configure bespoke cipers") { + auto testCipher = "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"; + options.setCiphers({testCipher}); + + auto ciphers = options.getCiphers(); + REQUIRE(ciphers.size() == 1); + CHECK(std::find(ciphers.begin(), ciphers.end(), testCipher) != ciphers.end()); + + REQUIRE_NOTHROW(options.validate(true)); + } + + SECTION("configure server options") { + options.setPeerType(iggy::ssl::PeerType::SERVER); + + // missing certificate path + REQUIRE_THROWS(options.validate()); + + // fix the issue + options.setPeerCertificatePath("test.pem"); + + // first try strict validation, fail + options.setMinimumSupportedProtocolVersion(iggy::ssl::ProtocolVersion::TLSV1_2); + REQUIRE_THROWS(options.validate(true)); + + // loosen the validation + REQUIRE_NOTHROW(options.validate(false)); + + // finally harden the settings and tighten up validation + options.setMinimumSupportedProtocolVersion(iggy::ssl::ProtocolVersion::TLSV1_3); + REQUIRE_NOTHROW(options.validate(true)); + } }