Skip to content

Commit

Permalink
[refactor](udaf) refactor call udaf function and support map type in …
Browse files Browse the repository at this point in the history
…return (#22508)
  • Loading branch information
Mryange authored Aug 9, 2023
1 parent 0d75a54 commit 768088c
Show file tree
Hide file tree
Showing 12 changed files with 760 additions and 1,830 deletions.
246 changes: 123 additions & 123 deletions be/src/vec/aggregate_functions/aggregate_function_java_udaf.h

Large diffs are not rendered by default.

44 changes: 21 additions & 23 deletions be/src/vec/functions/function_java_udf.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,27 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
auto& key_null_map_data =
assert_cast<ColumnVector<UInt8>*>(key_data_column_null_map.get())->get_data();
auto key_nested_nullmap_address = reinterpret_cast<int64_t>(key_null_map_data.data());
ColumnNullable& map_value_column_nullable =
assert_cast<ColumnNullable&>(map_col->get_values());
auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr();
auto value_data_column = map_value_column_nullable.get_nested_column_ptr();
auto& value_null_map_data =
assert_cast<ColumnVector<UInt8>*>(value_data_column_null_map.get())->get_data();
auto value_nested_nullmap_address = reinterpret_cast<int64_t>(value_null_map_data.data());
jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I");
int element_size = 0; // get all element size in num_rows of map column
for (int i = 0; i < num_rows; ++i) {
jobject obj = env->GetObjectArrayElement(result_obj, i);
if (obj == nullptr) {
continue;
}
element_size = element_size + env->CallIntMethod(obj, map_size);
env->DeleteLocalRef(obj);
}
map_key_column_nullable.resize(element_size);
memset(key_null_map_data.data(), 0, element_size);
map_value_column_nullable.resize(element_size);
memset(value_null_map_data.data(), 0, element_size);
int64_t key_nested_data_address = 0, key_nested_offset_address = 0;
if (key_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(key_data_column.get());
Expand All @@ -358,16 +379,7 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
key_nested_data_address =
reinterpret_cast<int64_t>(key_data_column->get_raw_data().data);
}

ColumnNullable& map_value_column_nullable =
assert_cast<ColumnNullable&>(map_col->get_values());
auto value_data_column_null_map = map_value_column_nullable.get_null_map_column_ptr();
auto value_data_column = map_value_column_nullable.get_nested_column_ptr();
auto& value_null_map_data =
assert_cast<ColumnVector<UInt8>*>(value_data_column_null_map.get())->get_data();
auto value_nested_nullmap_address = reinterpret_cast<int64_t>(value_null_map_data.data());
int64_t value_nested_data_address = 0, value_nested_offset_address = 0;
// array type need pass address: [nullmap_address], offset_address, nested_nullmap_address, nested_data_address/nested_char_address,nested_offset_address
if (value_data_column->is_column_string()) {
ColumnString* str_col = assert_cast<ColumnString*>(value_data_column.get());
ColumnString::Chars& chars = assert_cast<ColumnString::Chars&>(str_col->get_chars());
Expand All @@ -379,20 +391,6 @@ Status JavaFunctionCall::execute(FunctionContext* context, Block& block,
value_nested_data_address =
reinterpret_cast<int64_t>(value_data_column->get_raw_data().data);
}
jmethodID map_size = env->GetMethodID(hashmap_class, "size", "()I");
int element_size = 0; // get all element size in num_rows of map column
for (int i = 0; i < num_rows; ++i) {
jobject obj = env->GetObjectArrayElement(result_obj, i);
if (obj == nullptr) {
continue;
}
element_size = element_size + env->CallIntMethod(obj, map_size);
env->DeleteLocalRef(obj);
}
map_key_column_nullable.resize(element_size);
memset(key_null_map_data.data(), 0, element_size);
map_value_column_nullable.resize(element_size);
memset(value_null_map_data.data(), 0, element_size);
env->CallNonvirtualVoidMethod(jni_ctx->executor, jni_env->executor_cl,
jni_env->executor_result_map_batch_id, result_nullable,
num_rows, result_obj, nullmap_address, offset_address,
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -163,43 +163,6 @@ public void addBatchPlaces(int rowStart, int rowEnd, long placeAddr, int offset,
}
}

/**
 * Invokes the user-defined add function for every row in the half-open range [rowStart, rowEnd).
 *
 * <p>When {@code isSinglePlace} is true the place is read once from the first slot of the places
 * array and every row in the range is accumulated into the state registered for that place.
 * Otherwise the place is re-read for each row from slot {@code idx} (8 bytes per slot) and exactly
 * one row is accumulated per place lookup. A state object is created lazily through
 * {@code createAggState()} the first time a place is seen.
 *
 * @param isSinglePlace whether all rows of the range share one aggregation place
 * @param rowStart first row to add (inclusive)
 * @param rowEnd row at which to stop (exclusive); an empty range is a no-op
 * @throws UdfRuntimeException if reading the inputs or invoking the add function fails
 */
public void add(boolean isSinglePlace, long rowStart, long rowEnd) throws UdfRuntimeException {
    if (rowStart >= rowEnd) {
        // Empty range: nothing to add. Without this guard the loop body would still run once
        // and read a place/row outside the documented [rowStart, rowEnd) range.
        return;
    }
    try {
        long idx = rowStart;
        while (idx < rowEnd) {
            long placesBase = UdfUtils.UNSAFE.getLong(null, inputPlacesPtr);
            Long curPlace = UdfUtils.UNSAFE.getLong(null, isSinglePlace ? placesBase : placesBase + 8L * idx);
            Object state = stateObjMap.get(curPlace);
            if (state == null) {
                state = createAggState();
                stateObjMap.put(curPlace, state);
            }
            // Slot 0 carries the aggregation state, slots 1..n carry the row's arguments.
            Object[] inputArgs = new Object[argTypes.length + 1];
            inputArgs[0] = state;
            do {
                Object[] inputObjects = allocateInputObjects(idx, 1);
                System.arraycopy(inputObjects, 0, inputArgs, 1, argTypes.length);
                allMethods.get(UDAF_ADD_FUNCTION).invoke(udf, inputArgs);
                idx++;
            } while (isSinglePlace && idx < rowEnd);
        }
    } catch (Exception e) {
        // getCause() is null unless the exception wraps another one (e.g. a reflective
        // InvocationTargetException); fall back to the exception itself so logging cannot
        // raise a NullPointerException that would hide the real failure.
        Throwable cause = e.getCause() == null ? e : e.getCause();
        LOG.warn("invoke add function meet some error: " + cause.toString());
        throw new UdfRuntimeException("UDAF failed to add: ", e);
    }
}

/**
* invoke user create function to get obj.
*/
Expand Down Expand Up @@ -292,40 +255,71 @@ public void merge(long place, byte[] data) throws UdfRuntimeException {
/**
* invoke getValue to return finally result.
*/
public boolean getValue(long row, long place) throws UdfRuntimeException {

public Object getValue(long place) throws UdfRuntimeException {
try {
if (stateObjMap.get(place) == null) {
stateObjMap.put(place, createAggState());
}
return storeUdfResult(allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place)),
row, retClass);
return allMethods.get(UDAF_RESULT_FUNCTION).invoke(udf, stateObjMap.get((Long) place));
} catch (Exception e) {
LOG.warn("invoke getValue function meet some error: " + e.getCause().toString());
throw new UdfRuntimeException("UDAF failed to result", e);
}
}

@Override
protected boolean storeUdfResult(Object obj, long row, Class retClass) throws UdfRuntimeException {
if (obj == null) {
// If result is null, return true directly when row == 0 as we have already inserted default value.
if (UdfUtils.UNSAFE.getLong(null, outputNullPtr) == -1) {
public void copyTupleBasicResult(Object result, int row, long outputNullMapPtr, long outputBufferBase,
long charsAddress,
long offsetsAddr) throws UdfRuntimeException {
if (result == null) {
// put null obj
if (outputNullMapPtr == -1) {
throw new UdfRuntimeException("UDAF failed to store null data to not null column");
} else {
UdfUtils.UNSAFE.putByte(outputNullMapPtr + row, (byte) 1);
}
return true;
return;
}
try {
if (outputNullMapPtr != -1) {
UdfUtils.UNSAFE.putByte(outputNullMapPtr + row, (byte) 0);
}
copyTupleBasicResult(result, row, retClass, outputBufferBase, charsAddress,
offsetsAddr, retType);
} catch (UdfRuntimeException e) {
LOG.info(e.toString());
}
return super.storeUdfResult(obj, row, retClass);
}

@Override
protected long getCurrentOutputOffset(long row, boolean isArrayType) {
if (isArrayType) {
return Integer.toUnsignedLong(
UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 8L * (row - 1)));
} else {
return Integer.toUnsignedLong(
UdfUtils.UNSAFE.getInt(null, UdfUtils.UNSAFE.getLong(null, outputOffsetsPtr) + 4L * (row - 1)));
public void copyTupleArrayResult(long hasPutElementNum, boolean isNullable, int row, Object result,
long nullMapAddr,
long offsetsAddr, long nestedNullMapAddr, long dataAddr, long strOffsetAddr) throws UdfRuntimeException {
if (nullMapAddr > 0) {
UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 0);
}
copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, result, nullMapAddr, offsetsAddr, nestedNullMapAddr,
dataAddr, strOffsetAddr, retType.getItemType().getPrimitiveType());
}

/**
 * Writes one map-typed result into the output column buffers for the given row.
 *
 * <p>If a row-level null map address is supplied (greater than zero) the row is first marked as
 * not null. The result is then handed to {@code buildArrayListFromHashMap} together with the
 * key/value primitive types of the return type (presumably splitting the map into a key
 * collection and a value collection -- confirm against that helper), and the two outputs are
 * copied through {@code copyTupleArrayResultImpl}: values first, then keys.
 *
 * @throws UdfRuntimeException declared for failures of the copy helpers
 */
public void copyTupleMapResult(long hasPutElementNum, boolean isNullable, int row, Object result, long nullMapAddr,
        long offsetsAddr,
        long keyNsestedNullMapAddr, long keyDataAddr, long keyStrOffsetAddr,
        long valueNsestedNullMapAddr, long valueDataAddr, long valueStrOffsetAddr) throws UdfRuntimeException {
    final boolean hasRowNullMap = nullMapAddr > 0;
    if (hasRowNullMap) {
        UdfUtils.UNSAFE.putByte(nullMapAddr + row, (byte) 0);
    }

    final PrimitiveType keyPrimitive = retType.getKeyType().getPrimitiveType();
    final PrimitiveType valuePrimitive = retType.getValueType().getPrimitiveType();

    // Single-slot holders: one input map in, one key output and one value output back.
    final Object[] keysOut = new Object[1];
    final Object[] valuesOut = new Object[1];
    final Object[] singleResult = new Object[] {result};
    buildArrayListFromHashMap(singleResult, keyPrimitive, valuePrimitive, keysOut, valuesOut);

    // Values are copied before keys, matching the established call order.
    copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, valuesOut[0], nullMapAddr, offsetsAddr,
            valueNsestedNullMapAddr, valueDataAddr, valueStrOffsetAddr, valuePrimitive);
    copyTupleArrayResultImpl(hasPutElementNum, isNullable, row, keysOut[0], nullMapAddr, offsetsAddr,
            keyNsestedNullMapAddr, keyDataAddr, keyStrOffsetAddr, keyPrimitive);
}

@Override
Expand Down
Loading

0 comments on commit 768088c

Please sign in to comment.