feat(query): push rank limit into aggregate partial node (#16466)
* feat(query): push rank limit into aggregate partial node

* new-name
sundy-li committed Sep 19, 2024
1 parent 2b6431c commit 1fac99d
Showing 39 changed files with 395 additions and 785 deletions.
6 changes: 6 additions & 0 deletions src/query/expression/src/kernels/sort.rs
@@ -61,6 +61,12 @@ pub enum LimitType {
 }
 
 impl LimitType {
+    pub fn from_limit_rows(limit: Option<usize>) -> Self {
+        match limit {
+            Some(limit) => LimitType::LimitRows(limit),
+            None => LimitType::None,
+        }
+    }
     pub fn limit_rows(&self, rows: usize) -> usize {
         match self {
             LimitType::LimitRows(limit) => *limit,
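A quick aside on the new constructor: it maps the old Option<usize> row limit onto the LimitType enum, so call sites that previously passed an optional limit keep their behavior. A minimal self-contained sketch (the variants mirror the ones visible in this diff; the derives and main are added only for the example):

#[allow(dead_code)] // LimitRank is not constructed in this toy example
#[derive(Debug, Clone, Copy, PartialEq)]
enum LimitType {
    None,
    LimitRows(usize),
    LimitRank(usize),
}

impl LimitType {
    // Same body as the function added in the hunk above.
    fn from_limit_rows(limit: Option<usize>) -> Self {
        match limit {
            Some(limit) => LimitType::LimitRows(limit),
            None => LimitType::None,
        }
    }
}

fn main() {
    // A plain `LIMIT 10` becomes a row limit; an absent limit means no cap.
    assert_eq!(LimitType::from_limit_rows(Some(10)), LimitType::LimitRows(10));
    assert_eq!(LimitType::from_limit_rows(None), LimitType::None);
}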
10 changes: 7 additions & 3 deletions src/query/expression/src/kernels/sort_compare.rs
@@ -98,7 +98,7 @@ macro_rules! do_sorter {
 impl SortCompare {
     pub fn new(ordering_descs: Vec<SortColumnDescription>, rows: usize, limit: LimitType) -> Self {
         let equality_index =
-            if ordering_descs.len() == 1 && matches!(limit, LimitType::LimitRank(_)) {
+            if ordering_descs.len() == 1 && !matches!(limit, LimitType::LimitRank(_)) {
                 vec![]
             } else {
                 vec![1; rows as _]
@@ -114,6 +114,11 @@ impl SortCompare {
         }
     }
 
+    fn need_update_equality_index(&self) -> bool {
+        self.current_column_index != self.ordering_descs.len() - 1
+            || matches!(self.limit, LimitType::LimitRank(_))
+    }
+
    pub fn increment_column_index(&mut self) {
        self.current_column_index += 1;
    }
@@ -196,8 +201,7 @@ impl SortCompare {
         } else {
             let mut current = 1;
             let len = self.rows;
-            let need_update_equality_index =
-                self.current_column_index != self.ordering_descs.len() - 1;
+            let need_update_equality_index = self.need_update_equality_index();
 
             while current < len {
                 // Find the start of the next range of equal elements
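Note the inverted matches! in the first hunk: with a single sort column, the equality index can be skipped only for non-rank limits, because a rank limit still has to know where runs of equal keys end, even on the last column (which is also why need_update_equality_index now returns true for LimitType::LimitRank). Here is a toy model of that semantics, assuming LimitRank(n) means "keep every row whose dense rank over the sort key is at most n" (an assumption consistent with this diff, not code taken from Databend):

// Not Databend's implementation: a sketch of why a rank limit needs equality
// information. Rows with equal keys share a rank, so truncation happens at a
// rank boundary, not at a fixed row count.
fn truncate_at_rank(sorted_keys: &[i32], rank_limit: usize) -> usize {
    let mut rank = 0usize;
    let mut prev: Option<i32> = None;
    for (i, &k) in sorted_keys.iter().enumerate() {
        if prev != Some(k) {
            rank += 1; // a new run of equal keys starts a new rank
            if rank > rank_limit {
                return i; // keep rows [0, i)
            }
        }
        prev = Some(k);
    }
    sorted_keys.len()
}

fn main() {
    let keys = [1, 1, 2, 2, 2, 3, 4];
    // With a rank limit of 2, every row whose key is among the first two
    // distinct values survives: the two 1s and the three 2s.
    assert_eq!(truncate_at_rank(&keys, 2), 5);
}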
@@ -16,6 +16,7 @@ use std::sync::Arc;
 
 use databend_common_exception::Result;
 use databend_common_expression::DataBlock;
+use databend_common_expression::LimitType;
 use databend_common_expression::SortColumnDescription;
 use databend_common_pipeline_core::processors::InputPort;
 use databend_common_pipeline_core::processors::OutputPort;
@@ -25,13 +26,13 @@ use crate::processors::transforms::Transform;
 use crate::processors::transforms::Transformer;
 
 pub struct TransformSortPartial {
-    limit: Option<usize>,
+    limit: LimitType,
     sort_columns_descriptions: Arc<Vec<SortColumnDescription>>,
 }
 
 impl TransformSortPartial {
     pub fn new(
-        limit: Option<usize>,
+        limit: LimitType,
         sort_columns_descriptions: Arc<Vec<SortColumnDescription>>,
     ) -> Self {
         Self {
@@ -43,7 +44,7 @@ impl TransformSortPartial {
     pub fn try_create(
         input: Arc<InputPort>,
         output: Arc<OutputPort>,
-        limit: Option<usize>,
+        limit: LimitType,
         sort_columns_descriptions: Arc<Vec<SortColumnDescription>>,
     ) -> Result<Box<dyn Processor>> {
         Ok(Transformer::create(input, output, TransformSortPartial {
@@ -58,6 +59,6 @@ impl Transform for TransformSortPartial {
     const NAME: &'static str = "SortPartialTransform";
 
     fn transform(&mut self, block: DataBlock) -> Result<DataBlock> {
-        DataBlock::sort(&block, &self.sort_columns_descriptions, self.limit)
+        DataBlock::sort_with_type(&block, &self.sort_columns_descriptions, self.limit)
     }
 }
@@ -298,6 +298,7 @@ fn remove_exchange(plan: PhysicalPlan) -> PhysicalPlan {
                 input: Box::new(traverse(*plan.input)),
                 group_by: plan.group_by,
                 agg_funcs: plan.agg_funcs,
+                rank_limit: plan.rank_limit,
                 enable_experimental_aggregate_hashtable: plan
                     .enable_experimental_aggregate_hashtable,
                 group_by_display: plan.group_by_display,
@@ -310,7 +311,6 @@ fn remove_exchange(plan: PhysicalPlan) -> PhysicalPlan {
                 group_by: plan.group_by,
                 agg_funcs: plan.agg_funcs,
                 before_group_by_schema: plan.before_group_by_schema,
-                limit: plan.limit,
                 group_by_display: plan.group_by_display,
                 stat_info: plan.stat_info,
             }),
31 changes: 26 additions & 5 deletions src/query/service/src/pipelines/builders/builder_aggregate.rs
@@ -23,10 +23,13 @@ use databend_common_expression::DataBlock;
 use databend_common_expression::DataSchemaRef;
 use databend_common_expression::HashMethodKind;
 use databend_common_expression::HashTableConfig;
+use databend_common_expression::LimitType;
+use databend_common_expression::SortColumnDescription;
 use databend_common_functions::aggregates::AggregateFunctionFactory;
 use databend_common_pipeline_core::processors::ProcessorPtr;
 use databend_common_pipeline_core::query_spill_prefix;
 use databend_common_pipeline_transforms::processors::TransformPipelineHelper;
+use databend_common_pipeline_transforms::processors::TransformSortPartial;
 use databend_common_sql::executor::physical_plans::AggregateExpand;
 use databend_common_sql::executor::physical_plans::AggregateFinal;
 use databend_common_sql::executor::physical_plans::AggregateFunctionDesc;
@@ -111,7 +114,6 @@ impl PipelineBuilder {
             enable_experimental_aggregate_hashtable,
             self.is_exchange_neighbor,
             max_block_size as usize,
-            None,
             max_spill_io_requests as usize,
         )?;
 
@@ -125,7 +127,7 @@ impl PipelineBuilder {
 
         let group_cols = &params.group_columns;
         let schema_before_group_by = params.input_schema.clone();
-        let sample_block = DataBlock::empty_with_schema(schema_before_group_by);
+        let sample_block = DataBlock::empty_with_schema(schema_before_group_by.clone());
         let method = DataBlock::choose_hash_method(&sample_block, group_cols, efficiently_memory)?;
 
         // Need a global atomic to read the max current radix bits hint
@@ -136,6 +138,28 @@
                 .cluster_with_partial(true, self.ctx.get_cluster().nodes.len())
         };
 
+        // For rank limit, we can filter data using sort with rank before partial
+        if let Some(rank_limit) = &aggregate.rank_limit {
+            let sort_desc = rank_limit
+                .0
+                .iter()
+                .map(|desc| {
+                    let offset = schema_before_group_by.index_of(&desc.order_by.to_string())?;
+                    Ok(SortColumnDescription {
+                        offset,
+                        asc: desc.asc,
+                        nulls_first: desc.nulls_first,
+                        is_nullable: schema_before_group_by.field(offset).is_nullable(), // This information is not needed here.
+                    })
+                })
+                .collect::<Result<Vec<_>>>()?;
+            let sort_desc = Arc::new(sort_desc);
+
+            self.main_pipeline.add_transformer(|| {
+                TransformSortPartial::new(LimitType::LimitRank(rank_limit.1), sort_desc.clone())
+            });
+        }
+
         self.main_pipeline.add_transform(|input, output| {
             Ok(ProcessorPtr::create(
                 match params.aggregate_functions.is_empty() {
@@ -225,7 +249,6 @@ impl PipelineBuilder {
             enable_experimental_aggregate_hashtable,
             self.is_exchange_neighbor,
             max_block_size as usize,
-            aggregate.limit,
             max_spill_io_requests as usize,
         )?;
 
@@ -292,7 +315,6 @@ impl PipelineBuilder {
         enable_experimental_aggregate_hashtable: bool,
         cluster_aggregator: bool,
         max_block_size: usize,
-        limit: Option<usize>,
         max_spill_io_requests: usize,
     ) -> Result<Arc<AggregatorParams>> {
         let mut agg_args = Vec::with_capacity(agg_funcs.len());
@@ -335,7 +357,6 @@ impl PipelineBuilder {
             enable_experimental_aggregate_hashtable,
             cluster_aggregator,
             max_block_size,
-            limit,
             max_spill_io_requests,
         )?;
 
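The rank-limit hunk above is the heart of the PR: when the physical plan carries a rank_limit, a TransformSortPartial with LimitType::LimitRank(n) is placed in front of the partial aggregate, so each incoming block is sorted on the order-by keys and rows beyond the first n distinct key values are dropped before any hashing happens. Below is a self-contained toy model of why dropping those rows is safe, assuming the rank limit originates from an ORDER BY over the GROUP BY keys with LIMIT n (the shape this rewrite targets; all names here are illustrative, not Databend APIs):

use std::collections::BTreeMap;

// Toy "partial aggregate" over (key, value) rows: SUM(value) GROUP BY key.
fn partial_aggregate(rows: &[(i32, i64)]) -> BTreeMap<i32, i64> {
    let mut groups = BTreeMap::new();
    for &(k, v) in rows {
        *groups.entry(k).or_insert(0) += v;
    }
    groups
}

// Pre-filter a block for `ORDER BY key LIMIT n`: sort by key, then keep only
// rows whose key ranks among the first n distinct values (the rank limit).
fn rank_limit_block(mut rows: Vec<(i32, i64)>, n: usize) -> Vec<(i32, i64)> {
    rows.sort_by_key(|&(k, _)| k);
    let mut kept = Vec::new();
    let mut distinct = 0;
    let mut prev = None;
    for (k, v) in rows {
        if prev != Some(k) {
            distinct += 1;
            if distinct > n {
                break;
            }
        }
        prev = Some(k);
        kept.push((k, v));
    }
    kept
}

fn main() {
    let block = vec![(3, 1), (1, 10), (2, 5), (1, 7), (9, 2), (2, 1)];
    // Keep only the two smallest keys before aggregating; groups 3 and 9 can
    // never appear in the top-2 output, so their rows are dropped early.
    let filtered = rank_limit_block(block, 2);
    let partial = partial_aggregate(&filtered);
    assert_eq!(partial.get(&1), Some(&17));
    assert_eq!(partial.get(&2), Some(&6));
    assert_eq!(partial.get(&3), None);
}

Because whole groups (not arbitrary rows) are filtered, the partial and final aggregates still see every row of each surviving group, which is what makes the transformation correct.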
@@ -19,6 +19,7 @@ use std::sync::Arc;
 use databend_common_catalog::catalog::CatalogManager;
 use databend_common_exception::Result;
 use databend_common_expression::DataSchema;
+use databend_common_expression::LimitType;
 use databend_common_expression::SortColumnDescription;
 use databend_common_pipeline_core::processors::ProcessorPtr;
 use databend_common_pipeline_core::DynTransformBuilder;
@@ -230,7 +231,7 @@ impl PipelineBuilder {
             Ok(ProcessorPtr::create(TransformSortPartial::try_create(
                 transform_input_port,
                 transform_output_port,
-                None,
+                LimitType::None,
                 sort_desc.clone(),
             )?))
         },
8 changes: 7 additions & 1 deletion src/query/service/src/pipelines/builders/builder_sort.rs
@@ -16,6 +16,7 @@ use std::sync::Arc;
 
 use databend_common_exception::Result;
 use databend_common_expression::DataSchemaRef;
+use databend_common_expression::LimitType;
 use databend_common_expression::SortColumnDescription;
 use databend_common_pipeline_core::processors::ProcessorPtr;
 use databend_common_pipeline_core::query_spill_prefix;
@@ -197,7 +198,12 @@ impl SortPipelineBuilder {
 
     pub fn build_full_sort_pipeline(self, pipeline: &mut Pipeline) -> Result<()> {
         // Partial sort
-        pipeline.add_transformer(|| TransformSortPartial::new(self.limit, self.sort_desc.clone()));
+        pipeline.add_transformer(|| {
+            TransformSortPartial::new(
+                LimitType::from_limit_rows(self.limit),
+                self.sort_desc.clone(),
+            )
+        });
 
         self.build_merge_sort_pipeline(pipeline, false)
     }
@@ -44,8 +44,6 @@ pub struct AggregatorParams {
     pub enable_experimental_aggregate_hashtable: bool,
     pub cluster_aggregator: bool,
     pub max_block_size: usize,
-    // Limit is push down to AggregatorTransform
-    pub limit: Option<usize>,
     pub max_spill_io_requests: usize,
 }
 
@@ -59,7 +57,6 @@ impl AggregatorParams {
         enable_experimental_aggregate_hashtable: bool,
         cluster_aggregator: bool,
         max_block_size: usize,
-        limit: Option<usize>,
         max_spill_io_requests: usize,
     ) -> Result<Arc<AggregatorParams>> {
         let mut states_offsets: Vec<usize> = Vec::with_capacity(agg_funcs.len());
@@ -80,7 +77,6 @@ impl AggregatorParams {
             enable_experimental_aggregate_hashtable,
             cluster_aggregator,
             max_block_size,
-            limit,
             max_spill_io_requests,
         }))
     }
@@ -46,7 +46,6 @@ pub struct TransformFinalAggregate<Method: HashMethodBounds> {
     method: Method,
     params: Arc<AggregatorParams>,
     flush_state: PayloadFlushState,
-    reach_limit: bool,
 }
 
 impl<Method: HashMethodBounds> TransformFinalAggregate<Method> {
@@ -63,7 +62,6 @@ impl<Method: HashMethodBounds> TransformFinalAggregate<Method> {
                 method,
                 params,
                 flush_state: PayloadFlushState::default(),
-                reach_limit: false,
             },
         ))
     }
@@ -124,23 +122,11 @@ impl<Method: HashMethodBounds> TransformFinalAggregate<Method> {
         let mut blocks = vec![];
         self.flush_state.clear();
 
-        let mut rows = 0;
         loop {
             if ht.merge_result(&mut self.flush_state)? {
                 let mut cols = self.flush_state.take_aggregate_results();
                 cols.extend_from_slice(&self.flush_state.take_group_columns());
-                rows += cols[0].len();
                 blocks.push(DataBlock::new_from_columns(cols));
-
-                if rows >= self.params.limit.unwrap_or(usize::MAX) {
-                    log::info!(
-                        "reach limit optimization in flush agg hashtable, current {}, total {}",
-                        rows,
-                        ht.len(),
-                    );
-                    self.reach_limit = true;
-                    break;
-                }
             } else {
                 break;
             }
@@ -162,10 +148,6 @@ where Method: HashMethodBounds
     const NAME: &'static str = "TransformFinalAggregate";
 
     fn transform(&mut self, meta: AggregateMeta<Method, usize>) -> Result<Vec<DataBlock>> {
-        if self.reach_limit {
-            return Ok(vec![self.params.empty_result_block()]);
-        }
-
         if self.params.enable_experimental_aggregate_hashtable {
             return Ok(vec![self.transform_agg_hashtable(meta)?]);
         }
@@ -196,32 +178,15 @@
                 let (len, _) = keys_iter.size_hint();
                 let mut places = Vec::with_capacity(len);
 
-                let mut current_len = hash_cell.hashtable.len();
                 unsafe {
                     for key in keys_iter {
-                        if self.reach_limit {
-                            let entry = hash_cell.hashtable.entry(key);
-                            if let Some(entry) = entry {
-                                let place = Into::<StateAddr>::into(*entry.get());
-                                places.push(place);
-                            }
-                            continue;
-                        }
-
                         match hash_cell.hashtable.insert_and_entry(key) {
                             Ok(mut entry) => {
                                 let place =
                                     self.params.alloc_layout(&mut hash_cell.arena);
                                 places.push(place);
 
                                 *entry.get_mut() = place.addr();
-
-                                if let Some(limit) = self.params.limit {
-                                    current_len += 1;
-                                    if current_len >= limit {
-                                        self.reach_limit = true;
-                                    }
-                                }
                             }
                             Err(entry) => {
                                 let place = Into::<StateAddr>::into(*entry.get());