Skip to content

Commit

Permalink
collect_block in QuickwitCollector (#4753)
Browse files Browse the repository at this point in the history
* collect_block in QuickwitCollector

collect_block + using `first_vals` to batch fetch sort values

* Isolating the 3 components of the collector.

* Apply suggestions from code review

Co-authored-by: Paul Masurel <[email protected]>

* fix var name

---------

Co-authored-by: Paul Masurel <[email protected]>
  • Loading branch information
PSeitz and fulmicoton authored Mar 18, 2024
1 parent 33920d6 commit 3b43f71
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 38 deletions.
285 changes: 248 additions & 37 deletions quickwit/quickwit-search/src/collector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,20 @@ enum SortingFieldExtractorComponent {
}

impl SortingFieldExtractorComponent {
/// Tells whether this component is the `FastField` variant.
fn is_fast_field(&self) -> bool {
    match self {
        SortingFieldExtractorComponent::FastField { .. } => true,
        _ => false,
    }
}
/// Batch-loads the fast field values of `doc_ids` into the front of `values`, in their u64
/// representation. That representation maintains the ordering of the original value.
///
/// Only the first `doc_ids.len()` entries of `values` are handed to the column; components
/// that are not fast fields leave `values` untouched.
///
/// # Panics
///
/// For a fast-field component, panics if `values` is shorter than `doc_ids`.
#[inline]
fn extract_typed_sort_values_block(&self, doc_ids: &[DocId], values: &mut [Option<u64>]) {
    // Block collection carries no scores, so anything but a fast field has nothing to load.
    let SortingFieldExtractorComponent::FastField { sort_column, .. } = self else {
        return;
    };
    let num_docs = doc_ids.len();
    sort_column.first_vals(doc_ids, &mut values[..num_docs]);
}

/// Returns the sort value for the given element in its u64 representation. The returned u64
/// representation maintains the ordering of the original value.
///
Expand Down Expand Up @@ -369,6 +383,23 @@ pub(crate) struct SortingFieldExtractorPair {
}

impl SortingFieldExtractorPair {
/// Fills the sort value buffers for a whole block of docs.
///
/// The first `doc_ids.len()` entries of `values1` are passed to the first sort component.
/// When a second sort component exists, the same prefix of `values2` is passed to it;
/// otherwise `values2` is not touched.
///
/// See also [`SortingFieldExtractorComponent::extract_typed_sort_values_block`] for more
/// information.
///
/// # Panics
///
/// Panics if `values1` (or `values2`, when a second component exists) is shorter than
/// `doc_ids`.
#[inline]
fn extract_typed_sort_values(
    &self,
    doc_ids: &[DocId],
    values1: &mut [Option<u64>],
    values2: &mut [Option<u64>],
) {
    let num_docs = doc_ids.len();
    let first_values = &mut values1[..num_docs];
    self.first
        .extract_typed_sort_values_block(doc_ids, first_values);
    if let Some(second) = &self.second {
        let second_values = &mut values2[..num_docs];
        second.extract_typed_sort_values_block(doc_ids, second_values);
    }
}
/// Returns the list of sort values for the given element
///
/// See also [`SortingFieldExtractorComponent::extract_typed_sort_value_opt`] for more
Expand Down Expand Up @@ -469,17 +500,36 @@ enum AggregationSegmentCollectors {

/// Quickwit collector working at the scale of the segment.
///
/// Applies the optional timestamp filter to incoming docs, counts the accepted ones, and
/// forwards them to the optional top-k collector and the optional aggregation collector.
pub struct QuickwitSegmentCollector {
// When set, only docs for which `is_within_range` returns true are collected.
timestamp_filter_opt: Option<TimestampFilter>,
// Top-k hit collection. Built as `None` when `leaf_max_hits == 0`.
segment_top_k_collector: Option<QuickwitSegmentTopKCollector>,
// Scratch buffer for block collection: receives the docs of the current block that pass the
// timestamp filter.
filtered_docs: Box<[DocId; 64]>,
// Aggregation collector fed with the accepted docs, if any.
aggregation: Option<AggregationSegmentCollectors>,
// Number of docs accepted so far.
num_hits: u64,
}

impl QuickwitSegmentCollector {
#[inline]
fn accept_document(&self, doc_id: DocId) -> bool {
if let Some(ref timestamp_filter) = self.timestamp_filter_opt {
return timestamp_filter.is_within_range(doc_id);
}
true
}
}

/// Quickwit collector working at the scale of the segment.
///
/// Holds the top-k part of the segment collection: the sort value extractors, the heap of best
/// hits and the buffers used when sort values are fetched for a whole block of docs.
struct QuickwitSegmentTopKCollector {
// Split id attached to every hit when the heap is turned into `PartialHit`s.
split_id: String,
// Extracts the first and the optional second sort value of a doc.
score_extractor: SortingFieldExtractorPair,
// PartialHits in this heap don't contain a split_id yet.
top_k_hits: TopK<SegmentPartialHit, SegmentPartialHitSortingKey, HitSortingMapper>,
// Segment ordinal attached to every hit when the heap is turned into `PartialHit`s.
segment_ord: u32,
// NOTE(review): the two fields below are not set by the `QuickwitSegmentTopKCollector { .. }`
// literal that builds this struct, while the same fields exist on `QuickwitSegmentCollector`
// -- confirm whether they belong here.
timestamp_filter_opt: Option<TimestampFilter>,
aggregation: Option<AggregationSegmentCollectors>,
// When set, hits are compared against this position before being added to the heap.
search_after: Option<SearchAfterSegment>,
// Precomputed order for search_after for split_id and segment_ord
precomp_search_after_order: Ordering,
// Block buffers: first / second sort value of each doc of the block being collected.
sort_values1: Box<[Option<u64>; 64]>,
sort_values2: Box<[Option<u64>; 64]>,
}

/// Search After, but the sort values are converted to the u64 fast field representation.
Expand Down Expand Up @@ -542,16 +592,100 @@ impl SearchAfterSegment {
}
}

impl QuickwitSegmentCollector {
#[inline]
fn collect_top_k(&mut self, doc_id: DocId, score: Score) {
let (sort_value, sort_value2): (Option<u64>, Option<u64>) =
self.score_extractor.extract_typed_sort_value(doc_id, score);
impl QuickwitSegmentTopKCollector {
fn collect_top_k_block(&mut self, docs: &[DocId]) {
self.score_extractor.extract_typed_sort_values(
docs,
&mut self.sort_values1[..],
&mut self.sort_values2[..],
);
if self.search_after.is_some() {
// Search after not optimized for block collection yet
for ((doc_id, sort_value), sort_value2) in docs
.iter()
.cloned()
.zip(self.sort_values1.iter().cloned())
.zip(self.sort_values2.iter().cloned())
{
Self::collect_top_k_vals(
doc_id,
sort_value,
sort_value2,
&self.search_after,
self.precomp_search_after_order,
&mut self.top_k_hits,
);
}
} else {
// Probaly would make sense to check the fence against e.g. sort_values1 earlier,
// before creating the SegmentPartialHit.
//
// Below are different versions to avoid iterating the caches if they are unused.
//
// No sort values loaded. Sort only by doc_id.
if !self.score_extractor.first.is_fast_field() {
for doc_id in docs.iter().cloned() {
let hit = SegmentPartialHit {
sort_value: None,
sort_value2: None,
doc_id,
};
self.top_k_hits.add_entry(hit);
}
return;
}
let has_no_second_sort = !self
.score_extractor
.second
.as_ref()
.map(|extr| extr.is_fast_field())
.unwrap_or(false);
// No second sort values => We can skip iterating the second sort values cache.
if has_no_second_sort {
for (doc_id, sort_value) in
docs.iter().cloned().zip(self.sort_values1.iter().cloned())
{
let hit = SegmentPartialHit {
sort_value,
sort_value2: None,
doc_id,
};
self.top_k_hits.add_entry(hit);
}
return;
}

if let Some(search_after) = &self.search_after {
for ((doc_id, sort_value), sort_value2) in docs
.iter()
.cloned()
.zip(self.sort_values1.iter().cloned())
.zip(self.sort_values2.iter().cloned())
{
let hit = SegmentPartialHit {
sort_value,
sort_value2,
doc_id,
};
self.top_k_hits.add_entry(hit);
}
}
}
#[inline]
/// Generic top k collection, that includes search_after handling
///
/// Outside of the collector to circumvent lifetime issues.
fn collect_top_k_vals(
doc_id: DocId,
sort_value: Option<u64>,
sort_value2: Option<u64>,
search_after: &Option<SearchAfterSegment>,
precomp_search_after_order: Ordering,
top_k_hits: &mut TopK<SegmentPartialHit, SegmentPartialHitSortingKey, HitSortingMapper>,
) {
if let Some(search_after) = &search_after {
let search_after_value1 = search_after.sort_value;
let search_after_value2 = search_after.sort_value2;
let orders = &self.top_k_hits.sort_key_mapper;
let orders = &top_k_hits.sort_key_mapper;
let mut cmp_result = orders
.order1
.compare_opt(&sort_value, &search_after_value1)
Expand All @@ -565,7 +699,7 @@ impl QuickwitSegmentCollector {
// default
let order = orders.order1;
cmp_result = cmp_result
.then(self.precomp_search_after_order)
.then(precomp_search_after_order)
// We compare doc_id only if sort_value1, sort_value2, split_id and segment_ord
// are equal.
.then_with(|| order.compare(&doc_id, &search_after.doc_id))
Expand All @@ -581,15 +715,21 @@ impl QuickwitSegmentCollector {
sort_value2,
doc_id,
};
self.top_k_hits.add_entry(hit);
top_k_hits.add_entry(hit);
}

#[inline]
fn accept_document(&self, doc_id: DocId) -> bool {
if let Some(ref timestamp_filter) = self.timestamp_filter_opt {
return timestamp_filter.is_within_range(doc_id);
}
true
fn collect_top_k(&mut self, doc_id: DocId, score: Score) {
let (sort_value, sort_value2): (Option<u64>, Option<u64>) =
self.score_extractor.extract_typed_sort_value(doc_id, score);
Self::collect_top_k_vals(
doc_id,
sort_value,
sort_value2,
&self.search_after,
self.precomp_search_after_order,
&mut self.top_k_hits,
);
}
}

Expand Down Expand Up @@ -635,17 +775,72 @@ impl SegmentPartialHit {
}
}

pub use tantivy::COLLECT_BLOCK_BUFFER_LEN;
/// Applies the optional timestamp filter to one block of docs.
///
/// Without a filter, `docs` is returned as is and `filtered_docs_buffer` is not written.
/// With a filter, the accepted docs are stored, in their original order, at the front of
/// `filtered_docs_buffer`, and the returned slice is that prefix of the buffer.
///
/// # Panics
///
/// With a filter present, panics if a doc is visited once `COLLECT_BLOCK_BUFFER_LEN` docs have
/// already been accepted, so `docs` is expected to hold at most `COLLECT_BLOCK_BUFFER_LEN`
/// entries.
fn compute_filtered_block<'a>(
    timestamp_filter_opt: &Option<TimestampFilter>,
    docs: &'a [DocId],
    filtered_docs_buffer: &'a mut [DocId; COLLECT_BLOCK_BUFFER_LEN],
) -> &'a [DocId] {
    let Some(timestamp_filter) = timestamp_filter_opt else {
        return docs;
    };
    let mut len = 0;
    for &doc in docs {
        // Always write the candidate, then advance the cursor only when it is accepted: a
        // rejected doc gets overwritten by the next candidate.
        filtered_docs_buffer[len] = doc;
        len += usize::from(timestamp_filter.is_within_range(doc));
    }
    &filtered_docs_buffer[..len]
}

impl SegmentCollector for QuickwitSegmentCollector {
type Fruit = tantivy::Result<LeafSearchResponse>;

/// Collects a whole block of docs at once.
///
/// The block first goes through the optional timestamp filter; the accepted docs are then
/// counted and handed to the top-k collector and to the aggregation collector, when present.
#[inline]
fn collect_block(&mut self, unfiltered_docs: &[DocId]) {
    let accepted_docs: &[DocId] = compute_filtered_block(
        &self.timestamp_filter_opt,
        unfiltered_docs,
        &mut self.filtered_docs,
    );

    // Update results
    self.num_hits += accepted_docs.len() as u64;

    if let Some(top_k_collector) = &mut self.segment_top_k_collector {
        top_k_collector.collect_top_k_block(accepted_docs);
    }

    if let Some(aggregation) = &mut self.aggregation {
        match aggregation {
            AggregationSegmentCollectors::FindTraceIdsSegmentCollector(collector) => {
                collector.collect_block(accepted_docs);
            }
            AggregationSegmentCollectors::TantivyAggregationSegmentCollector(collector) => {
                collector.collect_block(accepted_docs);
            }
        }
    }
}

#[inline]
fn collect(&mut self, doc_id: DocId, score: Score) {
if !self.accept_document(doc_id) {
return;
}

self.num_hits += 1;
self.collect_top_k(doc_id, score);
if let Some(segment_top_k_collector) = self.segment_top_k_collector.as_mut() {
segment_top_k_collector.collect_top_k(doc_id, score);
}

match self.aggregation.as_mut() {
Some(AggregationSegmentCollectors::FindTraceIdsSegmentCollector(collector)) => {
Expand All @@ -659,19 +854,23 @@ impl SegmentCollector for QuickwitSegmentCollector {
}

fn harvest(self) -> Self::Fruit {
let partial_hits: Vec<PartialHit> = self
.top_k_hits
.finalize()
.into_iter()
.map(|segment_partial_hit: SegmentPartialHit| {
segment_partial_hit.into_partial_hit(
self.split_id.clone(),
self.segment_ord,
&self.score_extractor.first,
&self.score_extractor.second,
)
})
.collect();
let mut partial_hits: Vec<PartialHit> = Vec::new();
if let Some(segment_top_k_collector) = self.segment_top_k_collector {
// TODO put that in a method of segment_top_k_collector
partial_hits = segment_top_k_collector
.top_k_hits
.finalize()
.into_iter()
.map(|segment_partial_hit: SegmentPartialHit| {
segment_partial_hit.into_partial_hit(
segment_top_k_collector.split_id.clone(),
segment_top_k_collector.segment_ord,
&segment_top_k_collector.score_extractor.first,
&segment_top_k_collector.score_extractor.second,
)
})
.collect();
}

let intermediate_aggregation_result = match self.aggregation {
Some(AggregationSegmentCollectors::FindTraceIdsSegmentCollector(collector)) => {
Expand Down Expand Up @@ -897,16 +1096,28 @@ impl Collector for QuickwitCollector {
// Convert search_after into fast field u64
let search_after =
SearchAfterSegment::new(self.search_after.clone(), order1, order2, &score_extractor);

let segment_top_k_collector = if leaf_max_hits == 0 {
None
} else {
Some(QuickwitSegmentTopKCollector {
split_id: self.split_id.clone(),
score_extractor,
top_k_hits: TopK::new(leaf_max_hits, sort_key_mapper),
segment_ord,
search_after,
precomp_search_after_order,
sort_values1: Box::new([None; 64]),
sort_values2: Box::new([None; 64]),
})
};

Ok(QuickwitSegmentCollector {
num_hits: 0u64,
split_id: self.split_id.clone(),
score_extractor,
top_k_hits: TopK::new(leaf_max_hits, sort_key_mapper),
segment_ord,
num_hits: 0,
timestamp_filter_opt,
segment_top_k_collector,
aggregation,
search_after,
precomp_search_after_order,
filtered_docs: Box::new([0; 64]),
})
}

Expand Down
2 changes: 1 addition & 1 deletion quickwit/rest-api-tests/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def check_result_list(result, expected, context_path=""):
display_filtered_result = filtered_result[:5] + ['...'] if len(filtered_result) > 5 else filtered_result
else:
display_filtered_result = filtered_result
raise Exception("Wrong length at context %s. Expected: %s Received: %s,\n Expected \n%s \n Received \n%s" % (context_path, len(expected), len(result), display_filtered_result, expected))
raise Exception("Wrong length at context %s. Expected: %s Received: %s,\n Expected \n%s \n Received \n%s" % (context_path, len(expected), len(result), expected, display_filtered_result))
raise Exception("Wrong length at context %s. Expected: %s Received: %s" % (context_path, len(expected), len(result)))
for (i, (left, right)) in enumerate(zip(result, expected)):
check_result(left, right, context_path + "[%s]" % i)
Expand Down

0 comments on commit 3b43f71

Please sign in to comment.