Skip to content

Commit

Permalink
[feature](function) support hll functions hll_from_base64, hll_to_bas…
Browse files Browse the repository at this point in the history
…e64 (#32089)

Issue Number: #31320 

Support two hll functions:

- hll_from_base64
Convert a base64 string (the result of the hll_to_base64 function) into an HLL value.
- hll_to_base64
Convert an input hll to a base64 string.
  • Loading branch information
superdiaodiao authored Apr 16, 2024
1 parent 8d773a7 commit c7c8916
Show file tree
Hide file tree
Showing 12 changed files with 594 additions and 2 deletions.
111 changes: 111 additions & 0 deletions be/src/vec/functions/hll_from_base64.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <cstddef>
#include <cstdint>

#include "olap/hll.h"
#include "util/url_coding.h"
#include "vec/columns/column_complex.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_string.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_hll.h"
#include "vec/functions/simple_function_factory.h"

namespace doris::vectorized {

/// Scalar function `hll_from_base64(str)`.
///
/// Decodes each input string as base64 and deserializes the decoded bytes into a
/// HyperLogLog value. The result column is a nullable HLL column:
///   - a row whose length is not a multiple of 4, or which base64_decode() rejects
///     (negative return value), yields NULL;
///   - a row that decodes but cannot be deserialized as an HLL aborts the whole
///     call with Status::RuntimeError.
class FunctionHllFromBase64 : public IFunction {
public:
    static constexpr auto name = "hll_from_base64";

    String get_name() const override { return name; }

    static FunctionPtr create() { return std::make_shared<FunctionHllFromBase64>(); }

    /// The result is always Nullable(HLL), independent of the argument type.
    DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
        return make_nullable(std::make_shared<DataTypeHLL>());
    }

    size_t get_number_of_arguments() const override { return 1; }

    bool use_default_implementation_for_nulls() const override { return true; }

    /// Converts the string column at `arguments[0]` into a Nullable(HLL) column
    /// stored at position `result` of `block`.
    ///
    /// @param input_rows_count number of rows to process
    /// @return Status::OK() on success; Status::RuntimeError if a decoded payload
    ///         cannot be deserialized into a HyperLogLog.
    Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                        size_t result, size_t input_rows_count) const override {
        auto res_null_map = ColumnUInt8::create(input_rows_count, 0);
        auto res_data_column = ColumnHLL::create();
        auto& null_map = res_null_map->get_data();
        auto& res = res_data_column->get_data();

        const auto& argument_column = block.get_by_position(arguments[0]).column;
        const auto& str_column = static_cast<const ColumnString&>(*argument_column);
        const ColumnString::Chars& data = str_column.get_chars();
        const ColumnString::Offsets& offsets = str_column.get_offsets();

        res.reserve(input_rows_count);

        // Scratch buffer shared by all rows; it only ever grows.
        std::string decode_buff;
        for (size_t i = 0; i < input_rows_count; ++i) {
            // NOTE(review): for i == 0 this reads offsets[-1]; this relies on the
            // offsets container yielding 0 for that index — confirm.
            const char* src_str = reinterpret_cast<const char*>(&data[offsets[i - 1]]);
            int64_t src_size = offsets[i] - offsets[i - 1];

            // Every 4 base64 characters encode 3 bytes, so a length that is not a
            // multiple of 4 cannot be a valid (padded) base64 string.
            if (0 != src_size % 4) {
                res.emplace_back();
                null_map[i] = 1;
                continue;
            }

            // The decoded payload is never longer than the encoded input, so
            // 'src_size + 3' bytes is a generous upper bound for the output.
            const size_t required_len = static_cast<size_t>(src_size) + 3;
            if (decode_buff.size() < required_len) {
                decode_buff.resize(required_len);
            }

            auto outlen = base64_decode(src_str, src_size, decode_buff.data());
            if (outlen < 0) {
                res.emplace_back();
                null_map[i] = 1;
                continue;
            }

            doris::Slice decoded_slice(decode_buff.data(), outlen);
            doris::HyperLogLog hll;
            if (!hll.deserialize(decoded_slice)) {
                // src_str points into the shared chars buffer and is not
                // NUL-terminated per row, so bound it explicitly by src_size
                // before formatting.
                return Status::RuntimeError(
                        fmt::format("hll_from_base64 decode failed: base64: {}",
                                    std::string(src_str, static_cast<size_t>(src_size))));
            }
            res.emplace_back(std::move(hll));
        }

        block.get_by_position(result).column =
                ColumnNullable::create(std::move(res_data_column), std::move(res_null_map));
        return Status::OK();
    }
};

// Registers FunctionHllFromBase64 with the given SimpleFunctionFactory.
void register_function_hll_from_base64(SimpleFunctionFactory& factory) {
factory.register_function<FunctionHllFromBase64>();
}

} // namespace doris::vectorized
89 changes: 89 additions & 0 deletions be/src/vec/functions/hll_to_base64.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <cstddef>
#include <cstdint>

#include "olap/hll.h"
#include "util/url_coding.h"
#include "vec/columns/column_complex.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_string.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/data_type_hll.h"
#include "vec/data_types/data_type_string.h"
#include "vec/functions/function_totype.h"
#include "vec/functions/simple_function_factory.h"

namespace doris::vectorized {

// Name tag for the hll_to_base64 function; used as the name argument of
// FunctionUnaryToType when FunctionHllToBase64 is defined later in this file.
struct NameHllToBase64 {
static constexpr auto name = "hll_to_base64";
};

/// Implementation struct for `hll_to_base64`, used as the first template argument
/// of FunctionUnaryToType. Maps a vector of HyperLogLog values to a string column
/// holding the base64 encoding of each value's serialized form.
struct HllToBase64 {
    using ReturnType = DataTypeString;
    static constexpr auto TYPE_INDEX = TypeIndex::HLL;
    using Type = DataTypeHLL::FieldType;
    using ReturnColumnType = ColumnString;
    using Chars = ColumnString::Chars;
    using Offsets = ColumnString::Offsets;

    /// Serializes and base64-encodes every HLL in `data`.
    ///
    /// @param data    input HLL values, one per row
    /// @param chars   output character buffer; on return its size equals the total
    ///                number of encoded characters written
    /// @param offsets output end offsets; offsets[i] is the end of row i in `chars`
    /// @return Status::OK()
    static Status vector(const std::vector<HyperLogLog>& data, Chars& chars, Offsets& offsets) {
        const size_t size = data.size();
        offsets.resize(size);

        // Upper bound for the output: base64 emits 4 characters per (started)
        // group of 3 input bytes, i.e. 4 * ceil(n / 3) characters for n bytes.
        // NOTE(review): assumes base64_encode produces padded output of exactly
        // that length — confirm against util/url_coding.h.
        size_t output_char_size = 0;
        for (size_t i = 0; i < size; ++i) {
            auto& hll_val = const_cast<HyperLogLog&>(data[i]);
            const size_t ser_size = hll_val.max_serialized_size();
            output_char_size += 4 * ((ser_size + 2) / 3);
        }
        ColumnString::check_chars_length(output_char_size, size);
        chars.resize(output_char_size);
        auto chars_data = chars.data();

        // Scratch buffer for the serialized HLL; shared by all rows, only grows.
        std::string ser_buff;
        size_t encoded_offset = 0;
        for (size_t i = 0; i < size; ++i) {
            auto& hll_val = const_cast<HyperLogLog&>(data[i]);

            const size_t max_ser_size = hll_val.max_serialized_size();
            if (ser_buff.size() < max_ser_size) {
                ser_buff.resize(max_ser_size);
            }

            // Encode only the bytes serialize() actually wrote. Encoding the full
            // max_serialized_size() would append stale bytes left in the scratch
            // buffer by earlier rows, making the output depend on row order.
            const size_t real_ser_size =
                    hll_val.serialize(reinterpret_cast<uint8_t*>(ser_buff.data()));
            auto outlen =
                    base64_encode(reinterpret_cast<const unsigned char*>(ser_buff.data()),
                                  real_ser_size, chars_data + encoded_offset);
            DCHECK(outlen > 0);

            encoded_offset += outlen;
            offsets[i] = encoded_offset;
        }

        // The reservation above is only an upper bound; trim the buffer so its
        // size matches the last offset.
        chars.resize(encoded_offset);
        return Status::OK();
    }
};

// hll_to_base64 is a unary function built from the HllToBase64 implementation
// struct and the NameHllToBase64 name tag.
using FunctionHllToBase64 = FunctionUnaryToType<HllToBase64, NameHllToBase64>;

// Registers FunctionHllToBase64 with the given SimpleFunctionFactory.
void register_function_hll_to_base64(SimpleFunctionFactory& factory) {
factory.register_function<FunctionHllToBase64>();
}

} // namespace doris::vectorized
4 changes: 4 additions & 0 deletions be/src/vec/functions/simple_function_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ void register_function_comparison(SimpleFunctionFactory& factory);
void register_function_comparison_eq_for_null(SimpleFunctionFactory& factory);
void register_function_hll_cardinality(SimpleFunctionFactory& factory);
void register_function_hll_empty(SimpleFunctionFactory& factory);
void register_function_hll_from_base64(SimpleFunctionFactory& factory);
void register_function_hll_hash(SimpleFunctionFactory& factory);
void register_function_hll_to_base64(SimpleFunctionFactory& factory);
void register_function_logical(SimpleFunctionFactory& factory);
void register_function_case(SimpleFunctionFactory& factory);
void register_function_cast(SimpleFunctionFactory& factory);
Expand Down Expand Up @@ -222,7 +224,9 @@ class SimpleFunctionFactory {
register_function_bitmap_variadic(instance);
register_function_hll_cardinality(instance);
register_function_hll_empty(instance);
register_function_hll_from_base64(instance);
register_function_hll_hash(instance);
register_function_hll_to_base64(instance);
register_function_comparison(instance);
register_function_logical(instance);
register_function_case(instance);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,9 @@
import org.apache.doris.nereids.trees.expressions.functions.scalar.Hex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllCardinality;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllEmpty;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllFromBase64;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllHash;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HllToBase64;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Hour;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HourCeil;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HourFloor;
Expand Down Expand Up @@ -617,7 +619,9 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(Hex.class, "hex"),
scalar(HllCardinality.class, "hll_cardinality"),
scalar(HllEmpty.class, "hll_empty"),
scalar(HllFromBase64.class, "hll_from_base64"),
scalar(HllHash.class, "hll_hash"),
scalar(HllToBase64.class, "hll_to_base64"),
scalar(Hour.class, "hour"),
scalar(HourCeil.class, "hour_ceil"),
scalar(HourFloor.class, "hour_floor"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.HllType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* ScalarFunction 'hll_from_base64'.
*
* Takes a single VARCHAR or STRING argument and returns an HLL (see SIGNATURES).
* The class implements AlwaysNullable; presumably this marks the result as
* nullable for any input — confirm against the AlwaysNullable contract.
*/
public class HllFromBase64 extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {

// Supported signatures: HLL hll_from_base64(VARCHAR) and HLL hll_from_base64(STRING).
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(HllType.INSTANCE).args(VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(HllType.INSTANCE).args(StringType.INSTANCE)
);

/**
* constructor with 1 argument.
*
* @param arg the expression producing the base64 string to convert
*/
public HllFromBase64(Expression arg) {
super("hll_from_base64", arg);
}

/**
* withChildren.
*
* Returns a new HllFromBase64 built from the single given child.
*
* @param children replacement children; must contain exactly one expression
* @throws IllegalArgumentException if children.size() != 1
*/
@Override
public HllFromBase64 withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new HllFromBase64(children.get(0));
}

// Returns the static SIGNATURES list declared above.
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

// Visitor dispatch: forwards to ExpressionVisitor#visitHllFromBase64.
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitHllFromBase64(this, context);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.scalar;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.HllType;
import org.apache.doris.nereids.types.StringType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
* ScalarFunction 'hll_to_base64'.
*
* Takes a single HLL argument and returns a STRING (see SIGNATURES).
* The class implements PropagateNullable; presumably the result is nullable
* exactly when the argument is — confirm against the PropagateNullable contract.
*/
public class HllToBase64 extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {

// Supported signature: STRING hll_to_base64(HLL).
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(StringType.INSTANCE).args(HllType.INSTANCE)
);

/**
* constructor with 1 argument.
*
* @param arg the expression producing the HLL value to encode
*/
public HllToBase64(Expression arg) {
super("hll_to_base64", arg);
}

/**
* withChildren.
*
* Returns a new HllToBase64 built from the single given child.
*
* @param children replacement children; must contain exactly one expression
* @throws IllegalArgumentException if children.size() != 1
*/
@Override
public HllToBase64 withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new HllToBase64(children.get(0));
}

// Returns the static SIGNATURES list declared above.
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

// Visitor dispatch: forwards to ExpressionVisitor#visitHllToBase64.
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitHllToBase64(this, context);
}
}
Loading

0 comments on commit c7c8916

Please sign in to comment.