diff --git a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h index d51c219f3f108f..defd33b47546a2 100644 --- a/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h +++ b/be/src/vec/aggregate_functions/aggregate_function_java_udaf.h @@ -25,6 +25,7 @@ #include "common/compiler_util.h" #include "common/exception.h" +#include "common/logging.h" #include "common/status.h" #include "gutil/strings/substitute.h" #include "runtime/user_function_cache.h" @@ -55,15 +56,7 @@ const char* UDAF_EXECUTOR_RESET_SIGNATURE = "(J)V"; struct AggregateJavaUdafData { public: AggregateJavaUdafData() = default; - AggregateJavaUdafData(int64_t num_args) { - argument_size = num_args; - output_value_buffer = std::make_unique(0); - output_null_value = std::make_unique(0); - output_offsets_ptr = std::make_unique(0); - output_intermediate_state_ptr = std::make_unique(0); - output_array_null_ptr = std::make_unique(0); - output_array_string_offsets_ptr = std::make_unique(0); - } + AggregateJavaUdafData(int64_t num_args) { argument_size = num_args; } ~AggregateJavaUdafData() { JNIEnv* env; @@ -89,16 +82,6 @@ struct AggregateJavaUdafData { ctor_params.__set_fn(fn); ctor_params.__set_location(local_location); - ctor_params.__set_output_buffer_ptr((int64_t)output_value_buffer.get()); - - ctor_params.__set_output_null_ptr((int64_t)output_null_value.get()); - ctor_params.__set_output_offsets_ptr((int64_t)output_offsets_ptr.get()); - ctor_params.__set_output_intermediate_state_ptr( - (int64_t)output_intermediate_state_ptr.get()); - ctor_params.__set_output_array_null_ptr((int64_t)output_array_null_ptr.get()); - ctor_params.__set_output_array_string_offsets_ptr( - (int64_t)output_array_string_offsets_ptr.get()); - jbyteArray ctor_params_bytes; // Pushed frame will be popped when jni_frame goes out-of-scope. @@ -295,23 +278,27 @@ struct AggregateJavaUdafData { to.insert_default(); JNIEnv* env = nullptr; RETURN_NOT_OK_STATUS_WITH_WARN(JniUtil::GetJNIEnv(&env), "Java-Udaf get value function"); + int64_t nullmap_address = 0; if (result_type->is_nullable()) { auto& nullable = assert_cast(to); - *output_null_value = + nullmap_address = reinterpret_cast(nullable.get_null_map_column().get_raw_data().data); auto& data_col = nullable.get_nested_column(); - RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col)); + RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col, nullmap_address)); } else { - *output_null_value = -1; + nullmap_address = -1; auto& data_col = to; - RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col)); + RETURN_IF_ERROR(get_result(to, result_type, place, env, data_col, nullmap_address)); } return JniUtil::GetJniExceptionMsg(env); } private: - Status get_result(IColumn& to, const DataTypePtr& result_type, int64_t place, JNIEnv* env, - IColumn& data_col) const { + Status get_result(IColumn& to, const DataTypePtr& return_type, int64_t place, JNIEnv* env, + IColumn& data_col, int64_t nullmap_address) const { + jobject result_obj = env->CallNonvirtualObjectMethod(executor_obj, executor_cl, + executor_get_value_id, place); + bool result_nullable = return_type->is_nullable(); if (data_col.is_column_string()) { const ColumnString* str_col = check_and_get_column(data_col); ColumnString::Chars& chars = const_cast(str_col->get_chars()); @@ -320,109 +307,119 @@ struct AggregateJavaUdafData { int increase_buffer_size = 0; int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); chars.resize(buffer_size); - *output_value_buffer = reinterpret_cast(chars.data()); - *output_offsets_ptr = reinterpret_cast(offsets.data()); - *output_intermediate_state_ptr = chars.size(); - jboolean res = env->CallNonvirtualBooleanMethod( - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); - while (res != JNI_TRUE) { - RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); - increase_buffer_size++; - buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); - try { - chars.resize(buffer_size); - } catch (std::bad_alloc const& e) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "memory allocate failed in column string, " - "buffer:{},size:{},reason:{}", - increase_buffer_size, buffer_size, e.what()); - } - *output_value_buffer = reinterpret_cast(chars.data()); - *output_intermediate_state_ptr = chars.size(); - res = env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, - executor_result_id, to.size() - 1, place); - } + env->CallNonvirtualVoidMethod( + executor_obj, executor_cl, executor_copy_basic_result_id, result_obj, + to.size() - 1, nullmap_address, reinterpret_cast(chars.data()), + reinterpret_cast(&chars), reinterpret_cast(offsets.data())); } else if (data_col.is_numeric() || data_col.is_column_decimal()) { - *output_value_buffer = reinterpret_cast(data_col.get_raw_data().data); - env->CallNonvirtualBooleanMethod(executor_obj, executor_cl, executor_result_id, - to.size() - 1, place); + env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_copy_basic_result_id, + result_obj, to.size() - 1, nullmap_address, + reinterpret_cast(data_col.get_raw_data().data), + 0, 0); } else if (data_col.is_column_array()) { - ColumnArray& array_col = assert_cast(data_col); + jclass arraylist_class = env->FindClass("Ljava/util/ArrayList;"); + ColumnArray* array_col = assert_cast(&data_col); ColumnNullable& array_nested_nullable = - assert_cast(array_col.get_data()); + assert_cast(array_col->get_data()); auto data_column_null_map = array_nested_nullable.get_null_map_column_ptr(); auto data_column = array_nested_nullable.get_nested_column_ptr(); - auto& offset_column = array_col.get_offsets_column(); - int increase_buffer_size = 0; - int64_t buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); - *output_offsets_ptr = reinterpret_cast(offset_column.get_raw_data().data); - data_column_null_map->resize(buffer_size); + auto& offset_column = array_col->get_offsets_column(); + auto offset_address = reinterpret_cast(offset_column.get_raw_data().data); auto& null_map_data = assert_cast*>(data_column_null_map.get())->get_data(); - *output_array_null_ptr = reinterpret_cast(null_map_data.data()); - *output_intermediate_state_ptr = buffer_size; + auto nested_nullmap_address = reinterpret_cast(null_map_data.data()); + jmethodID list_size = env->GetMethodID(arraylist_class, "size", "()I"); + + size_t has_put_element_size = array_col->get_offsets().back(); + size_t arrar_list_size = env->CallIntMethod(result_obj, list_size); + size_t element_size = has_put_element_size + arrar_list_size; + array_nested_nullable.resize(element_size); + memset(null_map_data.data() + has_put_element_size, 0, arrar_list_size); + int64_t nested_data_address = 0, nested_offset_address = 0; if (data_column->is_column_string()) { ColumnString* str_col = assert_cast(data_column.get()); ColumnString::Chars& chars = assert_cast(str_col->get_chars()); ColumnString::Offsets& offsets = assert_cast(str_col->get_offsets()); - chars.resize(buffer_size); - offsets.resize(buffer_size); - *output_value_buffer = reinterpret_cast(chars.data()); - *output_array_string_offsets_ptr = reinterpret_cast(offsets.data()); - jboolean res = env->CallNonvirtualBooleanMethod( - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); - while (res != JNI_TRUE) { - RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); - increase_buffer_size++; - buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); - try { - null_map_data.resize(buffer_size); - chars.resize(buffer_size); - offsets.resize(buffer_size); - } catch (std::bad_alloc const& e) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "memory allocate failed in array column string, " - "buffer:{},size:{},reason:{}", - increase_buffer_size, buffer_size, e.what()); - } - *output_array_null_ptr = reinterpret_cast(null_map_data.data()); - *output_value_buffer = reinterpret_cast(chars.data()); - *output_array_string_offsets_ptr = reinterpret_cast(offsets.data()); - *output_intermediate_state_ptr = buffer_size; - res = env->CallNonvirtualBooleanMethod( - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); - } + nested_data_address = reinterpret_cast(&chars); + nested_offset_address = reinterpret_cast(offsets.data()); } else { - data_column->resize(buffer_size); - *output_value_buffer = reinterpret_cast(data_column->get_raw_data().data); - jboolean res = env->CallNonvirtualBooleanMethod( - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); - while (res != JNI_TRUE) { - RETURN_IF_ERROR(JniUtil::GetJniExceptionMsg(env)); - increase_buffer_size++; - buffer_size = JniUtil::IncreaseReservedBufferSize(increase_buffer_size); - try { - null_map_data.resize(buffer_size); - data_column->resize(buffer_size); - } catch (std::bad_alloc const& e) { - throw doris::Exception(ErrorCode::INTERNAL_ERROR, - "memory allocate failed in array number column, " - "buffer:{},size:{},reason:{}", - increase_buffer_size, buffer_size, e.what()); - } - *output_array_null_ptr = reinterpret_cast(null_map_data.data()); - *output_value_buffer = - reinterpret_cast(data_column->get_raw_data().data); - *output_intermediate_state_ptr = buffer_size; - res = env->CallNonvirtualBooleanMethod( - executor_obj, executor_cl, executor_result_id, to.size() - 1, place); - } + nested_data_address = reinterpret_cast(data_column->get_raw_data().data); + } + int row = to.size() - 1; + env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_copy_array_result_id, + has_put_element_size, result_nullable, row, result_obj, + nullmap_address, offset_address, nested_nullmap_address, + nested_data_address, nested_offset_address); + env->DeleteLocalRef(arraylist_class); + } else if (data_col.is_column_map()) { + jclass hashmap_class = env->FindClass("Ljava/util/HashMap;"); + ColumnMap* map_col = assert_cast(&data_col); + auto& offset_column = map_col->get_offsets_column(); + auto offset_address = reinterpret_cast(offset_column.get_raw_data().data); + ColumnNullable& map_key_column_nullable = + assert_cast(map_col->get_keys()); + auto key_data_column_null_map = map_key_column_nullable.get_null_map_column_ptr(); + auto key_data_column = map_key_column_nullable.get_nested_column_ptr(); + auto& key_null_map_data = + assert_cast*>(key_data_column_null_map.get())->get_data(); + auto key_nested_nullmap_address = reinterpret_cast(key_null_map_data.data()); + ColumnNullable& map_value_column_nullable = + assert_cast(map_col->get_values()); + auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr(); + auto value_data_column = map_value_column_nullable.get_nested_column_ptr(); + auto& value_null_map_data = + assert_cast*>(value_data_column_null_map.get())->get_data(); + auto value_nested_nullmap_address = + reinterpret_cast(value_null_map_data.data()); + jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I"); + size_t has_put_element_size = map_col->get_offsets().back(); + size_t hashmap_size = env->CallIntMethod(result_obj, map_size); + size_t element_size = has_put_element_size + hashmap_size; + map_key_column_nullable.resize(element_size); + memset(key_null_map_data.data() + has_put_element_size, 0, hashmap_size); + map_value_column_nullable.resize(element_size); + memset(value_null_map_data.data() + has_put_element_size, 0, hashmap_size); + + int64_t key_nested_data_address = 0, key_nested_offset_address = 0; + if (key_data_column->is_column_string()) { + ColumnString* str_col = assert_cast(key_data_column.get()); + ColumnString::Chars& chars = + assert_cast(str_col->get_chars()); + ColumnString::Offsets& offsets = + assert_cast(str_col->get_offsets()); + key_nested_data_address = reinterpret_cast(&chars); + key_nested_offset_address = reinterpret_cast(offsets.data()); + } else { + key_nested_data_address = + reinterpret_cast(key_data_column->get_raw_data().data); } + + int64_t value_nested_data_address = 0, value_nested_offset_address = 0; + if (value_data_column->is_column_string()) { + ColumnString* str_col = assert_cast(value_data_column.get()); + ColumnString::Chars& chars = + assert_cast(str_col->get_chars()); + ColumnString::Offsets& offsets = + assert_cast(str_col->get_offsets()); + value_nested_data_address = reinterpret_cast(&chars); + value_nested_offset_address = reinterpret_cast(offsets.data()); + } else { + value_nested_data_address = + reinterpret_cast(value_data_column->get_raw_data().data); + } + int row = to.size() - 1; + env->CallNonvirtualVoidMethod(executor_obj, executor_cl, executor_copy_map_result_id, + has_put_element_size, result_nullable, row, result_obj, + nullmap_address, offset_address, + key_nested_nullmap_address, key_nested_data_address, + key_nested_offset_address, value_nested_nullmap_address, + value_nested_data_address, value_nested_offset_address); + env->DeleteLocalRef(hashmap_class); } else { return Status::InvalidArgument(strings::Substitute( - "Java UDAF doesn't support return type is $0 now !", result_type->get_name())); + "Java UDAF doesn't support return type is $0 now !", return_type->get_name())); } return Status::OK(); } @@ -438,14 +435,12 @@ struct AggregateJavaUdafData { return s; }; RETURN_IF_ERROR(register_id("", UDAF_EXECUTOR_CTOR_SIGNATURE, executor_ctor_id)); - RETURN_IF_ERROR(register_id("add", UDAF_EXECUTOR_ADD_SIGNATURE, executor_add_id)); RETURN_IF_ERROR(register_id("reset", UDAF_EXECUTOR_RESET_SIGNATURE, executor_reset_id)); RETURN_IF_ERROR(register_id("close", UDAF_EXECUTOR_CLOSE_SIGNATURE, executor_close_id)); RETURN_IF_ERROR(register_id("merge", UDAF_EXECUTOR_MERGE_SIGNATURE, executor_merge_id)); RETURN_IF_ERROR( register_id("serialize", UDAF_EXECUTOR_SERIALIZE_SIGNATURE, executor_serialize_id)); - RETURN_IF_ERROR( - register_id("getValue", UDAF_EXECUTOR_RESULT_SIGNATURE, executor_result_id)); + RETURN_IF_ERROR(register_id("getValue", "(J)Ljava/lang/Object;", executor_get_value_id)); RETURN_IF_ERROR( register_id("destroy", UDAF_EXECUTOR_DESTROY_SIGNATURE, executor_destroy_id)); RETURN_IF_ERROR(register_id("convertBasicArguments", "(IZIIJJJ)[Ljava/lang/Object;", @@ -454,6 +449,16 @@ struct AggregateJavaUdafData { executor_convert_array_argument_id)); RETURN_IF_ERROR(register_id("convertMapArguments", "(IZIIJJJJJJJJ)[Ljava/lang/Object;", executor_convert_map_argument_id)); + + RETURN_IF_ERROR(register_id("copyTupleBasicResult", "(Ljava/lang/Object;IJJJJ)V", + executor_copy_basic_result_id)); + + RETURN_IF_ERROR(register_id("copyTupleArrayResult", "(JZILjava/lang/Object;JJJJJ)V", + executor_copy_array_result_id)); + + RETURN_IF_ERROR(register_id("copyTupleMapResult", "(JZILjava/lang/Object;JJJJJJJJ)V", + executor_copy_map_result_id)); + RETURN_IF_ERROR( register_id("addBatch", "(ZIIJI[Ljava/lang/Object;)V", executor_add_batch_id)); return Status::OK(); @@ -466,24 +471,19 @@ struct AggregateJavaUdafData { jobject executor_obj; jmethodID executor_ctor_id; - jmethodID executor_add_id; jmethodID executor_add_batch_id; jmethodID executor_merge_id; jmethodID executor_serialize_id; - jmethodID executor_result_id; + jmethodID executor_get_value_id; jmethodID executor_reset_id; jmethodID executor_close_id; jmethodID executor_destroy_id; jmethodID executor_convert_basic_argument_id; jmethodID executor_convert_array_argument_id; jmethodID executor_convert_map_argument_id; - std::unique_ptr output_value_buffer; - std::unique_ptr output_null_value; - std::unique_ptr output_offsets_ptr; - std::unique_ptr output_intermediate_state_ptr; - std::unique_ptr output_array_null_ptr; - std::unique_ptr output_array_string_offsets_ptr; - + jmethodID executor_copy_basic_result_id; + jmethodID executor_copy_array_result_id; + jmethodID executor_copy_map_result_id; int argument_size = 0; std::string serialize_data; }; diff --git a/be/src/vec/functions/function_java_udf.cpp b/be/src/vec/functions/function_java_udf.cpp index a2d41245517046..7c50e74117716b 100644 --- a/be/src/vec/functions/function_java_udf.cpp +++ b/be/src/vec/functions/function_java_udf.cpp @@ -346,6 +346,27 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, auto& key_null_map_data = assert_cast*>(key_data_column_null_map.get())->get_data(); auto key_nested_nullmap_address = reinterpret_cast(key_null_map_data.data()); + ColumnNullable& map_value_column_nullable = + assert_cast(map_col->get_values()); + auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr(); + auto value_data_column = map_value_column_nullable.get_nested_column_ptr(); + auto& value_null_map_data = + assert_cast*>(value_data_column_null_map.get())->get_data(); + auto value_nested_nullmap_address = reinterpret_cast(value_null_map_data.data()); + jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I"); + int element_size = 0; // get all element size in num_rows of map column + for (int i = 0; i < num_rows; ++i) { + jobject obj = env->GetObjectArrayElement(result_obj, i); + if (obj == nullptr) { + continue; + } + element_size = element_size + env->CallIntMethod(obj, map_size); + env->DeleteLocalRef(obj); + } + map_key_column_nullable.resize(element_size); + memset(key_null_map_data.data(), 0, element_size); + map_value_column_nullable.resize(element_size); + memset(value_null_map_data.data(), 0, element_size); int64_t key_nested_data_address = 0, key_nested_offset_address = 0; if (key_data_column->is_column_string()) { ColumnString* str_col = assert_cast(key_data_column.get()); @@ -358,16 +379,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, key_nested_data_address = reinterpret_cast(key_data_column->get_raw_data().data); } - - ColumnNullable& map_value_column_nullable = - assert_cast(map_col->get_values()); - auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr(); - auto value_data_column = map_value_column_nullable.get_nested_column_ptr(); - auto& value_null_map_data = - assert_cast*>(value_data_column_null_map.get())->get_data(); - auto value_nested_nullmap_address = reinterpret_cast(value_null_map_data.data()); int64_t value_nested_data_address = 0, value_nested_offset_address = 0; - // array type need pass address: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address,nested_offset_address if (value_data_column->is_column_string()) { ColumnString* str_col = assert_cast(value_data_column.get()); ColumnString::Chars& chars = assert_cast(str_col->get_chars()); @@ -379,20 +391,6 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block, value_nested_data_address = reinterpret_cast(value_data_column->get_raw_data().data); } - jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I"); - int element_size = 0; // get all element size in num_rows of map column - for (int i = 0; i < num_rows; ++i) { - jobject obj = env->GetObjectArrayElement(result_obj, i); - if (obj == nullptr) { - continue; - } - element_size = element_size + env->CallIntMethod(obj, map_size); - env->DeleteLocalRef(obj); - } - map_key_column_nullable.resize(element_size); - memset(key_null_map_data.data(), 0, element_size); - map_value_column_nullable.resize(element_size); - memset(value_null_map_data.data(), 0, element_size); env->CallNonvirtualVoidMethod(jni_ctx->executor, jni_env->executor_cl, jni_env->executor_result_map_batch_id, result_nullable, num_rows, result_obj, nullmap_address, offset_address, diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java index eae5270872cc3f..20f36866c8b168 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/BaseExecutor.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.Type; import org.apache.doris.common.exception.InternalException; import org.apache.doris.common.exception.UdfRuntimeException; +import org.apache.doris.common.jni.utils.JNINativeMethod; import org.apache.doris.common.jni.utils.UdfUtils; import org.apache.doris.common.jni.utils.UdfUtils.JavaUdfDataType; import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; @@ -73,24 +74,6 @@ public abstract class BaseExecutor { // The JavaUdfDataType enum maps it to corresponding primitive type. protected JavaUdfDataType[] argTypes; protected JavaUdfDataType retType; - - // Input buffer from the backend. This is valid for the duration of an - // evaluate() call. - // These buffers are allocated in the BE. - protected final long inputBufferPtrs; - protected final long inputNullsPtrs; - protected final long inputOffsetsPtrs; - protected final long inputArrayNullsPtrs; - protected final long inputArrayStringOffsetsPtrs; - - // Output buffer to return non-string values. These buffers are allocated in the - // BE. - protected final long outputBufferPtr; - protected final long outputNullPtr; - protected final long outputOffsetsPtr; - protected final long outputArrayNullPtr; - protected final long outputArrayStringOffsetsPtr; - protected final long outputIntermediateStatePtr; protected Class[] argClass; protected MethodAccess methodAccess; @@ -108,18 +91,6 @@ public BaseExecutor(byte[] thriftParams) throws Exception { } catch (TException e) { throw new InternalException(e.getMessage()); } - inputBufferPtrs = request.input_buffer_ptrs; - inputNullsPtrs = request.input_nulls_ptrs; - inputOffsetsPtrs = request.input_offsets_ptrs; - inputArrayNullsPtrs = request.input_array_nulls_buffer_ptr; - inputArrayStringOffsetsPtrs = request.input_array_string_offsets_ptrs; - outputBufferPtr = request.output_buffer_ptr; - outputNullPtr = request.output_null_ptr; - outputOffsetsPtr = request.output_offsets_ptr; - outputIntermediateStatePtr = request.output_intermediate_state_ptr; - outputArrayNullPtr = request.output_array_null_ptr; - outputArrayStringOffsetsPtr = request.output_array_string_offsets_ptr; - Type[] parameterTypes = new Type[request.fn.arg_types.size()]; for (int i = 0; i < request.fn.arg_types.size(); ++i) { parameterTypes[i] = Type.fromThrift(request.fn.arg_types.get(i)); @@ -132,359 +103,6 @@ public BaseExecutor(byte[] thriftParams) throws Exception { protected abstract void init(TJavaUdfExecutorCtorParams request, String jarPath, Type funcRetType, Type... parameterTypes) throws UdfRuntimeException; - protected Object[] allocateInputObjects(long row, int argClassOffset) throws UdfRuntimeException { - Object[] inputObjects = new Object[argTypes.length]; - - for (int i = 0; i < argTypes.length; ++i) { - if (UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) != -1 - && (UdfUtils.UNSAFE.getByte(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputNullsPtrs, i)) + row) == 1)) { - inputObjects[i] = null; - continue; - } - switch (argTypes[i]) { - case BOOLEAN: - inputObjects[i] = UdfUtils.UNSAFE.getBoolean(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); - break; - case TINYINT: - inputObjects[i] = UdfUtils.UNSAFE.getByte(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) + row); - break; - case SMALLINT: - inputObjects[i] = UdfUtils.UNSAFE.getShort(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case INT: - inputObjects[i] = UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case BIGINT: - inputObjects[i] = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case FLOAT: - inputObjects[i] = UdfUtils.UNSAFE.getFloat(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case DOUBLE: - inputObjects[i] = UdfUtils.UNSAFE.getDouble(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - break; - case DATE: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateToJavaDate(data, argClass[i + argClassOffset]); - break; - } - case DATETIME: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateTimeToJavaDateTime(data, argClass[i + argClassOffset]); - break; - } - case DATEV2: { - int data = UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateV2ToJavaDate(data, argClass[i + argClassOffset]); - break; - } - case DATETIMEV2: { - long data = UdfUtils.UNSAFE.getLong(null, - UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row); - inputObjects[i] = UdfUtils.convertDateTimeV2ToJavaDateTime(data, argClass[i + argClassOffset]); - break; - } - case LARGEINT: { - long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row; - byte[] bytes = new byte[argTypes[i].getLen()]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen()); - - inputObjects[i] = new BigInteger(UdfUtils.convertByteOrder(bytes)); - break; - } - case DECIMALV2: - case DECIMAL32: - case DECIMAL64: - case DECIMAL128: { - long base = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + argTypes[i].getLen() * row; - byte[] bytes = new byte[argTypes[i].getLen()]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, argTypes[i].getLen()); - - BigInteger value = new BigInteger(UdfUtils.convertByteOrder(bytes)); - inputObjects[i] = new BigDecimal(value, argTypes[i].getScale()); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - long offset = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * row)); - long numBytes = row == 0 ? offset - : offset - Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, - UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, i)) + 4L * (row - 1))); - long base = row == 0 - ? UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - : UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, i)) - + offset - numBytes; - byte[] bytes = new byte[(int) numBytes]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); - inputObjects[i] = new String(bytes, StandardCharsets.UTF_8); - break; - } - case ARRAY_TYPE: { - Type type = argTypes[i].getItemType(); - inputObjects[i] = arrayTypeInputData(type, i, row); - break; - } - default: - throw new UdfRuntimeException("Unsupported argument type: " + argTypes[i]); - } - } - return inputObjects; - } - - public ArrayList arrayTypeInputData(Type type, int argIdx, long row) - throws UdfRuntimeException { - long offsetStart = (row == 0) ? 0 - : Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, argIdx)) + 8L * (row - 1))); - long offsetEnd = Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, - UdfUtils.getAddressAtOffset(inputOffsetsPtrs, argIdx)) + 8L * row)); - long arrayNullMapBase = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputArrayNullsPtrs, argIdx)); - long arrayInputBufferBase = UdfUtils.UNSAFE.getLong(null, UdfUtils.getAddressAtOffset(inputBufferPtrs, argIdx)); - - switch (type.getPrimitiveType()) { - case BOOLEAN: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - boolean value = UdfUtils.UNSAFE.getBoolean(null, arrayInputBufferBase + offsetRow); - data.add(value); - } - } - return data; - } - case TINYINT: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - byte value = UdfUtils.UNSAFE.getByte(null, arrayInputBufferBase + offsetRow); - data.add(value); - } - } - return data; - } - case SMALLINT: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - short value = UdfUtils.UNSAFE.getShort(null, arrayInputBufferBase + 2L * offsetRow); - data.add(value); - } - } - return data; - } - case INT: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - int value = UdfUtils.UNSAFE.getInt(null, arrayInputBufferBase + 4L * offsetRow); - data.add(value); - } - } - return data; - } - case BIGINT: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long value = UdfUtils.UNSAFE.getLong(null, arrayInputBufferBase + 8L * offsetRow); - data.add(value); - } - } - return data; - } - case FLOAT: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - float value = UdfUtils.UNSAFE.getFloat(null, arrayInputBufferBase + 4L * offsetRow); - data.add(value); - } - } - return data; - } - case DOUBLE: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - double value = UdfUtils.UNSAFE.getDouble(null, arrayInputBufferBase + 8L * offsetRow); - data.add(value); - } - } - return data; - } - case DATE: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long value = UdfUtils.UNSAFE.getLong(null, arrayInputBufferBase + 8L * offsetRow); - // TODO: now argClass[argIdx + argClassOffset] is java.util.ArrayList, can't get - // nested class type - // LocalDate obj = UdfUtils.convertDateToJavaDate(value, argClass[argIdx + - // argClassOffset]); - LocalDate obj = (LocalDate) UdfUtils.convertDateToJavaDate(value, LocalDate.class); - data.add(obj); - } - } - return data; - } - case DATETIME: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long value = UdfUtils.UNSAFE.getLong(null, arrayInputBufferBase + 8L * offsetRow); - // Object obj = UdfUtils.convertDateTimeToJavaDateTime(value, argClass[argIdx + - // argClassOffset]); - LocalDateTime obj = (LocalDateTime) UdfUtils.convertDateTimeToJavaDateTime(value, - LocalDateTime.class); - data.add(obj); - } - } - return data; - } - case DATEV2: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - int value = UdfUtils.UNSAFE.getInt(null, arrayInputBufferBase + 4L * offsetRow); - // Object obj = UdfUtils.convertDateV2ToJavaDate(value, argClass[argIdx + - // argClassOffset]); - LocalDate obj = (LocalDate) UdfUtils.convertDateV2ToJavaDate(value, LocalDate.class); - data.add(obj); - } - } - return data; - } - case DATETIMEV2: { - ArrayList data = new ArrayList<>(); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long value = UdfUtils.UNSAFE.getLong(null, arrayInputBufferBase + 8L * offsetRow); - LocalDateTime obj = (LocalDateTime) UdfUtils.convertDateTimeV2ToJavaDateTime(value, - LocalDateTime.class); - data.add(obj); - } - } - return data; - } - case LARGEINT: { - ArrayList data = new ArrayList<>(); - byte[] bytes = new byte[16]; - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long value = UdfUtils.UNSAFE.getLong(null, arrayInputBufferBase + 16L * offsetRow); - UdfUtils.copyMemory(null, value, bytes, UdfUtils.BYTE_ARRAY_OFFSET, 16); - data.add(new BigInteger(UdfUtils.convertByteOrder(bytes))); - } - } - return data; - } - case DECIMALV2: - case DECIMAL32: - case DECIMAL64: - case DECIMAL128: { - int len; - if (type.getPrimitiveType() == PrimitiveType.DECIMAL32) { - len = 4; - } else if (type.getPrimitiveType() == PrimitiveType.DECIMAL64) { - len = 8; - } else { - len = 16; - } - ArrayList data = new ArrayList<>(); - byte[] bytes = new byte[len]; - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long value = UdfUtils.UNSAFE.getLong(null, arrayInputBufferBase + len * offsetRow); - UdfUtils.copyMemory(null, value, bytes, UdfUtils.BYTE_ARRAY_OFFSET, len); - BigInteger bigInteger = new BigInteger(UdfUtils.convertByteOrder(bytes)); - data.add(new BigDecimal(bigInteger, argTypes[argIdx].getScale())); - } - } - return data; - } - case CHAR: - case VARCHAR: - case STRING: { - ArrayList data = new ArrayList<>(); - long strOffsetBase = UdfUtils.UNSAFE - .getLong(null, UdfUtils.getAddressAtOffset(inputArrayStringOffsetsPtrs, argIdx)); - for (long offsetRow = offsetStart; offsetRow < offsetEnd; ++offsetRow) { - if ((UdfUtils.UNSAFE.getByte(null, arrayNullMapBase + offsetRow) == 1)) { - data.add(null); - } else { - long stringOffsetStart = (offsetRow == 0) ? 0 - : Integer.toUnsignedLong( - UdfUtils.UNSAFE.getInt(null, strOffsetBase + 4L * (offsetRow - 1))); - long stringOffsetEnd = Integer - .toUnsignedLong(UdfUtils.UNSAFE.getInt(null, strOffsetBase + 4L * offsetRow)); - - long numBytes = stringOffsetEnd - stringOffsetStart; - long base = arrayInputBufferBase + stringOffsetStart; - byte[] bytes = new byte[(int) numBytes]; - UdfUtils.copyMemory(null, base, bytes, UdfUtils.BYTE_ARRAY_OFFSET, numBytes); - data.add(new String(bytes, StandardCharsets.UTF_8)); - } - } - return data; - } - default: - throw new UdfRuntimeException("Unsupported argument type in nested array: " + type); - } - } - - protected abstract long getCurrentOutputOffset(long row, boolean isArrayType); - /** * Close the class loader we may have created. */ @@ -502,76 +120,74 @@ public void close() { classLoader = null; } - // Sets the result object 'obj' into the outputBufferPtr and outputNullPtr_ - protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException { - if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) != -1) { - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 0); - } + public void copyTupleBasicResult(Object obj, long row, Class retClass, + long outputBufferBase, long charsAddress, long offsetsAddr, JavaUdfDataType retType) + throws UdfRuntimeException { switch (retType) { case BOOLEAN: { boolean val = (boolean) obj; - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putByte(outputBufferBase + row * retType.getLen(), val ? (byte) 1 : 0); - return true; + break; } case TINYINT: { - UdfUtils.UNSAFE.putByte(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putByte(outputBufferBase + row * retType.getLen(), (byte) obj); - return true; + break; } case SMALLINT: { - UdfUtils.UNSAFE.putShort(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putShort(outputBufferBase + row * retType.getLen(), (short) obj); - return true; + break; } case INT: { - UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putInt(outputBufferBase + row * retType.getLen(), (int) obj); - return true; + break; } case BIGINT: { - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putLong(outputBufferBase + row * retType.getLen(), (long) obj); - return true; + break; } case FLOAT: { - UdfUtils.UNSAFE.putFloat(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putFloat(outputBufferBase + row * retType.getLen(), (float) obj); - return true; + break; } case DOUBLE: { - UdfUtils.UNSAFE.putDouble(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), + UdfUtils.UNSAFE.putDouble(outputBufferBase + row * retType.getLen(), (double) obj); - return true; + break; } case DATE: { long time = UdfUtils.convertToDate(obj, retClass); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; + UdfUtils.UNSAFE.putLong(outputBufferBase + row * retType.getLen(), time); + break; } case DATETIME: { long time = UdfUtils.convertToDateTime(obj, retClass); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; + UdfUtils.UNSAFE.putLong(outputBufferBase + row * retType.getLen(), time); + break; } case DATEV2: { int time = UdfUtils.convertToDateV2(obj, retClass); - UdfUtils.UNSAFE.putInt(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; + UdfUtils.UNSAFE.putInt(outputBufferBase + row * retType.getLen(), time); + break; } case DATETIMEV2: { long time = UdfUtils.convertToDateTimeV2(obj, retClass); - UdfUtils.UNSAFE.putLong(UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), time); - return true; + UdfUtils.UNSAFE.putLong(outputBufferBase + row * retType.getLen(), time); + break; } case LARGEINT: { BigInteger data = (BigInteger) obj; byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); - //here value is 16 bytes, so if result data greater than the maximum of 16 bytes - //it will return a wrong num to backend; + // here value is 16 bytes, so if result data greater than the maximum of 16 + // bytesit will return a wrong num to backend; byte[] value = new byte[16]; - //check data is negative + // check data is negative if (data.signum() == -1) { Arrays.fill(value, (byte) -1); } @@ -580,14 +196,14 @@ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws Ud } UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); - return true; + outputBufferBase + row * retType.getLen(), value.length); + break; } case DECIMALV2: { BigDecimal retValue = ((BigDecimal) obj).setScale(9, RoundingMode.HALF_EVEN); BigInteger data = retValue.unscaledValue(); byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); - //TODO: here is maybe overflow also, and may find a better way to handle + // TODO: here is maybe overflow also, and may find a better way to handle byte[] value = new byte[16]; if (data.signum() == -1) { Arrays.fill(value, (byte) -1); @@ -598,8 +214,8 @@ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws Ud } UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); - return true; + outputBufferBase + row * retType.getLen(), value.length); + break; } case DECIMAL32: case DECIMAL64: @@ -607,7 +223,7 @@ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws Ud BigDecimal retValue = ((BigDecimal) obj).setScale(retType.getScale(), RoundingMode.HALF_EVEN); BigInteger data = retValue.unscaledValue(); byte[] bytes = UdfUtils.convertByteOrder(data.toByteArray()); - //TODO: here is maybe overflow also, and may find a better way to handle + // TODO: here is maybe overflow also, and may find a better way to handle byte[] value = new byte[retType.getLen()]; if (data.signum() == -1) { Arrays.fill(value, (byte) -1); @@ -618,413 +234,29 @@ protected boolean storeUdfResult(Object obj, long row, Class retClass) throws Ud } UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + row * retType.getLen(), value.length); - return true; + outputBufferBase + row * retType.getLen(), value.length); + break; } case CHAR: case VARCHAR: case STRING: { - long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr); byte[] bytes = ((String) obj).getBytes(StandardCharsets.UTF_8); - long offset = getCurrentOutputOffset(row, false); - if (offset + bytes.length > bufferSize) { - return false; - } + long offset = UdfUtils.UNSAFE.getInt(null, offsetsAddr + 4L * (row - 1)); + int needLen = (int) (offset + bytes.length); + outputBufferBase = JNINativeMethod.resizeStringColumn(charsAddress, needLen); offset += bytes.length; - UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * row, - Integer.parseUnsignedInt(String.valueOf(offset))); - UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, - UdfUtils.UNSAFE.getLong(null, outputBufferPtr) + offset - bytes.length, bytes.length); + UdfUtils.UNSAFE.putInt(null, offsetsAddr + 4L * row, Integer.parseUnsignedInt(String.valueOf(offset))); + UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, outputBufferBase + offset - bytes.length, + bytes.length); updateOutputOffset(offset); - return true; - } - case ARRAY_TYPE: { - Type type = retType.getItemType(); - return arrayTypeOutputData(obj, type, row); + break; } + case ARRAY_TYPE: default: throw new UdfRuntimeException("Unsupported return type: " + retType); } } - public boolean arrayTypeOutputData(Object obj, Type type, long row) throws UdfRuntimeException { - long offset = getCurrentOutputOffset(row, true); - long bufferSize = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr); - long outputNullMapBase = UdfUtils.UNSAFE.getLong(null, outputArrayNullPtr); - long outputBufferBase = UdfUtils.UNSAFE.getLong(null, outputBufferPtr); - switch (type.getPrimitiveType()) { - case BOOLEAN: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Boolean value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putByte(outputBufferBase + (offset + i), value ? (byte) 1 : 0); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case TINYINT: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Byte value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putByte(outputBufferBase + (offset + i), value); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case SMALLINT: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Short value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putShort(outputBufferBase + ((offset + i) * 2L), value); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case INT: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Integer value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putInt(outputBufferBase + ((offset + i) * 4L), value); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case BIGINT: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Long value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putLong(outputBufferBase + ((offset + i) * 8L), value); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case FLOAT: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Float value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putFloat(outputBufferBase + ((offset + i) * 4L), value); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DOUBLE: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - Double value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - UdfUtils.UNSAFE.putDouble(outputBufferBase + ((offset + i) * 8L), value); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DATE: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - LocalDate value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - long time = UdfUtils.convertToDate(value, LocalDate.class); - UdfUtils.UNSAFE.putLong(outputBufferBase + ((offset + i) * 8L), time); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DATETIME: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - LocalDateTime value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - long time = UdfUtils.convertToDateTime(value, LocalDateTime.class); - UdfUtils.UNSAFE.putLong(outputBufferBase + ((offset + i) * 8L), time); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DATEV2: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - LocalDate value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - int time = UdfUtils.convertToDateV2(value, LocalDate.class); - UdfUtils.UNSAFE.putInt(outputBufferBase + ((offset + i) * 4L), time); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DATETIMEV2: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - LocalDateTime value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - long time = UdfUtils.convertToDateTimeV2(value, LocalDateTime.class); - UdfUtils.UNSAFE.putLong(outputBufferBase + ((offset + i) * 8L), time); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case LARGEINT: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - BigInteger bigInteger = data.get(i); - if (bigInteger == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - byte[] bytes = UdfUtils.convertByteOrder(bigInteger.toByteArray()); - byte[] value = new byte[16]; - // check data is negative - if (bigInteger.signum() == -1) { - Arrays.fill(value, (byte) -1); - } - for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { - value[index] = bytes[index]; - } - UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - outputBufferBase + ((offset + i) * 16L), value.length); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DECIMALV2: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - BigDecimal bigDecimal = data.get(i); - if (bigDecimal == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - BigInteger bigInteger = bigDecimal.setScale(9, RoundingMode.HALF_EVEN).unscaledValue(); - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - byte[] bytes = UdfUtils.convertByteOrder(bigInteger.toByteArray()); - byte[] value = new byte[16]; - // check data is negative - if (bigInteger.signum() == -1) { - Arrays.fill(value, (byte) -1); - } - for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { - value[index] = bytes[index]; - } - UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - outputBufferBase + ((offset + i) * 16L), value.length); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case DECIMAL32: - case DECIMAL64: - case DECIMAL128: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - for (int i = 0; i < num; ++i) { - BigDecimal bigDecimal = data.get(i); - if (bigDecimal == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - BigInteger bigInteger = bigDecimal.setScale(retType.getScale(), RoundingMode.HALF_EVEN) - .unscaledValue(); - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - byte[] bytes = UdfUtils.convertByteOrder(bigInteger.toByteArray()); - byte[] value = new byte[16]; - // check data is negative - if (bigInteger.signum() == -1) { - Arrays.fill(value, (byte) -1); - } - for (int index = 0; index < Math.min(bytes.length, value.length); ++index) { - value[index] = bytes[index]; - } - UdfUtils.copyMemory(value, UdfUtils.BYTE_ARRAY_OFFSET, null, - outputBufferBase + ((offset + i) * 16L), value.length); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - case CHAR: - case VARCHAR: - case STRING: { - ArrayList data = (ArrayList) obj; - int num = data.size(); - if (offset + num > bufferSize) { - return false; - } - long outputStrOffsetBase = UdfUtils.UNSAFE.getLong(null, outputArrayStringOffsetsPtr); - for (int i = 0; i < num; ++i) { - String value = data.get(i); - if (value == null) { - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 1); - } else { - byte[] bytes = value.getBytes(StandardCharsets.UTF_8); - long strOffset = (offset + i == 0) ? 0 - : Integer.toUnsignedLong(UdfUtils.UNSAFE.getInt(null, - outputStrOffsetBase + ((offset + i - 1) * 4L))); - if (strOffset + bytes.length > bufferSize) { - return false; - } - UdfUtils.UNSAFE.putByte(outputNullMapBase + (offset + i), (byte) 0); - strOffset += bytes.length; - UdfUtils.UNSAFE.putInt(null, outputStrOffsetBase + 4L * (offset + i), - Integer.parseUnsignedInt(String.valueOf(strOffset))); - UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, - outputBufferBase + strOffset - bytes.length, bytes.length); - } - } - offset += num; - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(offset))); - updateOutputOffset(offset); - return true; - } - default: - throw new UdfRuntimeException("Unsupported argument type in nested array: " + type); - } - } protected void updateOutputOffset(long offset) { } @@ -1556,120 +788,129 @@ public void copyBatchArrayResultImpl(boolean isNullable, int numRows, Object[] r PrimitiveType type) { long hasPutElementNum = 0; for (int row = 0; row < numRows; ++row) { - switch (type) { - case BOOLEAN: { - hasPutElementNum = UdfConvert - .copyBatchArrayBooleanResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case TINYINT: { - hasPutElementNum = UdfConvert - .copyBatchArrayTinyIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case SMALLINT: { - hasPutElementNum = UdfConvert - .copyBatchArraySmallIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case INT: { - hasPutElementNum = UdfConvert - .copyBatchArrayIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case BIGINT: { - hasPutElementNum = UdfConvert - .copyBatchArrayBigIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case LARGEINT: { - hasPutElementNum = UdfConvert - .copyBatchArrayLargeIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case FLOAT: { - hasPutElementNum = UdfConvert - .copyBatchArrayFloatResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DOUBLE: { - hasPutElementNum = UdfConvert - .copyBatchArrayDoubleResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case CHAR: - case VARCHAR: - case STRING: { - hasPutElementNum = UdfConvert - .copyBatchArrayStringResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr, strOffsetAddr); - break; - } - case DATE: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DATETIME: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateTimeResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DATEV2: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DATETIMEV2: { - hasPutElementNum = UdfConvert - .copyBatchArrayDateTimeV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMALV2: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalResult(hasPutElementNum, isNullable, row, result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL32: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalV3Result(retType.getScale(), 4L, hasPutElementNum, isNullable, row, - result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL64: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalV3Result(retType.getScale(), 8L, hasPutElementNum, isNullable, row, - result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - case DECIMAL128: { - hasPutElementNum = UdfConvert - .copyBatchArrayDecimalV3Result(retType.getScale(), 16L, hasPutElementNum, isNullable, row, - result, nullMapAddr, - offsetsAddr, nestedNullMapAddr, dataAddr); - break; - } - default: { - Preconditions.checkState(false, "Not support type in array: " + retType); - break; - } + hasPutElementNum = copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, result[row], nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr, strOffsetAddr, type); + } + } + + public long copyTupleArrayResultImpl(long hasPutElementNum, boolean isNullable, int row, Object result, + long nullMapAddr, + long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr, + PrimitiveType type) { + switch (type) { + case BOOLEAN: { + hasPutElementNum = UdfConvert + .copyBatchArrayBooleanResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case TINYINT: { + hasPutElementNum = UdfConvert + .copyBatchArrayTinyIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case SMALLINT: { + hasPutElementNum = UdfConvert + .copyBatchArraySmallIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case INT: { + hasPutElementNum = UdfConvert + .copyBatchArrayIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case BIGINT: { + hasPutElementNum = UdfConvert + .copyBatchArrayBigIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case LARGEINT: { + hasPutElementNum = UdfConvert + .copyBatchArrayLargeIntResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case FLOAT: { + hasPutElementNum = UdfConvert + .copyBatchArrayFloatResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DOUBLE: { + hasPutElementNum = UdfConvert + .copyBatchArrayDoubleResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case CHAR: + case VARCHAR: + case STRING: { + hasPutElementNum = UdfConvert + .copyBatchArrayStringResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr, strOffsetAddr); + break; + } + case DATE: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DATETIME: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateTimeResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DATEV2: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DATETIMEV2: { + hasPutElementNum = UdfConvert + .copyBatchArrayDateTimeV2Result(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMALV2: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalResult(hasPutElementNum, isNullable, row, result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL32: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalV3Result(retType.getScale(), 4L, hasPutElementNum, isNullable, row, + result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL64: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalV3Result(retType.getScale(), 8L, hasPutElementNum, isNullable, row, + result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + case DECIMAL128: { + hasPutElementNum = UdfConvert + .copyBatchArrayDecimalV3Result(retType.getScale(), 16L, hasPutElementNum, isNullable, row, + result, nullMapAddr, + offsetsAddr, nestedNullMapAddr, dataAddr); + break; + } + default: { + Preconditions.checkState(false, "Not support type in array: " + retType); + break; } } + return hasPutElementNum; } public void buildArrayListFromHashMap(Object[] result, PrimitiveType keyType, PrimitiveType valueType, diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java index dff689ed4038d6..fa19ad32888d2f 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdafExecutor.java @@ -163,43 +163,6 @@ public void addBatchPlaces(int rowStart, int rowEnd, long placeAddr, int offset, } } - /** - * invoke add function, add row in loop [rowStart, rowEnd). - */ - public void add(boolean isSinglePlace, long rowStart, long rowEnd) throws UdfRuntimeException { - try { - long idx = rowStart; - do { - Long curPlace = null; - if (isSinglePlace) { - curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr)); - } else { - curPlace = UdfUtils.UNSAFE.getLong(null, UdfUtils.UNSAFE.getLong(null, inputPlacesPtr) + 8L * idx); - } - Object[] inputArgs = new Object[argTypes.length + 1]; - Object state = stateObjMap.get(curPlace); - if (state != null) { - inputArgs[0] = state; - } else { - Object newState = createAggState(); - stateObjMap.put(curPlace, newState); - inputArgs[0] = newState; - } - do { - Object[] inputObjects = allocateInputObjects(idx, 1); - for (int i = 0; i < argTypes.length; ++i) { - inputArgs[i + 1] = inputObjects[i]; - } - allMethods.get(UDAF_ADD_FUNCTION).invoke(udf, inputArgs); - idx++; - } while (isSinglePlace && idx < rowEnd); - } while (idx < rowEnd); - } catch (Exception e) { - LOG.warn("invoke add function meet some error: " + e.getCause().toString()); - throw new UdfRuntimeException("UDAF failed to add: ", e); - } - } - /** * invoke user create function to get obj. */ @@ -292,40 +255,71 @@ public void merge(long place, byte[] data) throws UdfRuntimeException { /** * invoke getValue to return finally result. */ - public boolean getValue(long row, long place) throws UdfRuntimeException { + + public Object getValue(long place) throws UdfRuntimeException { try { if (stateObjMap.get(place) == null) { stateObjMap.put(place, createAggState()); } - return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place)), - row, retClass); + return allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place)); } catch (Exception e) { LOG.warn("invoke getValue function meet some error: " + e.getCause().toString()); throw new UdfRuntimeException("UDAF failed to result", e); } } - @Override - protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException { - if (obj == null) { - // If result is null, return true directly when row == 0 as we have already inserted default value. - if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) { + public void copyTupleBasicResult(Object result, int row, long outputNullMapPtr, long outputBufferBase, + long charsAddress, + long offsetsAddr) throws UdfRuntimeException { + if (result == null) { + // put null obj + if (outputNullMapPtr == -1) { throw new UdfRuntimeException("UDAF failed to store null data to not null column"); + } else { + UdfUtils.UNSAFE.putByte(outputNullMapPtr + row, (byte) 1); } - return true; + return; + } + try { + if (outputNullMapPtr != -1) { + UdfUtils.UNSAFE.putByte(outputNullMapPtr + row, (byte) 0); + } + copyTupleBasicResult(result, row, retClass, outputBufferBase, charsAddress, + offsetsAddr, retType); + } catch (UdfRuntimeException e) { + LOG.info(e.toString()); } - return super.storeUdfResult(obj, row, retClass); } - @Override - protected long getCurrentOutputOffset(long row, boolean isArrayType) { - if (isArrayType) { - return Integer.toUnsignedLong( - UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * (row - 1))); - } else { - return Integer.toUnsignedLong( - UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1))); + public void copyTupleArrayResult(long hasPutElementNum, boolean isNullable, int row, Object result, + long nullMapAddr, + long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) throws UdfRuntimeException { + if (nullMapAddr > 0) { + UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 0); + } + copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, result, nullMapAddr, offsetsAddr, nestedNullMapAddr, + dataAddr, strOffsetAddr, retType.getItemType().getPrimitiveType()); + } + + public void copyTupleMapResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, + long offsetsAddr, + long keyNsestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr, + long valueNsestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) throws UdfRuntimeException { + if (nullMapAddr > 0) { + UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 0); } + PrimitiveType keyType = retType.getKeyType().getPrimitiveType(); + PrimitiveType valueType = retType.getValueType().getPrimitiveType(); + Object[] keyCol = new Object[1]; + Object[] valueCol = new Object[1]; + Object[] resultArr = new Object[1]; + resultArr[0] = result; + buildArrayListFromHashMap(resultArr, keyType, valueType, keyCol, valueCol); + copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, + valueCol[0], nullMapAddr, offsetsAddr, + valueNsestedNullMapAddr, valueDataAddr, valueStrOffsetAddr, valueType); + copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, keyCol[0], nullMapAddr, offsetsAddr, + keyNsestedNullMapAddr, keyDataAddr, keyStrOffsetAddr, keyType); } @Override diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java index dc835408859fbe..7b3a151f0065af 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfConvert.java @@ -707,9 +707,9 @@ public static void copyBatchStringResult(boolean isNullable, int numRows, String //////////////////////////////////// copyBatchArray////////////////////////////////////////////////////////// - public static long copyBatchArrayBooleanResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayBooleanResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -741,9 +741,9 @@ public static long copyBatchArrayBooleanResult(long hasPutElementNum, boolean is return hasPutElementNum; } - public static long copyBatchArrayTinyIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayTinyIntResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -775,9 +775,9 @@ public static long copyBatchArrayTinyIntResult(long hasPutElementNum, boolean is return hasPutElementNum; } - public static long copyBatchArraySmallIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArraySmallIntResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -809,9 +809,9 @@ public static long copyBatchArraySmallIntResult(long hasPutElementNum, boolean i return hasPutElementNum; } - public static long copyBatchArrayIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayIntResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -843,9 +843,9 @@ public static long copyBatchArrayIntResult(long hasPutElementNum, boolean isNull return hasPutElementNum; } - public static long copyBatchArrayBigIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayBigIntResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -877,9 +877,9 @@ public static long copyBatchArrayBigIntResult(long hasPutElementNum, boolean isN return hasPutElementNum; } - public static long copyBatchArrayFloatResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayFloatResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -911,9 +911,9 @@ public static long copyBatchArrayFloatResult(long hasPutElementNum, boolean isNu return hasPutElementNum; } - public static long copyBatchArrayDoubleResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayDoubleResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -945,9 +945,9 @@ public static long copyBatchArrayDoubleResult(long hasPutElementNum, boolean isN return hasPutElementNum; } - public static long copyBatchArrayDateResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayDateResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -981,9 +981,9 @@ public static long copyBatchArrayDateResult(long hasPutElementNum, boolean isNul return hasPutElementNum; } - public static long copyBatchArrayDateTimeResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayDateTimeResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1017,9 +1017,9 @@ public static long copyBatchArrayDateTimeResult(long hasPutElementNum, boolean i return hasPutElementNum; } - public static long copyBatchArrayDateV2Result(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayDateV2Result(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1054,9 +1054,9 @@ public static long copyBatchArrayDateV2Result(long hasPutElementNum, boolean isN } public static long copyBatchArrayDateTimeV2Result(long hasPutElementNum, boolean isNullable, int row, - Object[] result, + Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1090,9 +1090,9 @@ public static long copyBatchArrayDateTimeV2Result(long hasPutElementNum, boolean return hasPutElementNum; } - public static long copyBatchArrayLargeIntResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayLargeIntResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1140,9 +1140,9 @@ public static long copyBatchArrayLargeIntResult(long hasPutElementNum, boolean i return hasPutElementNum; } - public static long copyBatchArrayDecimalResult(long hasPutElementNum, boolean isNullable, int row, Object[] result, + public static long copyBatchArrayDecimalResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1194,9 +1194,9 @@ public static long copyBatchArrayDecimalResult(long hasPutElementNum, boolean is public static long copyBatchArrayDecimalV3Result(int scale, long typeLen, long hasPutElementNum, boolean isNullable, int row, - Object[] result, + Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1247,9 +1247,9 @@ public static long copyBatchArrayDecimalV3Result(int scale, long typeLen, long h } public static long copyBatchArrayStringResult(long hasPutElementNum, boolean isNullable, int row, - Object[] result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr, + Object result, long nullMapAddr, long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) { - ArrayList data = (ArrayList) result[row]; + ArrayList data = (ArrayList) result; if (isNullable) { if (data == null) { UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 1); @@ -1270,8 +1270,12 @@ public static long copyBatchArrayStringResult(long hasPutElementNum, boolean isN offset += byteRes[i].length; offsets[i] = offset; } - byte[] bytes = new byte[offsets[num - 1] - oldOffsetNum]; - long bytesAddr = JNINativeMethod.resizeStringColumn(dataAddr, offsets[num - 1]); + int oldSzie = 0; + if (num > 0) { + oldSzie = offsets[num - 1]; + } + byte[] bytes = new byte[oldSzie - oldOffsetNum]; + long bytesAddr = JNINativeMethod.resizeStringColumn(dataAddr, oldSzie); int dst = 0; for (int i = 0; i < num; i++) { for (int j = 0; j < byteRes[i].length; j++) { @@ -1281,7 +1285,7 @@ public static long copyBatchArrayStringResult(long hasPutElementNum, boolean isN UdfUtils.copyMemory(offsets, UdfUtils.INT_ARRAY_OFFSET, null, strOffsetAddr + (4L * hasPutElementNum), num * 4L); UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, bytesAddr + oldOffsetNum, - offsets[num - 1] - oldOffsetNum); + oldSzie - oldOffsetNum); hasPutElementNum = hasPutElementNum + num; } } else { @@ -1300,9 +1304,13 @@ public static long copyBatchArrayStringResult(long hasPutElementNum, boolean isN offset += byteRes[i].length; offsets[i] = offset; } - byte[] bytes = new byte[offsets[num - 1]]; int oldOffsetNum = UdfUtils.UNSAFE.getInt(null, strOffsetAddr + ((hasPutElementNum - 1) * 4L)); - long bytesAddr = JNINativeMethod.resizeStringColumn(dataAddr, oldOffsetNum + offsets[num - 1]); + int oldSzie = 0; + if (num > 0) { + oldSzie = offsets[num - 1]; + } + byte[] bytes = new byte[oldSzie]; + long bytesAddr = JNINativeMethod.resizeStringColumn(dataAddr, oldOffsetNum + oldSzie); int dst = 0; for (int i = 0; i < num; i++) { for (int j = 0; j < byteRes[i].length; j++) { @@ -1312,7 +1320,7 @@ public static long copyBatchArrayStringResult(long hasPutElementNum, boolean isN UdfUtils.copyMemory(offsets, UdfUtils.INT_ARRAY_OFFSET, null, strOffsetAddr + (4L * oldOffsetNum), num * 4L); UdfUtils.copyMemory(bytes, UdfUtils.BYTE_ARRAY_OFFSET, null, bytesAddr + oldOffsetNum, - offsets[num - 1]); + oldSzie); hasPutElementNum = hasPutElementNum + num; } UdfUtils.UNSAFE.putLong(null, offsetsAddr + 8L * row, hasPutElementNum); diff --git a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java index 2f6ca99fdd5120..a77b441b67d997 100644 --- a/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java +++ b/fe/be-java-extensions/java-udf/src/main/java/org/apache/doris/udf/UdfExecutor.java @@ -74,50 +74,6 @@ public void close() { super.close(); } - /** - * evaluate function called by the backend. The inputs to the UDF have - * been serialized to 'input' - */ - public void evaluate() throws UdfRuntimeException { - int batchSize = UdfUtils.UNSAFE.getInt(null, batchSizePtr); - try { - if (retType.equals(JavaUdfDataType.STRING) || retType.equals(JavaUdfDataType.VARCHAR) - || retType.equals(JavaUdfDataType.CHAR) || retType.equals(JavaUdfDataType.ARRAY_TYPE) - || retType.equals(JavaUdfDataType.MAP_TYPE)) { - // If this udf return variable-size type (e.g.) String, we have to allocate output - // buffer multiple times until buffer size is enough to store output column. So we - // always begin with the last evaluated row instead of beginning of this batch. - rowIdx = UdfUtils.UNSAFE.getLong(null, outputIntermediateStatePtr + 8); - if (rowIdx == 0) { - outputOffset = 0L; - } - } else { - rowIdx = 0; - } - for (; rowIdx < batchSize; rowIdx++) { - inputObjects = allocateInputObjects(rowIdx, 0); - // `storeUdfResult` is called to store udf result to output column. If true - // is returned, current value is stored successfully. Otherwise, current result is - // not processed successfully (e.g. current output buffer is not large enough) so - // we break this loop directly. - if (!storeUdfResult(evaluate(inputObjects), rowIdx, method.getReturnType())) { - UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, rowIdx); - return; - } - } - } catch (Exception e) { - if (retType.equals(JavaUdfDataType.STRING) || retType.equals(JavaUdfDataType.ARRAY_TYPE) - || retType.equals(JavaUdfDataType.MAP_TYPE)) { - UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, batchSize); - } - throw new UdfRuntimeException("UDF::evaluate() ran into a problem.", e); - } - if (retType.equals(JavaUdfDataType.STRING) || retType.equals(JavaUdfDataType.ARRAY_TYPE) - || retType.equals(JavaUdfDataType.MAP_TYPE)) { - UdfUtils.UNSAFE.putLong(null, outputIntermediateStatePtr + 8, rowIdx); - } - } - public Object[] convertBasicArguments(int argIdx, boolean isNullable, int numRows, long nullMapAddr, long columnAddr, long strOffsetAddr) { return convertBasicArg(true, argIdx, isNullable, 0, numRows, nullMapAddr, columnAddr, strOffsetAddr); @@ -211,30 +167,6 @@ public Method getMethod() { return method; } - // Sets the result object 'obj' into the outputBufferPtr and outputNullPtr_ - @Override - protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException { - if (obj == null) { - if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) { - throw new UdfRuntimeException("UDF failed to store null data to not null column"); - } - UdfUtils.UNSAFE.putByte(null, UdfUtils.UNSAFE.getLong(null, outputNullPtr) + row, (byte) 1); - if (retType.equals(JavaUdfDataType.STRING)) { - UdfUtils.UNSAFE.putInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) - + 4L * row, Integer.parseUnsignedInt(String.valueOf(outputOffset))); - } else if (retType.equals(JavaUdfDataType.ARRAY_TYPE)) { - UdfUtils.UNSAFE.putLong(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * row, - Long.parseUnsignedLong(String.valueOf(outputOffset))); - } - return true; - } - return super.storeUdfResult(obj, row, retClass); - } - - @Override - protected long getCurrentOutputOffset(long row, boolean isArrayType) { - return outputOffset; - } @Override protected void updateOutputOffset(long offset) { diff --git a/fe/be-java-extensions/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java b/fe/be-java-extensions/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java deleted file mode 100644 index 7c725da50e00ee..00000000000000 --- a/fe/be-java-extensions/java-udf/src/test/java/org/apache/doris/udf/UdfExecutorTest.java +++ /dev/null @@ -1,600 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.udf; - -import org.apache.doris.common.jni.utils.UdfUtils; -import org.apache.doris.thrift.TFunction; -import org.apache.doris.thrift.TFunctionBinaryType; -import org.apache.doris.thrift.TFunctionName; -import org.apache.doris.thrift.TJavaUdfExecutorCtorParams; -import org.apache.doris.thrift.TPrimitiveType; -import org.apache.doris.thrift.TScalarFunction; -import org.apache.doris.thrift.TScalarType; -import org.apache.doris.thrift.TTypeDesc; -import org.apache.doris.thrift.TTypeNode; -import org.apache.doris.thrift.TTypeNodeType; - -import org.apache.thrift.TSerializer; -import org.apache.thrift.protocol.TBinaryProtocol; -import org.junit.Test; - -import java.math.BigDecimal; -import java.math.BigInteger; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; - -public class UdfExecutorTest { - - @Test - public void testDateTimeUdf() throws Exception { - TScalarFunction scalarFunction = new TScalarFunction(); - scalarFunction.symbol = "org.apache.doris.udf.DateTimeUdf"; - - TFunction fn = new TFunction(); - fn.setBinaryType(TFunctionBinaryType.JAVA_UDF); - TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); - typeNode.setScalarType(new TScalarType(TPrimitiveType.INT)); - fn.setRetType(new TTypeDesc(Collections.singletonList(typeNode))); - - TTypeNode typeNodeArg = new TTypeNode(TTypeNodeType.SCALAR); - typeNodeArg.setScalarType(new TScalarType(TPrimitiveType.DATETIME)); - TTypeDesc typeDescArg = new TTypeDesc(Collections.singletonList(typeNodeArg)); - fn.arg_types = Arrays.asList(typeDescArg); - - fn.scalar_fn = scalarFunction; - fn.name = new TFunctionName("DateTimeUdf"); - - long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4); - int batchSize = 10; - UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); - - TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.setBatchSizePtr(batchSizePtr); - params.setFn(fn); - - long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); - UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - - params.setOutputBufferPtr(outputBufferPtr); - params.setOutputNullPtr(outputNullPtr); - - int numCols = 1; - long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - - long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(8 * batchSize); - long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); - UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1); - - long[] inputLongDateTime = - new long[] {562960991655690406L, 563242466632401062L, 563523941609111718L, 563805416585822374L, - 564086891562533030L, 564368366539243686L, 564649841515954342L, 564931316492664998L, - 565212791469375654L, 565494266446086310L}; - - for (int i = 0; i < batchSize; ++i) { - UdfUtils.UNSAFE.putLong(null, inputBuffer1 + i * 8, inputLongDateTime[i]); - UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0); - } - - params.setInputBufferPtrs(inputBufferPtr); - params.setInputNullsPtrs(inputNullPtr); - params.setInputOffsetsPtrs(0); - - TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory(); - TSerializer serializer = new TSerializer(factory); - - UdfExecutor executor = new UdfExecutor(serializer.serialize(params)); - executor.evaluate(); - - for (int i = 0; i < batchSize; ++i) { - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); - assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == (2000 + i)); - } - } - - @Test - public void testDecimalUdf() throws Exception { - TScalarFunction scalarFunction = new TScalarFunction(); - scalarFunction.symbol = "org.apache.doris.udf.DecimalUdf"; - TFunction fn = new TFunction(); - fn.binary_type = TFunctionBinaryType.JAVA_UDF; - TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); - TScalarType scalarType = new TScalarType(TPrimitiveType.DECIMALV2); - scalarType.setScale(9); - scalarType.setPrecision(27); - typeNode.scalar_type = scalarType; - TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode)); - fn.ret_type = typeDesc; - fn.arg_types = Arrays.asList(typeDesc, typeDesc); - fn.scalar_fn = scalarFunction; - fn.name = new TFunctionName("DecimalUdf"); - - long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(8); - int batchSize = 10; - UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); - - TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.setBatchSizePtr(batchSizePtr); - params.setFn(fn); - - long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); - - long outputBuffer = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); - long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); - UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - - params.setOutputBufferPtr(outputBufferPtr); - params.setOutputNullPtr(outputNullPtr); - - int numCols = 2; - long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - - long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); - long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize); - - long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); - long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); - UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2); - UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1); - UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2); - - long[] inputLong = - new long[] {562960991655690406L, 563242466632401062L, 563523941609111718L, 563805416585822374L, - 564086891562533030L, 564368366539243686L, 564649841515954342L, 564931316492664998L, - 565212791469375654L, 565494266446086310L}; - - BigDecimal[] decimalArray = new BigDecimal[10]; - for (int i = 0; i < batchSize; ++i) { - BigInteger temp = BigInteger.valueOf(inputLong[i]); - decimalArray[i] = new BigDecimal(temp, 9); - } - - BigDecimal decimal2 = new BigDecimal(BigInteger.valueOf(0L), 9); - byte[] intput2 = convertByteOrder(decimal2.unscaledValue().toByteArray()); - byte[] value2 = new byte[16]; - if (decimal2.signum() == -1) { - Arrays.fill(value2, (byte) -1); - } - for (int index = 0; index < Math.min(intput2.length, value2.length); ++index) { - value2[index] = intput2[index]; - } - - for (int i = 0; i < batchSize; ++i) { - byte[] intput1 = convertByteOrder(decimalArray[i].unscaledValue().toByteArray()); - byte[] value1 = new byte[16]; - if (decimalArray[i].signum() == -1) { - Arrays.fill(value1, (byte) -1); - } - for (int index = 0; index < Math.min(intput1.length, value1.length); ++index) { - value1[index] = intput1[index]; - } - UdfUtils.copyMemory(value1, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + i * 16, value1.length); - UdfUtils.copyMemory(value2, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + i * 16, value2.length); - UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0); - UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0); - } - - params.setInputBufferPtrs(inputBufferPtr); - params.setInputNullsPtrs(inputNullPtr); - params.setInputOffsetsPtrs(0); - - TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory(); - TSerializer serializer = new TSerializer(factory); - - UdfExecutor udfExecutor = new UdfExecutor(serializer.serialize(params)); - udfExecutor.evaluate(); - - for (int i = 0; i < batchSize; ++i) { - byte[] bytes = new byte[16]; - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); - UdfUtils.copyMemory(null, outputBuffer + 16 * i, bytes, UdfUtils.BYTE_ARRAY_OFFSET, bytes.length); - - BigInteger integer = new BigInteger(convertByteOrder(bytes)); - BigDecimal result = new BigDecimal(integer, 9); - assert (result.equals(decimalArray[i])); - } - } - - @Test - public void testConstantOneUdf() throws Exception { - TScalarFunction scalarFunction = new TScalarFunction(); - scalarFunction.symbol = "org.apache.doris.udf.ConstantOneUdf"; - - TFunction fn = new TFunction(); - fn.binary_type = TFunctionBinaryType.JAVA_UDF; - TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); - typeNode.scalar_type = new TScalarType(TPrimitiveType.INT); - fn.ret_type = new TTypeDesc(Collections.singletonList(typeNode)); - fn.arg_types = new ArrayList<>(); - fn.scalar_fn = scalarFunction; - fn.name = new TFunctionName("ConstantOne"); - - - long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4); - int batchSize = 10; - UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); - TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.setBatchSizePtr(batchSizePtr); - params.setFn(fn); - - long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); - long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); - UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); - long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); - UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - params.setOutputBufferPtr(outputBufferPtr); - params.setOutputNullPtr(outputNullPtr); - params.setInputBufferPtrs(0); - params.setInputNullsPtrs(0); - params.setInputOffsetsPtrs(0); - - TBinaryProtocol.Factory factory = - new TBinaryProtocol.Factory(); - TSerializer serializer = new TSerializer(factory); - - UdfExecutor executor; - executor = new UdfExecutor(serializer.serialize(params)); - - executor.evaluate(); - for (int i = 0; i < 10; i++) { - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); - assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == 1); - } - } - - @Test - public void testSimpleAddUdf() throws Exception { - TScalarFunction scalarFunction = new TScalarFunction(); - scalarFunction.symbol = "org.apache.doris.udf.SimpleAddUdf"; - - TFunction fn = new TFunction(); - fn.binary_type = TFunctionBinaryType.JAVA_UDF; - TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); - typeNode.scalar_type = new TScalarType(TPrimitiveType.INT); - TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode)); - fn.ret_type = typeDesc; - fn.arg_types = Arrays.asList(typeDesc, typeDesc); - fn.scalar_fn = scalarFunction; - fn.name = new TFunctionName("SimpleAdd"); - - - long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(4); - int batchSize = 10; - UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); - - TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.setBatchSizePtr(batchSizePtr); - params.setFn(fn); - - long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputBuffer = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); - UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); - UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - - params.setOutputBufferPtr(outputBufferPtr); - params.setOutputNullPtr(outputNullPtr); - - int numCols = 2; - long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - - long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize); - long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); - UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2); - UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1); - UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2); - - for (int i = 0; i < batchSize; i++) { - UdfUtils.UNSAFE.putInt(null, inputBuffer1 + i * 4, i); - UdfUtils.UNSAFE.putInt(null, inputBuffer2 + i * 4, i); - - if (i % 2 == 0) { - UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 1); - } else { - UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0); - } - UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0); - } - params.setInputBufferPtrs(inputBufferPtr); - params.setInputNullsPtrs(inputNullPtr); - params.setInputOffsetsPtrs(0); - - TBinaryProtocol.Factory factory = - new TBinaryProtocol.Factory(); - TSerializer serializer = new TSerializer(factory); - - UdfExecutor executor; - executor = new UdfExecutor(serializer.serialize(params)); - - executor.evaluate(); - for (int i = 0; i < batchSize; i++) { - if (i % 2 == 0) { - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 1); - } else { - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); - assert (UdfUtils.UNSAFE.getInt(outputBuffer + 4 * i) == i * 2); - } - } - } - - @Test - public void testStringConcatUdf() throws Exception { - TScalarFunction scalarFunction = new TScalarFunction(); - scalarFunction.symbol = "org.apache.doris.udf.StringConcatUdf"; - - TFunction fn = new TFunction(); - fn.binary_type = TFunctionBinaryType.JAVA_UDF; - TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); - typeNode.scalar_type = new TScalarType(TPrimitiveType.STRING); - TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode)); - fn.ret_type = typeDesc; - fn.arg_types = Arrays.asList(typeDesc, typeDesc); - fn.scalar_fn = scalarFunction; - fn.name = new TFunctionName("StringConcat"); - - - long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(32); - int batchSize = 10; - UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); - - TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.setBatchSizePtr(batchSizePtr); - params.setFn(fn); - - long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputOffsetsPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputIntermediateStatePtr = UdfUtils.UNSAFE.allocateMemory(8 * 2); - - String[] input1 = new String[batchSize]; - String[] input2 = new String[batchSize]; - long[] inputOffsets1 = new long[batchSize]; - long[] inputOffsets2 = new long[batchSize]; - long inputBufferSize1 = 0; - long inputBufferSize2 = 0; - for (int i = 0; i < batchSize; i++) { - input1[i] = "Input1_" + i; - input2[i] = "Input2_" + i; - inputOffsets1[i] = i == 0 ? input1[i].getBytes(StandardCharsets.UTF_8).length - : inputOffsets1[i - 1] + input1[i].getBytes(StandardCharsets.UTF_8).length; - inputOffsets2[i] = i == 0 ? input2[i].getBytes(StandardCharsets.UTF_8).length - : inputOffsets2[i - 1] + input2[i].getBytes(StandardCharsets.UTF_8).length; - inputBufferSize1 += input1[i].getBytes(StandardCharsets.UTF_8).length; - inputBufferSize2 += input2[i].getBytes(StandardCharsets.UTF_8).length; - } - // In our test case, output buffer is (8 + 1) bytes * batchSize - long outputBuffer = UdfUtils.UNSAFE.allocateMemory(inputBufferSize1 + inputBufferSize2 + batchSize); - long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); - long outputOffset = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - - UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); - UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - UdfUtils.UNSAFE.putLong(outputOffsetsPtr, outputOffset); - // reserved buffer size - UdfUtils.UNSAFE.putLong(outputIntermediateStatePtr, inputBufferSize1 + inputBufferSize2 + batchSize); - // current row id - UdfUtils.UNSAFE.putLong(outputIntermediateStatePtr + 8, 0); - - params.setOutputBufferPtr(outputBufferPtr); - params.setOutputNullPtr(outputNullPtr); - params.setOutputOffsetsPtr(outputOffsetsPtr); - params.setOutputIntermediateStatePtr(outputIntermediateStatePtr); - - int numCols = 2; - long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputOffsetsPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - - long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(inputBufferSize1 + batchSize); - long inputOffset1 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(inputBufferSize2 + batchSize); - long inputOffset2 = UdfUtils.UNSAFE.allocateMemory(4 * batchSize); - - UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); - UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2); - UdfUtils.UNSAFE.putLong(inputNullPtr, -1); - UdfUtils.UNSAFE.putLong(inputNullPtr + 8, -1); - UdfUtils.UNSAFE.putLong(inputOffsetsPtr, inputOffset1); - UdfUtils.UNSAFE.putLong(inputOffsetsPtr + 8, inputOffset2); - - for (int i = 0; i < batchSize; i++) { - if (i == 0) { - UdfUtils.copyMemory(input1[i].getBytes(StandardCharsets.UTF_8), - UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1, - input1[i].getBytes(StandardCharsets.UTF_8).length); - UdfUtils.copyMemory(input2[i].getBytes(StandardCharsets.UTF_8), - UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2, - input2[i].getBytes(StandardCharsets.UTF_8).length); - } else { - UdfUtils.copyMemory(input1[i].getBytes(StandardCharsets.UTF_8), - UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + inputOffsets1[i - 1], - input1[i].getBytes(StandardCharsets.UTF_8).length); - UdfUtils.copyMemory(input2[i].getBytes(StandardCharsets.UTF_8), - UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + inputOffsets2[i - 1], - input2[i].getBytes(StandardCharsets.UTF_8).length); - } - UdfUtils.UNSAFE.putInt(null, inputOffset1 + 4L * i, - Integer.parseUnsignedInt(String.valueOf(inputOffsets1[i]))); - UdfUtils.UNSAFE.putInt(null, inputOffset2 + 4L * i, - Integer.parseUnsignedInt(String.valueOf(inputOffsets2[i]))); - } - params.setInputBufferPtrs(inputBufferPtr); - params.setInputNullsPtrs(inputNullPtr); - params.setInputOffsetsPtrs(inputOffsetsPtr); - - TBinaryProtocol.Factory factory = - new TBinaryProtocol.Factory(); - TSerializer serializer = new TSerializer(factory); - - UdfExecutor executor; - executor = new UdfExecutor(serializer.serialize(params)); - - executor.evaluate(); - for (int i = 0; i < batchSize; i++) { - byte[] bytes = new byte[input1[i].getBytes(StandardCharsets.UTF_8).length - + input2[i].getBytes(StandardCharsets.UTF_8).length]; - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); - if (i == 0) { - UdfUtils.copyMemory(null, outputBuffer, bytes, UdfUtils.BYTE_ARRAY_OFFSET, - bytes.length); - } else { - long lastOffset = UdfUtils.UNSAFE.getInt(null, outputOffset + 4 * (i - 1)); - UdfUtils.copyMemory(null, outputBuffer + lastOffset, bytes, UdfUtils.BYTE_ARRAY_OFFSET, - bytes.length); - } - assert (new String(bytes, StandardCharsets.UTF_8).equals(input1[i] + input2[i])); - assert (UdfUtils.UNSAFE.getByte(null, outputNull + i) == 0); - } - } - - @Test - public void testLargeIntUdf() throws Exception { - TScalarFunction scalarFunction = new TScalarFunction(); - scalarFunction.symbol = "org.apache.doris.udf.LargeIntUdf"; - TFunction fn = new TFunction(); - fn.binary_type = TFunctionBinaryType.JAVA_UDF; - TTypeNode typeNode = new TTypeNode(TTypeNodeType.SCALAR); - typeNode.scalar_type = new TScalarType(TPrimitiveType.LARGEINT); - - TTypeDesc typeDesc = new TTypeDesc(Collections.singletonList(typeNode)); - - fn.ret_type = typeDesc; - fn.arg_types = Arrays.asList(typeDesc, typeDesc); - fn.scalar_fn = scalarFunction; - fn.name = new TFunctionName("LargeIntUdf"); - - long batchSizePtr = UdfUtils.UNSAFE.allocateMemory(8); - int batchSize = 10; - UdfUtils.UNSAFE.putInt(batchSizePtr, batchSize); - - TJavaUdfExecutorCtorParams params = new TJavaUdfExecutorCtorParams(); - params.setBatchSizePtr(batchSizePtr); - params.setFn(fn); - - long outputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8); - long outputNullPtr = UdfUtils.UNSAFE.allocateMemory(8); - - long outputBuffer = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); - long outputNull = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(outputBufferPtr, outputBuffer); - UdfUtils.UNSAFE.putLong(outputNullPtr, outputNull); - - params.setOutputBufferPtr(outputBufferPtr); - params.setOutputNullPtr(outputNullPtr); - - int numCols = 2; - long inputBufferPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - long inputNullPtr = UdfUtils.UNSAFE.allocateMemory(8 * numCols); - - long inputBuffer1 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); - long inputNull1 = UdfUtils.UNSAFE.allocateMemory(batchSize); - - long inputBuffer2 = UdfUtils.UNSAFE.allocateMemory(16 * batchSize); - long inputNull2 = UdfUtils.UNSAFE.allocateMemory(batchSize); - - UdfUtils.UNSAFE.putLong(inputBufferPtr, inputBuffer1); - UdfUtils.UNSAFE.putLong(inputBufferPtr + 8, inputBuffer2); - UdfUtils.UNSAFE.putLong(inputNullPtr, inputNull1); - UdfUtils.UNSAFE.putLong(inputNullPtr + 8, inputNull2); - - long[] inputLong = - new long[] {562960991655690406L, 563242466632401062L, 563523941609111718L, 563805416585822374L, - 564086891562533030L, 564368366539243686L, 564649841515954342L, 564931316492664998L, - 565212791469375654L, 565494266446086310L}; - - BigInteger[] integerArray = new BigInteger[10]; - for (int i = 0; i < batchSize; ++i) { - integerArray[i] = BigInteger.valueOf(inputLong[i]); - } - BigInteger integer2 = BigInteger.valueOf(1L); - byte[] intput2 = convertByteOrder(integer2.toByteArray()); - byte[] value2 = new byte[16]; - if (integer2.signum() == -1) { - Arrays.fill(value2, (byte) -1); - } - for (int index = 0; index < Math.min(intput2.length, value2.length); ++index) { - value2[index] = intput2[index]; - } - - for (int i = 0; i < batchSize; ++i) { - byte[] intput1 = convertByteOrder(integerArray[i].toByteArray()); - byte[] value1 = new byte[16]; - if (integerArray[i].signum() == -1) { - Arrays.fill(value1, (byte) -1); - } - for (int index = 0; index < Math.min(intput1.length, value1.length); ++index) { - value1[index] = intput1[index]; - } - UdfUtils.copyMemory(value1, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer1 + i * 16, value1.length); - UdfUtils.copyMemory(value2, UdfUtils.BYTE_ARRAY_OFFSET, null, inputBuffer2 + i * 16, value2.length); - UdfUtils.UNSAFE.putByte(null, inputNull1 + i, (byte) 0); - UdfUtils.UNSAFE.putByte(null, inputNull2 + i, (byte) 0); - } - - params.setInputBufferPtrs(inputBufferPtr); - params.setInputNullsPtrs(inputNullPtr); - params.setInputOffsetsPtrs(0); - - TBinaryProtocol.Factory factory = new TBinaryProtocol.Factory(); - TSerializer serializer = new TSerializer(factory); - - UdfExecutor udfExecutor = new UdfExecutor(serializer.serialize(params)); - udfExecutor.evaluate(); - - for (int i = 0; i < batchSize; ++i) { - byte[] bytes = new byte[16]; - assert (UdfUtils.UNSAFE.getByte(outputNull + i) == 0); - UdfUtils.copyMemory(null, outputBuffer + 16 * i, bytes, UdfUtils.BYTE_ARRAY_OFFSET, bytes.length); - BigInteger result = new BigInteger(convertByteOrder(bytes)); - assert (result.equals(integerArray[i].add(BigInteger.valueOf(1)))); - } - } - - public byte[] convertByteOrder(byte[] bytes) { - int length = bytes.length; - for (int i = 0; i < length / 2; ++i) { - byte temp = bytes[i]; - bytes[i] = bytes[length - 1 - i]; - bytes[length - 1 - i] = temp; - } - return bytes; - } -} diff --git a/regression-test/data/javaudf_p0/test_javaudaf_return_map.out b/regression-test/data/javaudf_p0/test_javaudaf_return_map.out new file mode 100644 index 00000000000000..1a4ff1b2bfd09d --- /dev/null +++ b/regression-test/data/javaudf_p0/test_javaudaf_return_map.out @@ -0,0 +1,31 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select_1 -- +{1:10, 2:20, 3:30, 4:40, 5:50} + +-- !select_2 -- +{1:0.01, 2:0.02, 3:0.03, 4:0.04, 5:0.05} + +-- !select_3 -- +{1:10} +{2:20} +{3:30} +{4:40} +{5:50} + +-- !select_4 -- +{1:0.01} +{2:0.02} +{3:0.03} +{4:0.04} +{5:0.05} + +-- !select_5 -- +{"2 114":"0.02 514", "3 114":"0.03 514", "1 114":"0.01 514", "5 114":"0.05 514", "4 114":"0.04 514"} + +-- !select_6 -- +{"1 114":"0.01 514"} +{"2 114":"0.02 514"} +{"3 114":"0.03 514"} +{"4 114":"0.04 514"} +{"5 114":"0.05 514"} + diff --git a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MyReturnMapString.java b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MyReturnMapString.java new file mode 100644 index 00000000000000..a416a8371e4abc --- /dev/null +++ b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MyReturnMapString.java @@ -0,0 +1,75 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.apache.doris.udf; +import org.apache.log4j.Logger; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.*; + + +public class MyReturnMapString { + private static final Logger LOG = Logger.getLogger(MyReturnMapString.class); + public static class State { + public HashMap counter = new HashMap<>(); + } + + public State create() { + return new State(); + } + + public void destroy(State state) { + } + + public void add(State state, Integer k, Double v) { + LOG.info("udaf nest k v " + k + " " + v); + state.counter.put(k, v); + } + + public void serialize(State state, DataOutputStream out) throws IOException { + int size = state.counter.size(); + out.writeInt(size); + for(Map.Entry it : state.counter.entrySet()){ + out.writeInt(it.getKey()); + out.writeDouble(it.getValue()); + } + } + + public void deserialize(State state, DataInputStream in) throws IOException { + int size = in.readInt(); + for (int i = 0; i < size; ++i) { + Integer key = in.readInt(); + Double value = in.readDouble(); + state.counter.put(key, value); + } + } + + public void merge(State state, State rhs) { + for(Map.Entry it : rhs.counter.entrySet()){ + state.counter.put(it.getKey(), it.getValue()); + } + } + + public HashMap getValue(State state) { + //sort for regression test + HashMap map = new HashMap<>(); + for(Map.Entry it : state.counter.entrySet()){ + map.put(it.getKey() + " 114", it.getValue() + " 514"); + } + return map; + } +} \ No newline at end of file diff --git a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapInt.java b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapInt.java new file mode 100644 index 00000000000000..cab664ef36168e --- /dev/null +++ b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapInt.java @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.apache.doris.udf; +import org.apache.log4j.Logger; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.*; + + +public class MySumReturnMapInt { + private static final Logger LOG = Logger.getLogger(MySumReturnMapInt.class); + public static class State { + public HashMap counter = new HashMap<>(); + } + + public State create() { + return new State(); + } + + public void destroy(State state) { + } + + public void add(State state, Integer val) { + if (val == null) { + return; + } + state.counter.put(val, 10 * val); + } + + public void serialize(State state, DataOutputStream out) throws IOException { + int size = state.counter.size(); + out.writeInt(size); + for(Map.Entry it : state.counter.entrySet()){ + out.writeInt(it.getKey()); + out.writeInt(it.getValue()); + } + } + + public void deserialize(State state, DataInputStream in) throws IOException { + int size = in.readInt(); + for (int i = 0; i < size; ++i) { + Integer key = in.readInt(); + Integer value = in.readInt(); + state.counter.put(key, value); + } + } + + public void merge(State state, State rhs) { + for(Map.Entry it : rhs.counter.entrySet()){ + state.counter.put(it.getKey(), it.getValue()); + } + } + + public HashMap getValue(State state) { + //sort for regression test + return state.counter; + } +} \ No newline at end of file diff --git a/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapIntDou.java b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapIntDou.java new file mode 100644 index 00000000000000..7a86666ef3b535 --- /dev/null +++ b/regression-test/java-udf-src/src/main/java/org/apache/doris/udf/MySumReturnMapIntDou.java @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package org.apache.doris.udf; +import org.apache.log4j.Logger; + +import com.carrotsearch.hppc.DoubleByteAssociativeContainer; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.*; + + +public class MySumReturnMapIntDou { + private static final Logger LOG = Logger.getLogger(MySumReturnMapIntDou.class); + public static class State { + public HashMap counter = new HashMap<>(); + } + + public State create() { + return new State(); + } + + public void destroy(State state) { + } + + public void add(State state, Integer k, Double v) { + LOG.info("udaf nest k v " + k + " " + v); + state.counter.put(k, v); + } + + public void serialize(State state, DataOutputStream out) throws IOException { + int size = state.counter.size(); + out.writeInt(size); + for(Map.Entry it : state.counter.entrySet()){ + out.writeInt(it.getKey()); + out.writeDouble(it.getValue()); + } + } + + public void deserialize(State state, DataInputStream in) throws IOException { + int size = in.readInt(); + for (int i = 0; i < size; ++i) { + Integer key = in.readInt(); + Double value = in.readDouble(); + state.counter.put(key, value); + } + } + + public void merge(State state, State rhs) { + for(Map.Entry it : rhs.counter.entrySet()){ + state.counter.put(it.getKey(), it.getValue()); + } + } + + public HashMap getValue(State state) { + //sort for regression test + return state.counter; + } +} \ No newline at end of file diff --git a/regression-test/suites/javaudf_p0/test_javaudaf_return_map.groovy b/regression-test/suites/javaudf_p0/test_javaudaf_return_map.groovy new file mode 100644 index 00000000000000..85b6d042a030ea --- /dev/null +++ b/regression-test/suites/javaudf_p0/test_javaudaf_return_map.groovy @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +import org.codehaus.groovy.runtime.IOGroovyMethods + +import java.nio.charset.StandardCharsets +import java.nio.file.Files +import java.nio.file.Paths + +suite("test_javaudaf_return_map") { + def jarPath = """${context.file.parent}/jars/java-udf-case-jar-with-dependencies.jar""" + log.info("Jar path: ${jarPath}".toString()) + try { + try_sql("DROP FUNCTION IF EXISTS aggmap(int);") + try_sql("DROP FUNCTION IF EXISTS aggmap2(int,double);") + try_sql("DROP FUNCTION IF EXISTS aggmap3(int,double);") + try_sql("DROP TABLE IF EXISTS aggdb") + sql """ + CREATE TABLE IF NOT EXISTS aggdb( + `id` INT NULL COMMENT "" , + `d` Double NULL COMMENT "" + ) ENGINE=OLAP + DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "storage_format" = "V2" + ); + """ + + + + sql """ INSERT INTO aggdb VALUES(1,0.01); """ + sql """ INSERT INTO aggdb VALUES(2,0.02); """ + sql """ INSERT INTO aggdb VALUES(3,0.03); """ + sql """ INSERT INTO aggdb VALUES(4,0.04); """ + sql """ INSERT INTO aggdb VALUES(5,0.05); """ + + + sql """ + + CREATE AGGREGATE FUNCTION aggmap(int) RETURNS Map PROPERTIES ( + "file"="file://${jarPath}", + "symbol"="org.apache.doris.udf.MySumReturnMapInt", + "type"="JAVA_UDF" + ); + + """ + + sql """ + + CREATE AGGREGATE FUNCTION aggmap2(int,double) RETURNS Map PROPERTIES ( + "file"="file://${jarPath}", + "symbol"="org.apache.doris.udf.MySumReturnMapIntDou", + "type"="JAVA_UDF" + ); + + + """ + + + sql """ + + CREATE AGGREGATE FUNCTION aggmap3(int,double) RETURNS Map PROPERTIES ( + "file"="file://${jarPath}", + "symbol"="org.apache.doris.udf.MyReturnMapString", + "type"="JAVA_UDF" + ); + + + """ + + qt_select_1 """ select aggmap(id) from aggdb; """ + + qt_select_2 """ select aggmap2(id,d) from aggdb; """ + + qt_select_3 """ select aggmap(id) from aggdb group by id; """ + + qt_select_4 """ select aggmap2(id,d) from aggdb group by id; """ + + qt_select_5 """ select aggmap3(id,d) from aggdb; """ + + qt_select_6 """ select aggmap3(id,d) from aggdb group by id; """ + } finally { + try_sql("DROP FUNCTION IF EXISTS aggmap(int);") + try_sql("DROP FUNCTION IF EXISTS aggmap2(int,double);") + try_sql("DROP FUNCTION IF EXISTS aggmap3(int,double);") + try_sql("DROP TABLE IF EXISTS aggdb") + } +}