Skip to content

Commit

Permalink
[fix](function) fix Substring/SubReplace error result with input utf8…
Browse files Browse the repository at this point in the history
… string (#40929)

```

mysql [(none)]>select sub_replace("你好世界","a",1);
+-------------------------------------+
| sub_replace('你好世界', 'a', 1)     |
+-------------------------------------+
| �a�好世界                             |
+-------------------------------------+



mysql [(none)]>select SUBSTRING('中文测试',5);
+------------------------------------------+
| substring('中文测试', 5, 2147483647)     |
+------------------------------------------+
| 中文测试                                 |
+------------------------------------------+
1 row in set (0.04 sec)



now
mysql [(none)]>select sub_replace("你好世界","a",1);
+-------------------------------------+
| sub_replace('你好世界', 'a', 1)     |
+-------------------------------------+
| 你a世界                             |
+-------------------------------------+
1 row in set (0.05 sec)

mysql [(none)]>select SUBSTRING('中文测试',5);
+------------------------------------------+
| substring('中文测试', 5, 2147483647)     |
+------------------------------------------+
|                                          |
+------------------------------------------+
1 row in set (0.13 sec)
```
  • Loading branch information
Mryange authored Sep 19, 2024
1 parent 538817a commit cee07d6
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 37 deletions.
132 changes: 95 additions & 37 deletions be/src/vec/functions/function_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -242,9 +242,11 @@ struct SubstringUtil {
const char* str_data = (char*)chars.data() + offsets[i - 1];
int start_value = is_const ? start[0] : start[i];
int len_value = is_const ? len[0] : len[i];

// Unsigned numbers cannot be used here because start_value can be negative.
int char_len = simd::VStringFunctions::get_char_len(str_data, str_size);
// return empty string if start > src.length
if (start_value > str_size || str_size == 0 || start_value == 0 || len_value <= 0) {
// Here, start_value is compared against the length of the character.
if (start_value > char_len || str_size == 0 || start_value == 0 || len_value <= 0) {
StringOP::push_empty_string(i, res_chars, res_offsets);
continue;
}
Expand Down Expand Up @@ -3386,8 +3388,6 @@ class FunctionSubReplace : public IFunction {
return get_variadic_argument_types_impl().size();
}

bool use_default_implementation_for_nulls() const override { return false; }

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
return Impl::execute_impl(context, block, arguments, result, input_rows_count);
Expand All @@ -3398,59 +3398,116 @@ struct SubReplaceImpl {
static Status replace_execute(Block& block, const ColumnNumbers& arguments, size_t result,
size_t input_rows_count) {
auto res_column = ColumnString::create();
auto result_column = assert_cast<ColumnString*>(res_column.get());
auto* result_column = assert_cast<ColumnString*>(res_column.get());
auto args_null_map = ColumnUInt8::create(input_rows_count, 0);
ColumnPtr argument_columns[4];
bool col_const[4];
for (int i = 0; i < 4; ++i) {
argument_columns[i] =
block.get_by_position(arguments[i]).column->convert_to_full_column_if_const();
if (auto* nullable = check_and_get_column<ColumnNullable>(*argument_columns[i])) {
// Danger: Here must dispose the null map data first! Because
// argument_columns[i]=nullable->get_nested_column_ptr(); will release the mem
// of column nullable mem of null map
VectorizedUtils::update_null_map(args_null_map->get_data(),
nullable->get_null_map_data());
argument_columns[i] = nullable->get_nested_column_ptr();
}
std::tie(argument_columns[i], col_const[i]) =
unpack_if_const(block.get_by_position(arguments[i]).column);
}
const auto* data_column = assert_cast<const ColumnString*>(argument_columns[0].get());
const auto* mask_column = assert_cast<const ColumnString*>(argument_columns[1].get());
const auto* start_column =
assert_cast<const ColumnVector<Int32>*>(argument_columns[2].get());
const auto* length_column =
assert_cast<const ColumnVector<Int32>*>(argument_columns[3].get());

auto data_column = assert_cast<const ColumnString*>(argument_columns[0].get());
auto mask_column = assert_cast<const ColumnString*>(argument_columns[1].get());
auto start_column = assert_cast<const ColumnVector<Int32>*>(argument_columns[2].get());
auto length_column = assert_cast<const ColumnVector<Int32>*>(argument_columns[3].get());

vector(data_column, mask_column, start_column->get_data(), length_column->get_data(),
args_null_map->get_data(), result_column, input_rows_count);

std::visit(
[&](auto origin_str_const, auto new_str_const, auto start_const, auto len_const) {
if (simd::VStringFunctions::is_ascii(
StringRef {data_column->get_chars().data(), data_column->size()})) {
vector_ascii<origin_str_const, new_str_const, start_const, len_const>(
data_column, mask_column, start_column->get_data(),
length_column->get_data(), args_null_map->get_data(), result_column,
input_rows_count);
} else {
vector_utf8<origin_str_const, new_str_const, start_const, len_const>(
data_column, mask_column, start_column->get_data(),
length_column->get_data(), args_null_map->get_data(), result_column,
input_rows_count);
}
},
vectorized::make_bool_variant(col_const[0]),
vectorized::make_bool_variant(col_const[1]),
vectorized::make_bool_variant(col_const[2]),
vectorized::make_bool_variant(col_const[3]));
block.get_by_position(result).column =
ColumnNullable::create(std::move(res_column), std::move(args_null_map));
return Status::OK();
}

private:
static void vector(const ColumnString* data_column, const ColumnString* mask_column,
const PaddedPODArray<Int32>& start, const PaddedPODArray<Int32>& length,
NullMap& args_null_map, ColumnString* result_column,
size_t input_rows_count) {
template <bool origin_str_const, bool new_str_const, bool start_const, bool len_const>
static void vector_ascii(const ColumnString* data_column, const ColumnString* mask_column,
const PaddedPODArray<Int32>& args_start,
const PaddedPODArray<Int32>& args_length, NullMap& args_null_map,
ColumnString* result_column, size_t input_rows_count) {
ColumnString::Chars& res_chars = result_column->get_chars();
ColumnString::Offsets& res_offsets = result_column->get_offsets();
for (size_t row = 0; row < input_rows_count; ++row) {
StringRef origin_str = data_column->get_data_at(row);
StringRef new_str = mask_column->get_data_at(row);
size_t origin_str_len = origin_str.size;
StringRef origin_str =
data_column->get_data_at(index_check_const<origin_str_const>(row));
StringRef new_str = mask_column->get_data_at(index_check_const<new_str_const>(row));
const auto start = args_start[index_check_const<start_const>(row)];
const auto length = args_length[index_check_const<len_const>(row)];
const size_t origin_str_len = origin_str.size;
//input is null, start < 0, len < 0, str_size <= start. return NULL
if (args_null_map[row] || start[row] < 0 || length[row] < 0 ||
origin_str_len <= start[row]) {
if (args_null_map[row] || start < 0 || length < 0 || origin_str_len <= start) {
res_offsets.push_back(res_chars.size());
args_null_map[row] = 1;
} else {
std::string_view replace_str = new_str.to_string_view();
std::string result = origin_str.to_string();
result.replace(start[row], length[row], replace_str);
result.replace(start, length, replace_str);
result_column->insert_data(result.data(), result.length());
}
}
}

template <bool origin_str_const, bool new_str_const, bool start_const, bool len_const>
static void vector_utf8(const ColumnString* data_column, const ColumnString* mask_column,
const PaddedPODArray<Int32>& args_start,
const PaddedPODArray<Int32>& args_length, NullMap& args_null_map,
ColumnString* result_column, size_t input_rows_count) {
ColumnString::Chars& res_chars = result_column->get_chars();
ColumnString::Offsets& res_offsets = result_column->get_offsets();

for (size_t row = 0; row < input_rows_count; ++row) {
StringRef origin_str =
data_column->get_data_at(index_check_const<origin_str_const>(row));
StringRef new_str = mask_column->get_data_at(index_check_const<new_str_const>(row));
const auto start = args_start[index_check_const<start_const>(row)];
const auto length = args_length[index_check_const<len_const>(row)];
//input is null, start < 0, len < 0 return NULL
if (args_null_map[row] || start < 0 || length < 0) {
res_offsets.push_back(res_chars.size());
args_null_map[row] = 1;
continue;
}

const auto [start_byte_len, start_char_len] =
simd::VStringFunctions::iterate_utf8_with_limit_length(origin_str.begin(),
origin_str.end(), start);

// start >= orgin.size
DCHECK(start_char_len <= start);
if (start_byte_len == origin_str.size) {
res_offsets.push_back(res_chars.size());
args_null_map[row] = 1;
continue;
}

auto [end_byte_len, end_char_len] =
simd::VStringFunctions::iterate_utf8_with_limit_length(
origin_str.begin() + start_byte_len, origin_str.end(), length);
DCHECK(end_char_len <= length);
std::string_view replace_str = new_str.to_string_view();
std::string result = origin_str.to_string();
result.replace(start_byte_len, end_byte_len, replace_str);
result_column->insert_data(result.data(), result.length());
}
}
};

struct SubReplaceThreeImpl {
Expand All @@ -3467,13 +3524,14 @@ struct SubReplaceThreeImpl {

auto str_col =
block.get_by_position(arguments[1]).column->convert_to_full_column_if_const();
if (auto* nullable = check_and_get_column<const ColumnNullable>(*str_col)) {
if (const auto* nullable = check_and_get_column<const ColumnNullable>(*str_col)) {
str_col = nullable->get_nested_column_ptr();
}
auto& str_offset = assert_cast<const ColumnString*>(str_col.get())->get_offsets();

const auto* str_column = assert_cast<const ColumnString*>(str_col.get());
// use utf8 len
for (int i = 0; i < input_rows_count; ++i) {
strlen_data[i] = str_offset[i] - str_offset[i - 1];
StringRef str_ref = str_column->get_data_at(i);
strlen_data[i] = simd::VStringFunctions::get_char_len(str_ref.data, str_ref.size);
}

block.insert({std::move(params), std::make_shared<DataTypeInt32>(), "strlen"});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -386,3 +386,63 @@ tNEW-STRorigin str
-- !sql --
d***is

-- !sub_replace_utf8_sql1 --
你a世界

-- !sub_replace_utf8_sql2 --
你ab界

-- !sub_replace_utf8_sql3 --
你ab

-- !sub_replace_utf8_sql4 --
你abcd我界

-- !sub_replace_utf8_sql5 --
\N

-- !sub_replace_utf8_sql6 --
大家世界

-- !sub_replace_utf8_sql7 --
你大家114514

-- !sub_replace_utf8_sql8 --
\N

-- !sub_replace_utf8_sql9 --
\N

-- !sub_replace_utf8_sql10 --
\N

-- !sub_replace_utf8_sql1 --
你a世界

-- !sub_replace_utf8_sql2 --
你ab界

-- !sub_replace_utf8_sql3 --
你ab

-- !sub_replace_utf8_sql4 --
你abcd我界

-- !sub_replace_utf8_sql5 --
\N

-- !sub_replace_utf8_sql6 --
大家世界

-- !sub_replace_utf8_sql7 --
你大家114514

-- !sub_replace_utf8_sql8 --
\N

-- !sub_replace_utf8_sql9 --
\N

-- !sub_replace_utf8_sql10 --
\N

Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -191,4 +191,27 @@ suite("test_string_function") {

qt_sql "select sub_replace(\"this is origin str\",\"NEW-STR\",1);"
qt_sql "select sub_replace(\"doris\",\"***\",1,2);"
sql """ set debug_skip_fold_constant = true;"""
qt_sub_replace_utf8_sql1 " select sub_replace('你好世界','a',1);"
qt_sub_replace_utf8_sql2 " select sub_replace('你好世界','ab',1);"
qt_sub_replace_utf8_sql3 " select sub_replace('你好世界','ab',1,20);"
qt_sub_replace_utf8_sql4 " select sub_replace('你好世界','abcd我',1,2);"
qt_sub_replace_utf8_sql5 " select sub_replace('你好世界','a',6);"
qt_sub_replace_utf8_sql6 " select sub_replace('你好世界','大家',0);"
qt_sub_replace_utf8_sql7 " select sub_replace('你好世界','大家114514',1,20);"
qt_sub_replace_utf8_sql8 " select sub_replace('你好世界','大家114514',6,20);"
qt_sub_replace_utf8_sql9 " select sub_replace('你好世界','大家',4);"
qt_sub_replace_utf8_sql10 " select sub_replace('你好世界','大家',-1);"
sql """ set debug_skip_fold_constant = false;"""
qt_sub_replace_utf8_sql1 " select sub_replace('你好世界','a',1);"
qt_sub_replace_utf8_sql2 " select sub_replace('你好世界','ab',1);"
qt_sub_replace_utf8_sql3 " select sub_replace('你好世界','ab',1,20);"
qt_sub_replace_utf8_sql4 " select sub_replace('你好世界','abcd我',1,2);"
qt_sub_replace_utf8_sql5 " select sub_replace('你好世界','a',6);"
qt_sub_replace_utf8_sql6 " select sub_replace('你好世界','大家',0);"
qt_sub_replace_utf8_sql7 " select sub_replace('你好世界','大家114514',1,20);"
qt_sub_replace_utf8_sql8 " select sub_replace('你好世界','大家114514',6,20);"
qt_sub_replace_utf8_sql9 " select sub_replace('你好世界','大家',4);"
qt_sub_replace_utf8_sql10 " select sub_replace('你好世界','大家',-1);"

}
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,16 @@ suite("test_string_function", "arrow_flight_sql") {
qt_sql "select substring('abcdef',3,-1);"
qt_sql "select substring('abcdef',-3,-1);"
qt_sql "select substring('abcdef',10,1);"
sql """ set debug_skip_fold_constant = true;"""
qt_substring_utf8_sql "select substring('中文测试',5);"
qt_substring_utf8_sql "select substring('中文测试',4);"
qt_substring_utf8_sql "select substring('中文测试',2,2);"
qt_substring_utf8_sql "select substring('中文测试',-1,2);"
sql """ set debug_skip_fold_constant = false;"""
qt_substring_utf8_sql "select substring('中文测试',5);"
qt_substring_utf8_sql "select substring('中文测试',4);"
qt_substring_utf8_sql "select substring('中文测试',2,2);"
qt_substring_utf8_sql "select substring('中文测试',-1,2);"

sql """ drop table if exists test_string_function; """
sql """ create table test_string_function (
Expand Down

0 comments on commit cee07d6

Please sign in to comment.