From e5100d5e27dd92fa6f767a3021d599260ee75dc0 Mon Sep 17 00:00:00 2001 From: zhangstar333 <2561612514@qq.com> Date: Wed, 17 Apr 2024 11:53:15 +0800 Subject: [PATCH] support java-udtf --- .../pipeline/exec/table_function_operator.cpp | 11 +- .../pipeline/exec/table_function_operator.h | 9 +- be/src/vec/columns/column_array.cpp | 2 +- be/src/vec/exec/vtable_function_node.cpp | 4 +- be/src/vec/exec/vtable_function_node.h | 11 +- .../table_function/table_function_factory.cpp | 34 ++-- .../table_function/table_function_factory.h | 4 +- .../table_function/udf_table_function.cpp | 176 ++++++++++++++++++ .../exprs/table_function/udf_table_function.h | 110 +++++++++++ be/src/vec/exprs/vectorized_fn_call.cpp | 10 +- be/src/vec/functions/function_fake.h | 33 ++++ .../org/apache/doris/udf/BaseExecutor.java | 4 + .../apache/doris/common/FeMetaVersion.java | 4 +- fe/fe-core/src/main/cup/sql_parser.cup | 5 + .../doris/analysis/CreateFunctionStmt.java | 31 +++ .../doris/analysis/FunctionCallExpr.java | 10 + .../org/apache/doris/catalog/Function.java | 15 ++ .../apache/doris/catalog/FunctionUtil.java | 7 +- .../glue/translator/ExpressionTranslator.java | 9 + .../expressions/functions/udf/JavaUdtf.java | 171 +++++++++++++++++ .../functions/udf/JavaUdtfBuilder.java | 86 +++++++++ .../TableGeneratingFunctionVisitor.java | 4 + gensrc/thrift/Types.thrift | 1 + 23 files changed, 722 insertions(+), 29 deletions(-) create mode 100644 be/src/vec/exprs/table_function/udf_table_function.cpp create mode 100644 be/src/vec/exprs/table_function/udf_table_function.h create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdtf.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdtfBuilder.java diff --git a/be/src/pipeline/exec/table_function_operator.cpp b/be/src/pipeline/exec/table_function_operator.cpp index b4d993ef035acb8..13951076c8182f9 100644 --- 
a/be/src/pipeline/exec/table_function_operator.cpp +++ b/be/src/pipeline/exec/table_function_operator.cpp @@ -53,13 +53,15 @@ Status TableFunctionLocalState::open(RuntimeState* state) { for (size_t i = 0; i < _vfn_ctxs.size(); i++) { RETURN_IF_ERROR(p._vfn_ctxs[i]->clone(state, _vfn_ctxs[i])); - const std::string& tf_name = _vfn_ctxs[i]->root()->fn().name.function_name; vectorized::TableFunction* fn = nullptr; - RETURN_IF_ERROR(vectorized::TableFunctionFactory::get_fn(tf_name, state->obj_pool(), &fn)); + RETURN_IF_ERROR(vectorized::TableFunctionFactory::get_fn(_vfn_ctxs[i]->root()->fn(), + state->obj_pool(), &fn)); fn->set_expr_context(_vfn_ctxs[i]); _fns.push_back(fn); } - + for (auto* fn : _fns) { + RETURN_IF_ERROR(fn->open()); + } _cur_child_offset = -1; return Status::OK(); } @@ -269,9 +271,8 @@ Status TableFunctionOperatorX::init(const TPlanNode& tnode, RuntimeState* state) _vfn_ctxs.push_back(ctx); auto root = ctx->root(); - const std::string& tf_name = root->fn().name.function_name; vectorized::TableFunction* fn = nullptr; - RETURN_IF_ERROR(vectorized::TableFunctionFactory::get_fn(tf_name, _pool, &fn)); + RETURN_IF_ERROR(vectorized::TableFunctionFactory::get_fn(root->fn(), _pool, &fn)); fn->set_expr_context(ctx); _fns.push_back(fn); } diff --git a/be/src/pipeline/exec/table_function_operator.h b/be/src/pipeline/exec/table_function_operator.h index 49dd242bfe78d95..8a7b7bd43d45d19 100644 --- a/be/src/pipeline/exec/table_function_operator.h +++ b/be/src/pipeline/exec/table_function_operator.h @@ -56,6 +56,13 @@ class TableFunctionLocalState final : public PipelineXLocalState<> { ~TableFunctionLocalState() override = default; Status open(RuntimeState* state) override; + Status close(RuntimeState* state) override { + for (auto* fn : _fns) { + RETURN_IF_ERROR(fn->close()); + } + RETURN_IF_ERROR(PipelineXLocalState<>::close(state)); + return Status::OK(); + } void process_next_child_row(); Status get_expanded_block(RuntimeState* state, vectorized::Block* 
output_block, bool* eos); @@ -74,7 +81,7 @@ class TableFunctionLocalState final : public PipelineXLocalState<> { std::vector _fns; vectorized::VExprContextSPtrs _vfn_ctxs; - int64_t _cur_child_offset = 0; + int64_t _cur_child_offset = -1; std::unique_ptr _child_block; int _current_row_insert_times = 0; bool _child_eos = false; diff --git a/be/src/vec/columns/column_array.cpp b/be/src/vec/columns/column_array.cpp index 442ffd444227073..471f1b0ef568f25 100644 --- a/be/src/vec/columns/column_array.cpp +++ b/be/src/vec/columns/column_array.cpp @@ -413,7 +413,7 @@ void ColumnArray::insert(const Field& x) { } void ColumnArray::insert_from(const IColumn& src_, size_t n) { - DCHECK(n < src_.size()); + DCHECK(n < src_.size()) << n << " " << src_.size(); const ColumnArray& src = assert_cast(src_); size_t size = src.size_at(n); size_t offset = src.offset_at(n); diff --git a/be/src/vec/exec/vtable_function_node.cpp b/be/src/vec/exec/vtable_function_node.cpp index 0c35fae806ea934..9909d954bda728c 100644 --- a/be/src/vec/exec/vtable_function_node.cpp +++ b/be/src/vec/exec/vtable_function_node.cpp @@ -58,9 +58,8 @@ Status VTableFunctionNode::init(const TPlanNode& tnode, RuntimeState* state) { _vfn_ctxs.push_back(ctx); auto root = ctx->root(); - const std::string& tf_name = root->fn().name.function_name; TableFunction* fn = nullptr; - RETURN_IF_ERROR(TableFunctionFactory::get_fn(tf_name, _pool, &fn)); + RETURN_IF_ERROR(TableFunctionFactory::get_fn(root->fn(), _pool, &fn)); fn->set_expr_context(ctx); _fns.push_back(fn); } @@ -165,6 +164,7 @@ Status VTableFunctionNode::_get_expanded_block(RuntimeState* state, Block* outpu for (int i = 0; i < _fn_num; i++) { if (columns[i + _child_slots.size()]->is_nullable()) { _fns[i]->set_nullable(); + LOG(INFO) << "_fns[i]->set_nullable(): " << i + _child_slots.size(); } } diff --git a/be/src/vec/exec/vtable_function_node.h b/be/src/vec/exec/vtable_function_node.h index 0b64fe47cc541f3..41dbd8bab642300 100644 --- 
a/be/src/vec/exec/vtable_function_node.h +++ b/be/src/vec/exec/vtable_function_node.h @@ -58,7 +58,11 @@ class VTableFunctionNode final : public ExecNode { Status alloc_resource(RuntimeState* state) override { SCOPED_TIMER(_exec_timer); RETURN_IF_ERROR(ExecNode::alloc_resource(state)); - return VExpr::open(_vfn_ctxs, state); + RETURN_IF_ERROR(VExpr::open(_vfn_ctxs, state)); + for (auto* fn : _fns) { + RETURN_IF_ERROR(fn->open()); + } + return Status::OK(); } Status get_next(RuntimeState* state, Block* block, bool* eos) override; bool need_more_input_data() const { return !_child_block->rows() && !_child_eos; } @@ -67,6 +71,9 @@ class VTableFunctionNode final : public ExecNode { if (_num_rows_filtered_counter != nullptr) { COUNTER_SET(_num_rows_filtered_counter, static_cast(_num_rows_filtered)); } + for (auto* fn : _fns) { + static_cast(fn->close()); + } ExecNode::release_resource(state); } @@ -145,7 +152,7 @@ class VTableFunctionNode final : public ExecNode { std::shared_ptr _child_block; std::vector _child_slots; std::vector _output_slots; - int64_t _cur_child_offset = 0; + int64_t _cur_child_offset = -1; VExprContextSPtrs _vfn_ctxs; diff --git a/be/src/vec/exprs/table_function/table_function_factory.cpp b/be/src/vec/exprs/table_function/table_function_factory.cpp index e42c0a27fd11115..b6662e1215d2f5c 100644 --- a/be/src/vec/exprs/table_function/table_function_factory.cpp +++ b/be/src/vec/exprs/table_function/table_function_factory.cpp @@ -17,10 +17,14 @@ #include "vec/exprs/table_function/table_function_factory.h" +#include + +#include #include #include "common/object_pool.h" #include "vec/exprs/table_function/table_function.h" +#include "vec/exprs/table_function/udf_table_function.h" #include "vec/exprs/table_function/vexplode.h" #include "vec/exprs/table_function/vexplode_bitmap.h" #include "vec/exprs/table_function/vexplode_json_array.h" @@ -65,22 +69,26 @@ const std::unordered_map {}}, {"explode", TableFunctionCreator {}}}; -Status 
TableFunctionFactory::get_fn(const std::string& fn_name_raw, ObjectPool* pool, - TableFunction** fn) { - bool is_outer = match_suffix(fn_name_raw, COMBINATOR_SUFFIX_OUTER); - std::string fn_name_real = - is_outer ? remove_suffix(fn_name_raw, COMBINATOR_SUFFIX_OUTER) : fn_name_raw; +Status TableFunctionFactory::get_fn(const TFunction& t_fn, ObjectPool* pool, TableFunction** fn) { + const std::string fn_name_raw = t_fn.name.function_name; + if (t_fn.binary_type == TFunctionBinaryType::JAVA_UDF) { + *fn = pool->add(UDFTableFunction::create_unique(t_fn).release()); + return Status::OK(); + } else { + bool is_outer = match_suffix(t_fn.name.function_name, COMBINATOR_SUFFIX_OUTER); + std::string fn_name_real = + is_outer ? remove_suffix(fn_name_raw, COMBINATOR_SUFFIX_OUTER) : fn_name_raw; - auto fn_iterator = _function_map.find(fn_name_real); - if (fn_iterator != _function_map.end()) { - *fn = pool->add(fn_iterator->second().release()); - if (is_outer) { - (*fn)->set_outer(); - } + auto fn_iterator = _function_map.find(fn_name_real); + if (fn_iterator != _function_map.end()) { + *fn = pool->add(fn_iterator->second().release()); + if (is_outer) { + (*fn)->set_outer(); + } - return Status::OK(); + return Status::OK(); + } } - return Status::NotSupported("Table function {} is not support", fn_name_raw); } diff --git a/be/src/vec/exprs/table_function/table_function_factory.h b/be/src/vec/exprs/table_function/table_function_factory.h index a68a1763fc44e90..cd06c202f3778c3 100644 --- a/be/src/vec/exprs/table_function/table_function_factory.h +++ b/be/src/vec/exprs/table_function/table_function_factory.h @@ -17,6 +17,8 @@ #pragma once +#include + #include #include #include @@ -33,7 +35,7 @@ class TableFunction; class TableFunctionFactory { public: TableFunctionFactory() = delete; - static Status get_fn(const std::string& fn_name_raw, ObjectPool* pool, TableFunction** fn); + static Status get_fn(const TFunction& t_fn, ObjectPool* pool, TableFunction** fn); const static 
std::unordered_map()>> _function_map; diff --git a/be/src/vec/exprs/table_function/udf_table_function.cpp b/be/src/vec/exprs/table_function/udf_table_function.cpp new file mode 100644 index 000000000000000..cf5d7401fc1aa46 --- /dev/null +++ b/be/src/vec/exprs/table_function/udf_table_function.cpp @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "vec/exprs/table_function/udf_table_function.h" + +#include + +#include +#include +#include +#include +#include + +#include "common/status.h" +#include "runtime/user_function_cache.h" +#include "vec/columns/column.h" +#include "vec/columns/column_array.h" +#include "vec/columns/column_nullable.h" +#include "vec/columns/column_string.h" +#include "vec/common/assert_cast.h" +#include "vec/core/block.h" +#include "vec/core/column_with_type_and_name.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_factory.hpp" +#include "vec/exec/jni_connector.h" +#include "vec/exprs/vexpr.h" +#include "vec/exprs/vexpr_context.h" + +namespace doris::vectorized { +const char* EXECUTOR_CLASS = "org/apache/doris/udf/UdfExecutor"; +const char* EXECUTOR_CTOR_SIGNATURE = "([B)V"; +const char* EXECUTOR_EVALUATE_SIGNATURE = "(Ljava/util/Map;Ljava/util/Map;)J"; +const char* EXECUTOR_CLOSE_SIGNATURE = "()V"; +UDFTableFunction::UDFTableFunction(const TFunction& t_fn) : TableFunction(), _t_fn(t_fn) { + _fn_name = _t_fn.name.function_name; + _return_type = DataTypeFactory::instance().create_data_type( + TypeDescriptor::from_thrift(t_fn.ret_type)); + _return_type = std::make_shared(make_nullable(_return_type)); +} + +Status UDFTableFunction::open() { + JNIEnv* env = nullptr; + RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env)); + if (env == nullptr) { + return Status::InternalError("Failed to get/create JVM"); + } + _jni_ctx = std::make_shared(); + // Add a scoped cleanup jni reference object. This cleans up local refs made below. 
+ JniLocalFrame jni_frame; + { + std::string local_location; + auto* function_cache = UserFunctionCache::instance(); + RETURN_IF_ERROR(function_cache->get_jarpath(_t_fn.id, _t_fn.hdfs_location, _t_fn.checksum, + &local_location)); + TJavaUdfExecutorCtorParams ctor_params; + ctor_params.__set_fn(_t_fn); + ctor_params.__set_location(local_location); + jbyteArray ctor_params_bytes; + // Pushed frame will be popped when jni_frame goes out-of-scope. + RETURN_IF_ERROR(jni_frame.push(env)); + RETURN_IF_ERROR(SerializeThriftMsg(env, &ctor_params, &ctor_params_bytes)); + RETURN_IF_ERROR(JniUtil::GetGlobalClassRef(env, EXECUTOR_CLASS, &_jni_ctx->executor_cl)); + _jni_ctx->executor_ctor_id = + env->GetMethodID(_jni_ctx->executor_cl, "", EXECUTOR_CTOR_SIGNATURE); + _jni_ctx->executor_evaluate_id = + env->GetMethodID(_jni_ctx->executor_cl, "evaluate", EXECUTOR_EVALUATE_SIGNATURE); + _jni_ctx->executor_close_id = + env->GetMethodID(_jni_ctx->executor_cl, "close", EXECUTOR_CLOSE_SIGNATURE); + _jni_ctx->executor = env->NewObject(_jni_ctx->executor_cl, _jni_ctx->executor_ctor_id, + ctor_params_bytes); + jbyte* pBytes = env->GetByteArrayElements(ctor_params_bytes, nullptr); + env->ReleaseByteArrayElements(ctor_params_bytes, pBytes, JNI_ABORT); + env->DeleteLocalRef(ctor_params_bytes); + } + RETURN_ERROR_IF_EXC(env); + RETURN_IF_ERROR(JniUtil::LocalToGlobalRef(env, _jni_ctx->executor, &_jni_ctx->executor)); + _jni_ctx->open_successes = true; + return Status::OK(); +} + +Status UDFTableFunction::process_init(Block* block, RuntimeState* state) { + auto child_size = _expr_context->root()->children().size(); + std::vector child_column_idxs; + child_column_idxs.resize(child_size); + for (int i = 0; i < child_size; ++i) { + int result_id = -1; + RETURN_IF_ERROR(_expr_context->root()->children()[i]->execute(_expr_context.get(), block, + &result_id)); + DCHECK_NE(result_id, -1); + child_column_idxs[i] = result_id; + } + JNIEnv* env = nullptr; + RETURN_IF_ERROR(JniUtil::GetJNIEnv(&env)); + 
std::unique_ptr input_table; + RETURN_IF_ERROR( + JniConnector::to_java_table(block, block->rows(), child_column_idxs, input_table)); + auto input_table_schema = JniConnector::parse_table_schema(block, child_column_idxs, true); + std::map input_params = { + {"meta_address", std::to_string((long)input_table.get())}, + {"required_fields", input_table_schema.first}, + {"columns_types", input_table_schema.second}}; + + jobject input_map = JniUtil::convert_to_java_map(env, input_params); + _array_result_column = _return_type->create_column(); + _result_column_idx = block->columns(); + block->insert({_array_result_column, _return_type, "res"}); + auto output_table_schema = JniConnector::parse_table_schema(block, {_result_column_idx}, true); + std::string output_nullable = _return_type->is_nullable() ? "true" : "false"; + std::map output_params = {{"is_nullable", output_nullable}, + {"required_fields", output_table_schema.first}, + {"columns_types", output_table_schema.second}}; + + jobject output_map = JniUtil::convert_to_java_map(env, output_params); + DCHECK(_jni_ctx != nullptr); + DCHECK(_jni_ctx->executor != nullptr); + long output_address = env->CallLongMethod(_jni_ctx->executor, _jni_ctx->executor_evaluate_id, + input_map, output_map); + RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); + env->DeleteLocalRef(input_map); + env->DeleteLocalRef(output_map); + RETURN_IF_ERROR(JniConnector::fill_block(block, {_result_column_idx}, output_address)); + block->erase(_result_column_idx); + if (!extract_column_array_info(*_array_result_column, _array_column_detail)) { + return Status::NotSupported("column type {} not supported now", + block->get_by_position(_result_column_idx).column->get_name()); + } + return Status::OK(); +} + +void UDFTableFunction::process_row(size_t row_idx) { + TableFunction::process_row(row_idx); + if (!_array_column_detail.array_nullmap_data || + !_array_column_detail.array_nullmap_data[row_idx]) { + _array_offset = 
(*_array_column_detail.offsets_ptr)[row_idx - 1]; + _cur_size = (*_array_column_detail.offsets_ptr)[row_idx] - _array_offset; + } +} + +void UDFTableFunction::process_close() { + _array_result_column = nullptr; + _array_column_detail.reset(); + _array_offset = 0; +} + +void UDFTableFunction::get_value(MutableColumnPtr& column) { + size_t pos = _array_offset + _cur_offset; + if (current_empty() || (_array_column_detail.nested_nullmap_data && + _array_column_detail.nested_nullmap_data[pos])) { + column->insert_default(); + } else { + if (_is_nullable) { + auto* nullable_column = assert_cast(column.get()); + auto nested_column = nullable_column->get_nested_column_ptr(); + auto nullmap_column = nullable_column->get_null_map_column_ptr(); + nested_column->insert_from(*_array_column_detail.nested_col, pos); + assert_cast(nullmap_column.get())->insert_default(); + } else { + column->insert_from(*_array_column_detail.nested_col, pos); + } + } +} +} // namespace doris::vectorized diff --git a/be/src/vec/exprs/table_function/udf_table_function.h b/be/src/vec/exprs/table_function/udf_table_function.h new file mode 100644 index 000000000000000..3739f98e02ea1f8 --- /dev/null +++ b/be/src/vec/exprs/table_function/udf_table_function.h @@ -0,0 +1,110 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include +#include + +#include "common/status.h" +#include "jni.h" +#include "util/jni-util.h" +#include "vec/columns/column.h" +#include "vec/common/string_ref.h" +#include "vec/data_types/data_type.h" +#include "vec/exprs/table_function/table_function.h" +#include "vec/functions/array/function_array_utils.h" +namespace doris { +namespace vectorized { +class Block; +class ColumnString; +} // namespace vectorized +} // namespace doris + +namespace doris::vectorized { + +class UDFTableFunction final : public TableFunction { + ENABLE_FACTORY_CREATOR(UDFTableFunction); + +public: + UDFTableFunction(const TFunction& t_fn); + ~UDFTableFunction() override = default; + + Status open() override; + Status process_init(Block* block, RuntimeState* state) override; + void process_row(size_t row_idx) override; + void process_close() override; + void get_value(MutableColumnPtr& column) override; + Status close() override { + if (_jni_ctx) { + RETURN_IF_ERROR(_jni_ctx->close()); + } + return TableFunction::close(); + } + +private: + struct JniContext { + // Do not save parent directly, because parent is in VExpr, but jni context is in FunctionContext + // The deconstruct sequence is not determined, it will core. 
+ // JniContext's lifecycle should same with function context, not related with expr + jclass executor_cl; + jmethodID executor_ctor_id; + jmethodID executor_evaluate_id; + jmethodID executor_close_id; + jobject executor = nullptr; + bool is_closed = false; + bool open_successes = false; + + JniContext() = default; + + Status close() { + if (!open_successes) { + LOG_WARNING("maybe open failed, need check the reason"); + return Status::OK(); //maybe open failed, so can't call some jni + } + if (is_closed) { + return Status::OK(); + } + VLOG_DEBUG << "Free resources for JniContext"; + JNIEnv* env = nullptr; + Status status = JniUtil::GetJNIEnv(&env); + if (!status.ok() || env == nullptr) { + LOG(WARNING) << "errors while get jni env " << status; + return status; + } + env->CallNonvirtualVoidMethodA(executor, executor_cl, executor_close_id, nullptr); + env->DeleteGlobalRef(executor); + env->DeleteGlobalRef(executor_cl); + RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); + is_closed = true; + return Status::OK(); + } + }; + + const TFunction& _t_fn; + std::shared_ptr _jni_ctx = nullptr; + DataTypePtr _return_type = nullptr; + ColumnPtr _array_result_column = nullptr; + ColumnArrayExecutionData _array_column_detail; + size_t _result_column_idx = 0; // _array_result_column pos in block + size_t _array_offset = 0; // start offset of array[row_idx] +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/exprs/vectorized_fn_call.cpp b/be/src/vec/exprs/vectorized_fn_call.cpp index 1c08b721cf90890..9867f5f33cf68e1 100644 --- a/be/src/vec/exprs/vectorized_fn_call.cpp +++ b/be/src/vec/exprs/vectorized_fn_call.cpp @@ -39,6 +39,7 @@ #include "vec/data_types/data_type_agg_state.h" #include "vec/exprs/vexpr_context.h" #include "vec/functions/function_agg_state.h" +#include "vec/functions/function_fake.h" #include "vec/functions/function_java_udf.h" #include "vec/functions/function_rpc.h" #include "vec/functions/simple_function_factory.h" @@ -67,12 +68,17 @@ Status 
VectorizedFnCall::prepare(RuntimeState* state, const RowDescriptor& desc, _expr_name = fmt::format("VectorizedFnCall[{}](arguments={},return={})", _fn.name.function_name, get_child_names(), _data_type->get_name()); - + LOG(INFO) << "_expr_name: " << _expr_name; if (_fn.binary_type == TFunctionBinaryType::RPC) { _function = FunctionRPC::create(_fn, argument_template, _data_type); } else if (_fn.binary_type == TFunctionBinaryType::JAVA_UDF) { if (config::enable_java_support) { - _function = JavaFunctionCall::create(_fn, argument_template, _data_type); + if (_fn.is_udtf_function) { + _function = FakeJavaUDTF::create(_fn, argument_template, _data_type); + // _function = JavaFunctionCall::create(_fn, argument_template, _data_type); + } else { + _function = JavaFunctionCall::create(_fn, argument_template, _data_type); + } } else { return Status::InternalError( "Java UDF is not enabled, you can change be config enable_java_support to true " diff --git a/be/src/vec/functions/function_fake.h b/be/src/vec/functions/function_fake.h index 0dabdfb3c83f829..d7eb5c33957fcc8 100644 --- a/be/src/vec/functions/function_fake.h +++ b/be/src/vec/functions/function_fake.h @@ -63,4 +63,37 @@ class FunctionFake : public IFunction { } }; +struct UDTFImpl { + static DataTypePtr get_return_type_impl(const DataTypes& arguments) { + DCHECK(false) << "get_return_type_impl not supported"; + return nullptr; + } + static std::string get_error_msg() { return "Fake function do not support execute"; } +}; + +class FakeJavaUDTF : public FunctionFake { +public: + FakeJavaUDTF(const TFunction& fn, const DataTypes& argument_types, + const DataTypePtr& return_type) + : _fn(fn), _argument_types(argument_types), _return_type(return_type) {} + + static FunctionPtr create(const TFunction& fn, const ColumnsWithTypeAndName& argument_types, + const DataTypePtr& return_type) { + DataTypes data_types(argument_types.size()); + for (size_t i = 0; i < argument_types.size(); ++i) { + data_types[i] = 
argument_types[i].type; + } + return std::make_shared(fn, data_types, return_type); + } + String get_name() const override { return _fn.name.function_name; } + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + return _return_type; + } + +private: + const TFunction& _fn; + const DataTypes _argument_types; + const DataTypePtr _return_type; +}; + } // namespace doris::vectorized diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java index 8ad171d60138f1b..2cb8ed5351fd37b 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java @@ -17,6 +17,7 @@ package org.apache.doris.udf; +import org.apache.doris.catalog.ArrayType; import org.apache.doris.catalog.Type; import org.apache.doris.common.exception.InternalException; import org.apache.doris.common.exception.UdfRuntimeException; @@ -88,6 +89,9 @@ public BaseExecutor(byte[] thriftParams) throws Exception { fn = request.fn; String jarFile = request.location; Type funcRetType = Type.fromThrift(request.fn.ret_type); + if (request.fn.is_udtf_function) { + funcRetType = ArrayType.create(funcRetType, true); + } init(request, jarFile, funcRetType, parameterTypes); } diff --git a/fe/fe-common/src/main/java/org/apache/doris/common/FeMetaVersion.java b/fe/fe-common/src/main/java/org/apache/doris/common/FeMetaVersion.java index ad273b3b2a5bc5a..6251632bd29d6f5 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/common/FeMetaVersion.java +++ b/fe/fe-common/src/main/java/org/apache/doris/common/FeMetaVersion.java @@ -80,9 +80,11 @@ public final class FeMetaVersion { public static final int VERSION_129 = 129; public static final int VERSION_130 = 130; + // for java-udtf add a bool field to write + public static final int VERSION_131 = 131; // note: 
when increment meta version, should assign the latest version to VERSION_CURRENT - public static final int VERSION_CURRENT = VERSION_130; + public static final int VERSION_CURRENT = VERSION_131; // all logs meta version should >= the minimum version, so that we could remove many if clause, for example // if (FE_METAVERSION < VERSION_94) ... diff --git a/fe/fe-core/src/main/cup/sql_parser.cup b/fe/fe-core/src/main/cup/sql_parser.cup index bbb54aa8cf250af..47808f5f7945365 100644 --- a/fe/fe-core/src/main/cup/sql_parser.cup +++ b/fe/fe-core/src/main/cup/sql_parser.cup @@ -1802,6 +1802,11 @@ create_stmt ::= {: RESULT = new CreateFunctionStmt(type, ifNotExists, functionName, args, parameters, func); :} + | KW_CREATE opt_var_type:type KW_TABLES KW_FUNCTION opt_if_not_exists:ifNotExists function_name:functionName LPAREN func_args_def:args RPAREN + KW_RETURNS type_def:returnType opt_intermediate_type:intermediateType opt_properties:properties + {: + RESULT = new CreateFunctionStmt(type, ifNotExists, functionName, args, returnType, intermediateType, properties); + :} /* Table */ | KW_CREATE opt_external:isExternal KW_TABLE opt_if_not_exists:ifNotExists table_name:name KW_LIKE table_name:existed_name KW_WITH KW_ROLLUP LPAREN ident_list:rollupNames RPAREN {: diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java index d498d1f75bc6656..b000190d42f0336 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CreateFunctionStmt.java @@ -104,6 +104,7 @@ public class CreateFunctionStmt extends DdlStmt { private final FunctionName functionName; private final boolean isAggregate; private final boolean isAlias; + private boolean isTableFunction; private final FunctionArgsDef argsDef; private final TypeDef returnType; private TypeDef intermediateType; @@ -140,10 +141,18 @@ public 
CreateFunctionStmt(SetType type, boolean ifNotExists, boolean isAggregate this.properties = ImmutableSortedMap.copyOf(properties, String.CASE_INSENSITIVE_ORDER); } this.isAlias = false; + this.isTableFunction = false; this.parameters = ImmutableList.of(); this.originFunction = null; } + public CreateFunctionStmt(SetType type, boolean ifNotExists, FunctionName functionName, + FunctionArgsDef argsDef, + TypeDef returnType, TypeDef intermediateType, Map properties) { + this(type, ifNotExists, false, functionName, argsDef, returnType, intermediateType, properties); + this.isTableFunction = true; + } + public CreateFunctionStmt(SetType type, boolean ifNotExists, FunctionName functionName, FunctionArgsDef argsDef, List parameters, Expr originFunction) { this.type = type; @@ -158,6 +167,7 @@ public CreateFunctionStmt(SetType type, boolean ifNotExists, FunctionName functi } this.originFunction = originFunction; this.isAggregate = false; + this.isTableFunction = false; this.returnType = new TypeDef(Type.VARCHAR); this.properties = ImmutableSortedMap.of(); } @@ -208,6 +218,8 @@ public void analyze(Analyzer analyzer) throws UserException { analyzeUda(); } else if (isAlias) { analyzeAliasFunction(); + } else if (isTableFunction) { + analyzeTableFunction(); } else { analyzeUdf(); } @@ -301,6 +313,25 @@ private void computeObjectChecksum() throws IOException, NoSuchAlgorithmExceptio } } + private void analyzeTableFunction() throws AnalysisException { + String symbol = properties.get(SYMBOL_KEY); + if (Strings.isNullOrEmpty(symbol)) { + throw new AnalysisException("No 'symbol' in properties"); + } + if (!returnType.getType().isArrayType()) { + throw new AnalysisException("JAVA_UDF OF UDTF return type must be array type"); + } + analyzeJavaUdf(symbol); + URI location = URI.create(userFile); + function = ScalarFunction.createUdf(binaryType, + functionName, argsDef.getArgTypes(), + ((ArrayType) (returnType.getType())).getItemType(), argsDef.isVariadic(), + location, symbol, null, 
null); + function.setChecksum(checksum); + function.setNullableMode(returnNullMode); + function.setUDTFunction(true); + } + private void analyzeUda() throws AnalysisException { AggregateFunction.AggregateFunctionBuilder builder = AggregateFunction.AggregateFunctionBuilder.createUdfBuilder(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java index b703948468cb1f9..d64aa47215b521f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/FunctionCallExpr.java @@ -1681,6 +1681,16 @@ && collectChildReturnTypes()[0].isDecimalV3()) { fn = getTableFunction(fnName.getFunction(), childTypes, Function.CompareMode.IS_NONSTRICT_SUPERTYPE_OF); } + // find user defined functions + if (fn == null) { + fn = findUdf(fnName, analyzer); + if (fn != null) { + FunctionUtil.checkEnableJavaUdf(); + if (!fn.isUDTFunction()) { + throw new AnalysisException(getFunctionNotFoundError(argTypes)); + } + } + } if (fn == null) { throw new AnalysisException(getFunctionNotFoundError(argTypes)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java index 7dbf3a0ec0a1a93..336e12f9bf33bfa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Function.java @@ -138,6 +138,8 @@ public enum NullableMode { // If true, this function is global function protected boolean isGlobal = false; + // If true, this function is table function, mainly used by java-udtf + protected boolean isUDTFunction = false; // Only used for serialization protected Function() { @@ -563,6 +565,7 @@ public TFunction toThrift(Type realReturnType, Type[] realArgTypes, Boolean[] re fn.setChecksum(checksum); } fn.setVectorized(vectorized); + fn.setIsUdtfFunction(isUDTFunction); 
return fn; } @@ -671,6 +674,7 @@ protected void writeFields(DataOutput output) throws IOException { IOUtils.writeOptionString(output, libUrl); IOUtils.writeOptionString(output, checksum); output.writeUTF(nullableMode.toString()); + output.writeBoolean(isUDTFunction); } @Override @@ -708,6 +712,9 @@ public void readFields(DataInput input) throws IOException { if (Env.getCurrentEnvJournalVersion() >= FeMetaVersion.VERSION_126) { nullableMode = NullableMode.valueOf(input.readUTF()); } + if (Env.getCurrentEnvJournalVersion() >= FeMetaVersion.VERSION_131) { + isUDTFunction = input.readBoolean(); + } } public static Function read(DataInput input) throws IOException { @@ -775,6 +782,14 @@ public NullableMode getNullableMode() { return nullableMode; } + public void setUDTFunction(boolean isUDTFunction) { + this.isUDTFunction = isUDTFunction; + } + + public boolean isUDTFunction() { + return this.isUDTFunction; + } + // Try to serialize this function and write to nowhere. // Just for checking if we forget to implement write() method for some Exprs. // To avoid FE exist when writing edit log. 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionUtil.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionUtil.java index e6c7e073579bdc3..4c5dd85d4be07ba 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionUtil.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionUtil.java @@ -28,6 +28,7 @@ import org.apache.doris.nereids.trees.expressions.functions.udf.AliasUdf; import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf; import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdf; +import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdtf; import org.apache.doris.nereids.types.DataType; import com.google.common.base.Strings; @@ -238,7 +239,11 @@ public static boolean translateToNereids(String dbName, Function function) { if (function instanceof AliasFunction) { AliasUdf.translateToNereidsFunction(dbName, ((AliasFunction) function)); } else if (function instanceof ScalarFunction) { - JavaUdf.translateToNereidsFunction(dbName, ((ScalarFunction) function)); + if (function.isUDTFunction()) { + JavaUdtf.translateToNereidsFunction(dbName, ((ScalarFunction) function)); + } else { + JavaUdf.translateToNereidsFunction(dbName, ((ScalarFunction) function)); + } } else if (function instanceof AggregateFunction) { JavaUdaf.translateToNereidsFunction(dbName, ((AggregateFunction) function)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java index 6c7a1bd82c10786..70f1de8c55554bc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java @@ -96,6 +96,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction; import 
org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf;
 import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdf;
+import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdtf;
 import org.apache.doris.nereids.trees.expressions.functions.window.WindowFunction;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
@@ -650,6 +651,14 @@ public Expr visitJavaUdf(JavaUdf udf, PlanTranslatorContext context) {
         return new FunctionCallExpr(udf.getCatalogFunction(), exprs);
     }
 
+    @Override
+    public Expr visitJavaUdtf(JavaUdtf udf, PlanTranslatorContext context) {
+        FunctionParams exprs = new FunctionParams(udf.children().stream()
+                .map(expression -> expression.accept(this, context))
+                .collect(Collectors.toList()));
+        return new FunctionCallExpr(udf.getCatalogFunction(), exprs);
+    }
+
     @Override
     public Expr visitJavaUdaf(JavaUdaf udaf, PlanTranslatorContext context) {
         FunctionParams exprs = new FunctionParams(udaf.isDistinct(), udaf.children().stream()
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdtf.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdtf.java
new file mode 100644
index 000000000000000..48bf65edc574aee
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdtf.java
@@ -0,0 +1,171 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.udf;
+
+import org.apache.doris.analysis.FunctionName;
+import org.apache.doris.catalog.Env;
+import org.apache.doris.catalog.Function;
+import org.apache.doris.catalog.Function.NullableMode;
+import org.apache.doris.catalog.FunctionSignature;
+import org.apache.doris.catalog.Type;
+import org.apache.doris.common.util.URI;
+import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
+import org.apache.doris.nereids.trees.expressions.functions.Udf;
+import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
+import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.thrift.TFunctionBinaryType;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Java UDTF for Nereids: a table generating function that wraps a user-defined catalog function.
+ */
+public class JavaUdtf extends TableGeneratingFunction implements ExplicitlyCastableSignature, Udf {
+    private final String dbName;
+    private final long functionId;
+    private final TFunctionBinaryType binaryType;
+    private final FunctionSignature signature;
+    private final NullableMode nullableMode;
+    private final String objectFile;
+    private final String symbol;
+    private final String prepareFn;
+    private final String closeFn;
+    private final String checkSum;
+
+    /**
+     * Constructor of UDTF; every argument is stored as-is and {@code args} become the children.
+     */
+    public JavaUdtf(String name, long functionId, String dbName, TFunctionBinaryType binaryType,
+            FunctionSignature signature,
+            NullableMode nullableMode, String objectFile, String symbol, String prepareFn, String closeFn,
+            String checkSum, Expression... args) {
+        super(name, args);
+        this.dbName = dbName;
+        this.functionId = functionId;
+        this.binaryType = binaryType;
+        this.signature = signature;
+        this.nullableMode = nullableMode;
+        this.objectFile = objectFile;
+        this.symbol = symbol;
+        this.prepareFn = prepareFn;
+        this.closeFn = closeFn;
+        this.checkSum = checkSum;
+    }
+
+    /**
+     * withChildren: copy of this UDTF over the given children; the child count must not change.
+     */
+    @Override
+    public JavaUdtf withChildren(List<Expression> children) {
+        Preconditions.checkArgument(children.size() == this.children.size());
+        return new JavaUdtf(getName(), functionId, dbName, binaryType, signature, nullableMode,
+                objectFile, symbol, prepareFn, closeFn, checkSum, children.toArray(new Expression[0]));
+    }
+
+    @Override
+    public List<FunctionSignature> getSignatures() {
+        return ImmutableList.of(signature);
+    }
+
+    @Override
+    public boolean hasVarArguments() {
+        return signature.hasVarArgs;
+    }
+
+    @Override
+    public int arity() {
+        return signature.argumentsTypes.size();
+    }
+
+    @Override
+    public Function getCatalogFunction() {
+        try {
+            org.apache.doris.catalog.ScalarFunction expr = org.apache.doris.catalog.ScalarFunction.createUdf(
+                    binaryType,
+                    new FunctionName(dbName, getName()),
+                    signature.argumentsTypes.stream().map(DataType::toCatalogDataType).toArray(Type[]::new),
+                    signature.returnType.toCatalogDataType(),
+                    signature.hasVarArgs,
+                    URI.create(objectFile),
+                    symbol,
+                    prepareFn,
+                    closeFn
+            );
+            expr.setNullableMode(nullableMode);
+            expr.setChecksum(checkSum);
+            expr.setId(functionId);
+            expr.setUDTFunction(true);
+            return expr;
+        } catch (Exception e) {
+            throw new AnalysisException(e.getMessage(), e); // chain e itself so its stack trace is kept
+        }
+    }
+
+    /**
+     * translate catalog java udtf to nereids java udtf and add its builder to the function registry
+     */
+    public static void translateToNereidsFunction(String dbName, org.apache.doris.catalog.ScalarFunction scalar) {
+        String fnName = scalar.functionName();
+        DataType retType = DataType.fromCatalogType(scalar.getReturnType());
+        List<DataType> argTypes = Arrays.stream(scalar.getArgs())
+                .map(DataType::fromCatalogType)
+                .collect(Collectors.toList());
+
+        FunctionSignature.FuncSigBuilder sigBuilder = FunctionSignature.ret(retType);
+        FunctionSignature sig = scalar.hasVarArgs()
+                ? sigBuilder.varArgs(argTypes.toArray(new DataType[0]))
+                : sigBuilder.args(argTypes.toArray(new DataType[0]));
+
+        VirtualSlotReference[] virtualSlots = argTypes.stream()
+                .map(type -> new VirtualSlotReference(type.toString(), type, Optional.empty(),
+                        (shape) -> ImmutableList.of()))
+                .toArray(VirtualSlotReference[]::new);
+
+        JavaUdtf udf = new JavaUdtf(fnName, scalar.getId(), dbName, scalar.getBinaryType(), sig,
+                scalar.getNullableMode(),
+                scalar.getLocation().getLocation(),
+                scalar.getSymbolName(),
+                scalar.getPrepareFnSymbol(),
+                scalar.getCloseFnSymbol(),
+                scalar.getChecksum(),
+                virtualSlots);
+
+        JavaUdtfBuilder builder = new JavaUdtfBuilder(udf);
+        Env.getCurrentEnv().getFunctionRegistry().addUdf(dbName, fnName, builder);
+    }
+
+    @Override
+    public NullableMode getNullableMode() {
+        return nullableMode;
+    }
+
+    @Override
+    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
+        return visitor.visitJavaUdtf(this, context);
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdtfBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdtfBuilder.java
new file mode 100644
index 000000000000000..88ac5cc84d55940
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdtfBuilder.java
@@ -0,0 +1,86 @@
+// Licensed to the
Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions.functions.udf;
+
+import org.apache.doris.common.util.ReflectionUtils;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
+import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
+
+import com.google.common.collect.Lists;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * function builder for java udtf: checks candidate arguments and binds them to the wrapped JavaUdtf
+ */
+public class JavaUdtfBuilder extends UdfBuilder {
+    private final JavaUdtf udf;
+    private final int arity;
+    private final boolean isVarArgs;
+
+    public JavaUdtfBuilder(JavaUdtf udf) {
+        this.udf = udf;
+        this.isVarArgs = udf.hasVarArguments();
+        this.arity = udf.arity();
+    }
+
+    @Override
+    public List<DataType> getArgTypes() {
+        // derived from the udf's first signature on every call; nothing is cached
+        return udf.getSignatures().get(0).argumentsTypes.stream()
+                .map(DataType.class::cast)
+                .collect(Collectors.toList());
+    }
+
+    @Override
+    public Class<? extends BoundFunction> functionClass() {
+        return JavaUdtf.class;
+    }
+
+    @Override
+    public boolean canApply(List<? extends Object> arguments) {
+        if ((isVarArgs && arity > arguments.size() + 1) || (!isVarArgs && arguments.size() != arity)) {
+            return false; // argument count is incompatible with the signature's arity
+        }
+        for (Object argument : arguments) {
+            if (!(argument instanceof Expression)) {
+                Optional<Class> primitiveType = ReflectionUtils.getPrimitiveType(argument.getClass());
+                if (!primitiveType.isPresent() || !Expression.class.isAssignableFrom(primitiveType.get())) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    @Override
+    public BoundFunction build(String name, List<? extends Object> arguments) {
+        List<Expression> exprs = arguments.stream().map(Expression.class::cast).collect(Collectors.toList());
+        List<DataType> argTypes = udf.getSignatures().get(0).argumentsTypes;
+
+        List<Expression> processedExprs = Lists.newArrayList();
+        for (int i = 0; i < exprs.size(); ++i) {
+            processedExprs.add(TypeCoercionUtils.castIfNotSameType(exprs.get(i), ((DataType) argTypes.get(i))));
+        }
+        return udf.withChildren(processedExprs);
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/TableGeneratingFunctionVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/TableGeneratingFunctionVisitor.java
index 4e4c8ab2bd492db..61042283a5180e6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/TableGeneratingFunctionVisitor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/visitor/TableGeneratingFunctionVisitor.java
@@ -36,6 +36,7 @@
 import org.apache.doris.nereids.trees.expressions.functions.generator.ExplodeSplit;
 import org.apache.doris.nereids.trees.expressions.functions.generator.ExplodeSplitOuter;
 import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
+import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdtf;
 
 /**
  * visitor function for all table generating function.
@@ -115,4 +116,7 @@ default R visitExplodeJsonArrayJsonOuter(ExplodeJsonArrayJsonOuter explodeJsonAr return visitTableGeneratingFunction(explodeJsonArrayJsonOuter, context); } + default R visitJavaUdtf(JavaUdtf udtf, C context) { + return visitTableGeneratingFunction(udtf, context); + } } diff --git a/gensrc/thrift/Types.thrift b/gensrc/thrift/Types.thrift index 66694645d74f665..9404f892052a635 100644 --- a/gensrc/thrift/Types.thrift +++ b/gensrc/thrift/Types.thrift @@ -381,6 +381,7 @@ struct TFunction { 11: optional i64 id 12: optional string checksum 13: optional bool vectorized = false + 14: optional bool is_udtf_function = false } enum TJdbcOperation {