From d75300f166acdd4c9ca0a5d662472418d4272e95 Mon Sep 17 00:00:00 2001 From: TengJianPing <18241664+jacktengg@users.noreply.github.com> Date: Fri, 22 Dec 2023 15:45:12 +0800 Subject: [PATCH] [fix](hash join) fix stack overflow caused by evaluate case expr on huge build block (#28851) --- be/src/pipeline/exec/hashjoin_build_sink.cpp | 18 +++++++---- be/src/pipeline/exec/hashjoin_build_sink.h | 1 + be/src/vec/columns/column_vector.cpp | 3 +- be/src/vec/exec/join/vhash_join_node.cpp | 16 ++++++---- be/src/vec/exec/join/vhash_join_node.h | 1 + .../functions/function_binary_arithmetic.h | 5 ++-- be/src/vec/functions/function_case.h | 21 ++++++------- be/src/vec/functions/function_string.cpp | 30 +++++++++++++++++-- be/src/vec/functions/multiply.cpp | 3 +- 9 files changed, 69 insertions(+), 29 deletions(-) diff --git a/be/src/pipeline/exec/hashjoin_build_sink.cpp b/be/src/pipeline/exec/hashjoin_build_sink.cpp index c3238df035edff..8cd0376a957fa7 100644 --- a/be/src/pipeline/exec/hashjoin_build_sink.cpp +++ b/be/src/pipeline/exec/hashjoin_build_sink.cpp @@ -230,8 +230,6 @@ Status HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state, vectorized::ColumnRawPtrs raw_ptrs(_build_expr_ctxs.size()); vectorized::ColumnUInt8::MutablePtr null_map_val; - std::vector res_col_ids(_build_expr_ctxs.size()); - RETURN_IF_ERROR(_do_evaluate(block, _build_expr_ctxs, *_build_expr_call_timer, res_col_ids)); if (p._join_op == TJoinOp::LEFT_OUTER_JOIN || p._join_op == TJoinOp::FULL_OUTER_JOIN) { _convert_block_to_null(block); // first row is mocked @@ -247,7 +245,7 @@ Status HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state, // so we have to initialize this flag by the first build block. if (!_has_set_need_null_map_for_build) { _has_set_need_null_map_for_build = true; - _set_build_ignore_flag(block, res_col_ids); + _set_build_ignore_flag(block, _build_col_ids); } if (p._short_circuit_for_null_in_build_side || _build_side_ignore_null) { null_map_val = vectorized::ColumnUInt8::create(); @@ -255,7 +253,7 @@ Status HashJoinBuildSinkLocalState::process_build_block(RuntimeState* state, } // Get the key column that needs to be built - Status st = _extract_join_column(block, null_map_val, raw_ptrs, res_col_ids); + Status st = _extract_join_column(block, null_map_val, raw_ptrs, _build_col_ids); st = std::visit( Overload {[&](std::monostate& arg, auto join_op, auto has_null_value, @@ -458,13 +456,21 @@ Status HashJoinBuildSinkOperatorX::sink(RuntimeState* state, vectorized::Block* if (local_state._build_side_mutable_block.empty()) { auto tmp_build_block = vectorized::VectorizedUtils::create_empty_columnswithtypename( _child_x->row_desc()); + tmp_build_block = *(tmp_build_block.create_same_struct_block(1, false)); + local_state._build_col_ids.resize(_build_expr_ctxs.size()); + RETURN_IF_ERROR(local_state._do_evaluate(tmp_build_block, local_state._build_expr_ctxs, + *local_state._build_expr_call_timer, + local_state._build_col_ids)); local_state._build_side_mutable_block = vectorized::MutableBlock::build_mutable_block(&tmp_build_block); - RETURN_IF_ERROR(local_state._build_side_mutable_block.merge( - *(tmp_build_block.create_same_struct_block(1, false)))); } if (in_block->rows() != 0) { + std::vector res_col_ids(_build_expr_ctxs.size()); + RETURN_IF_ERROR(local_state._do_evaluate(*in_block, local_state._build_expr_ctxs, + *local_state._build_expr_call_timer, + res_col_ids)); + SCOPED_TIMER(local_state._build_side_merge_block_timer); RETURN_IF_ERROR(local_state._build_side_mutable_block.merge(*in_block)); if (local_state._build_side_mutable_block.rows() > diff --git a/be/src/pipeline/exec/hashjoin_build_sink.h b/be/src/pipeline/exec/hashjoin_build_sink.h index b2fbec8575ea6c..ecf0a4a31228f8 100644 --- a/be/src/pipeline/exec/hashjoin_build_sink.h +++ b/be/src/pipeline/exec/hashjoin_build_sink.h @@ -116,6 +116,7 @@ class HashJoinBuildSinkLocalState final bool _build_side_ignore_null = false; std::unordered_set _inserted_blocks; std::shared_ptr _shared_hash_table_dependency; + std::vector _build_col_ids; RuntimeProfile::Counter* _build_table_timer = nullptr; RuntimeProfile::Counter* _build_expr_call_timer = nullptr; diff --git a/be/src/vec/columns/column_vector.cpp b/be/src/vec/columns/column_vector.cpp index 45d9e8f70b0aa6..ca8db58fc98f2d 100644 --- a/be/src/vec/columns/column_vector.cpp +++ b/be/src/vec/columns/column_vector.cpp @@ -524,7 +524,8 @@ ColumnPtr ColumnVector::replicate(const IColumn::Offsets& offsets) const { res_data.reserve(offsets.back()); // vectorized this code to speed up - IColumn::Offset counts[size]; + auto counts_uptr = std::unique_ptr(new IColumn::Offset[size]); + IColumn::Offset* counts = counts_uptr.get(); for (ssize_t i = 0; i < size; ++i) { counts[i] = offsets[i] - offsets[i - 1]; } diff --git a/be/src/vec/exec/join/vhash_join_node.cpp b/be/src/vec/exec/join/vhash_join_node.cpp index 1202228ae85099..30f2450f458c11 100644 --- a/be/src/vec/exec/join/vhash_join_node.cpp +++ b/be/src/vec/exec/join/vhash_join_node.cpp @@ -725,12 +725,18 @@ Status HashJoinNode::sink(doris::RuntimeState* state, vectorized::Block* in_bloc if (_build_side_mutable_block.empty()) { auto tmp_build_block = VectorizedUtils::create_empty_columnswithtypename(child(1)->row_desc()); + tmp_build_block = *(tmp_build_block.create_same_struct_block(1, false)); + _build_col_ids.resize(_build_expr_ctxs.size()); + RETURN_IF_ERROR(_do_evaluate(tmp_build_block, _build_expr_ctxs, *_build_expr_call_timer, + _build_col_ids)); _build_side_mutable_block = MutableBlock::build_mutable_block(&tmp_build_block); - RETURN_IF_ERROR(_build_side_mutable_block.merge( - *(tmp_build_block.create_same_struct_block(1, false)))); } if (in_block->rows() != 0) { + std::vector res_col_ids(_build_expr_ctxs.size()); + RETURN_IF_ERROR(_do_evaluate(*in_block, _build_expr_ctxs, *_build_expr_call_timer, + res_col_ids)); + SCOPED_TIMER(_build_side_merge_block_timer); RETURN_IF_ERROR(_build_side_mutable_block.merge(*in_block)); if (_build_side_mutable_block.rows() > JOIN_BUILD_SIZE_LIMIT) { @@ -952,8 +958,6 @@ Status HashJoinNode::_process_build_block(RuntimeState* state, Block& block) { ColumnRawPtrs raw_ptrs(_build_expr_ctxs.size()); ColumnUInt8::MutablePtr null_map_val; - std::vector res_col_ids(_build_expr_ctxs.size()); - RETURN_IF_ERROR(_do_evaluate(block, _build_expr_ctxs, *_build_expr_call_timer, res_col_ids)); if (_join_op == TJoinOp::LEFT_OUTER_JOIN || _join_op == TJoinOp::FULL_OUTER_JOIN) { _convert_block_to_null(block); // first row is mocked @@ -969,7 +973,7 @@ Status HashJoinNode::_process_build_block(RuntimeState* state, Block& block) { // so we have to initialize this flag by the first build block. if (!_has_set_need_null_map_for_build) { _has_set_need_null_map_for_build = true; - _set_build_ignore_flag(block, res_col_ids); + _set_build_ignore_flag(block, _build_col_ids); } if (_short_circuit_for_null_in_build_side || _build_side_ignore_null) { null_map_val = ColumnUInt8::create(); @@ -977,7 +981,7 @@ Status HashJoinNode::_process_build_block(RuntimeState* state, Block& block) { } // Get the key column that needs to be built - Status st = _extract_join_column(block, null_map_val, raw_ptrs, res_col_ids); + Status st = _extract_join_column(block, null_map_val, raw_ptrs, _build_col_ids); st = std::visit( Overload {[&](std::monostate& arg, auto join_op, auto has_null_value, diff --git a/be/src/vec/exec/join/vhash_join_node.h b/be/src/vec/exec/join/vhash_join_node.h index 64f07af6504205..8304eedb2901d5 100644 --- a/be/src/vec/exec/join/vhash_join_node.h +++ b/be/src/vec/exec/join/vhash_join_node.h @@ -451,6 +451,7 @@ class HashJoinNode final : public VJoinNodeBase { std::vector _runtime_filters; std::atomic_bool _probe_open_finish = false; + std::vector _build_col_ids; }; } // namespace vectorized } // namespace doris diff --git a/be/src/vec/functions/function_binary_arithmetic.h b/be/src/vec/functions/function_binary_arithmetic.h index 30ede75ea17fe1..4b69561b14e9e4 100644 --- a/be/src/vec/functions/function_binary_arithmetic.h +++ b/be/src/vec/functions/function_binary_arithmetic.h @@ -265,7 +265,8 @@ struct DecimalBinaryOperation { make_bool_variant(need_adjust_scale && check_overflow)); if (OpTraits::is_multiply && need_adjust_scale && !check_overflow) { - int8_t sig[size]; + auto sig_uptr = std::unique_ptr(new int8_t[size]); + int8_t* sig = sig_uptr.get(); for (size_t i = 0; i < size; i++) { sig[i] = sgn(c[i].value); } @@ -917,7 +918,7 @@ class FunctionBinaryArithmetic : public IFunction { if constexpr (!std::is_same_v) { need_replace_null_data_to_default_ = IsDataTypeDecimal || - (name == "pow" && + (get_name() == "pow" && std::is_floating_point_v); if constexpr (IsDataTypeDecimal && IsDataTypeDecimal) { diff --git a/be/src/vec/functions/function_case.h b/be/src/vec/functions/function_case.h index 2ecc6bd186d29d..26e12e7bd1307c 100644 --- a/be/src/vec/functions/function_case.h +++ b/be/src/vec/functions/function_case.h @@ -159,9 +159,9 @@ class FunctionCase : public IFunction { int rows_count = column_holder.rows_count; // `then` data index corresponding to each row of results, 0 represents `else`. - int then_idx[rows_count]; - int* __restrict then_idx_ptr = then_idx; - memset(then_idx_ptr, 0, sizeof(then_idx)); + auto then_idx_uptr = std::unique_ptr(new int[rows_count]); + int* __restrict then_idx_ptr = then_idx_uptr.get(); + memset(then_idx_ptr, 0, rows_count * sizeof(int)); for (int row_idx = 0; row_idx < column_holder.rows_count; row_idx++) { for (int i = 1; i < column_holder.pair_count; i++) { @@ -189,7 +189,7 @@ class FunctionCase : public IFunction { } auto result_column_ptr = data_type->create_column(); - update_result_normal(result_column_ptr, then_idx, + update_result_normal(result_column_ptr, then_idx_ptr, column_holder); block.replace_by_position(result, std::move(result_column_ptr)); return Status::OK(); @@ -206,9 +206,9 @@ class FunctionCase : public IFunction { int rows_count = column_holder.rows_count; // `then` data index corresponding to each row of results, 0 represents `else`. - uint8_t then_idx[rows_count]; - uint8_t* __restrict then_idx_ptr = then_idx; - memset(then_idx_ptr, 0, sizeof(then_idx)); + auto then_idx_uptr = std::unique_ptr(new uint8_t[rows_count]); + uint8_t* __restrict then_idx_ptr = then_idx_uptr.get(); + memset(then_idx_ptr, 0, rows_count); auto case_column_ptr = column_holder.when_ptrs[0].value_or(nullptr); @@ -245,13 +245,13 @@ class FunctionCase : public IFunction { } } - return execute_update_result(data_type, result, block, then_idx, + return execute_update_result(data_type, result, block, then_idx_ptr, column_holder); } template Status execute_update_result(const DataTypePtr& data_type, size_t result, Block& block, - uint8* then_idx, CaseWhenColumnHolder& column_holder) const { + const uint8* then_idx, CaseWhenColumnHolder& column_holder) const { auto result_column_ptr = data_type->create_column(); if constexpr (std::is_same_v || @@ -282,7 +282,8 @@ class FunctionCase : public IFunction { } template - void update_result_normal(MutableColumnPtr& result_column_ptr, IndexType* then_idx, + void update_result_normal(MutableColumnPtr& result_column_ptr, + const IndexType* __restrict then_idx, CaseWhenColumnHolder& column_holder) const { std::vector is_consts(column_holder.then_ptrs.size()); std::vector raw_columns(column_holder.then_ptrs.size()); diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp index 6179d64e47d7db..7b4e043efe67a1 100644 --- a/be/src/vec/functions/function_string.cpp +++ b/be/src/vec/functions/function_string.cpp @@ -582,6 +582,7 @@ class FunctionTrim : public IFunction { } }; +static constexpr int MAX_STACK_CIPHER_LEN = 1024 * 64; struct UnHexImpl { static constexpr auto name = "unhex"; using ReturnType = DataTypeString; @@ -654,8 +655,16 @@ struct UnHexImpl { continue; } + char dst_array[MAX_STACK_CIPHER_LEN]; + char* dst = dst_array; + int cipher_len = srclen / 2; - char dst[cipher_len]; + std::unique_ptr dst_uptr; + if (cipher_len > MAX_STACK_CIPHER_LEN) { + dst_uptr.reset(new char[cipher_len]); + dst = dst_uptr.get(); + } + int outlen = hex_decode(source, srclen, dst); if (outlen < 0) { @@ -725,8 +734,16 @@ struct ToBase64Impl { continue; } + char dst_array[MAX_STACK_CIPHER_LEN]; + char* dst = dst_array; + int cipher_len = (int)(4.0 * ceil((double)srclen / 3.0)); - char dst[cipher_len]; + std::unique_ptr dst_uptr; + if (cipher_len > MAX_STACK_CIPHER_LEN) { + dst_uptr.reset(new char[cipher_len]); + dst = dst_uptr.get(); + } + int outlen = base64_encode((const unsigned char*)source, srclen, (unsigned char*)dst); if (outlen < 0) { @@ -765,8 +782,15 @@ struct FromBase64Impl { continue; } + char dst_array[MAX_STACK_CIPHER_LEN]; + char* dst = dst_array; + int cipher_len = srclen; - char dst[cipher_len]; + std::unique_ptr dst_uptr; + if (cipher_len > MAX_STACK_CIPHER_LEN) { + dst_uptr.reset(new char[cipher_len]); + dst = dst_uptr.get(); + } int outlen = base64_decode(source, srclen, dst); if (outlen < 0) { diff --git a/be/src/vec/functions/multiply.cpp b/be/src/vec/functions/multiply.cpp index 79653910991d86..0dc9f4a410c9d7 100644 --- a/be/src/vec/functions/multiply.cpp +++ b/be/src/vec/functions/multiply.cpp @@ -56,7 +56,8 @@ struct MultiplyImpl { static void vector_vector(const ColumnDecimal128::Container::value_type* __restrict a, const ColumnDecimal128::Container::value_type* __restrict b, ColumnDecimal128::Container::value_type* c, size_t size) { - int8 sgn[size]; + auto sng_uptr = std::unique_ptr(new int8[size]); + int8* sgn = sng_uptr.get(); auto max = DecimalV2Value::get_max_decimal(); auto min = DecimalV2Value::get_min_decimal();