Skip to content

Commit

Permalink
support java-udtf
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangstar333 committed Apr 17, 2024
1 parent 6aa7665 commit e5100d5
Show file tree
Hide file tree
Showing 23 changed files with 722 additions and 29 deletions.
11 changes: 6 additions & 5 deletions be/src/pipeline/exec/table_function_operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,15 @@ Status TableFunctionLocalState::open(RuntimeState* state) {
for (size_t i = 0; i < _vfn_ctxs.size(); i++) {
RETURN_IF_ERROR(p._vfn_ctxs[i]->clone(state, _vfn_ctxs[i]));

const std::string& tf_name = _vfn_ctxs[i]->root()->fn().name.function_name;
vectorized::TableFunction* fn = nullptr;
RETURN_IF_ERROR(vectorized::TableFunctionFactory::get_fn(tf_name, state->obj_pool(), &fn));
RETURN_IF_ERROR(vectorized::TableFunctionFactory::get_fn(_vfn_ctxs[i]->root()->fn(),
state->obj_pool(), &fn));
fn->set_expr_context(_vfn_ctxs[i]);
_fns.push_back(fn);
}

for (auto* fn : _fns) {
RETURN_IF_ERROR(fn->open());
}
_cur_child_offset = -1;
return Status::OK();
}
Expand Down Expand Up @@ -269,9 +271,8 @@ Status TableFunctionOperatorX::init(const TPlanNode& tnode, RuntimeState* state)
_vfn_ctxs.push_back(ctx);

auto root = ctx->root();
const std::string& tf_name = root->fn().name.function_name;
vectorized::TableFunction* fn = nullptr;
RETURN_IF_ERROR(vectorized::TableFunctionFactory::get_fn(tf_name, _pool, &fn));
RETURN_IF_ERROR(vectorized::TableFunctionFactory::get_fn(root->fn(), _pool, &fn));
fn->set_expr_context(ctx);
_fns.push_back(fn);
}
Expand Down
9 changes: 8 additions & 1 deletion be/src/pipeline/exec/table_function_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ class TableFunctionLocalState final : public PipelineXLocalState<> {
~TableFunctionLocalState() override = default;

Status open(RuntimeState* state) override;
Status close(RuntimeState* state) override {
for (auto* fn : _fns) {
RETURN_IF_ERROR(fn->close());
}
RETURN_IF_ERROR(PipelineXLocalState<>::close(state));
return Status::OK();
}
void process_next_child_row();
Status get_expanded_block(RuntimeState* state, vectorized::Block* output_block, bool* eos);

Expand All @@ -74,7 +81,7 @@ class TableFunctionLocalState final : public PipelineXLocalState<> {

std::vector<vectorized::TableFunction*> _fns;
vectorized::VExprContextSPtrs _vfn_ctxs;
int64_t _cur_child_offset = 0;
int64_t _cur_child_offset = -1;
std::unique_ptr<vectorized::Block> _child_block;
int _current_row_insert_times = 0;
bool _child_eos = false;
Expand Down
2 changes: 1 addition & 1 deletion be/src/vec/columns/column_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ void ColumnArray::insert(const Field& x) {
}

void ColumnArray::insert_from(const IColumn& src_, size_t n) {
DCHECK(n < src_.size());
DCHECK(n < src_.size()) << n << " " << src_.size();
const ColumnArray& src = assert_cast<const ColumnArray&>(src_);
size_t size = src.size_at(n);
size_t offset = src.offset_at(n);
Expand Down
4 changes: 2 additions & 2 deletions be/src/vec/exec/vtable_function_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,8 @@ Status VTableFunctionNode::init(const TPlanNode& tnode, RuntimeState* state) {
_vfn_ctxs.push_back(ctx);

auto root = ctx->root();
const std::string& tf_name = root->fn().name.function_name;
TableFunction* fn = nullptr;
RETURN_IF_ERROR(TableFunctionFactory::get_fn(tf_name, _pool, &fn));
RETURN_IF_ERROR(TableFunctionFactory::get_fn(root->fn(), _pool, &fn));
fn->set_expr_context(ctx);
_fns.push_back(fn);
}
Expand Down Expand Up @@ -165,6 +164,7 @@ Status VTableFunctionNode::_get_expanded_block(RuntimeState* state, Block* outpu
for (int i = 0; i < _fn_num; i++) {
if (columns[i + _child_slots.size()]->is_nullable()) {
_fns[i]->set_nullable();
LOG(INFO) << "_fns[i]->set_nullable(): " << i + _child_slots.size();
}
}

Expand Down
11 changes: 9 additions & 2 deletions be/src/vec/exec/vtable_function_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ class VTableFunctionNode final : public ExecNode {
Status alloc_resource(RuntimeState* state) override {
SCOPED_TIMER(_exec_timer);
RETURN_IF_ERROR(ExecNode::alloc_resource(state));
return VExpr::open(_vfn_ctxs, state);
RETURN_IF_ERROR(VExpr::open(_vfn_ctxs, state));
for (auto* fn : _fns) {
RETURN_IF_ERROR(fn->open());
}
return Status::OK();
}
Status get_next(RuntimeState* state, Block* block, bool* eos) override;
bool need_more_input_data() const { return !_child_block->rows() && !_child_eos; }
Expand All @@ -67,6 +71,9 @@ class VTableFunctionNode final : public ExecNode {
if (_num_rows_filtered_counter != nullptr) {
COUNTER_SET(_num_rows_filtered_counter, static_cast<int64_t>(_num_rows_filtered));
}
for (auto* fn : _fns) {
static_cast<void>(fn->close());
}
ExecNode::release_resource(state);
}

Expand Down Expand Up @@ -145,7 +152,7 @@ class VTableFunctionNode final : public ExecNode {
std::shared_ptr<Block> _child_block;
std::vector<SlotDescriptor*> _child_slots;
std::vector<SlotDescriptor*> _output_slots;
int64_t _cur_child_offset = 0;
int64_t _cur_child_offset = -1;

VExprContextSPtrs _vfn_ctxs;

Expand Down
34 changes: 21 additions & 13 deletions be/src/vec/exprs/table_function/table_function_factory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,14 @@

#include "vec/exprs/table_function/table_function_factory.h"

#include <gen_cpp/Types_types.h>

#include <string_view>
#include <utility>

#include "common/object_pool.h"
#include "vec/exprs/table_function/table_function.h"
#include "vec/exprs/table_function/udf_table_function.h"
#include "vec/exprs/table_function/vexplode.h"
#include "vec/exprs/table_function/vexplode_bitmap.h"
#include "vec/exprs/table_function/vexplode_json_array.h"
Expand Down Expand Up @@ -65,22 +69,26 @@ const std::unordered_map<std::string, std::function<std::unique_ptr<TableFunctio
{"explode_map", TableFunctionCreator<VExplodeMapTableFunction> {}},
{"explode", TableFunctionCreator<VExplodeTableFunction> {}}};

Status TableFunctionFactory::get_fn(const std::string& fn_name_raw, ObjectPool* pool,
TableFunction** fn) {
bool is_outer = match_suffix(fn_name_raw, COMBINATOR_SUFFIX_OUTER);
std::string fn_name_real =
is_outer ? remove_suffix(fn_name_raw, COMBINATOR_SUFFIX_OUTER) : fn_name_raw;
Status TableFunctionFactory::get_fn(const TFunction& t_fn, ObjectPool* pool, TableFunction** fn) {
const std::string fn_name_raw = t_fn.name.function_name;
if (t_fn.binary_type == TFunctionBinaryType::JAVA_UDF) {
*fn = pool->add(UDFTableFunction::create_unique(t_fn).release());
return Status::OK();
} else {
bool is_outer = match_suffix(t_fn.name.function_name, COMBINATOR_SUFFIX_OUTER);
std::string fn_name_real =
is_outer ? remove_suffix(fn_name_raw, COMBINATOR_SUFFIX_OUTER) : fn_name_raw;

auto fn_iterator = _function_map.find(fn_name_real);
if (fn_iterator != _function_map.end()) {
*fn = pool->add(fn_iterator->second().release());
if (is_outer) {
(*fn)->set_outer();
}
auto fn_iterator = _function_map.find(fn_name_real);
if (fn_iterator != _function_map.end()) {
*fn = pool->add(fn_iterator->second().release());
if (is_outer) {
(*fn)->set_outer();
}

return Status::OK();
return Status::OK();
}
}

return Status::NotSupported("Table function {} is not support", fn_name_raw);
}

Expand Down
4 changes: 3 additions & 1 deletion be/src/vec/exprs/table_function/table_function_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

#pragma once

#include <gen_cpp/Types_types.h>

#include <functional>
#include <memory>
#include <string>
Expand All @@ -33,7 +35,7 @@ class TableFunction;
class TableFunctionFactory {
public:
TableFunctionFactory() = delete;
static Status get_fn(const std::string& fn_name_raw, ObjectPool* pool, TableFunction** fn);
static Status get_fn(const TFunction& t_fn, ObjectPool* pool, TableFunction** fn);

const static std::unordered_map<std::string, std::function<std::unique_ptr<TableFunction>()>>
_function_map;
Expand Down
176 changes: 176 additions & 0 deletions be/src/vec/exprs/table_function/udf_table_function.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "vec/exprs/table_function/udf_table_function.h"

#include <glog/logging.h>

#include <algorithm>
#include <iterator>
#include <memory>
#include <ostream>
#include <sstream>

#include "common/status.h"
#include "runtime/user_function_cache.h"
#include "vec/columns/column.h"
#include "vec/columns/column_array.h"
#include "vec/columns/column_nullable.h"
#include "vec/columns/column_string.h"
#include "vec/common/assert_cast.h"
#include "vec/core/block.h"
#include "vec/core/column_with_type_and_name.h"
#include "vec/data_types/data_type_array.h"
#include "vec/data_types/data_type_factory.hpp"
#include "vec/exec/jni_connector.h"
#include "vec/exprs/vexpr.h"
#include "vec/exprs/vexpr_context.h"

namespace doris::vectorized {
const char* EXECUTOR_CLASS = "org/apache/doris/udf/UdfExecutor";
const char* EXECUTOR_CTOR_SIGNATURE = "([B)V";
const char* EXECUTOR_EVALUATE_SIGNATURE = "(Ljava/util/Map;Ljava/util/Map;)J";
const char* EXECUTOR_CLOSE_SIGNATURE = "()V";
UDFTableFunction::UDFTableFunction(const TFunction& t_fn) : TableFunction(), _t_fn(t_fn) {
_fn_name = _t_fn.name.function_name;
_return_type = DataTypeFactory::instance().create_data_type(
TypeDescriptor::from_thrift(t_fn.ret_type));
_return_type = std::make_shared<DataTypeArray>(make_nullable(_return_type));
}

Status UDFTableFunction::open() {
JNIEnv* env = nullptr;
RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env));
if (env == nullptr) {
return Status::InternalError("Failed to get/create JVM");
}
_jni_ctx = std::make_shared<JniContext>();
// Add a scoped cleanup jni reference object. This cleans up local refs made below.
JniLocalFrame jni_frame;
{
std::string local_location;
auto* function_cache = UserFunctionCache::instance();
RETURN_IF_ERROR(function_cache->get_jarpath(_t_fn.id, _t_fn.hdfs_location, _t_fn.checksum,
&local_location));
TJavaUdfExecutorCtorParams ctor_params;
ctor_params.__set_fn(_t_fn);
ctor_params.__set_location(local_location);
jbyteArray ctor_params_bytes;
// Pushed frame will be popped when jni_frame goes out-of-scope.
RETURN_IF_ERROR(jni_frame.push(env));
RETURN_IF_ERROR(SerializeThriftMsg(env, &ctor_params, &ctor_params_bytes));
RETURN_IF_ERROR(JniUtil::GetGlobalClassRef(env, EXECUTOR_CLASS, &_jni_ctx->executor_cl));
_jni_ctx->executor_ctor_id =
env->GetMethodID(_jni_ctx->executor_cl, "<init>", EXECUTOR_CTOR_SIGNATURE);
_jni_ctx->executor_evaluate_id =
env->GetMethodID(_jni_ctx->executor_cl, "evaluate", EXECUTOR_EVALUATE_SIGNATURE);
_jni_ctx->executor_close_id =
env->GetMethodID(_jni_ctx->executor_cl, "close", EXECUTOR_CLOSE_SIGNATURE);
_jni_ctx->executor = env->NewObject(_jni_ctx->executor_cl, _jni_ctx->executor_ctor_id,
ctor_params_bytes);
jbyte* pBytes = env->GetByteArrayElements(ctor_params_bytes, nullptr);
env->ReleaseByteArrayElements(ctor_params_bytes, pBytes, JNI_ABORT);
env->DeleteLocalRef(ctor_params_bytes);
}
RETURN_ERROR_IF_EXC(env);
RETURN_IF_ERROR(JniUtil::LocalToGlobalRef(env, _jni_ctx->executor, &_jni_ctx->executor));
_jni_ctx->open_successes = true;
return Status::OK();
}

Status UDFTableFunction::process_init(Block* block, RuntimeState* state) {
auto child_size = _expr_context->root()->children().size();
std::vector<size_t> child_column_idxs;
child_column_idxs.resize(child_size);
for (int i = 0; i < child_size; ++i) {
int result_id = -1;
RETURN_IF_ERROR(_expr_context->root()->children()[i]->execute(_expr_context.get(), block,
&result_id));
DCHECK_NE(result_id, -1);
child_column_idxs[i] = result_id;
}
JNIEnv* env = nullptr;
RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env));
std::unique_ptr<long[]> input_table;
RETURN_IF_ERROR(
JniConnector::to_java_table(block, block->rows(), child_column_idxs, input_table));
auto input_table_schema = JniConnector::parse_table_schema(block, child_column_idxs, true);
std::map<String, String> input_params = {
{"meta_address", std::to_string((long)input_table.get())},
{"required_fields", input_table_schema.first},
{"columns_types", input_table_schema.second}};

jobject input_map = JniUtil::convert_to_java_map(env, input_params);
_array_result_column = _return_type->create_column();
_result_column_idx = block->columns();
block->insert({_array_result_column, _return_type, "res"});
auto output_table_schema = JniConnector::parse_table_schema(block, {_result_column_idx}, true);
std::string output_nullable = _return_type->is_nullable() ? "true" : "false";
std::map<String, String> output_params = {{"is_nullable", output_nullable},
{"required_fields", output_table_schema.first},
{"columns_types", output_table_schema.second}};

jobject output_map = JniUtil::convert_to_java_map(env, output_params);
DCHECK(_jni_ctx != nullptr);
DCHECK(_jni_ctx->executor != nullptr);
long output_address = env->CallLongMethod(_jni_ctx->executor, _jni_ctx->executor_evaluate_id,
input_map, output_map);
RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env));
env->DeleteLocalRef(input_map);
env->DeleteLocalRef(output_map);
RETURN_IF_ERROR(JniConnector::fill_block(block, {_result_column_idx}, output_address));
block->erase(_result_column_idx);
if (!extract_column_array_info(*_array_result_column, _array_column_detail)) {
return Status::NotSupported("column type {} not supported now",
block->get_by_position(_result_column_idx).column->get_name());
}
return Status::OK();
}

void UDFTableFunction::process_row(size_t row_idx) {
TableFunction::process_row(row_idx);
if (!_array_column_detail.array_nullmap_data ||
!_array_column_detail.array_nullmap_data[row_idx]) {
_array_offset = (*_array_column_detail.offsets_ptr)[row_idx - 1];
_cur_size = (*_array_column_detail.offsets_ptr)[row_idx] - _array_offset;
}
}

void UDFTableFunction::process_close() {
_array_result_column = nullptr;
_array_column_detail.reset();
_array_offset = 0;
}

void UDFTableFunction::get_value(MutableColumnPtr& column) {
size_t pos = _array_offset + _cur_offset;
if (current_empty() || (_array_column_detail.nested_nullmap_data &&
_array_column_detail.nested_nullmap_data[pos])) {
column->insert_default();
} else {
if (_is_nullable) {
auto* nullable_column = assert_cast<ColumnNullable*>(column.get());
auto nested_column = nullable_column->get_nested_column_ptr();
auto nullmap_column = nullable_column->get_null_map_column_ptr();
nested_column->insert_from(*_array_column_detail.nested_col, pos);
assert_cast<ColumnUInt8*>(nullmap_column.get())->insert_default();
} else {
column->insert_from(*_array_column_detail.nested_col, pos);
}
}
}
} // namespace doris::vectorized
Loading

0 comments on commit e5100d5

Please sign in to comment.