Skip to content

Commit

Permalink
impl scalar functions trim_in
Browse files Browse the repository at this point in the history
  • Loading branch information
liujiwen-up committed Oct 10, 2024
1 parent a9caf05 commit 4452a4b
Show file tree
Hide file tree
Showing 9 changed files with 821 additions and 3 deletions.
99 changes: 96 additions & 3 deletions be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <math.h>
#include <re2/stringpiece.h>

#include <bitset>
#include <cstddef>
#include <string_view>

Expand Down Expand Up @@ -508,6 +509,15 @@ struct NameLTrim {
struct NameRTrim {
static constexpr auto name = "rtrim";
};
struct NameTrimIn {
static constexpr auto name = "trim_in";
};
struct NameLTrimIn {
static constexpr auto name = "ltrim_in";
};
struct NameRTrimIn {
static constexpr auto name = "rtrim_in";
};
template <bool is_ltrim, bool is_rtrim, bool trim_single>
struct TrimUtil {
static Status vector(const ColumnString::Chars& str_data,
Expand Down Expand Up @@ -535,6 +545,42 @@ struct TrimUtil {
return Status::OK();
}
};
template <bool is_ltrim_in, bool is_rtrim_in, bool trim_single>
struct TrimInUtil {
static Status vector(const ColumnString::Chars& str_data,
const ColumnString::Offsets& str_offsets, const StringRef& remove_str,
ColumnString::Chars& res_data, ColumnString::Offsets& res_offsets) {
const size_t offset_size = str_offsets.size();
res_offsets.resize(offset_size);
res_data.reserve(str_data.size());
std::bitset<256> char_lookup;
for (size_t i = 0; i < remove_str.size; ++i) {
char_lookup[static_cast<unsigned char>(remove_str.data[i])] = true;
}

for (size_t i = 0; i < offset_size; ++i) {
const auto* str_begin = str_data.data() + str_offsets[i - 1];
const auto* str_end = str_data.data() + str_offsets[i];
const auto* p = str_begin;

if constexpr (is_ltrim_in) {
while (p < str_end && char_lookup[static_cast<unsigned char>(*p)]) {
p++;
}
}
const auto* str_end_trimmed = str_end;
if constexpr (is_rtrim_in) {
while (str_end_trimmed > p &&
char_lookup[static_cast<unsigned char>(*(str_end_trimmed - 1))]) {
str_end_trimmed--;
}
}
res_data.insert_assume_reserved(p, str_end_trimmed);
res_offsets[i] = res_data.size();
}
return Status::OK();
}
};
// This is an implementation of a parameter for the Trim function.
template <bool is_ltrim, bool is_rtrim, typename Name>
struct Trim1Impl {
Expand Down Expand Up @@ -583,14 +629,23 @@ struct Trim2Impl {
const auto* remove_str_raw = col_right->get_chars().data();
const ColumnString::Offset remove_str_size = col_right->get_offsets()[0];
const StringRef remove_str(remove_str_raw, remove_str_size);

if (remove_str.size == 1) {
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, true>::vector(
col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(),
col_res->get_offsets())));
} else {
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, false>::vector(
col->get_chars(), col->get_offsets(), remove_str, col_res->get_chars(),
col_res->get_offsets())));
if constexpr (std::is_same<Name, NameTrimIn>::value ||
std::is_same<Name, NameLTrimIn>::value ||
std::is_same<Name, NameRTrimIn>::value) {
RETURN_IF_ERROR((TrimInUtil<is_ltrim, is_rtrim, false>::vector(
col->get_chars(), col->get_offsets(), remove_str,
col_res->get_chars(), col_res->get_offsets())));
} else {
RETURN_IF_ERROR((TrimUtil<is_ltrim, is_rtrim, false>::vector(
col->get_chars(), col->get_offsets(), remove_str,
col_res->get_chars(), col_res->get_offsets())));
}
}
block.replace_by_position(result, std::move(col_res));
} else {
Expand Down Expand Up @@ -640,6 +695,38 @@ class FunctionTrim : public IFunction {
}
};

template <typename impl>
class FunctionTrimIn : public IFunction {
public:
static constexpr auto name = impl::name;
static FunctionPtr create() { return std::make_shared<FunctionTrimIn<impl>>(); }
String get_name() const override { return impl::name; }

size_t get_number_of_arguments() const override {
return get_variadic_argument_types_impl().size();
}

DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
if (!is_string_or_fixed_string(arguments[0])) {
throw doris::Exception(ErrorCode::INVALID_ARGUMENT,
"Illegal type {} of argument of function {}",
arguments[0]->get_name(), get_name());
}
return arguments[0];
}
// The second parameter of "trim" is a constant.
ColumnNumbers get_arguments_that_are_always_constant() const override { return {1}; }

DataTypes get_variadic_argument_types_impl() const override {
return impl::get_variadic_argument_types();
}

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
size_t result, size_t input_rows_count) const override {
return impl::execute(context, block, arguments, result, input_rows_count);
}
};

static constexpr int MAX_STACK_CIPHER_LEN = 1024 * 64;
struct UnHexImpl {
static constexpr auto name = "unhex";
Expand Down Expand Up @@ -1023,6 +1110,12 @@ void register_function_string(SimpleFunctionFactory& factory) {
factory.register_function<FunctionTrim<Trim2Impl<true, true, NameTrim>>>();
factory.register_function<FunctionTrim<Trim2Impl<true, false, NameLTrim>>>();
factory.register_function<FunctionTrim<Trim2Impl<false, true, NameRTrim>>>();
factory.register_function<FunctionTrimIn<Trim1Impl<true, true, NameTrimIn>>>();
factory.register_function<FunctionTrimIn<Trim1Impl<true, false, NameLTrimIn>>>();
factory.register_function<FunctionTrimIn<Trim1Impl<false, true, NameRTrimIn>>>();
factory.register_function<FunctionTrimIn<Trim2Impl<true, true, NameTrimIn>>>();
factory.register_function<FunctionTrimIn<Trim2Impl<true, false, NameLTrimIn>>>();
factory.register_function<FunctionTrimIn<Trim2Impl<false, true, NameRTrimIn>>>();
factory.register_function<FunctionConvertTo>();
factory.register_function<FunctionSubstring<Substr3Impl>>();
factory.register_function<FunctionSubstring<Substr2Impl>>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lower;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lpad;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Ltrim;
import org.apache.doris.nereids.trees.expressions.functions.scalar.LtrimIn;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MakeDate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsKey;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MapContainsValue;
Expand Down Expand Up @@ -358,6 +359,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.RoundBankers;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Rpad;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Rtrim;
import org.apache.doris.nereids.trees.expressions.functions.scalar.RtrimIn;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecToTime;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Second;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecondCeil;
Expand Down Expand Up @@ -438,6 +440,7 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Tokenize;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Translate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Trim;
import org.apache.doris.nereids.trees.expressions.functions.scalar.TrimIn;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Truncate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Unhex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.UnixTimestamp;
Expand Down Expand Up @@ -760,6 +763,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(Lower.class, "lcase", "lower"),
scalar(Lpad.class, "lpad"),
scalar(Ltrim.class, "ltrim"),
scalar(LtrimIn.class, "ltrim_in"),
scalar(MakeDate.class, "makedate"),
scalar(MapContainsKey.class, "map_contains_key"),
scalar(MapContainsValue.class, "map_contains_value"),
Expand Down Expand Up @@ -835,6 +839,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(RoundBankers.class, "round_bankers"),
scalar(Rpad.class, "rpad"),
scalar(Rtrim.class, "rtrim"),
scalar(RtrimIn.class, "rtrim_in"),
scalar(Second.class, "second"),
scalar(SecondCeil.class, "second_ceil"),
scalar(SecondFloor.class, "second_floor"),
Expand Down Expand Up @@ -920,6 +925,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(ToQuantileState.class, "to_quantile_state"),
scalar(Translate.class, "translate"),
scalar(Trim.class, "trim"),
scalar(TrimIn.class, "trim_in"),
scalar(Truncate.class, "truncate"),
scalar(Unhex.class, "unhex"),
scalar(UnixTimestamp.class, "unix_timestamp"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,27 @@ private static String trimImpl(String first, String second, boolean left, boolea
return result;
}

private static String trimInImpl(String first, String second, boolean left, boolean right) {
StringBuilder result = new StringBuilder(first);

if (left) {
int start = 0;
while (start < result.length() && second.indexOf(result.charAt(start)) != -1) {
start++;
}
result.delete(0, start);
}
if (right) {
int end = result.length();
while (end > 0 && second.indexOf(result.charAt(end - 1)) != -1) {
end--;
}
result.delete(end, result.length());
}

return result.toString();
}

/**
* Executable arithmetic functions Trim
*/
Expand Down Expand Up @@ -199,6 +220,54 @@ public static Expression rtrimVarcharVarchar(StringLikeLiteral first, StringLike
return castStringLikeLiteral(first, trimImpl(first.getValue(), second.getValue(), false, true));
}

/**
* Executable arithmetic functions Trim_In
*/
@ExecFunction(name = "trim_in")
public static Expression trimInVarchar(StringLikeLiteral first) {
return castStringLikeLiteral(first, trimInImpl(first.getValue(), " ", true, true));
}

/**
* Executable arithmetic functions Trim_In
*/
@ExecFunction(name = "trim_in")
public static Expression trimInVarcharVarchar(StringLikeLiteral first, StringLikeLiteral second) {
return castStringLikeLiteral(first, trimInImpl(first.getValue(), second.getValue(), true, true));
}

/**
* Executable arithmetic functions ltrim_in
*/
@ExecFunction(name = "ltrim_in")
public static Expression ltrimInVarchar(StringLikeLiteral first) {
return castStringLikeLiteral(first, trimInImpl(first.getValue(), " ", true, false));
}

/**
* Executable arithmetic functions ltrim_in
*/
@ExecFunction(name = "ltrim_in")
public static Expression ltrimInVarcharVarchar(StringLikeLiteral first, StringLikeLiteral second) {
return castStringLikeLiteral(first, trimInImpl(first.getValue(), second.getValue(), true, false));
}

/**
* Executable arithmetic functions rtrim_in
*/
@ExecFunction(name = "rtrim_in")
public static Expression rtrimInVarchar(StringLikeLiteral first) {
return castStringLikeLiteral(first, trimInImpl(first.getValue(), " ", false, true));
}

/**
* Executable arithmetic functions rtrim_in
*/
@ExecFunction(name = "rtrim_in")
public static Expression rtrimInVarcharVarchar(StringLikeLiteral first, StringLikeLiteral second) {
return castStringLikeLiteral(first, trimInImpl(first.getValue(), second.getValue(), false, true));
}

/**
* Executable arithmetic functions Replace
*/
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* ScalarFunction 'ltrimIn'. This class is generated by GenerateFunction.
*/
public class LtrimIn extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {

private static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(StringType.INSTANCE).args(StringType.INSTANCE, StringType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(StringType.INSTANCE).args(StringType.INSTANCE)
);

private LtrimIn(List<Expression> args) {
super("ltrim_in", args);
}

/**
* constructor with 1 argument.
*/
public LtrimIn(Expression arg) {
super("ltrim_in", arg);
}

/**
* constructor with 2 argument.
*/
public LtrimIn(Expression arg0, Expression arg1) {
super("ltrim_in", arg0, arg1);
}

/**
* withChildren.
*/
@Override
public LtrimIn withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1 || children.size() == 2);
return new LtrimIn(children);
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitLtrimIn(this, context);
}
}
Loading

0 comments on commit 4452a4b

Please sign in to comment.