Skip to content

Commit

Permalink
[VL] UDF: Support variable arity in function signatures (#5495)
Browse files Browse the repository at this point in the history
  • Loading branch information
marin-ma authored Apr 26, 2024
1 parent c191753 commit 1e07bde
Show file tree
Hide file tree
Showing 11 changed files with 472 additions and 130 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.types.DataTypeUtils
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.util.Utils
Expand Down Expand Up @@ -72,6 +73,25 @@ case class UserDefinedAggregateFunction(
}
}

/**
 * Common view of a registered native function signature. It exposes the three fields that
 * signature matching and expression construction read: the result type, the declared argument
 * types, and the variable-arity flag.
 */
trait UDFSignatureBase {
// Result type of the function; its dataType and nullable are used to build the expression.
val expressionType: ExpressionType
// Declared argument data types, in order.
val children: Seq[DataType]
// When true, the last declared argument type may be supplied zero or more times.
val variableArity: Boolean
}

/**
 * One registered signature of a scalar UDF.
 *
 * @param expressionType result type of the function
 * @param children declared argument data types, in order
 * @param variableArity whether the last declared argument type may be supplied zero or more times
 */
case class UDFSignature(
expressionType: ExpressionType,
children: Seq[DataType],
variableArity: Boolean)
extends UDFSignatureBase

/**
 * One registered signature of an aggregate UDF (UDAF).
 *
 * @param expressionType result type of the function
 * @param children declared argument data types, in order
 * @param variableArity whether the last declared argument type may be supplied zero or more times
 * @param intermediateAttrs attributes describing the aggregation buffer; they are passed to
 *   UserDefinedAggregateFunction when a call is bound to this signature
 */
case class UDAFSignature(
expressionType: ExpressionType,
children: Seq[DataType],
variableArity: Boolean,
intermediateAttrs: Seq[AttributeReference])
extends UDFSignatureBase

case class UDFExpression(
name: String,
dataType: DataType,
Expand Down Expand Up @@ -109,31 +129,40 @@ case class UDFExpression(
object UDFResolver extends Logging {
private val UDFNames = mutable.HashSet[String]()
// (udf_name, arg1, arg2, ...) => return type
private val UDFMap = mutable.HashMap[(String, Seq[DataType]), ExpressionType]()
private val UDFMap = mutable.HashMap[String, mutable.MutableList[UDFSignature]]()

private val UDAFNames = mutable.HashSet[String]()
// (udaf_name, arg1, arg2, ...) => return type, intermediate attributes
private val UDAFMap =
mutable.HashMap[(String, Seq[DataType]), (ExpressionType, Seq[AttributeReference])]()
mutable.HashMap[String, mutable.MutableList[UDAFSignature]]()

private val LIB_EXTENSION = ".so"

// Called by JNI.
def registerUDF(name: String, returnType: Array[Byte], argTypes: Array[Byte]): Unit = {
def registerUDF(
name: String,
returnType: Array[Byte],
argTypes: Array[Byte],
variableArity: Boolean): Unit = {
registerUDF(
name,
ConverterUtils.parseFromBytes(returnType),
ConverterUtils.parseFromBytes(argTypes))
ConverterUtils.parseFromBytes(argTypes),
variableArity)
}

private def registerUDF(
name: String,
returnType: ExpressionType,
argTypes: ExpressionType): Unit = {
argTypes: ExpressionType,
variableArity: Boolean): Unit = {
assert(argTypes.dataType.isInstanceOf[StructType])
UDFMap.put(
(name, argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType)),
returnType)
val v =
UDFMap.getOrElseUpdate(name, mutable.MutableList[UDFSignature]())
v += UDFSignature(
returnType,
argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType),
variableArity)
UDFNames += name
logInfo(s"Registered UDF: $name($argTypes) -> $returnType")
}
Expand All @@ -142,20 +171,23 @@ object UDFResolver extends Logging {
name: String,
returnType: Array[Byte],
argTypes: Array[Byte],
intermediateTypes: Array[Byte]): Unit = {
intermediateTypes: Array[Byte],
variableArity: Boolean): Unit = {
registerUDAF(
name,
ConverterUtils.parseFromBytes(returnType),
ConverterUtils.parseFromBytes(argTypes),
ConverterUtils.parseFromBytes(intermediateTypes)
ConverterUtils.parseFromBytes(intermediateTypes),
variableArity
)
}

private def registerUDAF(
name: String,
returnType: ExpressionType,
argTypes: ExpressionType,
intermediateTypes: ExpressionType): Unit = {
intermediateTypes: ExpressionType,
variableArity: Boolean): Unit = {
assert(argTypes.dataType.isInstanceOf[StructType])
assert(intermediateTypes.dataType.isInstanceOf[StructType])

Expand All @@ -164,10 +196,14 @@ object UDFResolver extends Logging {
case (f, index) =>
AttributeReference(s"inter_$index", f.dataType, f.nullable)()
}
UDAFMap.put(
(name, argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType)),
(returnType, aggBufferAttributes)
)

val v =
UDAFMap.getOrElseUpdate(name, mutable.MutableList[UDAFSignature]())
v += UDAFSignature(
returnType,
argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType),
variableArity,
aggBufferAttributes)
UDAFNames += name
logInfo(s"Registered UDAF: $name($argTypes) -> $returnType")
}
Expand Down Expand Up @@ -319,30 +355,81 @@ object UDFResolver extends Logging {
}

private def getUdfExpression(name: String)(children: Seq[Expression]) = {
val expressionType =
UDFMap.getOrElse(
(name, children.map(_.dataType)),
throw new UnsupportedOperationException(
s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} " +
s"is not registered.")
)
UDFExpression(name, expressionType.dataType, expressionType.nullable, children)
def errorMessage: String =
s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered."

val signatures =
UDFMap.getOrElse(name, throw new UnsupportedOperationException(errorMessage));

signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
case Some(sig) =>
UDFExpression(name, sig.expressionType.dataType, sig.expressionType.nullable, children)
case None =>
throw new UnsupportedOperationException(errorMessage)
}
}

private def getUdafExpression(name: String)(children: Seq[Expression]) = {
val (expressionType, aggBufferAttributes) =
def errorMessage: String =
s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered."

val signatures =
UDAFMap.getOrElse(
(name, children.map(_.dataType)),
throw new UnsupportedOperationException(
s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} " +
s"is not registered.")
name,
throw new UnsupportedOperationException(errorMessage)
)

UserDefinedAggregateFunction(
name,
expressionType.dataType,
expressionType.nullable,
children,
aggBufferAttributes)
signatures.find(sig => tryBind(sig, children.map(_.dataType))) match {
case Some(sig) =>
UserDefinedAggregateFunction(
name,
sig.expressionType.dataType,
sig.expressionType.nullable,
children,
sig.intermediateAttrs)
case None =>
throw new UnsupportedOperationException(errorMessage)
}
}

/**
 * Checks whether the argument types supplied at a call site can be bound to a signature.
 *
 * For a fixed-arity signature, the number of arguments must equal the number of declared
 * types and every pair must satisfy `DataTypeUtils.sameType`.
 *
 * For a variable-arity signature, the last declared type is the repeated type: the leading
 * declared types must match the leading arguments one by one, and the remaining zero or more
 * arguments must all match the last declared type.
 *
 * @param sig candidate signature
 * @param requiredDataTypes data types of the supplied arguments, in order
 * @return true if the arguments can be bound to `sig`, false otherwise
 */
private def tryBind(sig: UDFSignatureBase, requiredDataTypes: Seq[DataType]): Boolean = {
  // Pairwise type comparison; callers are responsible for checking the lengths.
  def matchesPairwise(declared: Seq[DataType], supplied: Seq[DataType]): Boolean =
    declared.zip(supplied).forall {
      case (candidate, required) => DataTypeUtils.sameType(candidate, required)
    }

  if (!sig.variableArity) {
    sig.children.size == requiredDataTypes.size &&
    matchesPairwise(sig.children, requiredDataTypes)
  } else {
    sig.children.lastOption match {
      case None =>
        // A variable-arity signature is expected to declare at least the repeated type.
        // If it declares nothing there is no type to compare extra arguments against, so
        // only an empty argument list can bind (instead of failing on `children.last`).
        requiredDataTypes.isEmpty
      case Some(varArgType) =>
        val fixedTypes = sig.children.init
        // Split the call-site arguments into the fixed prefix and the variadic tail.
        val (fixedArgs, varArgs) = requiredDataTypes.splitAt(fixedTypes.size)
        // Fewer arguments than fixed parameters can never bind.
        fixedArgs.size == fixedTypes.size &&
        matchesPairwise(fixedTypes, fixedArgs) &&
        varArgs.forall(argType => DataTypeUtils.sameType(varArgType, argType))
    }
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,20 +71,24 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
.set("spark.memory.offHeap.size", "1024MB")
}

testWithSpecifiedSparkVersion("test udf", Some("3.2")) {
test("test udf") {
val df = spark.sql("""select
| myudf1(1),
| myudf1(1L),
| myudf2(100L),
| myudf1(100L),
| myudf2(1),
| myudf2(1L),
| myudf3(),
| myudf3(1),
| myudf3(1, 2, 3),
| myudf3(1L),
| myudf3(1L, 2L, 3L),
| mydate(cast('2024-03-25' as date), 5)
|""".stripMargin)
df.collect()
assert(
df.collect()
.sameElements(Array(Row(6, 6L, 105, Date.valueOf("2024-03-30")))))
.sameElements(Array(Row(105L, 6, 6L, 5, 6, 11, 6L, 11L, Date.valueOf("2024-03-30")))))
}

testWithSpecifiedSparkVersion("test udaf", Some("3.2")) {
test("test udaf") {
val df = spark.sql("""select
| myavg(1),
| myavg(1L),
Expand Down
9 changes: 5 additions & 4 deletions cpp/velox/jni/JniUdf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ void gluten::initVeloxJniUDF(JNIEnv* env) {
udfResolverClass = createGlobalClassReferenceOrError(env, kUdfResolverClassPath.c_str());

// methods
registerUDFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDF", "(Ljava/lang/String;[B[B)V");
registerUDAFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDAF", "(Ljava/lang/String;[B[B[B)V");
registerUDFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDF", "(Ljava/lang/String;[B[BZ)V");
registerUDAFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDAF", "(Ljava/lang/String;[B[B[BZ)V");
}

void gluten::finalizeVeloxJniUDF(JNIEnv* env) {
Expand Down Expand Up @@ -70,9 +70,10 @@ void gluten::jniGetFunctionSignatures(JNIEnv* env) {
0,
signature->intermediateType.length(),
reinterpret_cast<const jbyte*>(signature->intermediateType.c_str()));
env->CallVoidMethod(instance, registerUDAFMethod, name, returnType, argTypes, intermediateType);
env->CallVoidMethod(
instance, registerUDAFMethod, name, returnType, argTypes, intermediateType, signature->variableArity);
} else {
env->CallVoidMethod(instance, registerUDFMethod, name, returnType, argTypes);
env->CallVoidMethod(instance, registerUDFMethod, name, returnType, argTypes, signature->variableArity);
}
checkException(env);
}
Expand Down
1 change: 1 addition & 0 deletions cpp/velox/udf/Udaf.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ struct UdafEntry {
const char** argTypes;

const char* intermediateType{nullptr};
bool variableArity{false};
};

#define GLUTEN_GET_NUM_UDAF getNumUdaf
Expand Down
2 changes: 2 additions & 0 deletions cpp/velox/udf/Udf.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ struct UdfEntry {

size_t numArgs;
const char** argTypes;

bool variableArity{false};
};

#define GLUTEN_GET_NUM_UDF getNumUdf
Expand Down
31 changes: 27 additions & 4 deletions cpp/velox/udf/UdfLoader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,11 @@ std::unordered_set<std::shared_ptr<UdfLoader::UdfSignature>> UdfLoader::getRegis
const auto& entry = udfEntries[i];
auto dataType = toSubstraitTypeStr(entry.dataType);
auto argTypes = toSubstraitTypeStr(entry.numArgs, entry.argTypes);
signatures_.insert(std::make_shared<UdfSignature>(entry.name, dataType, argTypes));
signatures_.insert(std::make_shared<UdfSignature>(entry.name, dataType, argTypes, entry.variableArity));
}
free(udfEntries);
} else {
LOG(INFO) << "No UDFs found in " << libPath;
LOG(INFO) << "No UDF found in " << libPath;
}

// Handle UDAFs.
Expand All @@ -110,11 +110,12 @@ std::unordered_set<std::shared_ptr<UdfLoader::UdfSignature>> UdfLoader::getRegis
auto dataType = toSubstraitTypeStr(entry.dataType);
auto argTypes = toSubstraitTypeStr(entry.numArgs, entry.argTypes);
auto intermediateType = toSubstraitTypeStr(entry.intermediateType);
signatures_.insert(std::make_shared<UdfSignature>(entry.name, dataType, argTypes, intermediateType));
signatures_.insert(
std::make_shared<UdfSignature>(entry.name, dataType, argTypes, intermediateType, entry.variableArity));
}
free(udafEntries);
} else {
LOG(INFO) << "No UDAFs found in " << libPath;
LOG(INFO) << "No UDAF found in " << libPath;
}
}
return signatures_;
Expand Down Expand Up @@ -151,4 +152,26 @@ std::shared_ptr<UdfLoader> UdfLoader::getInstance() {
return instance;
}

std::string UdfLoader::toSubstraitTypeStr(const std::string& type) {
auto returnType = parser_.parse(type);
auto substraitType = convertor_.toSubstraitType(arena_, returnType);

std::string output;
substraitType.SerializeToString(&output);
return output;
}

std::string UdfLoader::toSubstraitTypeStr(int32_t numArgs, const char** args) {
std::vector<facebook::velox::TypePtr> argTypes;
argTypes.resize(numArgs);
for (auto i = 0; i < numArgs; ++i) {
argTypes[i] = parser_.parse(args[i]);
}
auto substraitType = convertor_.toSubstraitType(arena_, facebook::velox::ROW(std::move(argTypes)));

std::string output;
substraitType.SerializeToString(&output);
return output;
}

} // namespace gluten
Loading

0 comments on commit 1e07bde

Please sign in to comment.