Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[fix](hash join) fix stack overflow caused by evaluating a case expr on a huge build block (#28851) #28882

Merged
merged 1 commit into from
Dec 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion be/src/vec/columns/column_vector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,8 @@ ColumnPtr ColumnVector<T>::replicate(const IColumn::Offsets& offsets) const {
res_data.reserve(offsets.back());

// vectorized this code to speed up
IColumn::Offset counts[size];
auto counts_uptr = std::unique_ptr<IColumn::Offset[]>(new IColumn::Offset[size]);
IColumn::Offset* counts = counts_uptr.get();
for (ssize_t i = 0; i < size; ++i) {
counts[i] = offsets[i] - offsets[i - 1];
}
Expand Down
10 changes: 6 additions & 4 deletions be/src/vec/exec/join/vhash_join_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,10 @@ Status HashJoinNode::sink(doris::RuntimeState* state, vectorized::Block* in_bloc
_build_side_mem_used += in_block->allocated_bytes();

if (in_block->rows() != 0) {
_build_col_ids.resize(_build_expr_ctxs.size());
RETURN_IF_ERROR(_do_evaluate(*in_block, _build_expr_ctxs, *_build_expr_call_timer,
_build_col_ids));

SCOPED_TIMER(_build_side_merge_block_timer);
RETURN_IF_ERROR(_build_side_mutable_block.merge(*in_block));
}
Expand Down Expand Up @@ -1152,24 +1156,22 @@ Status HashJoinNode::_process_build_block(RuntimeState* state, Block& block, uin
ColumnRawPtrs raw_ptrs(_build_expr_ctxs.size());

ColumnUInt8::MutablePtr null_map_val;
std::vector<int> res_col_ids(_build_expr_ctxs.size());
RETURN_IF_ERROR(_do_evaluate(block, _build_expr_ctxs, *_build_expr_call_timer, res_col_ids));
if (_join_op == TJoinOp::LEFT_OUTER_JOIN || _join_op == TJoinOp::FULL_OUTER_JOIN) {
_convert_block_to_null(block);
}
// TODO: Now we are not sure whether a column is nullable only by ExecNode's `row_desc`
// so we have to initialize this flag by the first build block.
if (!_has_set_need_null_map_for_build) {
_has_set_need_null_map_for_build = true;
_set_build_ignore_flag(block, res_col_ids);
_set_build_ignore_flag(block, _build_col_ids);
}
if (_short_circuit_for_null_in_build_side || _build_side_ignore_null) {
null_map_val = ColumnUInt8::create();
null_map_val->get_data().assign(rows, (uint8_t)0);
}

// Get the key column that needs to be built
Status st = _extract_join_column<true>(block, null_map_val, raw_ptrs, res_col_ids);
Status st = _extract_join_column<true>(block, null_map_val, raw_ptrs, _build_col_ids);

st = std::visit(
Overload {
Expand Down
1 change: 1 addition & 0 deletions be/src/vec/exec/join/vhash_join_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,7 @@ class HashJoinNode final : public VJoinNodeBase {

std::vector<IRuntimeFilter*> _runtime_filters;
size_t _build_bf_cardinality = 0;
std::vector<int> _build_col_ids;
};
} // namespace vectorized
} // namespace doris
21 changes: 11 additions & 10 deletions be/src/vec/functions/function_case.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,9 +168,9 @@ class FunctionCase : public IFunction {
int rows_count = column_holder.rows_count;

// `then` data index corresponding to each row of results, 0 represents `else`.
int then_idx[rows_count];
int* __restrict then_idx_ptr = then_idx;
memset(then_idx_ptr, 0, sizeof(then_idx));
auto then_idx_uptr = std::unique_ptr<int[]>(new int[rows_count]);
int* __restrict then_idx_ptr = then_idx_uptr.get();
memset(then_idx_ptr, 0, rows_count * sizeof(int));

for (int row_idx = 0; row_idx < column_holder.rows_count; row_idx++) {
for (int i = 1; i < column_holder.pair_count; i++) {
Expand Down Expand Up @@ -198,7 +198,7 @@ class FunctionCase : public IFunction {
}

auto result_column_ptr = data_type->create_column();
update_result_normal<int, ColumnType, then_null>(result_column_ptr, then_idx,
update_result_normal<int, ColumnType, then_null>(result_column_ptr, then_idx_ptr,
column_holder);
block.replace_by_position(result, std::move(result_column_ptr));
return Status::OK();
Expand All @@ -215,9 +215,9 @@ class FunctionCase : public IFunction {
int rows_count = column_holder.rows_count;

// `then` data index corresponding to each row of results, 0 represents `else`.
uint8_t then_idx[rows_count];
uint8_t* __restrict then_idx_ptr = then_idx;
memset(then_idx_ptr, 0, sizeof(then_idx));
auto then_idx_uptr = std::unique_ptr<uint8_t[]>(new uint8_t[rows_count]);
uint8_t* __restrict then_idx_ptr = then_idx_uptr.get();
memset(then_idx_ptr, 0, rows_count);

auto case_column_ptr = column_holder.when_ptrs[0].value_or(nullptr);

Expand Down Expand Up @@ -254,13 +254,13 @@ class FunctionCase : public IFunction {
}
}

return execute_update_result<ColumnType, then_null>(data_type, result, block, then_idx,
return execute_update_result<ColumnType, then_null>(data_type, result, block, then_idx_ptr,
column_holder);
}

template <typename ColumnType, bool then_null>
Status execute_update_result(const DataTypePtr& data_type, size_t result, Block& block,
uint8* then_idx, CaseWhenColumnHolder& column_holder) {
const uint8* then_idx, CaseWhenColumnHolder& column_holder) {
auto result_column_ptr = data_type->create_column();

if constexpr (std::is_same_v<ColumnType, ColumnString> ||
Expand All @@ -287,7 +287,8 @@ class FunctionCase : public IFunction {
}

template <typename IndexType, typename ColumnType, bool then_null>
void update_result_normal(MutableColumnPtr& result_column_ptr, IndexType* then_idx,
void update_result_normal(MutableColumnPtr& result_column_ptr,
const IndexType* __restrict then_idx,
CaseWhenColumnHolder& column_holder) {
std::vector<uint8_t> is_consts(column_holder.then_ptrs.size());
std::vector<ColumnPtr> raw_columns(column_holder.then_ptrs.size());
Expand Down
31 changes: 28 additions & 3 deletions be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,7 @@ class FunctionTrim : public IFunction {
}
};

// Size, in bytes, of the fixed-size local output buffer used by the hex/base64
// helpers below. When the computed output length exceeds this value, those
// helpers switch to a heap-allocated buffer instead.
static constexpr int MAX_STACK_CIPHER_LEN = 1024 * 64;
struct UnHexImpl {
static constexpr auto name = "unhex";
using ReturnType = DataTypeString;
Expand Down Expand Up @@ -652,8 +653,16 @@ struct UnHexImpl {
continue;
}

char dst_array[MAX_STACK_CIPHER_LEN];
char* dst = dst_array;

int cipher_len = srclen / 2;
char dst[cipher_len];
std::unique_ptr<char[]> dst_uptr;
if (cipher_len > MAX_STACK_CIPHER_LEN) {
dst_uptr.reset(new char[cipher_len]);
dst = dst_uptr.get();
}

int outlen = hex_decode(source, srclen, dst);

if (outlen < 0) {
Expand Down Expand Up @@ -723,8 +732,16 @@ struct ToBase64Impl {
continue;
}

char dst_array[MAX_STACK_CIPHER_LEN];
char* dst = dst_array;

int cipher_len = (int)(4.0 * ceil((double)srclen / 3.0));
char dst[cipher_len];
std::unique_ptr<char[]> dst_uptr;
if (cipher_len > MAX_STACK_CIPHER_LEN) {
dst_uptr.reset(new char[cipher_len]);
dst = dst_uptr.get();
}

int outlen = base64_encode((const unsigned char*)source, srclen, (unsigned char*)dst);

if (outlen < 0) {
Expand Down Expand Up @@ -763,8 +780,16 @@ struct FromBase64Impl {
continue;
}

char dst_array[MAX_STACK_CIPHER_LEN];
char* dst = dst_array;

int cipher_len = srclen;
char dst[cipher_len];
std::unique_ptr<char[]> dst_uptr;
if (cipher_len > MAX_STACK_CIPHER_LEN) {
dst_uptr.reset(new char[cipher_len]);
dst = dst_uptr.get();
}

int outlen = base64_decode(source, srclen, dst);

if (outlen < 0) {
Expand Down
3 changes: 2 additions & 1 deletion be/src/vec/functions/multiply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ struct MultiplyImpl {
static void vector_vector(const ColumnDecimal128::Container::value_type* __restrict a,
const ColumnDecimal128::Container::value_type* __restrict b,
ColumnDecimal128::Container::value_type* c, size_t size) {
int8 sgn[size];
auto sng_uptr = std::unique_ptr<int8[]>(new int8[size]);
int8* sgn = sng_uptr.get();
auto max = DecimalV2Value::get_max_decimal();
auto min = DecimalV2Value::get_min_decimal();

Expand Down
Loading