diff --git a/mobile/library/cc/engine_builder.cc b/mobile/library/cc/engine_builder.cc index 74a66e6a4dfb..8152ccef64a4 100644 --- a/mobile/library/cc/engine_builder.cc +++ b/mobile/library/cc/engine_builder.cc @@ -929,14 +929,14 @@ EngineSharedPtr EngineBuilder::build() { Engine* engine = new Engine(envoy_engine); - auto options = std::make_unique(); + auto options = std::make_shared(); std::unique_ptr bootstrap = generateBootstrap(); if (bootstrap) { options->setConfigProto(std::move(bootstrap)); } ENVOY_BUG(options->setLogLevel(logLevelToString(log_level_)).ok(), "invalid log level"); options->setConcurrency(1); - envoy_engine->run(std::move(options)); + envoy_engine->run(options); // we can't construct via std::make_shared // because Engine is only constructible as a friend diff --git a/mobile/library/common/BUILD b/mobile/library/common/BUILD index b955b2db2d36..f2c47473ceaa 100644 --- a/mobile/library/common/BUILD +++ b/mobile/library/common/BUILD @@ -35,6 +35,7 @@ envoy_cc_library( "//library/common/types:c_types_lib", "@envoy//envoy/server:lifecycle_notifier_interface", "@envoy//envoy/stats:stats_interface", + "@envoy//source/common/common:thread_impl_lib_posix", "@envoy//source/common/runtime:runtime_lib", "@envoy_build_config//:extension_registry", ], diff --git a/mobile/library/common/engine_common.cc b/mobile/library/common/engine_common.cc index a7f5fb3bcfec..7cfe535ae3ec 100644 --- a/mobile/library/common/engine_common.cc +++ b/mobile/library/common/engine_common.cc @@ -67,8 +67,7 @@ class ServerLite : public Server::InstanceBase { } }; -EngineCommon::EngineCommon(std::unique_ptr&& options) - : options_(std::move(options)) { +EngineCommon::EngineCommon(std::shared_ptr options) : options_(options) { #if !defined(ENVOY_ENABLE_FULL_PROTOS) registerMobileProtoDescriptors(); diff --git a/mobile/library/common/engine_common.h b/mobile/library/common/engine_common.h index f96722c63c69..a098ac59ce72 100644 --- a/mobile/library/common/engine_common.h +++ b/mobile/library/common/engine_common.h @@ -22,7 +22,7 @@ namespace Envoy { */ class EngineCommon { public: - EngineCommon(std::unique_ptr&& options); + EngineCommon(std::shared_ptr options); bool run() { base_->runServer(); return true; @@ -41,11 +41,11 @@ class EngineCommon { Envoy::SignalAction handle_sigs_; Envoy::TerminateHandler log_on_terminate_; #endif - std::unique_ptr options_; + std::shared_ptr options_; Event::RealTimeSystem real_time_system_; // NO_CHECK_FORMAT(real_time) DefaultListenerHooks default_listener_hooks_; ProdComponentFactory prod_component_factory_; - std::unique_ptr base_; + std::shared_ptr base_; }; } // namespace Envoy diff --git a/mobile/library/common/extensions/cert_validator/platform_bridge/BUILD b/mobile/library/common/extensions/cert_validator/platform_bridge/BUILD index bf2371bba072..79553a0b1963 100644 --- a/mobile/library/common/extensions/cert_validator/platform_bridge/BUILD +++ b/mobile/library/common/extensions/cert_validator/platform_bridge/BUILD @@ -35,6 +35,9 @@ envoy_cc_library( ":c_types_lib", ":platform_bridge_cc_proto", "//library/common/system:system_helper_lib", + "@envoy//envoy/thread:thread_interface", + "@envoy//source/common/common:macros", + "@envoy//source/common/common:thread_impl_lib_posix", "@envoy//source/common/tls/cert_validator:cert_validator_lib", ], ) diff --git a/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.cc b/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.cc index 602c0b1e4213..9860ee3cec49 100644 --- a/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.cc +++ b/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.cc @@ -2,7 +2,6 @@ #include #include -#include #include "library/common/data/utility.h" #include "library/common/system/system_helper.h" @@ -13,22 +12,27 @@ namespace TransportSockets { namespace Tls { PlatformBridgeCertValidator::PlatformBridgeCertValidator( - const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats) + const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats, + Thread::PosixThreadFactoryPtr thread_factory) : allow_untrusted_certificate_(config != nullptr && config->trustChainVerification() == envoy::extensions::transport_sockets::tls::v3:: CertificateValidationContext::ACCEPT_UNTRUSTED), - stats_(stats) { + stats_(stats), thread_factory_(std::move(thread_factory)) { ENVOY_BUG(config != nullptr && config->caCert().empty() && config->certificateRevocationList().empty(), "Invalid certificate validation context config."); } +PlatformBridgeCertValidator::PlatformBridgeCertValidator( + const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats) + : PlatformBridgeCertValidator(config, stats, Thread::PosixThreadFactory::create()) {} + PlatformBridgeCertValidator::~PlatformBridgeCertValidator() { // Wait for validation threads to finish. for (auto& [id, job] : validation_jobs_) { - if (job.validation_thread_.joinable()) { - job.validation_thread_.join(); + if (job.validation_thread_->joinable()) { + job.validation_thread_->join(); } } } @@ -84,10 +88,19 @@ ValidationResults PlatformBridgeCertValidator::doVerifyCertChain( ValidationJob job; job.result_callback_ = std::move(callback); - job.validation_thread_ = - std::thread(&verifyCertChainByPlatform, &(job.result_callback_->dispatcher()), - std::move(certs), std::string(host), std::move(subject_alt_names), this); - std::thread::id thread_id = job.validation_thread_.get_id(); + Event::Dispatcher& dispatcher = job.result_callback_->dispatcher(); + job.validation_thread_ = thread_factory_->createThread( + [this, &dispatcher, certs = std::move(certs), host = std::string(host), + subject_alt_names = std::move(subject_alt_names)]() -> void { + verifyCertChainByPlatform(&dispatcher, certs, host, subject_alt_names, this); + }, + /* options= */ absl::nullopt, /* crash_on_failure=*/false); + if (job.validation_thread_ == nullptr) { + return {ValidationResults::ValidationStatus::Failed, + Envoy::Ssl::ClientValidationStatus::NotValidated, absl::nullopt, + "Failed creating a thread for cert chain validation."}; + } + Thread::ThreadId thread_id = job.validation_thread_->pthreadId(); validation_jobs_[thread_id] = std::move(job); return {ValidationResults::ValidationStatus::Pending, Envoy::Ssl::ClientValidationStatus::NotValidated, absl::nullopt, absl::nullopt}; @@ -146,7 +159,7 @@ void PlatformBridgeCertValidator::postVerifyResultAndCleanUp(bool success, std:: dispatcher->post([weak_alive_indicator, success, hostname = std::move(hostname), error = std::string(error_details), tls_alert, failure_type, - thread_id = std::this_thread::get_id(), parent]() { + thread_id = parent->thread_factory_->currentPthreadId(), parent]() { if (weak_alive_indicator.expired()) { return; } @@ -154,9 +167,10 @@ void PlatformBridgeCertValidator::postVerifyResultAndCleanUp(bool success, std:: }); } -void PlatformBridgeCertValidator::onVerificationComplete(std::thread::id thread_id, - std::string hostname, bool success, - std::string error, uint8_t tls_alert, +void PlatformBridgeCertValidator::onVerificationComplete(const Thread::ThreadId& thread_id, + const std::string& hostname, bool success, + const std::string& error, + uint8_t tls_alert, ValidationFailureType failure_type) { ENVOY_LOG(trace, "Got validation result for {} from platform", hostname); @@ -166,7 +180,7 @@ void PlatformBridgeCertValidator::onVerificationComplete(std::thread::id thread_ return; } ValidationJob& job = job_handle.mapped(); - job.validation_thread_.join(); + job.validation_thread_->join(); Ssl::ClientValidationStatus detailed_status = Envoy::Ssl::ClientValidationStatus::NotValidated; switch (failure_type) { diff --git a/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.h b/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.h index 42587bbb546c..39e114be795e 100644 --- a/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.h +++ b/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.h @@ -1,7 +1,8 @@ #pragma once -#include - +#include "source/common/common/macros.h" +#include "source/common/common/posix/thread_impl.h" +#include "source/common/common/thread.h" #include "source/common/tls/cert_validator/default_validator.h" #include "absl/container/flat_hash_map.h" @@ -55,6 +56,11 @@ class PlatformBridgeCertValidator : public CertValidator, Logger::Loggable validation_jobs_; + absl::flat_hash_map validation_jobs_; std::shared_ptr alive_indicator_{new size_t(1)}; + Thread::PosixThreadFactoryPtr thread_factory_; }; } // namespace Tls diff --git a/mobile/library/common/internal_engine.cc b/mobile/library/common/internal_engine.cc index cb0caf553c8f..9269ba676678 100644 --- a/mobile/library/common/internal_engine.cc +++ b/mobile/library/common/internal_engine.cc @@ -16,9 +16,10 @@ static std::atomic current_stream_handle_{0}; envoy_stream_t InternalEngine::initStream() { return current_stream_handle_++; } InternalEngine::InternalEngine(envoy_engine_callbacks callbacks, envoy_logger logger, - envoy_event_tracker event_tracker) - : callbacks_(callbacks), logger_(logger), event_tracker_(event_tracker), - dispatcher_(std::make_unique()) { + envoy_event_tracker event_tracker, + Thread::PosixThreadFactoryPtr thread_factory) + : thread_factory_(std::move(thread_factory)), callbacks_(callbacks), logger_(logger), + event_tracker_(event_tracker), dispatcher_(std::make_unique()) { ExtensionRegistry::registerFactories(); // TODO(Augustyniak): Capturing an address of event_tracker_ and registering it in the API @@ -32,12 +33,13 @@ InternalEngine::InternalEngine(envoy_engine_callbacks callbacks, envoy_logger lo Runtime::maybeSetRuntimeGuard("envoy.reloadable_features.dfp_mixed_scheme", true); } +InternalEngine::InternalEngine(envoy_engine_callbacks callbacks, envoy_logger logger, + envoy_event_tracker event_tracker) + : InternalEngine(callbacks, logger, event_tracker, Thread::PosixThreadFactory::create()) {} + envoy_status_t InternalEngine::run(const std::string& config, const std::string& log_level) { - // Start the Envoy on the dedicated thread. Note: due to how the assignment operator works with - // std::thread, main_thread_ is the same object after this call, but its state is replaced with - // that of the temporary. The temporary object's state becomes the default state, which does - // nothing. - auto options = std::make_unique(); + // Start the Envoy on the dedicated thread. + auto options = std::make_shared(); options->setConfigYaml(config); if (!log_level.empty()) { ENVOY_BUG(options->setLogLevel(log_level).ok(), "invalid log level"); @@ -46,12 +48,17 @@ envoy_status_t InternalEngine::run(const std::string& config, const std::string& return run(std::move(options)); } -envoy_status_t InternalEngine::run(std::unique_ptr&& options) { - main_thread_ = std::thread(&InternalEngine::main, this, std::move(options)); - return ENVOY_SUCCESS; +// This function takes a `std::shared_ptr` instead of `std::unique_ptr` because `std::function` is a +// copy-constructible type, so it's not possible to move capture `std::unique_ptr` with +// `std::function`. +envoy_status_t InternalEngine::run(std::shared_ptr options) { + main_thread_ = + thread_factory_->createThread([this, options]() mutable -> void { main(options); }, + /* options= */ absl::nullopt, /* crash_on_failure= */ false); + return (main_thread_ != nullptr) ? ENVOY_SUCCESS : ENVOY_FAILURE; } -envoy_status_t InternalEngine::main(std::unique_ptr&& options) { +envoy_status_t InternalEngine::main(std::shared_ptr options) { // Using unique_ptr ensures main_common's lifespan is strictly scoped to this function. std::unique_ptr main_common; { @@ -81,7 +88,7 @@ envoy_status_t InternalEngine::main(std::unique_ptr&& op std::make_unique(log_mutex_, Logger::Registry::getSink()); } - main_common = std::make_unique(std::move(options)); + main_common = std::make_unique(options); server_ = main_common->server(); event_dispatcher_ = &server_->dispatcher(); @@ -150,8 +157,12 @@ envoy_status_t InternalEngine::terminate() { IS_ENVOY_BUG("attempted to double terminate engine"); return ENVOY_FAILURE; } + // The Engine could not be created. + if (main_thread_ == nullptr) { + return ENVOY_FAILURE; + } // If main_thread_ has finished (or hasn't started), there's nothing more to do. - if (!main_thread_.joinable()) { + if (!main_thread_->joinable()) { return ENVOY_FAILURE; } @@ -170,7 +181,7 @@ envoy_status_t InternalEngine::terminate() { dispatcher_->post([this]() { http_client_->shutdownApiListener(); }); // Exit the event loop and finish up in Engine::run(...) - if (std::this_thread::get_id() == main_thread_.get_id()) { + if (thread_factory_->currentPthreadId() == main_thread_->pthreadId()) { // TODO(goaway): figure out some way to support this. PANIC("Terminating the engine from its own main thread is currently unsupported."); } else { @@ -178,8 +189,8 @@ envoy_status_t InternalEngine::terminate() { } } // lock(_mutex) - if (std::this_thread::get_id() != main_thread_.get_id()) { - main_thread_.join(); + if (thread_factory_->currentPthreadId() != main_thread_->pthreadId()) { + main_thread_->join(); } terminated_ = true; return ENVOY_SUCCESS; @@ -265,7 +276,7 @@ void handlerStats(Stats::Store& stats, Buffer::Instance& response) { } std::string InternalEngine::dumpStats() { - if (!main_thread_.joinable()) { + if (!main_thread_->joinable()) { return ""; } diff --git a/mobile/library/common/internal_engine.h b/mobile/library/common/internal_engine.h index 0b45ae95bde1..3ac09ced5d07 100644 --- a/mobile/library/common/internal_engine.h +++ b/mobile/library/common/internal_engine.h @@ -4,6 +4,9 @@ #include "envoy/stats/store.h" #include "source/common/common/logger.h" +#include "source/common/common/macros.h" +#include "source/common/common/posix/thread_impl.h" +#include "source/common/common/thread.h" #include "absl/base/call_once.h" #include "extension_registry.h" @@ -37,7 +40,7 @@ class InternalEngine : public Logger::Loggable { * @param log_level, the log level. */ envoy_status_t run(const std::string& config, const std::string& log_level); - envoy_status_t run(std::unique_ptr&& options); + envoy_status_t run(std::shared_ptr options); /** * Immediately terminate the engine, if running. Calling this function when @@ -118,10 +121,16 @@ class InternalEngine : public Logger::Loggable { Stats::Store& getStatsStore(); private: - envoy_status_t main(std::unique_ptr&& options); + GTEST_FRIEND_CLASS(InternalEngineTest, ThreadCreationFailed); + + InternalEngine(envoy_engine_callbacks callbacks, envoy_logger logger, + envoy_event_tracker event_tracker, Thread::PosixThreadFactoryPtr thread_factory); + + envoy_status_t main(std::shared_ptr options); static void logInterfaces(absl::string_view event, std::vector& interfaces); + Thread::PosixThreadFactoryPtr thread_factory_; Event::Dispatcher* event_dispatcher_{}; Stats::ScopeSharedPtr client_scope_; Stats::StatNameSetPtr stat_name_set_; @@ -142,7 +151,7 @@ class InternalEngine : public Logger::Loggable { Server::ServerLifecycleNotifier::HandlePtr postinit_callback_handler_; // main_thread_ should be destroyed first, hence it is the last member variable. Objects with // instructions scheduled on the main_thread_ need to have a longer lifetime. - std::thread main_thread_{}; // Empty placeholder to be populated later. + Thread::PosixThreadPtr main_thread_{nullptr}; // Empty placeholder to be populated later. bool terminated_{false}; }; diff --git a/mobile/test/common/BUILD b/mobile/test/common/BUILD index 91686f69ecef..d44bd610c643 100644 --- a/mobile/test/common/BUILD +++ b/mobile/test/common/BUILD @@ -32,6 +32,7 @@ envoy_cc_test( "//test/common/mocks/common:common_mocks", "//test/common/mocks/event:event_mocks", "@envoy//test/common/http:common_lib", + "@envoy//test/mocks/thread:thread_mocks", "@envoy_build_config//:test_extensions", ], ) diff --git a/mobile/test/common/engine_common_test.cc b/mobile/test/common/engine_common_test.cc index ffb511fa51a6..54a2d741ff53 100644 --- a/mobile/test/common/engine_common_test.cc +++ b/mobile/test/common/engine_common_test.cc @@ -7,7 +7,7 @@ namespace Envoy { TEST(EngineCommonTest, SignalHandlingFalse) { ExtensionRegistry::registerFactories(); - auto options = std::make_unique(); + auto options = std::make_shared(); Platform::EngineBuilder builder; options->setConfigProto(builder.generateBootstrap()); diff --git a/mobile/test/common/extensions/cert_validator/platform_bridge/BUILD b/mobile/test/common/extensions/cert_validator/platform_bridge/BUILD index 2956e5eeed83..1c743e436a73 100644 --- a/mobile/test/common/extensions/cert_validator/platform_bridge/BUILD +++ b/mobile/test/common/extensions/cert_validator/platform_bridge/BUILD @@ -25,6 +25,7 @@ envoy_extension_cc_test( "@envoy//test/common/tls/test_data:cert_infos", "@envoy//test/mocks/event:event_mocks", "@envoy//test/mocks/ssl:ssl_mocks", + "@envoy//test/mocks/thread:thread_mocks", "@envoy//test/test_common:environment_lib", "@envoy//test/test_common:file_system_for_test_lib", "@envoy//test/test_common:test_runtime_lib", diff --git a/mobile/test/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator_test.cc b/mobile/test/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator_test.cc index 2d05403d394a..e5e4e6391641 100644 --- a/mobile/test/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator_test.cc +++ b/mobile/test/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator_test.cc @@ -2,7 +2,6 @@ #include #include "source/common/buffer/buffer_impl.h" -#include "source/common/crypto/crypto_impl.h" #include "source/common/crypto/utility.h" #include "source/common/network/transport_socket_options_impl.h" #include "source/common/tls/cert_validator/default_validator.h" @@ -16,6 +15,7 @@ #include "test/common/tls/test_data/san_dns2_cert_info.h" #include "test/mocks/event/mocks.h" #include "test/mocks/ssl/mocks.h" +#include "test/mocks/thread/mocks.h" #include "test/test_common/environment.h" #include "test/test_common/test_runtime.h" #include "test/test_common/utility.h" @@ -32,6 +32,7 @@ using SSLContextPtr = Envoy::CSmartPtr; using envoy::extensions::transport_sockets::tls::v3::CertificateValidationContext; using testing::_; +using testing::ByMove; using testing::NiceMock; using testing::Return; using testing::ReturnRef; @@ -63,11 +64,12 @@ class PlatformBridgeCertValidatorTest : public testing::TestWithParam { protected: PlatformBridgeCertValidatorTest() - : api_(Api::createApiForTest()), dispatcher_(api_->allocateDispatcher("test_thread")), + : thread_factory_(Thread::PosixThreadFactory::create()), api_(Api::createApiForTest()), + dispatcher_(api_->allocateDispatcher("test_thread")), stats_(generateSslStats(*test_store_.rootScope())), ssl_ctx_(SSL_CTX_new(TLS_method())), callback_(std::make_unique()), is_server_(false), mock_validator_(std::make_unique()), - main_thread_id_(std::this_thread::get_id()), + main_thread_id_(thread_factory_->currentPthreadId()), helper_handle_(test::SystemHelperPeer::replaceSystemHelper()) { ON_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _)) .WillByDefault(WithArgs<0, 1>(Invoke(this, &PlatformBridgeCertValidatorTest::validate))); @@ -86,7 +88,7 @@ class PlatformBridgeCertValidatorTest ~PlatformBridgeCertValidatorTest() { mock_validator_.reset(); - main_thread_id_ = std::thread::id(); + main_thread_id_ = thread_factory_->currentPthreadId(); Envoy::Assert::resetEnvoyBugCountersForTest(); } @@ -104,7 +106,7 @@ class PlatformBridgeCertValidatorTest envoy_cert_validation_result validate(const std::vector& certs, absl::string_view hostname) { // Validate must be called on the worker thread, not the main thread. - EXPECT_NE(main_thread_id_, std::this_thread::get_id()); + EXPECT_NE(main_thread_id_, thread_factory_->currentPthreadId()); // Make sure the cert was converted correctly. const Buffer::InstancePtr buffer(new Buffer::OwnedImpl(certs[0])); @@ -115,10 +117,11 @@ class PlatformBridgeCertValidatorTest void cleanup() { // Validate must be called on the worker thread, not the main thread. - EXPECT_NE(main_thread_id_, std::this_thread::get_id()); + EXPECT_NE(main_thread_id_, thread_factory_->currentPthreadId()); mock_validator_->cleanup(); } + Thread::PosixThreadFactoryPtr thread_factory_; Api::ApiPtr api_; Event::DispatcherPtr dispatcher_; Stats::TestUtil::TestStore test_store_; @@ -131,7 +134,7 @@ class PlatformBridgeCertValidatorTest std::unique_ptr callback_; bool is_server_; std::unique_ptr mock_validator_; - std::thread::id main_thread_id_; + Thread::ThreadId main_thread_id_; std::unique_ptr helper_handle_; }; @@ -222,7 +225,7 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificate) { EXPECT_CALL(callback_ref, onCertValidationResult(true, Envoy::Ssl::ClientValidationStatus::Validated, "", 46)) .WillOnce(Invoke([this]() { - EXPECT_EQ(main_thread_id_, std::this_thread::get_id()); + EXPECT_EQ(main_thread_id_, thread_factory_->currentPthreadId()); dispatcher_->exit(); })); EXPECT_FALSE(waitForDispatcherToExit()); @@ -257,7 +260,7 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateEmptySanOverrides) { EXPECT_CALL(callback_ref, onCertValidationResult(true, Envoy::Ssl::ClientValidationStatus::Validated, "", 46)) .WillOnce(Invoke([this]() { - EXPECT_EQ(main_thread_id_, std::this_thread::get_id()); + EXPECT_EQ(main_thread_id_, thread_factory_->currentPthreadId()); dispatcher_->exit(); })); EXPECT_FALSE(waitForDispatcherToExit()); @@ -292,7 +295,7 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateEmptyHostNoOverrides) { EXPECT_CALL(callback_ref, onCertValidationResult(true, Envoy::Ssl::ClientValidationStatus::Validated, "", 46)) .WillOnce(Invoke([this]() { - EXPECT_EQ(main_thread_id_, std::this_thread::get_id()); + EXPECT_EQ(main_thread_id_, thread_factory_->currentPthreadId()); dispatcher_->exit(); })); EXPECT_FALSE(waitForDispatcherToExit()); @@ -389,6 +392,26 @@ TEST_P(PlatformBridgeCertValidatorTest, DeletedWithValidationPending) { EXPECT_TRUE(waitForDispatcherToExit()); } +TEST_P(PlatformBridgeCertValidatorTest, ThreadCreationFailed) { + initializeConfig(); + auto thread_factory = std::make_unique(); + EXPECT_CALL(*thread_factory, createThread(_, _, false)).WillOnce(Return(ByMove(nullptr))); + PlatformBridgeCertValidator validator(&config_, stats_, std::move(thread_factory)); + + std::string hostname = "server1.example.com"; + bssl::UniquePtr cert_chain = readCertChainFromFile( + TestEnvironment::substitute("{{ test_rundir }}/test/common/tls/test_data/san_dns2_cert.pem")); + auto& callback_ref = *callback_; + EXPECT_CALL(callback_ref, dispatcher()).WillRepeatedly(ReturnRef(*dispatcher_)); + + ValidationResults results = + validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, + *ssl_ctx_, validation_context_, is_server_, hostname); + EXPECT_EQ(ValidationResults::ValidationStatus::Failed, results.status); + EXPECT_EQ(Ssl::ClientValidationStatus::NotValidated, results.detailed_status); + EXPECT_EQ("Failed creating a thread for cert chain validation.", *results.error_details); +} + } // namespace Tls } // namespace TransportSockets } // namespace Extensions diff --git a/mobile/test/common/internal_engine_test.cc b/mobile/test/common/internal_engine_test.cc index 22e91ac3094c..93fc7b440a98 100644 --- a/mobile/test/common/internal_engine_test.cc +++ b/mobile/test/common/internal_engine_test.cc @@ -4,6 +4,7 @@ #include "test/common/http/common.h" #include "test/common/mocks/common/mocks.h" +#include "test/mocks/thread/mocks.h" #include "absl/synchronization/notification.h" #include "gtest/gtest.h" @@ -17,6 +18,7 @@ namespace Envoy { using testing::_; +using testing::ByMove; using testing::HasSubstr; using testing::Return; using testing::ReturnRef; @@ -114,8 +116,6 @@ struct EngineTestContext { // between the main thread and the engine thread both writing to the // Envoy::Logger::current_log_context global. struct TestEngine { - std::unique_ptr engine_; - envoy_engine_t handle() const { return reinterpret_cast(engine_.get()); } TestEngine(envoy_engine_callbacks callbacks, const std::string& level) { engine_.reset(new Envoy::InternalEngine(callbacks, {}, {})); Platform::EngineBuilder builder; @@ -124,14 +124,13 @@ struct TestEngine { engine_->run(yaml, level); } + envoy_engine_t handle() const { return reinterpret_cast(engine_.get()); } + envoy_status_t terminate() const { return engine_->terminate(); } + [[nodiscard]] bool isTerminated() const { return engine_->isTerminated(); } - ~TestEngine() { - if (!engine_->isTerminated()) { - engine_->terminate(); - } - } + std::unique_ptr engine_; }; // Transform C map to C++ map. @@ -664,4 +663,27 @@ TEST_F(InternalEngineTest, SetLogger) { EXPECT_EQ(engine->terminate(), ENVOY_SUCCESS); } +TEST_F(InternalEngineTest, ThreadCreationFailed) { + const std::string level = "debug"; + EngineTestContext engine_cbs_context{}; + envoy_engine_callbacks engine_cbs{[](void* context) -> void { + auto* engine_running = + static_cast(context); + engine_running->on_engine_running.Notify(); + } /*on_engine_running*/, + [](void* context) -> void { + auto* exit = static_cast(context); + exit->on_exit.Notify(); + } /*on_exit*/, + &engine_cbs_context /*context*/}; + auto thread_factory = std::make_unique(); + EXPECT_CALL(*thread_factory, createThread(_, _, false)).WillOnce(Return(ByMove(nullptr))); + std::unique_ptr engine( + new Envoy::InternalEngine(engine_cbs, {}, {}, std::move(thread_factory))); + envoy_status_t status = engine->run(BUFFERED_TEST_CONFIG, level); + EXPECT_EQ(status, ENVOY_FAILURE); + // Calling `terminate()` should not crash. + EXPECT_EQ(engine->terminate(), ENVOY_FAILURE); +} + } // namespace Envoy diff --git a/source/common/common/macros.h b/source/common/common/macros.h index f2b06b84f340..75032a0f1dc0 100644 --- a/source/common/common/macros.h +++ b/source/common/common/macros.h @@ -57,4 +57,8 @@ namespace Envoy { #if (defined(__GNUC__) && !defined(__clang__)) #define GCC_COMPILER #endif + +#define GTEST_FRIEND_CLASS(test_case_name, test_name) \ + friend class test_case_name##_##test_name##_Test + } // namespace Envoy diff --git a/source/common/common/posix/thread_impl.cc b/source/common/common/posix/thread_impl.cc index e935ac53c11b..e89fb16382c3 100644 --- a/source/common/common/posix/thread_impl.cc +++ b/source/common/common/posix/thread_impl.cc @@ -1,5 +1,8 @@ +#include "source/common/common/posix/thread_impl.h" + +#include "envoy/thread/thread.h" + #include "source/common/common/assert.h" -#include "source/common/common/thread_impl.h" #include "absl/strings/str_cat.h" @@ -31,93 +34,142 @@ int64_t getCurrentThreadId() { // so we need to truncate the string_view to 15 bytes. #define PTHREAD_MAX_THREADNAME_LEN_INCLUDING_NULL_BYTE 16 +ThreadHandle::ThreadHandle(std::function thread_routine) + : thread_routine_(thread_routine) {} + +/** Returns the thread routine. */ +std::function& ThreadHandle::routine() { return thread_routine_; }; + +/** Returns the thread handle. */ +pthread_t& ThreadHandle::handle() { return thread_handle_; } + /** * Wrapper for a pthread thread. We don't use std::thread because it eats exceptions and leads to * unusable stack traces. */ -class ThreadImplPosix : public Thread { -public: - ThreadImplPosix(std::function thread_routine, OptionsOptConstRef options) - : thread_routine_(std::move(thread_routine)) { - if (options) { - name_ = options->name_.substr(0, PTHREAD_MAX_THREADNAME_LEN_INCLUDING_NULL_BYTE - 1); - } - RELEASE_ASSERT(Logger::Registry::initialized(), ""); - const int rc = pthread_create( - &thread_handle_, nullptr, - [](void* arg) -> void* { - static_cast(arg)->thread_routine_(); - return nullptr; - }, - this); - RELEASE_ASSERT(rc == 0, ""); +PosixThread::PosixThread(ThreadHandle* thread_handle, OptionsOptConstRef options) + : thread_handle_(thread_handle) { + if (options) { + name_ = options->name_.substr(0, PTHREAD_MAX_THREADNAME_LEN_INCLUDING_NULL_BYTE - 1); + } #if SUPPORTS_PTHREAD_NAMING - // If the name was not specified, get it from the OS. If the name was - // specified, write it into the thread, and assert that the OS sees it the - // same way. - if (name_.empty()) { - getNameFromOS(name_); + // If the name was not specified, get it from the OS. If the name was + // specified, write it into the thread, and assert that the OS sees it the + // same way. + if (name_.empty()) { + getNameFromOS(name_); + } else { + const int set_name_rc = pthread_setname_np(thread_handle_->handle(), name_.c_str()); + if (set_name_rc != 0) { + ENVOY_LOG_MISC(trace, "Error {} setting name `{}'", set_name_rc, name_); } else { - const int set_name_rc = pthread_setname_np(thread_handle_, name_.c_str()); - if (set_name_rc != 0) { - ENVOY_LOG_MISC(trace, "Error {} setting name `{}'", set_name_rc, name_); - } else { - // When compiling in debug mode, read back the thread-name from the OS, - // and verify it's what we asked for. This ensures the truncation is as - // expected, and that the OS will actually retain all the bytes of the - // name we expect. - // - // Note that the system-call to read the thread name may fail in case - // the thread exits after the call to set the name above, and before the - // call to get the name, so we can only do the assert if that call - // succeeded. - std::string check_name; - ASSERT(!getNameFromOS(check_name) || check_name == name_, - absl::StrCat("configured name=", name_, " os name=", check_name)); - } + // When compiling in debug mode, read back the thread-name from the OS, + // and verify it's what we asked for. This ensures the truncation is as + // expected, and that the OS will actually retain all the bytes of the + // name we expect. + // + // Note that the system-call to read the thread name may fail in case + // the thread exits after the call to set the name above, and before the + // call to get the name, so we can only do the assert if that call + // succeeded. + std::string check_name; + ASSERT(!getNameFromOS(check_name) || check_name == name_, + absl::StrCat("configured name=", name_, " os name=", check_name)); } -#endif } +#endif +} + +PosixThread::~PosixThread() { + ASSERT(joined_); + delete thread_handle_; +} - ~ThreadImplPosix() override { ASSERT(joined_); } +std::string PosixThread::name() const { return name_; } - std::string name() const override { return name_; } +// Thread::Thread +void PosixThread::join() { + ASSERT(!joined_); + joined_ = true; + const int rc = pthread_join(thread_handle_->handle(), nullptr); + RELEASE_ASSERT(rc == 0, ""); +} - // Thread::Thread - void join() override { - ASSERT(!joined_); - joined_ = true; - const int rc = pthread_join(thread_handle_, nullptr); - RELEASE_ASSERT(rc == 0, ""); - } +bool PosixThread::joinable() const { return !joined_; } + +ThreadId PosixThread::pthreadId() const { +#if defined(__linux__) + return ThreadId(static_cast(thread_handle_->handle())); +#elif defined(__APPLE__) + uint64_t tid; + pthread_threadid_np(thread_handle_->handle(), &tid); + return ThreadId(tid); +#else +#error "Enable and test pthread id retrieval code for you arch in pthread/thread_impl.cc" +#endif +} -private: #if SUPPORTS_PTHREAD_NAMING - // Attempts to get the name from the operating system, returning true and - // updating 'name' if successful. Note that during normal operation this - // may fail, if the thread exits prior to the system call. - bool getNameFromOS(std::string& name) { - // Verify that the name got written into the thread as expected. - char buf[PTHREAD_MAX_THREADNAME_LEN_INCLUDING_NULL_BYTE] = {0}; - const int get_name_rc = pthread_getname_np(thread_handle_, buf, sizeof(buf)); - name = buf; - return get_name_rc == 0; - } +// Attempts to get the name from the operating system, returning true and +// updating 'name' if successful. Note that during normal operation this +// may fail, if the thread exits prior to the system call. +bool PosixThread::getNameFromOS(std::string& name) { + // Verify that the name got written into the thread as expected. + char buf[PTHREAD_MAX_THREADNAME_LEN_INCLUDING_NULL_BYTE] = {0}; + const int get_name_rc = pthread_getname_np(thread_handle_->handle(), buf, sizeof(buf)); + name = buf; + return get_name_rc == 0; +} #endif - std::function thread_routine_; - pthread_t thread_handle_; - std::string name_; - bool joined_{false}; +class PosixThreadFactoryImpl : public PosixThreadFactory { +public: + ThreadPtr createThread(std::function thread_routine, + OptionsOptConstRef options) override { + return createThread(thread_routine, options, /* crash_on_failure= */ true); + }; + + PosixThreadPtr createThread(std::function thread_routine, OptionsOptConstRef options, + bool crash_on_failure) override { + auto thread_handle = new ThreadHandle(thread_routine); + const int rc = pthread_create( + &thread_handle->handle(), nullptr, + [](void* arg) -> void* { + static_cast(arg)->routine()(); + return nullptr; + }, + reinterpret_cast(thread_handle)); + if (rc != 0) { + delete thread_handle; + if (crash_on_failure) { + RELEASE_ASSERT(false, fmt::format("Unable to create a thread with return code: {}", rc)); + } else { + IS_ENVOY_BUG(fmt::format("Unable to create a thread with return code: {}", rc)); + } + return nullptr; + } + return std::make_unique(thread_handle, options); + }; + + ThreadId currentThreadId() override { return ThreadId(getCurrentThreadId()); }; + + ThreadId currentPthreadId() override { +#if defined(__linux__) + return static_cast(static_cast(pthread_self())); +#elif defined(__APPLE__) + uint64_t tid; + pthread_threadid_np(pthread_self(), &tid); + return ThreadId(tid); +#else +#error "Enable and test pthread id retrieval code for you arch in pthread/thread_impl.cc" +#endif + } }; -ThreadPtr ThreadFactoryImplPosix::createThread(std::function thread_routine, - OptionsOptConstRef options) { - return std::make_unique(thread_routine, options); +PosixThreadFactoryPtr PosixThreadFactory::create() { + return std::make_unique(); } -ThreadId ThreadFactoryImplPosix::currentThreadId() { return ThreadId(getCurrentThreadId()); } - } // namespace Thread } // namespace Envoy diff --git a/source/common/common/posix/thread_impl.h b/source/common/common/posix/thread_impl.h index 9b373ecaceb6..2201d3e61886 100644 --- a/source/common/common/posix/thread_impl.h +++ b/source/common/common/posix/thread_impl.h @@ -4,19 +4,99 @@ #include +#include "envoy/common/platform.h" #include "envoy/thread/thread.h" namespace Envoy { namespace Thread { -/** - * Implementation of ThreadFactory - */ -class ThreadFactoryImplPosix : public ThreadFactory { +class ThreadHandle { public: - // Thread::ThreadFactory - ThreadPtr createThread(std::function thread_routine, OptionsOptConstRef options) override; - ThreadId currentThreadId() override; + explicit ThreadHandle(std::function thread_routine); + + /** Returns the thread routine. */ + std::function& routine(); + + /** Returns the thread handle. */ + pthread_t& handle(); + +private: + std::function thread_routine_; + pthread_t thread_handle_; +}; + +class PosixThread : public Thread { +public: + PosixThread(ThreadHandle* thread_handle, OptionsOptConstRef options); + ~PosixThread() override; + + // Envoy::Thread + std::string name() const override; + void join() override; + + /** + * Returns true if the thread object identifies an active thread of execution, + * false otherwise. + * A thread that has finished executing code, but has not yet been joined is + * still considered an active thread of execution and is therefore joinable. + */ + bool joinable() const; + + /** + * Returns the pthread ID. The thread ID returned from this call is the same + * thread ID returned from `pthread_self()`: + * https://man7.org/linux/man-pages/man3/pthread_self.3.html + */ + ThreadId pthreadId() const; + +private: +#if SUPPORTS_PTHREAD_NAMING + // Attempts to get the name from the operating system, returning true and + // updating 'name' if successful. Note that during normal operation this + // may fail, if the thread exits prior to the system call. + bool getNameFromOS(std::string& name); +#endif + + std::function thread_routine_; + ThreadHandle* thread_handle_; + std::string name_; + bool joined_{false}; +}; + +using PosixThreadPtr = std::unique_ptr; + +class PosixThreadFactory; +using PosixThreadFactoryPtr = std::unique_ptr; + +/** An interface for POSIX `ThreadFactory` */ +class PosixThreadFactory : public ThreadFactory { +public: + // /** Creates a new instance of `PosixThreadPtr`. */ + static PosixThreadFactoryPtr create(); + + /** + * Creates a new generic thread from the specified `thread_routine`. When the + * thread cannot be created, this function will crash. + */ + ThreadPtr createThread(std::function thread_routine, OptionsOptConstRef options) PURE; + + /** + * Creates a new POSIX thread from the specified `thread_routine`. When + * `crash_on_failure` is set to true, this function will crash when the thread + * cannot be created; otherwise a `nullptr` will be returned. + */ + virtual PosixThreadPtr createThread(std::function thread_routine, + OptionsOptConstRef options, bool crash_on_failure) PURE; + + /** + * On Linux, `currentThreadId()` uses `gettid()` and it returns the kernel + * thread ID. The thread ID returned from this call is not the same as the + * thread ID returned from `currentPThreadId()`. + */ + ThreadId currentThreadId() PURE; + + /** Returns the current pthread ID. It uses `pthread_self()`. */ + virtual ThreadId currentPthreadId() PURE; }; } // namespace Thread diff --git a/source/exe/linux/platform_impl.cc b/source/exe/linux/platform_impl.cc index a1c04243eb77..63de62b11936 100644 --- a/source/exe/linux/platform_impl.cc +++ b/source/exe/linux/platform_impl.cc @@ -11,7 +11,7 @@ namespace Envoy { PlatformImpl::PlatformImpl() - : thread_factory_(std::make_unique()), + : thread_factory_(Thread::PosixThreadFactory::create()), file_system_(std::make_unique()) {} PlatformImpl::~PlatformImpl() = default; diff --git a/source/exe/posix/platform_impl.cc b/source/exe/posix/platform_impl.cc index 391d3be940ae..4eae6bee3b10 100644 --- a/source/exe/posix/platform_impl.cc +++ b/source/exe/posix/platform_impl.cc @@ -5,7 +5,7 @@ namespace Envoy { PlatformImpl::PlatformImpl() - : thread_factory_(std::make_unique()), + : thread_factory_(Thread::PosixThreadFactory::create()), file_system_(std::make_unique()) {} PlatformImpl::~PlatformImpl() = default; diff --git a/test/common/common/thread_test.cc b/test/common/common/thread_test.cc index 8cfc1c6ac771..8c8a28f7d3d9 100644 --- a/test/common/common/thread_test.cc +++ b/test/common/common/thread_test.cc @@ -1,5 +1,9 @@ #include +#if defined(__linux__) || defined(__APPLE__) +#include "source/common/common/posix/thread_impl.h" +#endif + #include "source/common/common/thread.h" #include "source/common/common/thread_synchronizer.h" @@ -246,6 +250,30 @@ TEST_F(ThreadAsyncPtrTest, NameNotSpecifiedWait) { thread->join(); } +#if defined(__linux__) || defined(__APPLE__) +TEST(PosixThreadTest, PThreadId) { + auto thread_factory = PosixThreadFactory::create(); + ThreadId thread_id; + auto thread = + thread_factory->createThread([&]() { thread_id = thread_factory->currentPthreadId(); }, + /* options= */ absl::nullopt, /* crash_on_failure= */ false); + thread->join(); + + EXPECT_EQ(thread->pthreadId(), thread_id); + EXPECT_NE(thread->pthreadId(), thread_factory->currentThreadId()); +} + +TEST(PosixThreadTest, Joinable) { + auto thread_factory = PosixThreadFactory::create(); + auto thread = thread_factory->createThread([&]() {}, /* options= */ absl::nullopt, + /* crash_on_failure= */ true); + + EXPECT_TRUE(thread->joinable()); + thread->join(); + EXPECT_FALSE(thread->joinable()); +} +#endif + } // namespace } // namespace Thread } // namespace Envoy diff --git a/test/mocks/thread/BUILD b/test/mocks/thread/BUILD new file mode 100644 index 000000000000..62f26f23abd7 --- /dev/null +++ b/test/mocks/thread/BUILD @@ -0,0 +1,18 @@ +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_mock", + "envoy_package", +) + +licenses(["notice"]) # Apache 2 + +envoy_package() + +envoy_cc_mock( + name = "thread_mocks", + hdrs = ["mocks.h"], + deps = [ + "//envoy/thread:thread_interface", + "//source/common/common:thread_lib", + ], +) diff --git a/test/mocks/thread/mocks.h b/test/mocks/thread/mocks.h new file mode 100644 index 000000000000..9593a01b478b --- /dev/null +++ b/test/mocks/thread/mocks.h @@ -0,0 +1,30 @@ +#pragma once + +#include "envoy/thread/thread.h" + +#if defined(__linux__) || defined(__APPLE__) +#include "source/common/common/posix/thread_impl.h" +#endif + +namespace Envoy { +namespace Thread { + +class MockThreadFactory : public ThreadFactory { +public: + MOCK_METHOD(ThreadPtr, createThread, (std::function, OptionsOptConstRef)); + MOCK_METHOD(ThreadId, currentThreadId, ()); +}; + +#if defined(__linux__) || defined(__APPLE__) +class MockPosixThreadFactory : public PosixThreadFactory { +public: + MOCK_METHOD(ThreadPtr, createThread, (std::function, OptionsOptConstRef)); + MOCK_METHOD(PosixThreadPtr, createThread, + (std::function, OptionsOptConstRef, bool crash_on_failure)); + MOCK_METHOD(ThreadId, currentThreadId, ()); + MOCK_METHOD(ThreadId, currentPthreadId, ()); +}; +#endif + +} // namespace Thread +} // namespace Envoy diff --git a/test/per_file_coverage.sh b/test/per_file_coverage.sh index 4a0e5fc32502..7f02290a378e 100755 --- a/test/per_file_coverage.sh +++ b/test/per_file_coverage.sh @@ -6,6 +6,7 @@ declare -a KNOWN_LOW_COVERAGE=( "source/common:95.9" # TODO(#32149): increase this once io_uring is tested. "source/common/api:84.5" # flaky due to posix: be careful adjusting "source/common/api/posix:83.8" # flaky (accept failover non-deterministic): be careful adjusting +"source/common/common/posix:88.8" # No easy way to test pthread_create failure. "source/common/config:95.4" "source/common/crypto:95.5" "source/common/event:95.0" # Emulated edge events guards don't report LCOV diff --git a/test/test_common/thread_factory_for_test.cc b/test/test_common/thread_factory_for_test.cc index 55c4d84ed12b..0d1c2db9cae6 100644 --- a/test/test_common/thread_factory_for_test.cc +++ b/test/test_common/thread_factory_for_test.cc @@ -9,7 +9,7 @@ ThreadFactory& threadFactoryForTest() { #ifdef WIN32 static auto* thread_factory = new ThreadFactoryImplWin32(); #else - static auto* thread_factory = new ThreadFactoryImplPosix(); + static auto* thread_factory = PosixThreadFactory::create().release(); #endif return *thread_factory; } diff --git a/tools/spelling/spelling_dictionary.txt b/tools/spelling/spelling_dictionary.txt index 7f66087cef02..15e00c4ff5cf 100644 --- a/tools/spelling/spelling_dictionary.txt +++ b/tools/spelling/spelling_dictionary.txt @@ -881,6 +881,7 @@ iteratively javascript jitter jittered +joinable js kafka keepalive