Skip to content

Commit

Permalink
Improve test coverage for SSLOptions
Browse files Browse the repository at this point in the history
  • Loading branch information
0xg0nz0 committed Mar 31, 2024
1 parent 4743845 commit dd2c13a
Show file tree
Hide file tree
Showing 3 changed files with 179 additions and 83 deletions.
157 changes: 84 additions & 73 deletions sdk/net/crypto/ssl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,9 @@ const std::vector<std::string> iggy::ssl::SSLOptions::getDefaultCipherList(iggy:
ciphers.push_back("TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305");
#ifndef NO_RSA
ciphers.push_back("TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305");
#endif
#endif
#endif // NO_RSA
#endif // HAVE_CHACHA && HAVE_POLY1305
#endif // WOLFSSL_NO_TLS12
} else {
auto protocolVersionName = iggy::ssl::getProtocolVersionName(protocolVersion);
throw std::runtime_error(fmt::format("Unsupported protocol version: {}", protocolVersionName));
Expand All @@ -62,92 +63,102 @@ const std::vector<std::string> iggy::ssl::SSLOptions::getDefaultCipherList(iggy:
throw std::runtime_error(fmt::format("No ciphers available for the specified protocol version: {}", protocolVersionName));
}
return ciphers;
#endif
}

iggy::ssl::SSLContext::SSLContext(const SSLOptions& options, const PKIEnvironment& pkiEnv)
: options(options)
, pkiEnv(pkiEnv) {
// before we make any other wolfSSL calls, make sure library is initialized once and only once
std::call_once(sslInitDone, []() { wolfSSL_Init(); });

// for now we only support a TLS 1.3 client context; if we generalize this code we can expand
this->ctx = wolfSSL_CTX_new(wolfTLSv1_3_client_method());
if (!this->ctx) {
char* errMsg = wolfSSL_ERR_error_string(wolfSSL_ERR_get_error(), nullptr);
throw std::runtime_error(fmt::format("Failed to allocate WolfSSL TLS context: {}", errMsg));
}
this->cm = wolfSSL_CTX_GetCertManager(ctx);

iggy::ssl::SSLContext::SSLContext(const SSLOptions& options, const PKIEnvironment& pkiEnv)
: options(options)
, pkiEnv(pkiEnv) {
// before we make any other wolfSSL calls, make sure library is initialized once and only once
std::call_once(sslInitDone, []() { wolfSSL_Init(); });

// for now we only support a TLS 1.3 client context; if we generalize this code we can expand
this->ctx = wolfSSL_CTX_new(wolfTLSv1_3_client_method());
if (!this->ctx) {
char* errMsg = wolfSSL_ERR_error_string(wolfSSL_ERR_get_error(), nullptr);
throw std::runtime_error(fmt::format("Failed to allocate WolfSSL TLS context: {}", errMsg));
}
this->cm = wolfSSL_CTX_GetCertManager(ctx);
// set up the supported ciphers
std::string delimiter = ":";
std::string joinedCiphers;

// set up the supported ciphers
std::string delimiter = ":";
std::string joinedCiphers;
auto supportedCiphers = options.getCiphers();
if (!supportedCiphers.empty()) {
joinedCiphers = std::accumulate(std::next(supportedCiphers.begin()), supportedCiphers.end(), supportedCiphers[0],
[delimiter](std::string a, std::string b) { return a + delimiter + b; });
}
int ret = wolfSSL_CTX_set_cipher_list(this->ctx, joinedCiphers.c_str());
if (ret != SSL_SUCCESS) {
char* errMsg = wolfSSL_ERR_error_string(wolfSSL_ERR_get_error(), nullptr);
throw std::runtime_error(fmt::format("Failed to set cipher list: {}", errMsg));
}
}

auto supportedCiphers = options.getCiphers();
if (!supportedCiphers.empty()) {
joinedCiphers = std::accumulate(std::next(supportedCiphers.begin()), supportedCiphers.end(), supportedCiphers[0],
[delimiter](std::string a, std::string b) { return a + delimiter + b; });
}
int ret = wolfSSL_CTX_set_cipher_list(this->ctx, joinedCiphers.c_str());
if (ret != SSL_SUCCESS) {
char* errMsg = wolfSSL_ERR_error_string(wolfSSL_ERR_get_error(), nullptr);
throw std::runtime_error(fmt::format("Failed to set cipher list: {}", errMsg));
void iggy::ssl::SSLOptions::validate(bool strict) const {
if (strict) {
if (this->minimumSupportedProtocolVersion != iggy::ssl::ProtocolVersion::TLSV1_3) {
throw std::runtime_error("Only TLS 1.3 is supported in strict mode");
}
}

iggy::ssl::SSLContext::SSLContext(const SSLContext& other)
: options(other.options)
, pkiEnv(other.pkiEnv) {
this->ctx = wolfSSL_CTX_new(wolfTLSv1_3_client_method());
this->cm = wolfSSL_CTX_GetCertManager(ctx);
if (this->peerType == iggy::ssl::PeerType::SERVER && !this->peerCertPath.has_value()) {
throw std::runtime_error("Server mode requires a peer certificate path");
}

iggy::ssl::SSLContext::SSLContext(SSLContext && other)
: options(other.options)
, pkiEnv(other.pkiEnv) {
this->ctx = other.ctx;
this->cm = other.cm;
other.ctx = nullptr;
other.cm = nullptr;
}

iggy::ssl::SSLContext::SSLContext(const SSLContext& other)
: options(other.options)
, pkiEnv(other.pkiEnv) {
this->ctx = wolfSSL_CTX_new(wolfTLSv1_3_client_method());
this->cm = wolfSSL_CTX_GetCertManager(ctx);
}

iggy::ssl::SSLContext::SSLContext(SSLContext&& other)
: options(other.options)
, pkiEnv(other.pkiEnv) {
this->ctx = other.ctx;
this->cm = other.cm;
other.ctx = nullptr;
other.cm = nullptr;
}

iggy::ssl::SSLContext::~SSLContext() {
if (this->ctx) {
wolfSSL_CTX_free(this->ctx);
}
}

iggy::ssl::SSLContext::~SSLContext() {
iggy::ssl::SSLContext& iggy::ssl::SSLContext::operator=(const iggy::ssl::SSLContext& other) {
if (this != &other) {
if (this->ctx) {
wolfSSL_CTX_free(this->ctx);
}
this->ctx = wolfSSL_CTX_new(wolfTLSv1_3_client_method());
}
return *this;
}

iggy::ssl::SSLContext& iggy::ssl::SSLContext::operator=(const iggy::ssl::SSLContext& other) {
if (this != &other) {
if (this->ctx) {
wolfSSL_CTX_free(this->ctx);
}
this->ctx = wolfSSL_CTX_new(wolfTLSv1_3_client_method());
}
return *this;
}

iggy::ssl::SSLContext& iggy::ssl::SSLContext::operator=(SSLContext&& other) {
if (this != &other) {
if (this->ctx) {
wolfSSL_CTX_free(this->ctx);
}
this->ctx = other.ctx;
other.ctx = nullptr;
iggy::ssl::SSLContext& iggy::ssl::SSLContext::operator=(SSLContext&& other) {
if (this != &other) {
if (this->ctx) {
wolfSSL_CTX_free(this->ctx);
}
return *this;
this->ctx = other.ctx;
other.ctx = nullptr;
}

std::string iggy::ssl::getProtocolVersionName(iggy::ssl::ProtocolVersion protocolVersion) {
switch (protocolVersion) {
case iggy::ssl::ProtocolVersion::TLSV1_3:
return "TLSV1_3";
case iggy::ssl::ProtocolVersion::TLSV1_2:
return "TLSV1_2";
default:
int protocolVersionInt = static_cast<int>(protocolVersion);
throw std::runtime_error(fmt::format("Unsupported protocol version code: {}", protocolVersionInt));
}
return *this;
}

std::string iggy::ssl::getProtocolVersionName(iggy::ssl::ProtocolVersion protocolVersion) {
switch (protocolVersion) {
case iggy::ssl::ProtocolVersion::TLSV1_3:
return "TLSV1_3";
case iggy::ssl::ProtocolVersion::TLSV1_2:
return "TLSV1_2";
default:
int protocolVersionInt = static_cast<int>(protocolVersion);
throw std::runtime_error(fmt::format("Unsupported protocol version code: {}", protocolVersionInt));
}
}

std::once_flag iggy::ssl::SSLContext::sslInitDone = std::once_flag();
std::once_flag iggy::ssl::SSLContext::sslInitDone = std::once_flag();
45 changes: 43 additions & 2 deletions sdk/net/crypto/ssl.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,56 @@ std::string getProtocolVersionName(iggy::ssl::ProtocolVersion protocolVersion);
*/
class SSLOptions {
private:
PeerType peerType;
std::optional<std::string> peerCertPath = std::nullopt;
PeerType peerType = PeerType::CLIENT;
ProtocolVersion minimumSupportedProtocolVersion = ProtocolVersion::TLSV1_3;
std::vector<std::string> ciphers = getDefaultCipherList(ProtocolVersion::TLSV1_3);

public:
/**
* Creates a default set of options for a TLS 1.3-compatible client.
*/
SSLOptions() = default;
explicit SSLOptions(PeerType peerType = PeerType::CLIENT)
: peerType(peerType) {}

/**
* @brief Gets the default cipher list for use in SSL/TLS contexts.
* @return A vector of cipher strings, all uppercase.
*/
static const std::vector<std::string> getDefaultCipherList(ProtocolVersion protocolVersion);

/**
* @brief Gets the type of peer endpoint represented by this end of the socket.
*/
PeerType getPeerType() const { return this->peerType; }

/**
* @brief Sets the type of peer endpoint represented by the local end of the socket.
*/
void setPeerType(PeerType peerType) { this->peerType = peerType; }

/**
* @brief Sets the type of peer endpoint represented by the local end of the socket.
*/
std::optional<std::string> getPeerCertificatePath() const { return this->peerCertPath; }

/**
* @brief Sets the path to the peer's certificate, if any, to use for verifying the peer's identity.
*/
void setPeerCertificatePath(const std::string& peerCertPath) { this->peerCertPath = peerCertPath; }

/**
* @brief Gets the minimum supported protocol version for the SSL/TLS context.
*/
ProtocolVersion getMinimumSupportedProtocolVersion() const { return this->minimumSupportedProtocolVersion; }

/**
* @brief Sets the minimum supported protocol version for the SSL/TLS context.
*/
void setMinimumSupportedProtocolVersion(ProtocolVersion minimumSupportedProtocolVersion) {
this->minimumSupportedProtocolVersion = minimumSupportedProtocolVersion;
}

/**
* @brief Gets the list of requested supported ciphers; will be validated by the context during init.
*/
Expand All @@ -63,6 +96,14 @@ class SSLOptions {
* @brief Sets the list of requested supported ciphers; will be validated by the context during init.
*/
void setCiphers(const std::vector<std::string>& ciphers) { this->ciphers = ciphers; }

/**
* @brief Sanity checks the combination of options configured by the user.
* @param strict if true, will apply additional validations that may be more restrictive.
*
* Applies basic validations to the SSL options, e.g. if PeerType::SERVER is set, then a peer certificate path must be provided.
*/
void validate(bool strict = true) const;
};

/**
Expand Down
60 changes: 52 additions & 8 deletions tests/ssl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,58 @@

TEST_CASE("SSL configuration", UT_TAG) {
iggy::ssl::SSLOptions options;
auto cipherListTLSV1_2 = options.getDefaultCipherList(iggy::ssl::ProtocolVersion::TLSV1_2);
auto cipherListTLSV1_3 = options.getDefaultCipherList(iggy::ssl::ProtocolVersion::TLSV1_3);

REQUIRE(cipherListTLSV1_2.size() == 6);
REQUIRE(cipherListTLSV1_3.size() == 3);
SECTION("expected basic default settings") {
// default options should always be strictly valid
REQUIRE_NOTHROW(options.validate(true));

std::string tls12Cipher = "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384";
std::string tls13Cipher = "TLS_CHACHA20_POLY1305_SHA256";
CHECK(std::find(cipherListTLSV1_2.begin(), cipherListTLSV1_2.end(), tls12Cipher) != cipherListTLSV1_2.end());
CHECK(std::find(cipherListTLSV1_3.begin(), cipherListTLSV1_3.end(), tls13Cipher) != cipherListTLSV1_3.end());
REQUIRE(options.getPeerType() == iggy::ssl::PeerType::CLIENT);
REQUIRE(options.getPeerCertificatePath().has_value() == false);
REQUIRE(options.getMinimumSupportedProtocolVersion() == iggy::ssl::ProtocolVersion::TLSV1_3);
}

SECTION("default cipher list configured") {
auto cipherListTLSV1_2 = options.getDefaultCipherList(iggy::ssl::ProtocolVersion::TLSV1_2);
auto cipherListTLSV1_3 = options.getDefaultCipherList(iggy::ssl::ProtocolVersion::TLSV1_3);

REQUIRE(cipherListTLSV1_2.size() == 6);
REQUIRE(cipherListTLSV1_3.size() == 3);

std::string tls12Cipher = "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384";
std::string tls13Cipher = "TLS_CHACHA20_POLY1305_SHA256";
CHECK(std::find(cipherListTLSV1_2.begin(), cipherListTLSV1_2.end(), tls12Cipher) != cipherListTLSV1_2.end());
CHECK(std::find(cipherListTLSV1_3.begin(), cipherListTLSV1_3.end(), tls13Cipher) != cipherListTLSV1_3.end());
}

SECTION("configure bespoke cipers") {
auto testCipher = "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384";
options.setCiphers({testCipher});

auto ciphers = options.getCiphers();
REQUIRE(ciphers.size() == 1);
CHECK(std::find(ciphers.begin(), ciphers.end(), testCipher) != ciphers.end());

REQUIRE_NOTHROW(options.validate(true));
}

SECTION("configure server options") {
options.setPeerType(iggy::ssl::PeerType::SERVER);

// missing certificate path
REQUIRE_THROWS(options.validate());

// fix the issue
options.setPeerCertificatePath("test.pem");

// first try strict validation, fail
options.setMinimumSupportedProtocolVersion(iggy::ssl::ProtocolVersion::TLSV1_2);
REQUIRE_THROWS(options.validate(true));

// loosen the validation
REQUIRE_NOTHROW(options.validate(false));

// finally harden the settings and tighten up validation
options.setMinimumSupportedProtocolVersion(iggy::ssl::ProtocolVersion::TLSV1_3);
REQUIRE_NOTHROW(options.validate(true));
}
}

0 comments on commit dd2c13a

Please sign in to comment.