diff --git a/source/extensions/filters/http/cache/cache_filter.cc b/source/extensions/filters/http/cache/cache_filter.cc index 2e9c073a6680..88aed69f9f08 100644 --- a/source/extensions/filters/http/cache/cache_filter.cc +++ b/source/extensions/filters/http/cache/cache_filter.cc @@ -303,38 +303,19 @@ CacheFilter::resolveLookupStatus(absl::optional cache_entry_st void CacheFilter::getHeaders(Http::RequestHeaderMap& request_headers) { ASSERT(lookup_, "CacheFilter is trying to call getHeaders with no LookupContext"); - - // If the cache posts a callback to the dispatcher then the CacheFilter is destroyed for any - // reason (e.g client disconnected and HTTP stream terminated), then there is no guarantee that - // the posted callback will run before the filter is deleted. Hence, a weak_ptr to the CacheFilter - // is captured and used to make sure the CacheFilter is still alive before accessing it in the - // posted callback. - // TODO(yosrym93): Look into other options for handling this (also in getBody and getTrailers) as - // they arise, e.g. cancellable posts, guaranteed ordering of posted callbacks and deletions, etc. - CacheFilterWeakPtr self = weak_from_this(); - - // The dispatcher needs to be captured because there's no guarantee that - // decoder_callbacks_->dispatcher() is thread-safe. - lookup_->getHeaders([self, &request_headers, &dispatcher = decoder_callbacks_->dispatcher()]( + callback_called_directly_ = true; + lookup_->getHeaders([this, &request_headers, &dispatcher = decoder_callbacks_->dispatcher()]( LookupResult&& result, bool end_stream) { - // The callback is posted to the dispatcher to make sure it is called on the worker thread. - dispatcher.post([self, &request_headers, result = std::move(result), end_stream]() mutable { - if (CacheFilterSharedPtr cache_filter = self.lock()) { - cache_filter->onHeaders(std::move(result), request_headers, end_stream); - } - }); + ASSERT(!callback_called_directly_ && dispatcher.isThreadSafe(), + "caches must post the callback to the filter's dispatcher"); + onHeaders(std::move(result), request_headers, end_stream); }); + callback_called_directly_ = false; } void CacheFilter::getBody() { ASSERT(lookup_, "CacheFilter is trying to call getBody with no LookupContext"); ASSERT(!remaining_ranges_.empty(), "No reason to call getBody when there's no body to get."); - // If the cache posts a callback to the dispatcher then the CacheFilter is destroyed for any - // reason (e.g client disconnected and HTTP stream terminated), then there is no guarantee that - // the posted callback will run before the filter is deleted. Hence, a weak_ptr to the CacheFilter - // is captured and used to make sure the CacheFilter is still alive before accessing it in the - // posted callback. - CacheFilterWeakPtr self = weak_from_this(); // We don't want to request more than a buffer-size at a time from the cache. uint64_t fetch_size_limit = encoder_callbacks_->encoderBufferLimit(); @@ -347,41 +328,27 @@ void CacheFilter::getBody() { ? (remaining_ranges_[0].begin() + fetch_size_limit) : remaining_ranges_[0].end()}; - // The dispatcher needs to be captured because there's no guarantee that - // decoder_callbacks_->dispatcher() is thread-safe. - lookup_->getBody(fetch_range, [self, &dispatcher = decoder_callbacks_->dispatcher()]( + callback_called_directly_ = true; + lookup_->getBody(fetch_range, [this, &dispatcher = decoder_callbacks_->dispatcher()]( Buffer::InstancePtr&& body, bool end_stream) { - // The callback is posted to the dispatcher to make sure it is called on the worker thread. - dispatcher.post([self, body = std::move(body), end_stream]() mutable { - if (CacheFilterSharedPtr cache_filter = self.lock()) { - cache_filter->onBody(std::move(body), end_stream); - } - }); + ASSERT(!callback_called_directly_ && dispatcher.isThreadSafe(), + "caches must post the callback to the filter's dispatcher"); + onBody(std::move(body), end_stream); }); + callback_called_directly_ = false; } void CacheFilter::getTrailers() { ASSERT(lookup_, "CacheFilter is trying to call getTrailers with no LookupContext"); - // If the cache posts a callback to the dispatcher then the CacheFilter is destroyed for any - // reason (e.g client disconnected and HTTP stream terminated), then there is no guarantee that - // the posted callback will run before the filter is deleted. Hence, a weak_ptr to the CacheFilter - // is captured and used to make sure the CacheFilter is still alive before accessing it in the - // posted callback. - CacheFilterWeakPtr self = weak_from_this(); - - // The dispatcher needs to be captured because there's no guarantee that - // decoder_callbacks_->dispatcher() is thread-safe. - lookup_->getTrailers([self, &dispatcher = decoder_callbacks_->dispatcher()]( + callback_called_directly_ = true; + lookup_->getTrailers([this, &dispatcher = decoder_callbacks_->dispatcher()]( Http::ResponseTrailerMapPtr&& trailers) { - // The callback is posted to the dispatcher to make sure it is called on the worker thread. - // The lambda must be mutable as it captures trailers as a unique_ptr. - dispatcher.post([self, trailers = std::move(trailers)]() mutable { - if (CacheFilterSharedPtr cache_filter = self.lock()) { - cache_filter->onTrailers(std::move(trailers)); - } - }); + ASSERT(!callback_called_directly_ && dispatcher.isThreadSafe(), + "caches must post the callback to the filter's dispatcher"); + onTrailers(std::move(trailers)); }); + callback_called_directly_ = false; } void CacheFilter::onHeaders(LookupResult&& result, Http::RequestHeaderMap& request_headers, diff --git a/source/extensions/filters/http/cache/cache_filter.h b/source/extensions/filters/http/cache/cache_filter.h index 3641797b5c7a..923a1cf1975e 100644 --- a/source/extensions/filters/http/cache/cache_filter.h +++ b/source/extensions/filters/http/cache/cache_filter.h @@ -164,6 +164,8 @@ class CacheFilter : public Http::PassThroughFilter, FilterState filter_state_ = FilterState::Initial; bool is_head_request_ = false; + // This toggle is used to detect callbacks being called directly and not posted. + bool callback_called_directly_ = false; // The status of the insert operation or header update, or decision not to insert or update. // If it's too early to determine the final status, this is empty. absl::optional insert_status_; diff --git a/source/extensions/filters/http/cache/cache_insert_queue.cc b/source/extensions/filters/http/cache/cache_insert_queue.cc index 66cb9d41ea12..80de4b61280c 100644 --- a/source/extensions/filters/http/cache/cache_insert_queue.cc +++ b/source/extensions/filters/http/cache/cache_insert_queue.cc @@ -14,7 +14,7 @@ class CacheInsertFragment { // on_complete is called when the cache completes the operation. virtual void send(InsertContext& context, - std::function on_complete) PURE; + absl::AnyInvocable on_complete) PURE; virtual ~CacheInsertFragment() = default; }; @@ -27,14 +27,14 @@ class CacheInsertFragmentBody : public CacheInsertFragment { CacheInsertFragmentBody(const Buffer::Instance& buffer, bool end_stream) : buffer_(buffer), end_stream_(end_stream) {} - void - send(InsertContext& context, - std::function on_complete) override { + void send(InsertContext& context, + absl::AnyInvocable on_complete) + override { size_t sz = buffer_.length(); context.insertBody( std::move(buffer_), - [on_complete, end_stream = end_stream_, sz](bool cache_success) { - on_complete(cache_success, end_stream, sz); + [cb = std::move(on_complete), end_stream = end_stream_, sz](bool cache_success) mutable { + std::move(cb)(cache_success, end_stream, sz); }, end_stream_); } @@ -52,14 +52,15 @@ class CacheInsertFragmentTrailers : public CacheInsertFragment { Http::ResponseTrailerMapImpl::copyFrom(*trailers_, trailers); } - void - send(InsertContext& context, - std::function on_complete) override { + void send(InsertContext& context, + absl::AnyInvocable on_complete) + override { // While zero isn't technically true for the size of trailers, it doesn't // matter at this point because watermarks after the stream is complete // aren't useful. - context.insertTrailers( - *trailers_, [on_complete](bool cache_success) { on_complete(cache_success, true, 0); }); + context.insertTrailers(*trailers_, [cb = std::move(on_complete)](bool cache_success) mutable { + std::move(cb)(cache_success, true, 0); + }); } private: @@ -72,7 +73,7 @@ CacheInsertQueue::CacheInsertQueue(std::shared_ptr cache, : dispatcher_(encoder_callbacks.dispatcher()), insert_context_(std::move(insert_context)), low_watermark_bytes_(encoder_callbacks.encoderBufferLimit() / 2), high_watermark_bytes_(encoder_callbacks.encoderBufferLimit()), - encoder_callbacks_(encoder_callbacks), abort_callback_(abort), cache_(cache) {} + encoder_callbacks_(encoder_callbacks), abort_callback_(std::move(abort)), cache_(cache) {} void CacheInsertQueue::insertHeaders(const Http::ResponseHeaderMap& response_headers, const ResponseMetadata& metadata, bool end_stream) { @@ -123,59 +124,54 @@ void CacheInsertQueue::insertTrailers(const Http::ResponseTrailerMap& trailers) } void CacheInsertQueue::onFragmentComplete(bool cache_success, bool end_stream, size_t sz) { - // If the cache implementation is asynchronous, this may be called from whatever - // thread that cache implementation runs on. Therefore, we post it to the - // dispatcher to be certain any callbacks and updates are called on the filter's - // thread (and therefore we don't have to mutex-guard anything). - dispatcher_.post([this, cache_success, end_stream, sz]() { - fragment_in_flight_ = false; - if (aborting_) { - // Parent filter was destroyed, so we can quit this operation. - fragments_.clear(); - self_ownership_.reset(); - return; + ASSERT(dispatcher_.isThreadSafe()); + fragment_in_flight_ = false; + if (aborting_) { + // Parent filter was destroyed, so we can quit this operation. + fragments_.clear(); + self_ownership_.reset(); + return; + } + ASSERT(queue_size_bytes_ >= sz, "queue can't be emptied by more than its size"); + queue_size_bytes_ -= sz; + if (watermarked_ && queue_size_bytes_ <= low_watermark_bytes_) { + if (encoder_callbacks_.has_value()) { + encoder_callbacks_.value().get().onEncoderFilterBelowWriteBufferLowWatermark(); } - ASSERT(queue_size_bytes_ >= sz, "queue can't be emptied by more than its size"); - queue_size_bytes_ -= sz; - if (watermarked_ && queue_size_bytes_ <= low_watermark_bytes_) { + watermarked_ = false; + } + if (!cache_success) { + // canceled by cache; unwatermark if necessary, inform the filter if + // it's still around, and delete the queue. + if (watermarked_) { if (encoder_callbacks_.has_value()) { encoder_callbacks_.value().get().onEncoderFilterBelowWriteBufferLowWatermark(); } watermarked_ = false; } - if (!cache_success) { - // canceled by cache; unwatermark if necessary, inform the filter if - // it's still around, and delete the queue. - if (watermarked_) { - if (encoder_callbacks_.has_value()) { - encoder_callbacks_.value().get().onEncoderFilterBelowWriteBufferLowWatermark(); - } - watermarked_ = false; - } - fragments_.clear(); - // Clearing self-ownership might provoke the destructor, so take a copy of the - // abort callback to avoid reading from 'this' after it may be deleted. - auto abort_callback = abort_callback_; - self_ownership_.reset(); - abort_callback(); - return; - } - if (end_stream) { - ASSERT(fragments_.empty(), "ending a stream with the queue not empty is a bug"); - ASSERT(!watermarked_, "being over the high watermark when the queue is empty makes no sense"); - self_ownership_.reset(); - return; - } - if (!fragments_.empty()) { - // If there's more in the queue, push the next fragment to the cache. - auto fragment = std::move(fragments_.front()); - fragments_.pop_front(); - fragment_in_flight_ = true; - fragment->send(*insert_context_, [this](bool cache_success, bool end_stream, size_t sz) { - onFragmentComplete(cache_success, end_stream, sz); - }); - } - }); + fragments_.clear(); + // Clearing self-ownership might provoke the destructor, so take a copy of the + // abort callback to avoid reading from 'this' after it may be deleted. + auto abort_callback = std::move(abort_callback_); + self_ownership_.reset(); + std::move(abort_callback)(); + return; + } + if (end_stream) { + ASSERT(fragments_.empty(), "ending a stream with the queue not empty is a bug"); + ASSERT(!watermarked_, "being over the high watermark when the queue is empty makes no sense"); + self_ownership_.reset(); + return; + } + if (!fragments_.empty()) { + // If there's more in the queue, push the next fragment to the cache. + auto fragment = std::move(fragments_.front()); + fragments_.pop_front(); + fragment_in_flight_ = true; + fragment->send(*insert_context_, [this](bool cache_success, bool end_stream, size_t sz) { + onFragmentComplete(cache_success, end_stream, sz); + }); + } } void CacheInsertQueue::setSelfOwned(std::unique_ptr self) { diff --git a/source/extensions/filters/http/cache/cache_insert_queue.h b/source/extensions/filters/http/cache/cache_insert_queue.h index feae50414a63..22297a70a528 100644 --- a/source/extensions/filters/http/cache/cache_insert_queue.h +++ b/source/extensions/filters/http/cache/cache_insert_queue.h @@ -12,7 +12,7 @@ namespace Cache { using OverHighWatermarkCallback = std::function; using UnderLowWatermarkCallback = std::function; -using AbortInsertCallback = std::function; +using AbortInsertCallback = absl::AnyInvocable; class CacheInsertFragment; // This queue acts as an intermediary between CacheFilter and the cache diff --git a/source/extensions/filters/http/cache/http_cache.h b/source/extensions/filters/http/cache/http_cache.h index 59face7598bb..e01fbf7d162a 100644 --- a/source/extensions/filters/http/cache/http_cache.h +++ b/source/extensions/filters/http/cache/http_cache.h @@ -122,19 +122,20 @@ struct CacheInfo { bool supports_range_requests_ = false; }; -using LookupBodyCallback = std::function; -using LookupHeadersCallback = std::function; -using LookupTrailersCallback = std::function; -using InsertCallback = std::function; +using LookupBodyCallback = absl::AnyInvocable; +using LookupHeadersCallback = absl::AnyInvocable; +using LookupTrailersCallback = absl::AnyInvocable; +using InsertCallback = absl::AnyInvocable; +using UpdateHeadersCallback = absl::AnyInvocable; // Manages the lifetime of an insertion. class InsertContext { public: // Accepts response_headers for caching. Only called once. // - // Implementations MUST call insert_complete(true) on success, or - // insert_complete(false) to attempt to abort the insertion. This - // call may be made asynchronously, but any async operation that can + // Implementations MUST post to the filter's dispatcher insert_complete(true) + // on success, or insert_complete(false) to attempt to abort the insertion. + // This call may be made asynchronously, but any async operation that can // potentially silently fail must include a timeout, to avoid memory leaks. virtual void insertHeaders(const Http::ResponseHeaderMap& response_headers, const ResponseMetadata& metadata, InsertCallback insert_complete, @@ -149,17 +150,17 @@ class InsertContext { // InsertContextPtr. A cache can abort the insertion by passing 'false' into // ready_for_next_fragment. // - // The cache implementation MUST call ready_for_next_fragment. This call may be - // made asynchronously, but any async operation that can potentially silently - // fail must include a timeout, to avoid memory leaks. + // The cache implementation MUST post ready_for_next_fragment to the filter's + // dispatcher. This post may be made asynchronously, but any async operation + // that can potentially silently fail must include a timeout, to avoid memory leaks. virtual void insertBody(const Buffer::Instance& fragment, InsertCallback ready_for_next_fragment, bool end_stream) PURE; // Inserts trailers into the cache. // - // The cache implementation MUST call insert_complete. This call may be - // made asynchronously, but any async operation that can potentially silently - // fail must include a timeout, to avoid memory leaks. + // The cache implementation MUST post insert_complete to the filter's dispatcher. + // This call may be made asynchronously, but any async operation that can + // potentially silently fail must include a timeout, to avoid memory leaks. virtual void insertTrailers(const Http::ResponseTrailerMap& trailers, InsertCallback insert_complete) PURE; @@ -199,6 +200,9 @@ class LookupContext { // implementation should wait until that is known before calling the callback, // and must pass a LookupResult with range_details_->satisfiable_ = false // if the request is invalid. + // + // A cache that posts the callback must wrap it such that if the LookupContext is + // destroyed before the callback is executed, the callback is not executed. virtual void getHeaders(LookupHeadersCallback&& cb) PURE; // Reads the next fragment from the cache, calling cb when the fragment is ready. @@ -228,11 +232,17 @@ class LookupContext { // getBody requests bytes 0-23 .......... callback with bytes 0-9 // getBody requests bytes 10-23 .......... callback with bytes 10-19 // getBody requests bytes 20-23 .......... callback with bytes 20-23 + // + // A cache that posts the callback must wrap it such that if the LookupContext is + // destroyed before the callback is executed, the callback is not executed. virtual void getBody(const AdjustedByteRange& range, LookupBodyCallback&& cb) PURE; // Get the trailers from the cache. Only called if the request reached the end of // the body and LookupBodyCallback did not pass true for end_stream. The // Http::ResponseTrailerMapPtr passed to cb must not be null. + // + // A cache that posts the callback must wrap it such that if the LookupContext is + // destroyed before the callback is executed, the callback is not executed. virtual void getTrailers(LookupTrailersCallback&& cb) PURE; // This routine is called prior to a LookupContext being destroyed. LookupContext is responsible @@ -248,7 +258,7 @@ class LookupContext { // 5. [Other thread] RPC completes and calls RPCLookupContext::onRPCDone. // --> RPCLookupContext's destructor and onRpcDone cause a data race in RPCLookupContext. // onDestroy() should cancel any outstanding async operations and, if necessary, - // it should block on that cancellation to avoid data races. InsertContext must not invoke any + // it should block on that cancellation to avoid data races. LookupContext must not invoke any // callbacks to the CacheFilter after having onDestroy() invoked. virtual void onDestroy() PURE; @@ -289,7 +299,7 @@ class HttpCache { virtual void updateHeaders(const LookupContext& lookup_context, const Http::ResponseHeaderMap& response_headers, const ResponseMetadata& metadata, - std::function on_complete) PURE; + UpdateHeadersCallback on_complete) PURE; // Returns statically known information about a cache. virtual CacheInfo cacheInfo() const PURE; diff --git a/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.cc b/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.cc index 1e5a49c6c6f6..5e220e2b4598 100644 --- a/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.cc +++ b/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.cc @@ -127,13 +127,13 @@ class HeaderUpdateContext : public Logger::Loggable { HeaderUpdateContext(Event::Dispatcher& dispatcher, const FileSystemHttpCache& cache, const Key& key, std::shared_ptr cleanup, const Http::ResponseHeaderMap& response_headers, - const ResponseMetadata& metadata, std::function on_complete) + const ResponseMetadata& metadata, UpdateHeadersCallback on_complete) : dispatcher_(dispatcher), filepath_(absl::StrCat(cache.cachePath(), cache.generateFilename(key))), cache_path_(cache.cachePath()), cleanup_(cleanup), async_file_manager_(cache.asyncFileManager()), response_headers_(Http::createHeaderMap(response_headers)), - response_metadata_(metadata), on_complete_(on_complete) {} + response_metadata_(metadata), on_complete_(std::move(on_complete)) {} void begin(std::shared_ptr ctx) { async_file_manager_->openExistingFile( @@ -278,14 +278,14 @@ class HeaderUpdateContext : public Logger::Loggable { fail("failed to link new cache file", link_result); return; } - on_complete_(true); + std::move(on_complete_)(true); }); ASSERT(queued.ok()); } void fail(absl::string_view msg, absl::Status status) { ENVOY_LOG(warn, "file_system_http_cache: {} for update cache file {}: {}", msg, filepath_, status); - on_complete_(false); + std::move(on_complete_)(false); } Event::Dispatcher* dispatcher() { return &dispatcher_; } Event::Dispatcher& dispatcher_; @@ -300,13 +300,13 @@ class HeaderUpdateContext : public Logger::Loggable { CacheFileHeader header_proto_; AsyncFileHandle read_handle_; AsyncFileHandle write_handle_; - std::function on_complete_; + UpdateHeadersCallback on_complete_; }; void FileSystemHttpCache::updateHeaders(const LookupContext& base_lookup_context, const Http::ResponseHeaderMap& response_headers, const ResponseMetadata& metadata, - std::function on_complete) { + UpdateHeadersCallback on_complete) { const FileLookupContext& lookup_context = dynamic_cast(base_lookup_context); const Key& key = lookup_context.key(); @@ -314,8 +314,9 @@ void FileSystemHttpCache::updateHeaders(const LookupContext& base_lookup_context if (!cleanup) { return; } - auto ctx = std::make_shared( - *lookup_context.dispatcher(), *this, key, cleanup, response_headers, metadata, on_complete); + auto ctx = + std::make_shared(*lookup_context.dispatcher(), *this, key, cleanup, + response_headers, metadata, std::move(on_complete)); ctx->begin(ctx); } diff --git a/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.h b/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.h index 9e81bf0410b0..be4c59402444 100644 --- a/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.h +++ b/source/extensions/http/cache/file_system_http_cache/file_system_http_cache.h @@ -79,8 +79,7 @@ class FileSystemHttpCache : public HttpCache, */ void updateHeaders(const LookupContext& lookup_context, const Http::ResponseHeaderMap& response_headers, - const ResponseMetadata& metadata, - std::function on_complete) override; + const ResponseMetadata& metadata, UpdateHeadersCallback on_complete) override; /** * The config of this cache. Used by the factory to ensure there aren't incompatible diff --git a/source/extensions/http/cache/file_system_http_cache/insert_context.cc b/source/extensions/http/cache/file_system_http_cache/insert_context.cc index 98e05657734e..55503ef6dd25 100644 --- a/source/extensions/http/cache/file_system_http_cache/insert_context.cc +++ b/source/extensions/http/cache/file_system_http_cache/insert_context.cc @@ -32,7 +32,7 @@ void FileInsertContext::insertHeaders(const Http::ResponseHeaderMap& response_he const ResponseMetadata& metadata, InsertCallback insert_complete, bool end_stream) { ASSERT(dispatcher()->isThreadSafe()); - callback_in_flight_ = insert_complete; + callback_in_flight_ = std::move(insert_complete); const VaryAllowList& vary_allow_list = lookup_context_->lookup().varyAllowList(); const Http::RequestHeaderMap& request_headers = lookup_context_->lookup().requestHeaders(); if (VaryHeaderUtils::hasVary(response_headers)) { @@ -59,7 +59,6 @@ void FileInsertContext::insertHeaders(const Http::ResponseHeaderMap& response_he } cache_file_header_proto_ = makeCacheFileHeaderProto(key_, response_headers, metadata); end_stream_after_headers_ = end_stream; - on_insert_complete_ = std::move(insert_complete); createFile(); } @@ -140,10 +139,10 @@ void FileInsertContext::insertBody(const Buffer::Instance& fragment, ASSERT(!callback_in_flight_); if (!cleanup_) { // Already cancelled, do nothing, return failure. - ready_for_next_fragment(false); + std::move(ready_for_next_fragment)(false); return; } - callback_in_flight_ = ready_for_next_fragment; + callback_in_flight_ = std::move(ready_for_next_fragment); size_t sz = fragment.length(); Buffer::OwnedImpl consumable_fragment(fragment); auto queued = file_handle_->write( @@ -172,10 +171,10 @@ void FileInsertContext::insertTrailers(const Http::ResponseTrailerMap& trailers, ASSERT(!callback_in_flight_); if (!cleanup_) { // Already cancelled, do nothing, return failure. - insert_complete(false); + std::move(insert_complete)(false); return; } - callback_in_flight_ = insert_complete; + callback_in_flight_ = std::move(insert_complete); CacheFileTrailer file_trailer = makeCacheFileTrailerProto(trailers); Buffer::OwnedImpl consumable_buffer = bufferFromProto(file_trailer); size_t sz = consumable_buffer.length(); diff --git a/source/extensions/http/cache/file_system_http_cache/insert_context.h b/source/extensions/http/cache/file_system_http_cache/insert_context.h index 1f8e6665338f..a70fbbae1d06 100644 --- a/source/extensions/http/cache/file_system_http_cache/insert_context.h +++ b/source/extensions/http/cache/file_system_http_cache/insert_context.h @@ -88,7 +88,6 @@ class FileInsertContext : public InsertContext, public Logger::Loggable lookup_context_; Key key_; std::shared_ptr cache_; diff --git a/source/extensions/http/cache/file_system_http_cache/lookup_context.cc b/source/extensions/http/cache/file_system_http_cache/lookup_context.cc index a58b13a44119..9f3d7e028f5f 100644 --- a/source/extensions/http/cache/file_system_http_cache/lookup_context.cc +++ b/source/extensions/http/cache/file_system_http_cache/lookup_context.cc @@ -140,7 +140,7 @@ void FileLookupContext::getBody(const AdjustedByteRange& range, LookupBodyCallba ASSERT(file_handle_); auto queued = file_handle_->read( dispatcher(), header_block_.offsetToBody() + range.begin(), range.length(), - [this, cb = std::move(cb), range](absl::StatusOr read_result) { + [this, cb = std::move(cb), range](absl::StatusOr read_result) mutable { ASSERT(dispatcher()->isThreadSafe()); cancel_action_in_flight_ = nullptr; if (!read_result.ok() || read_result.value()->length() != range.length()) { @@ -164,7 +164,7 @@ void FileLookupContext::getTrailers(LookupTrailersCallback&& cb) { ASSERT(file_handle_); auto queued = file_handle_->read( dispatcher(), header_block_.offsetToTrailers(), header_block_.trailerSize(), - [this, cb = std::move(cb)](absl::StatusOr read_result) { + [this, cb = std::move(cb)](absl::StatusOr read_result) mutable { ASSERT(dispatcher()->isThreadSafe()); cancel_action_in_flight_ = nullptr; if (!read_result.ok() || read_result.value()->length() != header_block_.trailerSize()) { diff --git a/source/extensions/http/cache/simple_http_cache/simple_http_cache.cc b/source/extensions/http/cache/simple_http_cache/simple_http_cache.cc index a2be23281c86..04c70bba9d39 100644 --- a/source/extensions/http/cache/simple_http_cache/simple_http_cache.cc +++ b/source/extensions/http/cache/simple_http_cache/simple_http_cache.cc @@ -33,35 +33,57 @@ absl::optional variedRequestKey(const LookupRequest& request, class SimpleLookupContext : public LookupContext { public: - SimpleLookupContext(SimpleHttpCache& cache, LookupRequest&& request) - : cache_(cache), request_(std::move(request)) {} + SimpleLookupContext(Event::Dispatcher& dispatcher, SimpleHttpCache& cache, + LookupRequest&& request) + : dispatcher_(dispatcher), cache_(cache), request_(std::move(request)) {} void getHeaders(LookupHeadersCallback&& cb) override { auto entry = cache_.lookup(request_); body_ = std::move(entry.body_); trailers_ = std::move(entry.trailers_); - cb(entry.response_headers_ ? request_.makeLookupResult(std::move(entry.response_headers_), - std::move(entry.metadata_), body_.size()) - : LookupResult{}, - body_.empty() && trailers_ == nullptr); + LookupResult result = entry.response_headers_ + ? request_.makeLookupResult(std::move(entry.response_headers_), + std::move(entry.metadata_), body_.size()) + : LookupResult{}; + bool end_stream = body_.empty() && trailers_ == nullptr; + dispatcher_.post([result = std::move(result), cb = std::move(cb), end_stream, + cancelled = cancelled_]() mutable { + if (!*cancelled) { + std::move(cb)(std::move(result), end_stream); + } + }); } void getBody(const AdjustedByteRange& range, LookupBodyCallback&& cb) override { ASSERT(range.end() <= body_.length(), "Attempt to read past end of body."); - cb(std::make_unique(&body_[range.begin()], range.length()), - trailers_ == nullptr && range.end() == body_.length()); + auto result = std::make_unique(&body_[range.begin()], range.length()); + bool end_stream = trailers_ == nullptr && range.end() == body_.length(); + dispatcher_.post([result = std::move(result), cb = std::move(cb), end_stream, + cancelled = cancelled_]() mutable { + if (!*cancelled) { + std::move(cb)(std::move(result), end_stream); + } + }); } // The cache must call cb with the cached trailers. void getTrailers(LookupTrailersCallback&& cb) override { ASSERT(trailers_); - cb(std::move(trailers_)); + dispatcher_.post( + [cb = std::move(cb), trailers = std::move(trailers_), cancelled = cancelled_]() mutable { + if (!*cancelled) { + std::move(cb)(std::move(trailers)); + } + }); } const LookupRequest& request() const { return request_; } - void onDestroy() override {} + void onDestroy() override { *cancelled_ = true; } + Event::Dispatcher& dispatcher() const { return dispatcher_; } private: + Event::Dispatcher& dispatcher_; + std::shared_ptr cancelled_ = std::make_shared(false); SimpleHttpCache& cache_; const LookupRequest request_; std::string body_; @@ -70,13 +92,18 @@ class SimpleLookupContext : public LookupContext { class SimpleInsertContext : public InsertContext { public: - SimpleInsertContext(LookupContext& lookup_context, SimpleHttpCache& cache) - : key_(dynamic_cast(lookup_context).request().key()), - request_headers_( - dynamic_cast(lookup_context).request().requestHeaders()), - vary_allow_list_( - dynamic_cast(lookup_context).request().varyAllowList()), - cache_(cache) {} + SimpleInsertContext(SimpleLookupContext& lookup_context, SimpleHttpCache& cache) + : dispatcher_(lookup_context.dispatcher()), key_(lookup_context.request().key()), + request_headers_(lookup_context.request().requestHeaders()), + vary_allow_list_(lookup_context.request().varyAllowList()), cache_(cache) {} + + void post(InsertCallback cb, bool result) { + dispatcher_.post([cb = std::move(cb), result = result, cancelled = cancelled_]() mutable { + if (!*cancelled) { + std::move(cb)(result); + } + }); + } void insertHeaders(const Http::ResponseHeaderMap& response_headers, const ResponseMetadata& metadata, InsertCallback insert_success, @@ -85,9 +112,9 @@ class SimpleInsertContext : public InsertContext { response_headers_ = Http::createHeaderMap(response_headers); metadata_ = metadata; if (end_stream) { - insert_success(commit()); + post(std::move(insert_success), commit()); } else { - insert_success(true); + post(std::move(insert_success), true); } } @@ -98,9 +125,9 @@ class SimpleInsertContext : public InsertContext { body_.add(chunk); if (end_stream) { - ready_for_next_chunk(commit()); + post(std::move(ready_for_next_chunk), commit()); } else { - ready_for_next_chunk(true); + post(std::move(ready_for_next_chunk), true); } } @@ -108,10 +135,10 @@ class SimpleInsertContext : public InsertContext { InsertCallback insert_complete) override { ASSERT(!committed_); trailers_ = Http::createHeaderMap(trailers); - insert_complete(commit()); + post(std::move(insert_complete), commit()); } - void onDestroy() override {} + void onDestroy() override { *cancelled_ = true; } private: bool commit() { @@ -126,6 +153,8 @@ class SimpleInsertContext : public InsertContext { } } + Event::Dispatcher& dispatcher_; + std::shared_ptr cancelled_ = std::make_shared(false); Key key_; const Http::RequestHeaderMap& request_headers_; const VaryAllowList& vary_allow_list_; @@ -139,32 +168,38 @@ class SimpleInsertContext : public InsertContext { } // namespace LookupContextPtr SimpleHttpCache::makeLookupContext(LookupRequest&& request, - Http::StreamDecoderFilterCallbacks&) { - return std::make_unique(*this, std::move(request)); + Http::StreamDecoderFilterCallbacks& callbacks) { + return std::make_unique(callbacks.dispatcher(), *this, std::move(request)); } void SimpleHttpCache::updateHeaders(const LookupContext& lookup_context, const Http::ResponseHeaderMap& response_headers, const ResponseMetadata& metadata, - std::function on_complete) { + UpdateHeadersCallback on_complete) { const auto& simple_lookup_context = static_cast(lookup_context); const Key& key = simple_lookup_context.request().key(); absl::WriterMutexLock lock(&mutex_); auto iter = map_.find(key); + auto post_complete = [on_complete = std::move(on_complete), + &dispatcher = simple_lookup_context.dispatcher()](bool result) mutable { + dispatcher.post([on_complete = std::move(on_complete), result]() mutable { + std::move(on_complete)(result); + }); + }; if (iter == map_.end() || !iter->second.response_headers_) { - on_complete(false); + std::move(post_complete)(false); return; } if (VaryHeaderUtils::hasVary(*iter->second.response_headers_)) { absl::optional varied_key = variedRequestKey(simple_lookup_context.request(), *iter->second.response_headers_); if (!varied_key.has_value()) { - on_complete(false); + std::move(post_complete)(false); return; } iter = map_.find(varied_key.value()); if (iter == map_.end() || !iter->second.response_headers_) { - on_complete(false); + std::move(post_complete)(false); return; } } @@ -172,7 +207,7 @@ void SimpleHttpCache::updateHeaders(const LookupContext& lookup_context, applyHeaderUpdate(response_headers, *entry.response_headers_); entry.metadata_ = metadata; - on_complete(true); + std::move(post_complete)(true); } SimpleHttpCache::Entry SimpleHttpCache::lookup(const LookupRequest& request) { @@ -278,7 +313,8 @@ bool SimpleHttpCache::varyInsert(const Key& request_key, InsertContextPtr SimpleHttpCache::makeInsertContext(LookupContextPtr&& lookup_context, Http::StreamEncoderFilterCallbacks&) { ASSERT(lookup_context != nullptr); - return std::make_unique(*lookup_context, *this); + return std::make_unique(dynamic_cast(*lookup_context), + *this); } constexpr absl::string_view Name = "envoy.extensions.http.cache.simple"; diff --git a/source/extensions/http/cache/simple_http_cache/simple_http_cache.h b/source/extensions/http/cache/simple_http_cache/simple_http_cache.h index bfad38d9fdaf..515ccc5f0183 100644 --- a/source/extensions/http/cache/simple_http_cache/simple_http_cache.h +++ b/source/extensions/http/cache/simple_http_cache/simple_http_cache.h @@ -40,8 +40,7 @@ class SimpleHttpCache : public HttpCache, public Singleton::Instance { Http::StreamEncoderFilterCallbacks& callbacks) override; void updateHeaders(const LookupContext& lookup_context, const Http::ResponseHeaderMap& response_headers, - const ResponseMetadata& metadata, - std::function on_complete) override; + const ResponseMetadata& metadata, UpdateHeadersCallback on_complete) override; CacheInfo cacheInfo() const override; Entry lookup(const LookupRequest& request); diff --git a/test/extensions/filters/http/cache/cache_filter_test.cc b/test/extensions/filters/http/cache/cache_filter_test.cc index 27bd88716b80..14edc705c36c 100644 --- a/test/extensions/filters/http/cache/cache_filter_test.cc +++ b/test/extensions/filters/http/cache/cache_filter_test.cc @@ -291,6 +291,40 @@ TEST_F(CacheFilterTest, CacheMissWithTrailers) { dispatcher_->run(Event::Dispatcher::RunType::Block); } +TEST_F(CacheFilterTest, CacheMissWithTrailersWhenCacheRespondsQuickerThanUpstream) { + request_headers_.setHost("CacheMissWithTrailers"); + const std::string body = "abc"; + Buffer::OwnedImpl body_buffer(body); + Http::TestResponseTrailerMapImpl trailers; + + for (int request = 1; request <= 2; request++) { + // Each iteration a request is sent to a different host, therefore the second one is a miss + request_headers_.setHost("CacheMissWithTrailers" + std::to_string(request)); + + // Create filter for request 1 + CacheFilterSharedPtr filter = makeFilter(simple_cache_); + + testDecodeRequestMiss(filter); + + // Encode response header + EXPECT_EQ(filter->encodeHeaders(response_headers_, false), Http::FilterHeadersStatus::Continue); + // Resolve cache response + dispatcher_->run(Event::Dispatcher::RunType::Block); + EXPECT_EQ(filter->encodeData(body_buffer, false), Http::FilterDataStatus::Continue); + // Resolve cache response + dispatcher_->run(Event::Dispatcher::RunType::Block); + EXPECT_EQ(filter->encodeTrailers(trailers), Http::FilterTrailersStatus::Continue); + // Resolve cache response + dispatcher_->run(Event::Dispatcher::RunType::Block); + + filter->onStreamComplete(); + EXPECT_THAT(lookupStatus(), IsOkAndHolds(LookupStatus::CacheMiss)); + EXPECT_THAT(insertStatus(), IsOkAndHolds(InsertStatus::InsertSucceeded)); + } + // Clear events off the dispatcher. + dispatcher_->run(Event::Dispatcher::RunType::Block); +} + TEST_F(CacheFilterTest, CacheHitNoBody) { request_headers_.setHost("CacheHitNoBody"); @@ -372,22 +406,25 @@ TEST_F(CacheFilterTest, WatermarkEventsAreSentIfCacheBlocksStreamAndLimitExceede return std::move(mock_insert_context); }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - cb(LookupResult{}, false); + dispatcher_->post([cb = std::move(cb)]() mutable { std::move(cb)(LookupResult{}, false); }); }); EXPECT_CALL(*mock_insert_context, insertHeaders(_, _, _, false)) .WillOnce([&](const Http::ResponseHeaderMap&, const ResponseMetadata&, - InsertCallback insert_complete, bool) { insert_complete(true); }); + InsertCallback insert_complete, bool) { + dispatcher_->post([cb = std::move(insert_complete)]() mutable { std::move(cb)(true); }); + }); InsertCallback captured_insert_body_callback; // The first time insertBody is called, block until the test is ready to call it. // For completion chunk, complete immediately. EXPECT_CALL(*mock_insert_context, insertBody(_, _, false)) .WillOnce([&](const Buffer::Instance&, InsertCallback ready_for_next_chunk, bool) { EXPECT_THAT(captured_insert_body_callback, IsNull()); - captured_insert_body_callback = ready_for_next_chunk; + captured_insert_body_callback = std::move(ready_for_next_chunk); }); EXPECT_CALL(*mock_insert_context, insertBody(_, _, true)) .WillOnce([&](const Buffer::Instance&, InsertCallback ready_for_next_chunk, bool) { - ready_for_next_chunk(true); + dispatcher_->post( + [cb = std::move(ready_for_next_chunk)]() mutable { std::move(cb)(true); }); }); EXPECT_CALL(*mock_insert_context, onDestroy()); EXPECT_CALL(*mock_lookup_context, onDestroy()); @@ -444,18 +481,20 @@ TEST_F(CacheFilterTest, FilterDestroyedWhileWatermarkedSendsLowWatermarkEvent) { return std::move(mock_insert_context); }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - cb(LookupResult{}, false); + dispatcher_->post([cb = std::move(cb)]() mutable { std::move(cb)(LookupResult{}, false); }); }); EXPECT_CALL(*mock_insert_context, insertHeaders(_, _, _, false)) .WillOnce([&](const Http::ResponseHeaderMap&, const ResponseMetadata&, - InsertCallback insert_complete, bool) { insert_complete(true); }); + InsertCallback insert_complete, bool) { + dispatcher_->post([cb = std::move(insert_complete)]() mutable { std::move(cb)(true); }); + }); InsertCallback captured_insert_body_callback; // The first time insertBody is called, block until the test is ready to call it. // Cache aborts, so there is no second call. EXPECT_CALL(*mock_insert_context, insertBody(_, _, false)) .WillOnce([&](const Buffer::Instance&, InsertCallback ready_for_next_chunk, bool) { EXPECT_THAT(captured_insert_body_callback, IsNull()); - captured_insert_body_callback = ready_for_next_chunk; + captured_insert_body_callback = std::move(ready_for_next_chunk); }); EXPECT_CALL(*mock_insert_context, onDestroy()); EXPECT_CALL(*mock_lookup_context, onDestroy()); @@ -477,15 +516,15 @@ TEST_F(CacheFilterTest, FilterDestroyedWhileWatermarkedSendsLowWatermarkEvent) { Buffer::OwnedImpl body1buf(body1); Buffer::OwnedImpl body2buf(body2); EXPECT_EQ(filter->encodeData(body1buf, false), Http::FilterDataStatus::Continue); + dispatcher_->run(Event::Dispatcher::RunType::Block); EXPECT_EQ(filter->encodeData(body2buf, true), Http::FilterDataStatus::Continue); + dispatcher_->run(Event::Dispatcher::RunType::Block); ASSERT_THAT(captured_insert_body_callback, NotNull()); // When the filter is destroyed, a low watermark event should be sent. EXPECT_CALL(encoder_callbacks_, onEncoderFilterBelowWriteBufferLowWatermark()); filter->onDestroy(); filter.reset(); captured_insert_body_callback(false); - // The cache insertBody callback should be posted to the dispatcher. - // Run events on the dispatcher so that the callback is invoked. dispatcher_->run(Event::Dispatcher::RunType::Block); } } @@ -507,19 +546,28 @@ TEST_F(CacheFilterTest, CacheEntryStreamedWithTrailersAndNoContentLengthCanDeliv }); // response_headers_ intentionally has no content length, LookupResult also has no content length. EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - cb(LookupResult{CacheEntryStatus::Ok, - std::make_unique(response_headers_), - absl::nullopt, absl::nullopt}, - /* end_stream = */ false); + dispatcher_->post([cb = std::move(cb), this]() mutable { + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, + std::make_unique(response_headers_), + absl::nullopt, absl::nullopt}, + /* end_stream = */ false); + }); }); EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(0, Gt(5)), _)) .WillOnce([&](AdjustedByteRange, LookupBodyCallback&& cb) { - cb(std::make_unique(body), false); + dispatcher_->post([cb = std::move(cb), &body]() mutable { + std::move(cb)(std::make_unique(body), false); + }); }); EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(5, Gt(5)), _)) - .WillOnce([&](AdjustedByteRange, LookupBodyCallback&& cb) { cb(nullptr, false); }); + .WillOnce([&](AdjustedByteRange, LookupBodyCallback&& cb) { + dispatcher_->post([cb = std::move(cb)]() mutable { std::move(cb)(nullptr, false); }); + }); EXPECT_CALL(*mock_lookup_context, getTrailers(_)).WillOnce([&](LookupTrailersCallback&& cb) { - cb(std::make_unique()); + dispatcher_->post([cb = std::move(cb)]() mutable { + std::move(cb)(std::make_unique()); + }); }); EXPECT_CALL(*mock_lookup_context, onDestroy()); { @@ -558,7 +606,11 @@ TEST_F(CacheFilterTest, OnDestroyBeforeOnHeadersAbortsAction) { EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { std::unique_ptr response_headers = std::make_unique(response_headers_); - cb(LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 8, absl::nullopt}, false); + dispatcher_->post([cb = std::move(cb), + response_headers = std::move(response_headers)]() mutable { + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 8, absl::nullopt}, false); + }); }); auto filter = makeFilter(mock_http_cache, false); EXPECT_EQ(filter->decodeHeaders(request_headers_, true), @@ -581,19 +633,27 @@ TEST_F(CacheFilterTest, OnDestroyBeforeOnBodyAbortsAction) { EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { std::unique_ptr response_headers = std::make_unique(response_headers_); - cb(LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 5, absl::nullopt}, false); + dispatcher_->post([cb = std::move(cb), + response_headers = std::move(response_headers)]() mutable { + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 5, absl::nullopt}, false); + }); }); LookupBodyCallback body_callback; EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(0, 5), _)) - .WillOnce([&](const AdjustedByteRange&, LookupBodyCallback&& cb) { body_callback = cb; }); + .WillOnce([&](const AdjustedByteRange&, LookupBodyCallback&& cb) { + body_callback = std::move(cb); + }); + EXPECT_CALL(*mock_lookup_context, onDestroy()); auto filter = makeFilter(mock_http_cache, false); EXPECT_EQ(filter->decodeHeaders(request_headers_, true), Http::FilterHeadersStatus::StopAllIterationAndWatermark); dispatcher_->run(Event::Dispatcher::RunType::NonBlock); filter->onDestroy(); - // onBody should do nothing because the filter was destroyed. - body_callback(std::make_unique("abcde"), true); - dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + ::testing::Mock::VerifyAndClearExpectations(mock_lookup_context.get()); + EXPECT_THAT(body_callback, NotNull()); + // body_callback should not be called because LookupContext::onDestroy, + // correctly implemented, should have aborted it. } TEST_F(CacheFilterTest, OnDestroyBeforeOnTrailersAbortsAction) { @@ -608,15 +668,21 @@ TEST_F(CacheFilterTest, OnDestroyBeforeOnTrailersAbortsAction) { EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { std::unique_ptr response_headers = std::make_unique(response_headers_); - cb(LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 5, absl::nullopt}, false); + dispatcher_->post([cb = std::move(cb), + response_headers = std::move(response_headers)]() mutable { + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 5, absl::nullopt}, false); + }); }); EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(0, 5), _)) .WillOnce([&](const AdjustedByteRange&, LookupBodyCallback&& cb) { - cb(std::make_unique("abcde"), false); + dispatcher_->post([cb = std::move(cb)]() mutable { + std::move(cb)(std::make_unique("abcde"), false); + }); }); LookupTrailersCallback trailers_callback; EXPECT_CALL(*mock_lookup_context, getTrailers(_)).WillOnce([&](LookupTrailersCallback&& cb) { - trailers_callback = cb; + trailers_callback = std::move(cb); }); auto filter = makeFilter(mock_http_cache, false); EXPECT_EQ(filter->decodeHeaders(request_headers_, true), @@ -643,15 +709,23 @@ TEST_F(CacheFilterTest, BodyReadFromCacheLimitedToBufferSizeChunks) { EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { std::unique_ptr response_headers = std::make_unique(response_headers_); - cb(LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 8, absl::nullopt}, false); + dispatcher_->post([cb = std::move(cb), + response_headers = std::move(response_headers)]() mutable { + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, std::move(response_headers), 8, absl::nullopt}, false); + }); }); EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(0, 5), _)) .WillOnce([&](const AdjustedByteRange&, LookupBodyCallback&& cb) { - cb(std::make_unique("abcde"), false); + dispatcher_->post([cb = std::move(cb)]() mutable { + std::move(cb)(std::make_unique("abcde"), false); + }); }); EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(5, 8), _)) .WillOnce([&](const AdjustedByteRange&, LookupBodyCallback&& cb) { - cb(std::make_unique("fgh"), true); + dispatcher_->post([cb = std::move(cb)]() mutable { + std::move(cb)(std::make_unique("fgh"), true); + }); }); EXPECT_CALL(*mock_lookup_context, onDestroy()); @@ -703,14 +777,17 @@ TEST_F(CacheFilterTest, CacheInsertAbortedByCache) { return std::move(mock_insert_context); }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - cb(LookupResult{}, false); + dispatcher_->post([cb = std::move(cb)]() mutable { std::move(cb)(LookupResult{}, false); }); }); EXPECT_CALL(*mock_insert_context, insertHeaders(_, _, _, false)) .WillOnce([&](const Http::ResponseHeaderMap&, const ResponseMetadata&, - InsertCallback insert_complete, bool) { insert_complete(true); }); + InsertCallback insert_complete, bool) { + dispatcher_->post([cb = std::move(insert_complete)]() mutable { std::move(cb)(true); }); + }); EXPECT_CALL(*mock_insert_context, insertBody(_, _, true)) .WillOnce([&](const Buffer::Instance&, InsertCallback ready_for_next_chunk, bool) { - ready_for_next_chunk(false); + dispatcher_->post( + [cb = std::move(ready_for_next_chunk)]() mutable { std::move(cb)(false); }); }); EXPECT_CALL(*mock_insert_context, onDestroy()); EXPECT_CALL(*mock_lookup_context, onDestroy()); @@ -753,13 +830,13 @@ TEST_F(CacheFilterTest, FilterDeletedWhileIncompleteCacheWriteInQueueShouldAband return std::move(mock_insert_context); }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - cb(LookupResult{}, false); + dispatcher_->post([cb = std::move(cb)]() mutable { std::move(cb)(LookupResult{}, false); }); }); InsertCallback captured_insert_header_callback; EXPECT_CALL(*mock_insert_context, insertHeaders(_, _, _, false)) .WillOnce([&](const Http::ResponseHeaderMap&, const ResponseMetadata&, InsertCallback insert_complete, - bool) { captured_insert_header_callback = insert_complete; }); + bool) { captured_insert_header_callback = std::move(insert_complete); }); EXPECT_CALL(*mock_insert_context, onDestroy()); EXPECT_CALL(*mock_lookup_context, onDestroy()); { @@ -772,16 +849,14 @@ TEST_F(CacheFilterTest, FilterDeletedWhileIncompleteCacheWriteInQueueShouldAband // Encode header of response. response_headers_.setContentLength(body.size()); EXPECT_EQ(filter->encodeHeaders(response_headers_, false), Http::FilterHeadersStatus::Continue); - // Destroy the filter prematurely. + // Destroy the filter prematurely (it goes out of scope). } ASSERT_THAT(captured_insert_header_callback, NotNull()); EXPECT_THAT(weak_cache_pointer.lock(), NotNull()) << "cache instance was unexpectedly destroyed when filter was destroyed"; + // The callback should now do nothing visible, because the filter has been destroyed. + // Calling it allows the CacheInsertQueue to discard its self-ownership. captured_insert_header_callback(true); - // The callback should be posted to the dispatcher. - // Run events on the dispatcher so that the callback is invoked, - // where it should now do nothing due to the filter being destroyed. - dispatcher_->run(Event::Dispatcher::RunType::Block); } TEST_F(CacheFilterTest, FilterDeletedWhileCompleteCacheWriteInQueueShouldContinueWrite) { @@ -801,17 +876,17 @@ TEST_F(CacheFilterTest, FilterDeletedWhileCompleteCacheWriteInQueueShouldContinu return std::move(mock_insert_context); }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - cb(LookupResult{}, false); + dispatcher_->post([cb = std::move(cb)]() mutable { std::move(cb)(LookupResult{}, false); }); }); InsertCallback captured_insert_header_callback; InsertCallback captured_insert_body_callback; EXPECT_CALL(*mock_insert_context, insertHeaders(_, _, _, false)) .WillOnce([&](const Http::ResponseHeaderMap&, const ResponseMetadata&, InsertCallback insert_complete, - bool) { captured_insert_header_callback = insert_complete; }); + bool) { captured_insert_header_callback = std::move(insert_complete); }); EXPECT_CALL(*mock_insert_context, insertBody(_, _, true)) .WillOnce([&](const Buffer::Instance&, InsertCallback ready_for_next_chunk, bool) { - captured_insert_body_callback = ready_for_next_chunk; + captured_insert_body_callback = std::move(ready_for_next_chunk); }); EXPECT_CALL(*mock_insert_context, onDestroy()); EXPECT_CALL(*mock_lookup_context, onDestroy()); @@ -1391,12 +1466,15 @@ TEST_F(CacheFilterDeathTest, BadRangeRequestLookup) { return std::move(mock_lookup_context); }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { - // LookupResult with unknown length and an unsatisfiable RangeDetails is invalid. - cb(LookupResult{CacheEntryStatus::Ok, - std::make_unique(response_headers_), - absl::nullopt, - RangeDetails{/*satisfiable_ = */ false, {AdjustedByteRange{0, 5}}}}, - false); + dispatcher_->post([cb = std::move(cb), this]() mutable { + // LookupResult with unknown length and an unsatisfiable RangeDetails is invalid. + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, + std::make_unique(response_headers_), + absl::nullopt, + RangeDetails{/*satisfiable_ = */ false, {AdjustedByteRange{0, 5}}}}, + false); + }); }); EXPECT_CALL(*mock_lookup_context, onDestroy()); { @@ -1429,15 +1507,20 @@ TEST_F(CacheFilterTest, RangeRequestSatisfiedBeforeLengthKnown) { }); EXPECT_CALL(*mock_lookup_context, getHeaders(_)).WillOnce([&](LookupHeadersCallback&& cb) { // LookupResult with unknown length and an unsatisfiable RangeDetails is invalid. - cb(LookupResult{CacheEntryStatus::Ok, - std::make_unique(response_headers_), - absl::nullopt, - RangeDetails{/*satisfiable_ = */ true, {AdjustedByteRange{0, 5}}}}, - false); + dispatcher_->post([cb = std::move(cb), this]() mutable { + std::move(cb)( + LookupResult{CacheEntryStatus::Ok, + std::make_unique(response_headers_), + absl::nullopt, + RangeDetails{/*satisfiable_ = */ true, {AdjustedByteRange{0, 5}}}}, + false); + }); }); EXPECT_CALL(*mock_lookup_context, getBody(RangeMatcher(0, 5), _)) .WillOnce([&](AdjustedByteRange, LookupBodyCallback&& cb) { - cb(std::make_unique(body), false); + dispatcher_->post([cb = std::move(cb), &body]() mutable { + cb(std::make_unique(body), false); + }); }); EXPECT_CALL(*mock_lookup_context, onDestroy()); { diff --git a/test/extensions/filters/http/cache/http_cache_implementation_test_common.cc b/test/extensions/filters/http/cache/http_cache_implementation_test_common.cc index c5df0fa728f6..d4c6bb578b01 100644 --- a/test/extensions/filters/http/cache/http_cache_implementation_test_common.cc +++ b/test/extensions/filters/http/cache/http_cache_implementation_test_common.cc @@ -153,6 +153,20 @@ absl::Status HttpCacheImplementationTest::insert( return absl::OkStatus(); } +LookupContextPtr HttpCacheImplementationTest::lookupContextWithAllParts() { + absl::string_view path = "/common"; + Http::TestResponseHeaderMapImpl response_headers{ + {":status", "200"}, + {"date", formatter_.fromTime(time_system_.systemTime())}, + {"cache-control", "public,max-age=3600"}}; + Http::TestResponseTrailerMapImpl response_trailers{ + {"common-trailer", "irrelevant value"}, + }; + EXPECT_THAT(insert(lookup(path), response_headers, "commonbody", response_trailers), IsOk()); + LookupRequest request = makeLookupRequest(path); + return cache()->makeLookupContext(std::move(request), decoder_callbacks_); +} + absl::Status HttpCacheImplementationTest::insert(absl::string_view request_path, const Http::TestResponseHeaderMapImpl& headers, const absl::string_view body) { @@ -777,6 +791,48 @@ TEST_P(HttpCacheImplementationTest, EmptyTrailers) { EXPECT_TRUE(expectLookupSuccessWithBodyAndTrailers(name_lookup_context.get(), body1)); } +TEST_P(HttpCacheImplementationTest, DoesNotRunHeadersCallbackWhenCancelledAfterPosted) { + bool was_called = false; + { + LookupContextPtr context = lookupContextWithAllParts(); + context->getHeaders([&was_called](LookupResult&&, bool) { was_called = true; }); + pumpIntoDispatcher(); + context->onDestroy(); + } + pumpDispatcher(); + EXPECT_FALSE(was_called); +} + +TEST_P(HttpCacheImplementationTest, DoesNotRunBodyCallbackWhenCancelledAfterPosted) { + bool was_called = false; + { + LookupContextPtr context = lookupContextWithAllParts(); + context->getHeaders([](LookupResult&&, bool) {}); + pumpDispatcher(); + context->getBody({0, 10}, [&was_called](Buffer::InstancePtr&&, bool) { was_called = true; }); + pumpIntoDispatcher(); + context->onDestroy(); + } + pumpDispatcher(); + EXPECT_FALSE(was_called); +} + +TEST_P(HttpCacheImplementationTest, DoesNotRunTrailersCallbackWhenCancelledAfterPosted) { + bool was_called = false; + { + LookupContextPtr context = lookupContextWithAllParts(); + context->getHeaders([](LookupResult&&, bool) {}); + pumpDispatcher(); + context->getBody({0, 10}, [](Buffer::InstancePtr&&, bool) {}); + pumpDispatcher(); + context->getTrailers([&was_called](Http::ResponseTrailerMapPtr&&) { was_called = true; }); + pumpIntoDispatcher(); + context->onDestroy(); + } + pumpDispatcher(); + EXPECT_FALSE(was_called); +} + } // namespace Cache } // namespace HttpFilters } // namespace Extensions diff --git a/test/extensions/filters/http/cache/http_cache_implementation_test_common.h b/test/extensions/filters/http/cache/http_cache_implementation_test_common.h index 497bfcad81eb..26f20312fd70 100644 --- a/test/extensions/filters/http/cache/http_cache_implementation_test_common.h +++ b/test/extensions/filters/http/cache/http_cache_implementation_test_common.h @@ -63,6 +63,7 @@ class HttpCacheImplementationTest std::shared_ptr cache() const { return delegate_->cache(); } bool validationEnabled() const { return delegate_->validationEnabled(); } + void pumpIntoDispatcher() { delegate_->beforePumpingDispatcher(); } void pumpDispatcher() { delegate_->pumpDispatcher(); } LookupContextPtr lookup(absl::string_view request_path); @@ -91,6 +92,8 @@ class HttpCacheImplementationTest LookupRequest makeLookupRequest(absl::string_view request_path); + LookupContextPtr lookupContextWithAllParts(); + testing::AssertionResult expectLookupSuccessWithHeaders(LookupContext* lookup_context, const Http::TestResponseHeaderMapImpl& headers); diff --git a/test/extensions/filters/http/cache/mocks.h b/test/extensions/filters/http/cache/mocks.h index 0a243fab1910..d1739f958448 100644 --- a/test/extensions/filters/http/cache/mocks.h +++ b/test/extensions/filters/http/cache/mocks.h @@ -17,7 +17,7 @@ class MockHttpCache : public HttpCache { (LookupContextPtr && lookup_context, Http::StreamEncoderFilterCallbacks& callbacks)); MOCK_METHOD(void, updateHeaders, (const LookupContext& lookup_context, const Http::ResponseHeaderMap& response_headers, - const ResponseMetadata& metadata, std::function on_complete)); + const ResponseMetadata& metadata, absl::AnyInvocable on_complete)); MOCK_METHOD(CacheInfo, cacheInfo, (), (const)); };