From cee07d62ce6fcd9370f5789daec4571a09af41a4 Mon Sep 17 00:00:00 2001 From: Mryange <59914473+Mryange@users.noreply.github.com> Date: Thu, 19 Sep 2024 09:29:31 +0800 Subject: [PATCH] [fix](function) fix Substring/SubReplace error result with input utf8 string (#40929) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ``` mysql [(none)]>select sub_replace("你好世界","a",1); +-------------------------------------+ | sub_replace('你好世界', 'a', 1) | +-------------------------------------+ | �a�好世界 | +-------------------------------------+ mysql [(none)]>select SUBSTRING('中文测试',5); +------------------------------------------+ | substring('中文测试', 5, 2147483647) | +------------------------------------------+ | 中文测试 | +------------------------------------------+ 1 row in set (0.04 sec) now mysql [(none)]>select sub_replace("你好世界","a",1); +-------------------------------------+ | sub_replace('你好世界', 'a', 1) | +-------------------------------------+ | 你a世界 | +-------------------------------------+ 1 row in set (0.05 sec) mysql [(none)]>select SUBSTRING('中文测试',5); +------------------------------------------+ | substring('中文测试', 5, 2147483647) | +------------------------------------------+ | | +------------------------------------------+ 1 row in set (0.13 sec) ``` --- be/src/vec/functions/function_string.h | 132 +++++++++++++----- .../string_functions/test_string_function.out | 60 ++++++++ .../string_functions/test_string_function.out | Bin 4590 -> 4838 bytes .../test_string_function.groovy | 23 +++ .../test_string_function.groovy | 10 ++ 5 files changed, 188 insertions(+), 37 deletions(-) diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h index 53c300f50aa761..4ae8cbf5ff2402 100644 --- a/be/src/vec/functions/function_string.h +++ b/be/src/vec/functions/function_string.h @@ -242,9 +242,11 @@ struct SubstringUtil { const char* str_data = (char*)chars.data() + offsets[i - 1]; int start_value = is_const 
? start[0] : start[i]; int len_value = is_const ? len[0] : len[i]; - + // Unsigned numbers cannot be used here because start_value can be negative. + int char_len = simd::VStringFunctions::get_char_len(str_data, str_size); // return empty string if start > src.length - if (start_value > str_size || str_size == 0 || start_value == 0 || len_value <= 0) { + // Here, start_value is compared against the length of the character. + if (start_value > char_len || str_size == 0 || start_value == 0 || len_value <= 0) { StringOP::push_empty_string(i, res_chars, res_offsets); continue; } @@ -3386,8 +3388,6 @@ class FunctionSubReplace : public IFunction { return get_variadic_argument_types_impl().size(); } - bool use_default_implementation_for_nulls() const override { return false; } - Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { return Impl::execute_impl(context, block, arguments, result, input_rows_count); @@ -3398,59 +3398,116 @@ struct SubReplaceImpl { static Status replace_execute(Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) { auto res_column = ColumnString::create(); - auto result_column = assert_cast(res_column.get()); + auto* result_column = assert_cast(res_column.get()); auto args_null_map = ColumnUInt8::create(input_rows_count, 0); ColumnPtr argument_columns[4]; + bool col_const[4]; for (int i = 0; i < 4; ++i) { - argument_columns[i] = - block.get_by_position(arguments[i]).column->convert_to_full_column_if_const(); - if (auto* nullable = check_and_get_column(*argument_columns[i])) { - // Danger: Here must dispose the null map data first! 
Because - // argument_columns[i]=nullable->get_nested_column_ptr(); will release the mem - // of column nullable mem of null map - VectorizedUtils::update_null_map(args_null_map->get_data(), - nullable->get_null_map_data()); - argument_columns[i] = nullable->get_nested_column_ptr(); - } + std::tie(argument_columns[i], col_const[i]) = + unpack_if_const(block.get_by_position(arguments[i]).column); } + const auto* data_column = assert_cast(argument_columns[0].get()); + const auto* mask_column = assert_cast(argument_columns[1].get()); + const auto* start_column = + assert_cast*>(argument_columns[2].get()); + const auto* length_column = + assert_cast*>(argument_columns[3].get()); - auto data_column = assert_cast(argument_columns[0].get()); - auto mask_column = assert_cast(argument_columns[1].get()); - auto start_column = assert_cast*>(argument_columns[2].get()); - auto length_column = assert_cast*>(argument_columns[3].get()); - - vector(data_column, mask_column, start_column->get_data(), length_column->get_data(), - args_null_map->get_data(), result_column, input_rows_count); - + std::visit( + [&](auto origin_str_const, auto new_str_const, auto start_const, auto len_const) { + if (simd::VStringFunctions::is_ascii( + StringRef {data_column->get_chars().data(), data_column->size()})) { + vector_ascii( + data_column, mask_column, start_column->get_data(), + length_column->get_data(), args_null_map->get_data(), result_column, + input_rows_count); + } else { + vector_utf8( + data_column, mask_column, start_column->get_data(), + length_column->get_data(), args_null_map->get_data(), result_column, + input_rows_count); + } + }, + vectorized::make_bool_variant(col_const[0]), + vectorized::make_bool_variant(col_const[1]), + vectorized::make_bool_variant(col_const[2]), + vectorized::make_bool_variant(col_const[3])); block.get_by_position(result).column = ColumnNullable::create(std::move(res_column), std::move(args_null_map)); return Status::OK(); } private: - static void 
vector(const ColumnString* data_column, const ColumnString* mask_column, - const PaddedPODArray& start, const PaddedPODArray& length, - NullMap& args_null_map, ColumnString* result_column, - size_t input_rows_count) { + template + static void vector_ascii(const ColumnString* data_column, const ColumnString* mask_column, + const PaddedPODArray& args_start, + const PaddedPODArray& args_length, NullMap& args_null_map, + ColumnString* result_column, size_t input_rows_count) { ColumnString::Chars& res_chars = result_column->get_chars(); ColumnString::Offsets& res_offsets = result_column->get_offsets(); for (size_t row = 0; row < input_rows_count; ++row) { - StringRef origin_str = data_column->get_data_at(row); - StringRef new_str = mask_column->get_data_at(row); - size_t origin_str_len = origin_str.size; + StringRef origin_str = + data_column->get_data_at(index_check_const(row)); + StringRef new_str = mask_column->get_data_at(index_check_const(row)); + const auto start = args_start[index_check_const(row)]; + const auto length = args_length[index_check_const(row)]; + const size_t origin_str_len = origin_str.size; //input is null, start < 0, len < 0, str_size <= start. 
return NULL - if (args_null_map[row] || start[row] < 0 || length[row] < 0 || - origin_str_len <= start[row]) { + if (args_null_map[row] || start < 0 || length < 0 || origin_str_len <= start) { res_offsets.push_back(res_chars.size()); args_null_map[row] = 1; } else { std::string_view replace_str = new_str.to_string_view(); std::string result = origin_str.to_string(); - result.replace(start[row], length[row], replace_str); + result.replace(start, length, replace_str); result_column->insert_data(result.data(), result.length()); } } } + + template + static void vector_utf8(const ColumnString* data_column, const ColumnString* mask_column, + const PaddedPODArray& args_start, + const PaddedPODArray& args_length, NullMap& args_null_map, + ColumnString* result_column, size_t input_rows_count) { + ColumnString::Chars& res_chars = result_column->get_chars(); + ColumnString::Offsets& res_offsets = result_column->get_offsets(); + + for (size_t row = 0; row < input_rows_count; ++row) { + StringRef origin_str = + data_column->get_data_at(index_check_const(row)); + StringRef new_str = mask_column->get_data_at(index_check_const(row)); + const auto start = args_start[index_check_const(row)]; + const auto length = args_length[index_check_const(row)]; + //input is null, start < 0, len < 0 return NULL + if (args_null_map[row] || start < 0 || length < 0) { + res_offsets.push_back(res_chars.size()); + args_null_map[row] = 1; + continue; + } + + const auto [start_byte_len, start_char_len] = + simd::VStringFunctions::iterate_utf8_with_limit_length(origin_str.begin(), + origin_str.end(), start); + + // start >= origin_str length in characters (the byte offset reached the end): return NULL + DCHECK(start_char_len <= start); + if (start_byte_len == origin_str.size) { + res_offsets.push_back(res_chars.size()); + args_null_map[row] = 1; + continue; + } + + auto [end_byte_len, end_char_len] = + simd::VStringFunctions::iterate_utf8_with_limit_length( + origin_str.begin() + start_byte_len, origin_str.end(), length); + DCHECK(end_char_len <= length); + 
std::string_view replace_str = new_str.to_string_view(); + std::string result = origin_str.to_string(); + result.replace(start_byte_len, end_byte_len, replace_str); + result_column->insert_data(result.data(), result.length()); + } + } }; struct SubReplaceThreeImpl { @@ -3467,13 +3524,14 @@ struct SubReplaceThreeImpl { auto str_col = block.get_by_position(arguments[1]).column->convert_to_full_column_if_const(); - if (auto* nullable = check_and_get_column(*str_col)) { + if (const auto* nullable = check_and_get_column(*str_col)) { str_col = nullable->get_nested_column_ptr(); } - auto& str_offset = assert_cast(str_col.get())->get_offsets(); - + const auto* str_column = assert_cast(str_col.get()); + // use utf8 len for (int i = 0; i < input_rows_count; ++i) { - strlen_data[i] = str_offset[i] - str_offset[i - 1]; + StringRef str_ref = str_column->get_data_at(i); + strlen_data[i] = simd::VStringFunctions::get_char_len(str_ref.data, str_ref.size); } block.insert({std::move(params), std::make_shared(), "strlen"}); diff --git a/regression-test/data/nereids_p0/sql_functions/string_functions/test_string_function.out b/regression-test/data/nereids_p0/sql_functions/string_functions/test_string_function.out index e8305c284ff520..d85794989f7de0 100644 --- a/regression-test/data/nereids_p0/sql_functions/string_functions/test_string_function.out +++ b/regression-test/data/nereids_p0/sql_functions/string_functions/test_string_function.out @@ -386,3 +386,63 @@ tNEW-STRorigin str -- !sql -- d***is +-- !sub_replace_utf8_sql1 -- +你a世界 + +-- !sub_replace_utf8_sql2 -- +你ab界 + +-- !sub_replace_utf8_sql3 -- +你ab + +-- !sub_replace_utf8_sql4 -- +你abcd我界 + +-- !sub_replace_utf8_sql5 -- +\N + +-- !sub_replace_utf8_sql6 -- +大家世界 + +-- !sub_replace_utf8_sql7 -- +你大家114514 + +-- !sub_replace_utf8_sql8 -- +\N + +-- !sub_replace_utf8_sql9 -- +\N + +-- !sub_replace_utf8_sql10 -- +\N + +-- !sub_replace_utf8_sql1 -- +你a世界 + +-- !sub_replace_utf8_sql2 -- +你ab界 + +-- !sub_replace_utf8_sql3 -- +你ab + +-- 
!sub_replace_utf8_sql4 -- +你abcd我界 + +-- !sub_replace_utf8_sql5 -- +\N + +-- !sub_replace_utf8_sql6 -- +大家世界 + +-- !sub_replace_utf8_sql7 -- +你大家114514 + +-- !sub_replace_utf8_sql8 -- +\N + +-- !sub_replace_utf8_sql9 -- +\N + +-- !sub_replace_utf8_sql10 -- +\N + diff --git a/regression-test/data/query_p0/sql_functions/string_functions/test_string_function.out b/regression-test/data/query_p0/sql_functions/string_functions/test_string_function.out index dfcf50a244b48a4ae186880bad82e11312b2e4cb..cadf5039794dd87ff4e1af559db1ccb3172a8b23 100644 GIT binary patch delta 268 zcmaE-{7iL2K39EdQgKO9W?p)HX-S$zd~sopg03zX7niQCf?_dNxfkoF;*)+ht^L{7 YZhR_G45%M4%L{Yjfgapk%2mk@05V-(#sB~S delta 12 TcmaE+`c8R6KG)`S?s|3rCHMsK diff --git a/regression-test/suites/nereids_p0/sql_functions/string_functions/test_string_function.groovy b/regression-test/suites/nereids_p0/sql_functions/string_functions/test_string_function.groovy index 20c8294b1144ef..6e9cd947bc2ed5 100644 --- a/regression-test/suites/nereids_p0/sql_functions/string_functions/test_string_function.groovy +++ b/regression-test/suites/nereids_p0/sql_functions/string_functions/test_string_function.groovy @@ -191,4 +191,27 @@ suite("test_string_function") { qt_sql "select sub_replace(\"this is origin str\",\"NEW-STR\",1);" qt_sql "select sub_replace(\"doris\",\"***\",1,2);" + sql """ set debug_skip_fold_constant = true;""" + qt_sub_replace_utf8_sql1 " select sub_replace('你好世界','a',1);" + qt_sub_replace_utf8_sql2 " select sub_replace('你好世界','ab',1);" + qt_sub_replace_utf8_sql3 " select sub_replace('你好世界','ab',1,20);" + qt_sub_replace_utf8_sql4 " select sub_replace('你好世界','abcd我',1,2);" + qt_sub_replace_utf8_sql5 " select sub_replace('你好世界','a',6);" + qt_sub_replace_utf8_sql6 " select sub_replace('你好世界','大家',0);" + qt_sub_replace_utf8_sql7 " select sub_replace('你好世界','大家114514',1,20);" + qt_sub_replace_utf8_sql8 " select sub_replace('你好世界','大家114514',6,20);" + qt_sub_replace_utf8_sql9 " select sub_replace('你好世界','大家',4);" + 
qt_sub_replace_utf8_sql10 " select sub_replace('你好世界','大家',-1);" + sql """ set debug_skip_fold_constant = false;""" + qt_sub_replace_utf8_sql1 " select sub_replace('你好世界','a',1);" + qt_sub_replace_utf8_sql2 " select sub_replace('你好世界','ab',1);" + qt_sub_replace_utf8_sql3 " select sub_replace('你好世界','ab',1,20);" + qt_sub_replace_utf8_sql4 " select sub_replace('你好世界','abcd我',1,2);" + qt_sub_replace_utf8_sql5 " select sub_replace('你好世界','a',6);" + qt_sub_replace_utf8_sql6 " select sub_replace('你好世界','大家',0);" + qt_sub_replace_utf8_sql7 " select sub_replace('你好世界','大家114514',1,20);" + qt_sub_replace_utf8_sql8 " select sub_replace('你好世界','大家114514',6,20);" + qt_sub_replace_utf8_sql9 " select sub_replace('你好世界','大家',4);" + qt_sub_replace_utf8_sql10 " select sub_replace('你好世界','大家',-1);" + } diff --git a/regression-test/suites/query_p0/sql_functions/string_functions/test_string_function.groovy b/regression-test/suites/query_p0/sql_functions/string_functions/test_string_function.groovy index b71d339a5387a6..6e18fb57eeb4cf 100644 --- a/regression-test/suites/query_p0/sql_functions/string_functions/test_string_function.groovy +++ b/regression-test/suites/query_p0/sql_functions/string_functions/test_string_function.groovy @@ -228,6 +228,16 @@ suite("test_string_function", "arrow_flight_sql") { qt_sql "select substring('abcdef',3,-1);" qt_sql "select substring('abcdef',-3,-1);" qt_sql "select substring('abcdef',10,1);" + sql """ set debug_skip_fold_constant = true;""" + qt_substring_utf8_sql "select substring('中文测试',5);" + qt_substring_utf8_sql "select substring('中文测试',4);" + qt_substring_utf8_sql "select substring('中文测试',2,2);" + qt_substring_utf8_sql "select substring('中文测试',-1,2);" + sql """ set debug_skip_fold_constant = false;""" + qt_substring_utf8_sql "select substring('中文测试',5);" + qt_substring_utf8_sql "select substring('中文测试',4);" + qt_substring_utf8_sql "select substring('中文测试',2,2);" + qt_substring_utf8_sql "select substring('中文测试',-1,2);" sql """ drop table if 
exists test_string_function; """ sql """ create table test_string_function (