Skip to content

Commit

Permalink
[GLUTEN-7780][CH] Fix split diff (#7781)
Browse files Browse the repository at this point in the history
* fix split diff

* fix code style

* fix code style
  • Loading branch information
taiyang-li authored Nov 4, 2024
1 parent 78d3604 commit c40735b
Show file tree
Hide file tree
Showing 7 changed files with 252 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -563,5 +563,12 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite {
compareResultsAgainstVanillaSpark(sql, true, { _ => })
spark.sql("drop table t1")
}

// Regression test for GLUTEN-7780: runs split() over a non-constant string
// (built from `id`) with three spellings of the pipe separator and checks
// that the results match vanilla Spark.
test("GLUTEN-7780 fix split diff") {
val sql = "select split(concat('a|b|c', cast(id as string)), '\\|')" +
", split(concat('a|b|c', cast(id as string)), '\\\\|')" +
", split(concat('a|b|c', cast(id as string)), '|') from range(10)"
compareResultsAgainstVanillaSpark(sql, true, { _ => })
}
}
// scalastyle:off line.size.limit
239 changes: 239 additions & 0 deletions cpp-ch/local-engine/Functions/SparkFunctionSplitByRegexp.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <Columns/ColumnConst.h>
#include <DataTypes/IDataType.h>
#include <Functions/FunctionFactory.h>
#include <Functions/FunctionHelpers.h>
#include <Functions/FunctionTokens.h>
#include <Functions/Regexps.h>
#include <Common/StringUtils.h>
#include <base/map.h>
#include <Common/assert_cast.h>


namespace DB
{

/// Error codes referenced in this file. They are only declared here (`extern`);
/// the definitions live elsewhere.
namespace ErrorCodes
{
/// Thrown by SparkSplitByRegexpImpl::init when the pattern argument is not a constant string.
extern const int ILLEGAL_COLUMN;
}


/** Spark-compatible function that splits a string into an array of strings
 * using a regular expression as the separator.
 *
 * splitByRegexpSpark(regexp, s[, max_substrings])
 */
namespace
{

/// Cursor into the bytes of the string being tokenized (see SparkSplitByRegexpImpl::set / get).
using Pos = const char *;

class SparkSplitByRegexpImpl
{
private:
Regexps::RegexpPtr re;
OptimizedRegularExpression::MatchVec matches;

Pos pos;
Pos end;

std::optional<size_t> max_splits;
size_t splits;
bool max_substrings_includes_remaining_string;

public:
static constexpr auto name = "splitByRegexpSpark";

static bool isVariadic() { return true; }
static size_t getNumberOfArguments() { return 0; }

static ColumnNumbers getArgumentsThatAreAlwaysConstant() { return {0, 2}; }

static void checkArguments(const IFunction & func, const ColumnsWithTypeAndName & arguments)
{
checkArgumentsWithSeparatorAndOptionalMaxSubstrings(func, arguments);
}

static constexpr auto strings_argument_position = 1uz;

void init(const ColumnsWithTypeAndName & arguments, bool max_substrings_includes_remaining_string_)
{
const ColumnConst * col = checkAndGetColumnConstStringOrFixedString(arguments[0].column.get());

if (!col)
throw Exception(ErrorCodes::ILLEGAL_COLUMN, "Illegal column {} of first argument of function {}. "
"Must be constant string.", arguments[0].column->getName(), name);

if (!col->getValue<String>().empty())
re = std::make_shared<OptimizedRegularExpression>(Regexps::createRegexp<false, false, false>(col->getValue<String>()));

max_substrings_includes_remaining_string = max_substrings_includes_remaining_string_;
max_splits = extractMaxSplits(arguments, 2);
}

/// Called for each next string.
void set(Pos pos_, Pos end_)
{
pos = pos_;
end = end_;
splits = 0;
}

/// Get the next token, if any, or return false.
bool get(Pos & token_begin, Pos & token_end)
{
if (!re)
{
if (pos == end)
return false;

token_begin = pos;

if (max_splits)
{
if (max_substrings_includes_remaining_string)
{
if (splits == *max_splits - 1)
{
token_end = end;
pos = end;
return true;
}
}
else
if (splits == *max_splits)
return false;
}

++pos;
token_end = pos;
++splits;
}
else
{
if (!pos || pos > end)
return false;

token_begin = pos;

if (max_splits)
{
if (max_substrings_includes_remaining_string)
{
if (splits == *max_splits - 1)
{
token_end = end;
pos = nullptr;
return true;
}
}
else
if (splits == *max_splits)
return false;
}

auto res = re->match(pos, end - pos, matches);
if (!res)
{
token_end = end;
pos = end + 1;
}
else if (!matches[0].length)
{
/// If match part is empty, increment position to avoid infinite loop.
token_end = (pos == end ? end : pos + 1);
++pos;
++splits;
}
else
{
token_end = pos + matches[0].offset;
pos = token_end + matches[0].length;
++splits;
}
}

return true;
}
};

/// The regexp-based function itself: FunctionTokens instantiated with SparkSplitByRegexpImpl as token generator.
using SparkFunctionSplitByRegexp = FunctionTokens<SparkSplitByRegexpImpl>;

/// Fallback splitByRegexp to splitByChar when its 1st argument is a trivial char for better performance
class SparkSplitByRegexpOverloadResolver : public IFunctionOverloadResolver
{
public:
static constexpr auto name = "splitByRegexpSpark";
static FunctionOverloadResolverPtr create(ContextPtr context) { return std::make_unique<SparkSplitByRegexpOverloadResolver>(context); }

explicit SparkSplitByRegexpOverloadResolver(ContextPtr context_)
: context(context_)
, split_by_regexp(SparkFunctionSplitByRegexp::create(context)) {}

String getName() const override { return name; }
size_t getNumberOfArguments() const override { return SparkSplitByRegexpImpl::getNumberOfArguments(); }
bool isVariadic() const override { return SparkSplitByRegexpImpl::isVariadic(); }

FunctionBasePtr buildImpl(const ColumnsWithTypeAndName & arguments, const DataTypePtr & return_type) const override
{
if (patternIsTrivialChar(arguments))
return FunctionFactory::instance().getImpl("splitByChar", context)->build(arguments);
return std::make_unique<FunctionToFunctionBaseAdaptor>(
split_by_regexp, collections::map<DataTypes>(arguments, [](const auto & elem) { return elem.type; }), return_type);
}

DataTypePtr getReturnTypeImpl(const ColumnsWithTypeAndName & arguments) const override
{
return split_by_regexp->getReturnTypeImpl(arguments);
}

private:
bool patternIsTrivialChar(const ColumnsWithTypeAndName & arguments) const
{
if (!arguments[0].column.get())
return false;
const ColumnConst * col = checkAndGetColumnConstStringOrFixedString(arguments[0].column.get());
if (!col)
return false;

String pattern = col->getValue<String>();
if (pattern.size() == 1)
{
OptimizedRegularExpression re = Regexps::createRegexp<false, false, false>(pattern);

std::string required_substring;
bool is_trivial;
bool required_substring_is_prefix;
re.getAnalyzeResult(required_substring, is_trivial, required_substring_is_prefix);
return is_trivial && required_substring == pattern;
}
return false;
}

ContextPtr context;
FunctionPtr split_by_regexp;
};
}

/// Registers SparkSplitByRegexpOverloadResolver with the function factory.
REGISTER_FUNCTION(SparkSplitByRegexp)
{
factory.registerFunction<SparkSplitByRegexpOverloadResolver>();
}

}
12 changes: 6 additions & 6 deletions cpp-ch/local-engine/Parser/scalar_function_parser/split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@

namespace local_engine
{
class SparkFunctionSplitParser : public FunctionParser
class FunctionSplitParser : public FunctionParser
{
public:
SparkFunctionSplitParser(ParserContextPtr parser_context_) : FunctionParser(parser_context_) {}
~SparkFunctionSplitParser() override = default;
FunctionSplitParser(ParserContextPtr parser_context_) : FunctionParser(parser_context_) {}
~FunctionSplitParser() override = default;
static constexpr auto name = "split";
String getName() const override { return name; }
String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return "splitByRegexp"; }
String getCHFunctionName(const substrait::Expression_ScalarFunction &) const override { return "splitByRegexpSpark"; }

const DB::ActionsDAG::Node * parse(const substrait::Expression_ScalarFunction & substrait_func, DB::ActionsDAG & actions_dag) const override
{
Expand All @@ -35,14 +35,14 @@ class SparkFunctionSplitParser : public FunctionParser
for (const auto & arg : args)
parsed_args.emplace_back(parseExpression(actions_dag, arg.value()));
/// In Spark: split(str, regex [, limit] )
/// In CH: splitByRegexp(regexp, str [, limit])
/// In CH: splitByRegexpSpark(regexp, str [, limit])
if (parsed_args.size() >= 2)
std::swap(parsed_args[0], parsed_args[1]);
auto ch_function_name = getCHFunctionName(substrait_func);
const auto * func_node = toFunctionNode(actions_dag, ch_function_name, parsed_args);
return convertNodeTypeIfNeeded(substrait_func, func_node, actions_dag);
}
};
static FunctionParserRegister<SparkFunctionSplitParser> register_split;
static FunctionParserRegister<FunctionSplitParser> register_split;
}

Original file line number Diff line number Diff line change
Expand Up @@ -844,8 +844,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("SPARK-32110: compare special double/float values in struct")
enableSuite[GlutenRandomSuite].exclude("random").exclude("SPARK-9127 codegen with long seed")
enableSuite[GlutenRegexpExpressionsSuite]
.exclude("LIKE ALL")
.exclude("LIKE ANY")
.exclude("LIKE Pattern")
.exclude("LIKE Pattern ESCAPE '/'")
.exclude("LIKE Pattern ESCAPE '#'")
Expand All @@ -854,8 +852,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("RegexReplace")
.exclude("RegexExtract")
.exclude("RegexExtractAll")
.exclude("SPLIT")
.exclude("SPARK-34814: LikeSimplification should handle NULL")
enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix")
enableSuite[GlutenStringExpressionsSuite]
.exclude("StringComparison")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -817,8 +817,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("SPARK-32110: compare special double/float values in struct")
enableSuite[GlutenRandomSuite].exclude("random").exclude("SPARK-9127 codegen with long seed")
enableSuite[GlutenRegexpExpressionsSuite]
.exclude("LIKE ALL")
.exclude("LIKE ANY")
.exclude("LIKE Pattern")
.exclude("LIKE Pattern ESCAPE '/'")
.exclude("LIKE Pattern ESCAPE '#'")
Expand All @@ -827,8 +825,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("RegexReplace")
.exclude("RegexExtract")
.exclude("RegexExtractAll")
.exclude("SPLIT")
.exclude("SPARK - 34814: LikeSimplification should handleNULL")
enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix")
enableSuite[GlutenStringExpressionsSuite]
.exclude("StringComparison")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,8 +740,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("SPARK-32110: compare special double/float values in struct")
enableSuite[GlutenRandomSuite].exclude("random").exclude("SPARK-9127 codegen with long seed")
enableSuite[GlutenRegexpExpressionsSuite]
.exclude("LIKE ALL")
.exclude("LIKE ANY")
.exclude("LIKE Pattern")
.exclude("LIKE Pattern ESCAPE '/'")
.exclude("LIKE Pattern ESCAPE '#'")
Expand All @@ -750,8 +748,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("RegexReplace")
.exclude("RegexExtract")
.exclude("RegexExtractAll")
.exclude("SPLIT")
.exclude("SPARK - 34814: LikeSimplification should handleNULL")
enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix")
enableSuite[GlutenStringExpressionsSuite]
.exclude("StringComparison")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -740,8 +740,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("SPARK-32110: compare special double/float values in struct")
enableSuite[GlutenRandomSuite].exclude("random").exclude("SPARK-9127 codegen with long seed")
enableSuite[GlutenRegexpExpressionsSuite]
.exclude("LIKE ALL")
.exclude("LIKE ANY")
.exclude("LIKE Pattern")
.exclude("LIKE Pattern ESCAPE '/'")
.exclude("LIKE Pattern ESCAPE '#'")
Expand All @@ -750,8 +748,6 @@ class ClickHouseTestSettings extends BackendTestSettings {
.exclude("RegexReplace")
.exclude("RegexExtract")
.exclude("RegexExtractAll")
.exclude("SPLIT")
.exclude("SPARK - 34814: LikeSimplification should handleNULL")
enableSuite[GlutenSortOrderExpressionsSuite].exclude("SortPrefix")
enableSuite[GlutenStringExpressionsSuite]
.exclude("StringComparison")
Expand Down

0 comments on commit c40735b

Please sign in to comment.