Skip to content

Commit

Permalink
[fix](window_func) fix window_funnel bug used in window function
Browse files Browse the repository at this point in the history
Add test cases for agg functions used in window function
  • Loading branch information
jacktengg committed Sep 17, 2024
1 parent 0cff225 commit 511e455
Show file tree
Hide file tree
Showing 14 changed files with 1,339 additions and 52 deletions.
3 changes: 0 additions & 3 deletions be/src/vec/aggregate_functions/aggregate_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,6 @@ class IAggregateFunction {
virtual void deserialize_and_merge_from_column(AggregateDataPtr __restrict place,
const IColumn& column, Arena* arena) const = 0;

/// Returns true if a function requires Arena to handle own states (see add(), merge(), deserialize()).
virtual bool allocates_memory_in_arena() const { return false; }

/// Inserts results into a column.
virtual void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const = 0;

Expand Down
4 changes: 2 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_binary.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ struct AggregateFunctionBinary

String get_name() const override { return StatFunc::Data::name(); }

/// Resets the aggregate state located at `place` by delegating to the
/// state object's own reset().
void reset(AggregateDataPtr __restrict place) const override {
    auto& state = this->data(place);
    state.reset();
}

DataTypePtr get_return_type() const override {
return std::make_shared<DataTypeNumber<ResultType>>();
}

bool allocates_memory_in_arena() const override { return false; }

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
this->data(place).add(
Expand Down
2 changes: 0 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_collect.h
Original file line number Diff line number Diff line change
Expand Up @@ -626,8 +626,6 @@ class AggregateFunctionCollect
return std::make_shared<DataTypeArray>(make_nullable(return_type));
}

bool allocates_memory_in_arena() const override { return ENABLE_ARENA; }

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
auto& data = this->data(place);
Expand Down
9 changes: 9 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_corr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,15 @@ struct CorrMoment {
}

static String name() { return "corr"; }

/// Re-initializes every accumulator member of this state (m0, x1, y1, xy,
/// x2, y2) by assigning an empty braced-init-list to each of them.
void reset() {
m0 = {};
x1 = {};
y1 = {};
xy = {};
x2 = {};
y2 = {};
}
};

AggregateFunctionPtr create_aggregate_corr_function(const std::string& name,
Expand Down
2 changes: 0 additions & 2 deletions be/src/vec/aggregate_functions/aggregate_function_distinct.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,8 +338,6 @@ class AggregateFunctionDistinct

DataTypePtr get_return_type() const override { return nested_func->get_return_type(); }

bool allocates_memory_in_arena() const override { return true; }

AggregateFunctionPtr transmit_to_stable() override {
return AggregateFunctionPtr(new AggregateFunctionDistinct<Data, true>(
nested_func, IAggregateFunction::argument_types));
Expand Down
4 changes: 0 additions & 4 deletions be/src/vec/aggregate_functions/aggregate_function_foreach.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,6 @@ class AggregateFunctionForEach : public IAggregateFunctionDataHelper<AggregateFu
offsets_to.push_back(offsets_to.back() + state.dynamic_array_size);
}

bool allocates_memory_in_arena() const override {
return nested_function->allocates_memory_in_arena();
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
std::vector<const IColumn*> nested(num_arguments);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,11 @@ struct AggregateFunctionGroupArrayIntersectData {
Set value;
bool init = false;

/// Clears the `init` flag and replaces `value` with a newly
/// default-constructed NullableNumericOrDateSetType.
void reset() {
init = false;
// Allocate a new set object rather than clearing the existing one in place.
value = std::make_unique<NullableNumericOrDateSetType>();
}

void process_col_data(auto& column_data, size_t offset, size_t arr_size, bool& init, Set& set) {
const bool is_column_data_nullable = column_data.is_nullable();

Expand Down Expand Up @@ -163,7 +168,7 @@ class AggregateFunctionGroupArrayIntersect

DataTypePtr get_return_type() const override { return argument_type; }

bool allocates_memory_in_arena() const override { return false; }
/// Resets the aggregate state located at `place` by delegating to the
/// state object's own reset().
void reset(AggregateDataPtr __restrict place) const override {
    auto& state = this->data(place);
    state.reset();
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
Expand Down Expand Up @@ -331,6 +336,11 @@ struct AggregateFunctionGroupArrayIntersectGenericData {
: value(std::make_unique<NullableStringSet>()) {}
Set value;
bool init = false;

/// Clears the `init` flag and replaces `value` with a newly
/// default-constructed NullableStringSet, the same way the constructor
/// of this struct initializes `value`.
void reset() {
init = false;
value = std::make_unique<NullableStringSet>();
}
};

/** Template parameter with true value should be used for columns that store their elements in memory continuously.
Expand All @@ -357,7 +367,7 @@ class AggregateFunctionGroupArrayIntersectGeneric

DataTypePtr get_return_type() const override { return input_data_type; }

bool allocates_memory_in_arena() const override { return true; }
/// Resets the aggregate state located at `place` by delegating to the
/// state object's own reset().
void reset(AggregateDataPtr __restrict place) const override {
    auto& state = this->data(place);
    state.reset();
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena* arena) const override {
Expand Down
8 changes: 0 additions & 8 deletions be/src/vec/aggregate_functions/aggregate_function_null.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,10 +177,6 @@ class AggregateFunctionNullBaseInline : public IAggregateFunctionHelper<Derived>
nested_function->insert_result_into(nested_place(place), to);
}
}

bool allocates_memory_in_arena() const override {
return nested_function->allocates_memory_in_arena();
}
};

/** There are two cases: for single argument and variadic.
Expand Down Expand Up @@ -329,10 +325,6 @@ class AggregateFunctionNullVariadicInline final
arena);
}

bool allocates_memory_in_arena() const override {
return this->nested_function->allocates_memory_in_arena();
}

private:
// The array length is fixed in the implementation of some aggregate functions.
// Therefore we choose 256 as the appropriate maximum length limit.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,11 @@ struct AggOrthBitmapBaseData {
public:
using ColVecData = std::conditional_t<IsNumber<T>, ColumnVector<T>, ColumnString>;

/// Re-initializes `bitmap` with an empty braced-init-list and sets
/// `first_init` back to true.
/// NOTE(review): the derived merge() methods return early when the other
/// side has first_init == true, so true presumably means "nothing added
/// yet" — confirm against the member declaration.
void reset() {
bitmap = {};
first_init = true;
}

void add(const IColumn** columns, size_t row_num) {
const auto& bitmap_col =
assert_cast<const ColumnBitmap&, TypeCheckOnRelease::DISABLE>(*columns[0]);
Expand Down Expand Up @@ -99,6 +104,11 @@ struct AggOrthBitMapIntersect : public AggOrthBitmapBaseData<T> {

static DataTypePtr get_return_type() { return std::make_shared<DataTypeBitMap>(); }

/// Resets the base-class part via AggOrthBitmapBaseData<T>::reset() and
/// then calls reset() on this struct's own `result` member.
void reset() {
AggOrthBitmapBaseData<T>::reset();
result.reset();
}

void merge(const AggOrthBitMapIntersect& rhs) {
if (rhs.first_init) {
return;
Expand Down Expand Up @@ -170,6 +180,11 @@ struct AggOrthBitMapIntersectCount : public AggOrthBitmapBaseData<T> {

static DataTypePtr get_return_type() { return std::make_shared<DataTypeInt64>(); }

/// Resets the base-class part via AggOrthBitmapBaseData<T>::reset() and
/// then sets this struct's own `result` member to 0.
void reset() {
AggOrthBitmapBaseData<T>::reset();
result = 0;
}

void merge(const AggOrthBitMapIntersectCount& rhs) {
if (rhs.first_init) {
return;
Expand Down Expand Up @@ -225,6 +240,11 @@ struct AggOrthBitmapExprCalBaseData {
}
}

/// Re-initializes `bitmap_expr_cal` with an empty braced-init-list and
/// restores `first_init` to true, which is also its in-class default.
void reset() {
bitmap_expr_cal = {};
first_init = true;
}

protected:
doris::BitmapExprCalculation bitmap_expr_cal;
bool first_init = true;
Expand Down Expand Up @@ -263,6 +283,11 @@ struct AggOrthBitMapExprCal : public AggOrthBitmapExprCalBaseData<T> {
->bitmap_expr_cal.bitmap_calculate());
}

/// Resets the base-class part via AggOrthBitmapExprCalBaseData<T>::reset()
/// and then calls reset() on the BitmapValue `result` member.
void reset() {
AggOrthBitmapExprCalBaseData<T>::reset();
result.reset();
}

private:
BitmapValue result;
};
Expand Down Expand Up @@ -299,6 +324,11 @@ struct AggOrthBitMapExprCalCount : public AggOrthBitmapExprCalBaseData<T> {
->bitmap_expr_cal.bitmap_calculate_count());
}

/// Resets the base-class part via AggOrthBitmapExprCalBaseData<T>::reset()
/// and then sets `result` back to 0, which is also its in-class default.
void reset() {
AggOrthBitmapExprCalBaseData<T>::reset();
result = 0;
}

private:
int64_t result = 0;
};
Expand Down Expand Up @@ -330,6 +360,11 @@ struct OrthBitmapUnionCountData {
column.get_data().emplace_back(result ? result : value.cardinality());
}

/// Calls reset() on the BitmapValue `value` and sets `result` back to 0,
/// which is also its in-class default.
void reset() {
value.reset();
result = 0;
}

private:
BitmapValue value;
int64_t result = 0;
Expand All @@ -347,6 +382,8 @@ class AggFunctionOrthBitmapFunc final

DataTypePtr get_return_type() const override { return Impl::get_return_type(); }

/// Resets the aggregate state located at `place` by delegating to the
/// state object's own reset().
void reset(AggregateDataPtr __restrict place) const override {
    auto& state = this->data(place);
    state.reset();
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
this->data(place).init_add_key(columns, row_num, _argument_size);
Expand Down
4 changes: 4 additions & 0 deletions be/src/vec/aggregate_functions/aggregate_function_uniq.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ struct AggregateFunctionUniqExactData {
Set set;

static String get_name() { return "multi_distinct"; }

/// Resets this state by calling clear() on `set`.
void reset() { set.clear(); }
};

namespace detail {
Expand Down Expand Up @@ -115,6 +117,8 @@ class AggregateFunctionUniq final

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }

/// Resets the aggregate state located at `place` by delegating to the
/// state object's own reset().
void reset(AggregateDataPtr __restrict place) const override {
    auto& state = this->data(place);
    state.reset();
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
detail::OneAdder<T, Data>::add(this->data(place), *columns[0], row_num);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ struct AggregateFunctionUniqDistributeKeyData {

Set set;
UInt64 count = 0;

/// Resets this state by calling clear() on `set` and setting `count`
/// back to 0, which is also its in-class default.
void reset() {
set.clear();
count = 0;
}
};

template <typename T, typename Data>
Expand All @@ -83,6 +88,8 @@ class AggregateFunctionUniqDistributeKey final

DataTypePtr get_return_type() const override { return std::make_shared<DataTypeInt64>(); }

/// Resets the aggregate state located at `place` by delegating to the
/// state object's own reset().
void reset(AggregateDataPtr __restrict place) const override {
    auto& state = this->data(place);
    state.reset();
}

void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
Arena*) const override {
detail::OneAdder<T, Data>::add(this->data(place), *columns[0], row_num);
Expand Down
52 changes: 23 additions & 29 deletions be/src/vec/aggregate_functions/aggregate_function_window_funnel.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,9 @@ struct WindowFunnelState {
bool enable_mode;
WindowFunnelMode window_funnel_mode;
mutable vectorized::MutableBlock mutable_block;
ColumnVector<NativeType>::Container* timestamp_column_data;
ColumnVector<NativeType>::Container* timestamp_column_data = nullptr;
std::vector<ColumnVector<UInt8>::Container*> event_columns_datas;
SortDescription sort_description {1};
bool sorted;

WindowFunnelState() {
event_count = 0;
Expand All @@ -97,20 +96,15 @@ struct WindowFunnelState {
sort_description[0].column_number = 0;
sort_description[0].direction = 1;
sort_description[0].nulls_direction = -1;
sorted = false;
}
WindowFunnelState(int arg_event_count) : WindowFunnelState() {
event_count = arg_event_count;
event_columns_datas.resize(event_count);
auto timestamp_column = ColumnVector<NativeType>::create();
timestamp_column_data =
&assert_cast<ColumnVector<NativeType>&>(*timestamp_column).get_data();

MutableColumns event_columns;
for (int i = 0; i < event_count; i++) {
auto event_column = ColumnVector<UInt8>::create();
event_columns_datas.emplace_back(
&assert_cast<ColumnVector<UInt8>&>(*event_column).get_data());
event_columns.emplace_back(std::move(event_column));
event_columns.emplace_back(ColumnVector<UInt8>::create());
}
Block tmp_block;
tmp_block.insert({std::move(timestamp_column),
Expand All @@ -122,15 +116,18 @@ struct WindowFunnelState {
}

mutable_block = MutableBlock(std::move(tmp_block));
_reset_columns_ptr();
}

void reset() {
window = 0;
mutable_block.clear();
timestamp_column_data = nullptr;
event_columns_datas.clear();
sorted = false;
/// Re-points the cached raw container pointers at the columns currently
/// held by `mutable_block`: position 0 is read as the timestamp column and
/// positions 1..event_count as the per-event UInt8 columns.
/// The visible callers invoke this right after `mutable_block` is
/// (re)assigned.
void _reset_columns_ptr() {
    timestamp_column_data =
            &assert_cast<ColumnVector<NativeType>&>(*mutable_block.get_column_by_position(0))
                     .get_data();
    for (int event_idx = 0; event_idx != event_count; ++event_idx) {
        // Event columns are stored after the timestamp column, hence the +1.
        auto& column = mutable_block.get_column_by_position(event_idx + 1);
        event_columns_datas[event_idx] = &assert_cast<ColumnVector<UInt8>&>(*column).get_data();
    }
}
/// Resets the state by calling clear_column_data() on `mutable_block`.
/// `window` and the funnel mode are not touched here; add() assigns
/// `window` from its argument on every call.
/// NOTE(review): the cached column pointers (timestamp_column_data,
/// event_columns_datas) are not refreshed here — presumably
/// clear_column_data() empties the columns in place and keeps the column
/// objects alive; confirm against MutableBlock.
void reset() { mutable_block.clear_column_data(); }

void add(const IColumn** arg_columns, ssize_t row_num, int64_t win, WindowFunnelMode mode) {
window = win;
Expand All @@ -146,26 +143,23 @@ struct WindowFunnelState {
}

void sort() {
if (sorted) {
return;
}

Block tmp_block = mutable_block.to_block();
auto block = tmp_block.clone_without_columns();
sort_block(tmp_block, block, sort_description, 0);
mutable_block = MutableBlock(std::move(block));
sorted = true;
mutable_block = std::move(block);
_reset_columns_ptr();
}

template <WindowFunnelMode WINDOW_FUNNEL_MODE>
int _match_event_list(size_t& start_row, size_t row_count,
const NativeType* timestamp_data) const {
int _match_event_list(size_t& start_row, size_t row_count) const {
int matched_count = 0;
DateValueType start_timestamp;
DateValueType end_timestamp;
TimeInterval interval(SECOND, window, false);

int column_idx = 1;

const NativeType* timestamp_data = timestamp_column_data->data();
const auto& first_event_column = mutable_block.get_column_by_position(column_idx);
const auto& first_event_data =
assert_cast<const ColumnVector<UInt8>&>(*first_event_column).get_data();
Expand Down Expand Up @@ -250,14 +244,9 @@ struct WindowFunnelState {
int _get_internal() const {
size_t start_row = 0;
int max_found_event_count = 0;
const auto& ts_column = mutable_block.get_column_by_position(0)->get_ptr();
const auto& timestamp_data =
assert_cast<const ColumnVector<NativeType>&>(*ts_column).get_data().data();

auto row_count = mutable_block.rows();
while (start_row < row_count) {
auto found_event_count =
_match_event_list<WINDOW_FUNNEL_MODE>(start_row, row_count, timestamp_data);
auto found_event_count = _match_event_list<WINDOW_FUNNEL_MODE>(start_row, row_count);
if (found_event_count == event_count) {
return found_event_count;
}
Expand Down Expand Up @@ -324,6 +313,7 @@ struct WindowFunnelState {
status = block.serialize(
5, &pblock, &uncompressed_bytes, &compressed_bytes,
segment_v2::CompressionTypePB::ZSTD); // ZSTD for better compression ratio
block.clear_column_data();
if (!status.ok()) {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, status.to_string());
return;
Expand All @@ -336,6 +326,9 @@ struct WindowFunnelState {
auto data_bytes = buff.size();
write_var_uint(data_bytes, out);
out.write(buff.data(), data_bytes);

mutable_block = std::move(block);
const_cast<WindowFunnelState<TYPE_INDEX, NativeType>*>(this)->_reset_columns_ptr();
}

void read(BufferReadable& in) {
Expand Down Expand Up @@ -366,6 +359,7 @@ struct WindowFunnelState {
throw doris::Exception(ErrorCode::INTERNAL_ERROR, status.to_string());
}
mutable_block = MutableBlock(std::move(block));
_reset_columns_ptr();
}
};

Expand Down
Loading

0 comments on commit 511e455

Please sign in to comment.