From 1e07bde4710961ae8e3457c7a07c51d03b08cf09 Mon Sep 17 00:00:00 2001 From: Rong Ma Date: Fri, 26 Apr 2024 08:52:33 +0800 Subject: [PATCH] [VL] UDF: Support variable arity in function sigatures (#5495) --- .../spark/sql/expression/UDFResolver.scala | 155 ++++++++--- .../gluten/expression/VeloxUdfSuite.scala | 18 +- cpp/velox/jni/JniUdf.cc | 9 +- cpp/velox/udf/Udaf.h | 1 + cpp/velox/udf/Udf.h | 2 + cpp/velox/udf/UdfLoader.cc | 31 ++- cpp/velox/udf/UdfLoader.h | 45 ++- cpp/velox/udf/examples/MyUDF.cc | 257 ++++++++++++++---- .../sql/catalyst/types/DataTypeUtils.scala | 28 ++ .../sql/catalyst/types/DataTypeUtils.scala | 28 ++ .../sql/catalyst/types/DataTypeUtils.scala | 28 ++ 11 files changed, 472 insertions(+), 130 deletions(-) create mode 100644 shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala create mode 100644 shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala create mode 100644 shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala index 3445e40e5634..bdfd24ed5c1b 100644 --- a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala +++ b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.types.{DataType, StructField, StructType} import org.apache.spark.util.Utils @@ -72,6 +73,25 @@ case class UserDefinedAggregateFunction( } } +trait UDFSignatureBase { + val expressionType: ExpressionType + val children: Seq[DataType] + val variableArity: Boolean +} + +case class UDFSignature( + expressionType: ExpressionType, + children: Seq[DataType], + variableArity: Boolean) + extends UDFSignatureBase + +case class UDAFSignature( + expressionType: ExpressionType, + children: Seq[DataType], + variableArity: Boolean, + intermediateAttrs: Seq[AttributeReference]) + extends UDFSignatureBase + case class UDFExpression( name: String, dataType: DataType, @@ -109,31 +129,40 @@ case class UDFExpression( object UDFResolver extends Logging { private val UDFNames = mutable.HashSet[String]() // (udf_name, arg1, arg2, ...) => return type - private val UDFMap = mutable.HashMap[(String, Seq[DataType]), ExpressionType]() + private val UDFMap = mutable.HashMap[String, mutable.MutableList[UDFSignature]]() private val UDAFNames = mutable.HashSet[String]() // (udaf_name, arg1, arg2, ...) => return type, intermediate attributes private val UDAFMap = - mutable.HashMap[(String, Seq[DataType]), (ExpressionType, Seq[AttributeReference])]() + mutable.HashMap[String, mutable.MutableList[UDAFSignature]]() private val LIB_EXTENSION = ".so" // Called by JNI. - def registerUDF(name: String, returnType: Array[Byte], argTypes: Array[Byte]): Unit = { + def registerUDF( + name: String, + returnType: Array[Byte], + argTypes: Array[Byte], + variableArity: Boolean): Unit = { registerUDF( name, ConverterUtils.parseFromBytes(returnType), - ConverterUtils.parseFromBytes(argTypes)) + ConverterUtils.parseFromBytes(argTypes), + variableArity) } private def registerUDF( name: String, returnType: ExpressionType, - argTypes: ExpressionType): Unit = { + argTypes: ExpressionType, + variableArity: Boolean): Unit = { assert(argTypes.dataType.isInstanceOf[StructType]) - UDFMap.put( - (name, argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType)), - returnType) + val v = + UDFMap.getOrElseUpdate(name, mutable.MutableList[UDFSignature]()) + v += UDFSignature( + returnType, + argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType), + variableArity) UDFNames += name logInfo(s"Registered UDF: $name($argTypes) -> $returnType") } @@ -142,12 +171,14 @@ object UDFResolver extends Logging { name: String, returnType: Array[Byte], argTypes: Array[Byte], - intermediateTypes: Array[Byte]): Unit = { + intermediateTypes: Array[Byte], + variableArity: Boolean): Unit = { registerUDAF( name, ConverterUtils.parseFromBytes(returnType), ConverterUtils.parseFromBytes(argTypes), - ConverterUtils.parseFromBytes(intermediateTypes) + ConverterUtils.parseFromBytes(intermediateTypes), + variableArity ) } @@ -155,7 +186,8 @@ object UDFResolver extends Logging { name: String, returnType: ExpressionType, argTypes: ExpressionType, - intermediateTypes: ExpressionType): Unit = { + intermediateTypes: ExpressionType, + variableArity: Boolean): Unit = { assert(argTypes.dataType.isInstanceOf[StructType]) assert(intermediateTypes.dataType.isInstanceOf[StructType]) @@ -164,10 +196,14 @@ object UDFResolver extends Logging { case (f, index) => AttributeReference(s"inter_$index", f.dataType, f.nullable)() } - UDAFMap.put( - (name, argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType)), - (returnType, aggBufferAttributes) - ) + + val v = + UDAFMap.getOrElseUpdate(name, mutable.MutableList[UDAFSignature]()) + v += UDAFSignature( + returnType, + argTypes.dataType.asInstanceOf[StructType].fields.map(_.dataType), + variableArity, + aggBufferAttributes) UDAFNames += name logInfo(s"Registered UDAF: $name($argTypes) -> $returnType") } @@ -319,30 +355,81 @@ object UDFResolver extends Logging { } private def getUdfExpression(name: String)(children: Seq[Expression]) = { - val expressionType = - UDFMap.getOrElse( - (name, children.map(_.dataType)), - throw new UnsupportedOperationException( - s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} " + - s"is not registered.") - ) - UDFExpression(name, expressionType.dataType, expressionType.nullable, children) + def errorMessage: String = + s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered." + + val signatures = + UDFMap.getOrElse(name, throw new UnsupportedOperationException(errorMessage)); + + signatures.find(sig => tryBind(sig, children.map(_.dataType))) match { + case Some(sig) => + UDFExpression(name, sig.expressionType.dataType, sig.expressionType.nullable, children) + case None => + throw new UnsupportedOperationException(errorMessage) + } } private def getUdafExpression(name: String)(children: Seq[Expression]) = { - val (expressionType, aggBufferAttributes) = + def errorMessage: String = + s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered." + + val signatures = UDAFMap.getOrElse( - (name, children.map(_.dataType)), - throw new UnsupportedOperationException( - s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} " + - s"is not registered.") + name, + throw new UnsupportedOperationException(errorMessage) ) - UserDefinedAggregateFunction( - name, - expressionType.dataType, - expressionType.nullable, - children, - aggBufferAttributes) + signatures.find(sig => tryBind(sig, children.map(_.dataType))) match { + case Some(sig) => + UserDefinedAggregateFunction( + name, + sig.expressionType.dataType, + sig.expressionType.nullable, + children, + sig.intermediateAttrs) + case None => + throw new UnsupportedOperationException(errorMessage) + } + } + + // Returns true if required data types match the function signature. + // If the function signature is variable arity, the number of the last argument can be zero + // or more. + private def tryBind(sig: UDFSignatureBase, requiredDataTypes: Seq[DataType]): Boolean = { + if (!sig.variableArity) { + sig.children.size == requiredDataTypes.size && + sig.children + .zip(requiredDataTypes) + .forall { case (candidate, required) => DataTypeUtils.sameType(candidate, required) } + } else { + // If variableArity is true, there must be at least one argument in the signature. + if (requiredDataTypes.size < sig.children.size - 1) { + false + } else if (requiredDataTypes.size == sig.children.size - 1) { + sig.children + .dropRight(1) + .zip(requiredDataTypes) + .forall { case (candidate, required) => DataTypeUtils.sameType(candidate, required) } + } else { + val varArgStartIndex = sig.children.size - 1 + // First check all var args has the same type with the last argument of the signature. + if ( + !requiredDataTypes + .drop(varArgStartIndex) + .forall(argType => DataTypeUtils.sameType(sig.children.last, argType)) + ) { + false + } else if (varArgStartIndex == 0) { + // No fixed args. + true + } else { + // Whether fixed args matches. + sig.children + .dropRight(1) + .zip(requiredDataTypes.dropRight(1 + requiredDataTypes.size - sig.children.size)) + .forall { case (candidate, required) => DataTypeUtils.sameType(candidate, required) } + } + } + } } } diff --git a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala index 40452dfbc9cd..4d2f9fae3147 100644 --- a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala +++ b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala @@ -71,20 +71,24 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper { .set("spark.memory.offHeap.size", "1024MB") } - testWithSpecifiedSparkVersion("test udf", Some("3.2")) { + test("test udf") { val df = spark.sql("""select - | myudf1(1), - | myudf1(1L), - | myudf2(100L), + | myudf1(100L), + | myudf2(1), + | myudf2(1L), + | myudf3(), + | myudf3(1), + | myudf3(1, 2, 3), + | myudf3(1L), + | myudf3(1L, 2L, 3L), | mydate(cast('2024-03-25' as date), 5) |""".stripMargin) - df.collect() assert( df.collect() - .sameElements(Array(Row(6, 6L, 105, Date.valueOf("2024-03-30"))))) + .sameElements(Array(Row(105L, 6, 6L, 5, 6, 11, 6L, 11L, Date.valueOf("2024-03-30"))))) } - testWithSpecifiedSparkVersion("test udaf", Some("3.2")) { + test("test udaf") { val df = spark.sql("""select | myavg(1), | myavg(1L), diff --git a/cpp/velox/jni/JniUdf.cc b/cpp/velox/jni/JniUdf.cc index cd5a4f7c8861..cab90b325fe5 100644 --- a/cpp/velox/jni/JniUdf.cc +++ b/cpp/velox/jni/JniUdf.cc @@ -41,8 +41,8 @@ void gluten::initVeloxJniUDF(JNIEnv* env) { udfResolverClass = createGlobalClassReferenceOrError(env, kUdfResolverClassPath.c_str()); // methods - registerUDFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDF", "(Ljava/lang/String;[B[B)V"); - registerUDAFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDAF", "(Ljava/lang/String;[B[B[B)V"); + registerUDFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDF", "(Ljava/lang/String;[B[BZ)V"); + registerUDAFMethod = getMethodIdOrError(env, udfResolverClass, "registerUDAF", "(Ljava/lang/String;[B[B[BZ)V"); } void gluten::finalizeVeloxJniUDF(JNIEnv* env) { @@ -70,9 +70,10 @@ void gluten::jniGetFunctionSignatures(JNIEnv* env) { 0, signature->intermediateType.length(), reinterpret_cast(signature->intermediateType.c_str())); - env->CallVoidMethod(instance, registerUDAFMethod, name, returnType, argTypes, intermediateType); + env->CallVoidMethod( + instance, registerUDAFMethod, name, returnType, argTypes, intermediateType, signature->variableArity); } else { - env->CallVoidMethod(instance, registerUDFMethod, name, returnType, argTypes); + env->CallVoidMethod(instance, registerUDFMethod, name, returnType, argTypes, signature->variableArity); } checkException(env); } diff --git a/cpp/velox/udf/Udaf.h b/cpp/velox/udf/Udaf.h index 7e8f03402f95..5b33e0611ba2 100644 --- a/cpp/velox/udf/Udaf.h +++ b/cpp/velox/udf/Udaf.h @@ -27,6 +27,7 @@ struct UdafEntry { const char** argTypes; const char* intermediateType{nullptr}; + bool variableArity{false}; }; #define GLUTEN_GET_NUM_UDAF getNumUdaf diff --git a/cpp/velox/udf/Udf.h b/cpp/velox/udf/Udf.h index c3e579c443cc..1fa3c54d5213 100644 --- a/cpp/velox/udf/Udf.h +++ b/cpp/velox/udf/Udf.h @@ -25,6 +25,8 @@ struct UdfEntry { size_t numArgs; const char** argTypes; + + bool variableArity{false}; }; #define GLUTEN_GET_NUM_UDF getNumUdf diff --git a/cpp/velox/udf/UdfLoader.cc b/cpp/velox/udf/UdfLoader.cc index a8a99ce9fe8d..02aa410a95e1 100644 --- a/cpp/velox/udf/UdfLoader.cc +++ b/cpp/velox/udf/UdfLoader.cc @@ -86,11 +86,11 @@ std::unordered_set> UdfLoader::getRegis const auto& entry = udfEntries[i]; auto dataType = toSubstraitTypeStr(entry.dataType); auto argTypes = toSubstraitTypeStr(entry.numArgs, entry.argTypes); - signatures_.insert(std::make_shared(entry.name, dataType, argTypes)); + signatures_.insert(std::make_shared(entry.name, dataType, argTypes, entry.variableArity)); } free(udfEntries); } else { - LOG(INFO) << "No UDFs found in " << libPath; + LOG(INFO) << "No UDF found in " << libPath; } // Handle UDAFs. @@ -110,11 +110,12 @@ std::unordered_set> UdfLoader::getRegis auto dataType = toSubstraitTypeStr(entry.dataType); auto argTypes = toSubstraitTypeStr(entry.numArgs, entry.argTypes); auto intermediateType = toSubstraitTypeStr(entry.intermediateType); - signatures_.insert(std::make_shared(entry.name, dataType, argTypes, intermediateType)); + signatures_.insert( + std::make_shared(entry.name, dataType, argTypes, intermediateType, entry.variableArity)); } free(udafEntries); } else { - LOG(INFO) << "No UDAFs found in " << libPath; + LOG(INFO) << "No UDAF found in " << libPath; } } return signatures_; @@ -151,4 +152,26 @@ std::shared_ptr UdfLoader::getInstance() { return instance; } +std::string UdfLoader::toSubstraitTypeStr(const std::string& type) { + auto returnType = parser_.parse(type); + auto substraitType = convertor_.toSubstraitType(arena_, returnType); + + std::string output; + substraitType.SerializeToString(&output); + return output; +} + +std::string UdfLoader::toSubstraitTypeStr(int32_t numArgs, const char** args) { + std::vector argTypes; + argTypes.resize(numArgs); + for (auto i = 0; i < numArgs; ++i) { + argTypes[i] = parser_.parse(args[i]); + } + auto substraitType = convertor_.toSubstraitType(arena_, facebook::velox::ROW(std::move(argTypes))); + + std::string output; + substraitType.SerializeToString(&output); + return output; +} + } // namespace gluten diff --git a/cpp/velox/udf/UdfLoader.h b/cpp/velox/udf/UdfLoader.h index 31098d2f437d..2783beb85511 100644 --- a/cpp/velox/udf/UdfLoader.h +++ b/cpp/velox/udf/UdfLoader.h @@ -36,11 +36,22 @@ class UdfLoader { std::string intermediateType{}; - UdfSignature(std::string name, std::string returnType, std::string argTypes) - : name(name), returnType(returnType), argTypes(argTypes) {} - - UdfSignature(std::string name, std::string returnType, std::string argTypes, std::string intermediateType) - : name(name), returnType(returnType), argTypes(argTypes), intermediateType(intermediateType) {} + bool variableArity; + + UdfSignature(std::string name, std::string returnType, std::string argTypes, bool variableArity) + : name(name), returnType(returnType), argTypes(argTypes), variableArity(variableArity) {} + + UdfSignature( + std::string name, + std::string returnType, + std::string argTypes, + std::string intermediateType, + bool variableArity) + : name(name), + returnType(returnType), + argTypes(argTypes), + intermediateType(intermediateType), + variableArity(variableArity) {} ~UdfSignature() = default; }; @@ -58,27 +69,9 @@ class UdfLoader { private: void loadUdfLibraries0(const std::vector& libPaths); - std::string toSubstraitTypeStr(const std::string& type) { - auto returnType = parser_.parse(type); - auto substraitType = convertor_.toSubstraitType(arena_, returnType); - - std::string output; - substraitType.SerializeToString(&output); - return output; - } - - std::string toSubstraitTypeStr(int32_t numArgs, const char** args) { - std::vector argTypes; - argTypes.resize(numArgs); - for (auto i = 0; i < numArgs; ++i) { - argTypes[i] = parser_.parse(args[i]); - } - auto substraitType = convertor_.toSubstraitType(arena_, facebook::velox::ROW(std::move(argTypes))); - - std::string output; - substraitType.SerializeToString(&output); - return output; - } + std::string toSubstraitTypeStr(const std::string& type); + + std::string toSubstraitTypeStr(int32_t numArgs, const char** args); std::unordered_map handles_; diff --git a/cpp/velox/udf/examples/MyUDF.cc b/cpp/velox/udf/examples/MyUDF.cc index 578e3effb2b1..88bc3ad85da3 100644 --- a/cpp/velox/udf/examples/MyUDF.cc +++ b/cpp/velox/udf/examples/MyUDF.cc @@ -21,8 +21,6 @@ #include #include "udf/Udf.h" -namespace { - using namespace facebook::velox; using namespace facebook::velox::exec; @@ -30,10 +28,26 @@ static const char* kInteger = "int"; static const char* kBigInt = "bigint"; static const char* kDate = "date"; +class UdfRegisterer { + public: + ~UdfRegisterer() = default; + + // Returns the number of UDFs in populateUdfEntries. + virtual int getNumUdf() = 0; + + // Populate the udfEntries, starting at the given index. + virtual void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) = 0; + + // Register all function signatures to velox. + virtual void registerSignatures() = 0; +}; + +namespace myudf { + template -class PlusConstantFunction : public exec::VectorFunction { +class PlusFiveFunction : public exec::VectorFunction { public: - explicit PlusConstantFunction(int32_t addition) : addition_(addition) {} + explicit PlusFiveFunction() {} void apply( const SelectivityVector& rows, @@ -42,12 +56,6 @@ class PlusConstantFunction : public exec::VectorFunction { exec::EvalCtx& context, VectorPtr& result) const override { using nativeType = typename TypeTraits::NativeType; - VELOX_CHECK_EQ(args.size(), 1); - - auto& arg = args[0]; - - // The argument may be flat or constant. - VELOX_CHECK(arg->isFlatEncoding() || arg->isConstantEncoding()); BaseVector::ensureWritable(rows, createScalarType(), context.pool(), result); @@ -56,79 +64,218 @@ class PlusConstantFunction : public exec::VectorFunction { flatResult->clearNulls(rows); - if (arg->isConstantEncoding()) { - auto value = arg->as>()->valueAt(0); - rows.applyToSelected([&](auto row) { rawResult[row] = value + addition_; }); - } else { - auto* rawInput = arg->as>()->rawValues(); + rows.applyToSelected([&](auto row) { rawResult[row] = 5; }); - rows.applyToSelected([&](auto row) { rawResult[row] = rawInput[row] + addition_; }); + if (args.size() == 0) { + return; } - } - private: - const int32_t addition_; -}; - -template -struct MyDateSimpleFunction { - VELOX_DEFINE_FUNCTION_TYPES(T); - - FOLLY_ALWAYS_INLINE void call(int32_t& result, const arg_type& date, const arg_type addition) { - result = date + addition; + for (int i = 0; i < args.size(); i++) { + auto& arg = args[i]; + VELOX_CHECK(arg->isFlatEncoding() || arg->isConstantEncoding()); + if (arg->isConstantEncoding()) { + auto value = arg->as>()->valueAt(0); + rows.applyToSelected([&](auto row) { rawResult[row] += value; }); + } else { + auto* rawInput = arg->as>()->rawValues(); + rows.applyToSelected([&](auto row) { rawResult[row] += rawInput[row]; }); + } + } } }; -std::shared_ptr makeMyUdf1( +static std::shared_ptr makePlusConstant( const std::string& /*name*/, const std::vector& inputArgs, const core::QueryConfig& /*config*/) { + if (inputArgs.size() == 0) { + return std::make_shared>(); + } auto typeKind = inputArgs[0].type->kind(); switch (typeKind) { case TypeKind::INTEGER: - return std::make_shared>(5); + return std::make_shared>(); case TypeKind::BIGINT: - return std::make_shared>(5); + return std::make_shared>(); default: VELOX_UNREACHABLE(); } } -static std::vector> integerSignatures() { - // integer -> integer, bigint ->bigint - return { - exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").build(), - exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()}; -} +// name: myudf1 +// signatures: +// bigint -> bigint +// type: VectorFunction +class MyUdf1Registerer final : public UdfRegisterer { + public: + int getNumUdf() override { + return 1; + } -static std::vector> bigintSignatures() { - // bigint -> bigint - return {exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()}; -} + void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override { + udfEntries[index++] = {name_.c_str(), kBigInt, 1, bigintArg_}; + } + + void registerSignatures() override { + facebook::velox::exec::registerVectorFunction( + name_, bigintSignatures(), std::make_unique>()); + } + + private: + std::vector> bigintSignatures() { + return {exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()}; + } + + const std::string name_ = "myudf1"; + const char* bigintArg_[1] = {kBigInt}; +}; + +// name: myudf2 +// signatures: +// integer -> integer +// bigint -> bigint +// type: StatefulVectorFunction +class MyUdf2Registerer final : public UdfRegisterer { + public: + int getNumUdf() override { + return 2; + } + + void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override { + udfEntries[index++] = {name_.c_str(), kInteger, 1, integerArg_}; + udfEntries[index++] = {name_.c_str(), kBigInt, 1, bigintArg_}; + } + + void registerSignatures() override { + facebook::velox::exec::registerStatefulVectorFunction(name_, integerAndBigintSignatures(), makePlusConstant); + } + + private: + std::vector> integerAndBigintSignatures() { + return { + exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").build(), + exec::FunctionSignatureBuilder().returnType("bigint").argumentType("bigint").build()}; + } + + const std::string name_ = "myudf2"; + const char* integerArg_[1] = {kInteger}; + const char* bigintArg_[1] = {kBigInt}; +}; + +// name: myudf3 +// signatures: +// [integer,] ... -> integer +// bigint, [bigint,] ... -> bigint +// type: StatefulVectorFunction with variable arity +class MyUdf3Registerer final : public UdfRegisterer { + public: + int getNumUdf() override { + return 2; + } + + void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override { + udfEntries[index++] = {name_.c_str(), kInteger, 1, integerArg_, true}; + udfEntries[index++] = {name_.c_str(), kBigInt, 2, bigintArgs_, true}; + } + + void registerSignatures() override { + facebook::velox::exec::registerStatefulVectorFunction( + name_, integerAndBigintSignaturesWithVariableArity(), makePlusConstant); + } + + private: + std::vector> integerAndBigintSignaturesWithVariableArity() { + return { + exec::FunctionSignatureBuilder().returnType("integer").argumentType("integer").variableArity().build(), + exec::FunctionSignatureBuilder() + .returnType("bigint") + .argumentType("bigint") + .argumentType("bigint") + .variableArity() + .build()}; + } -} // namespace + const std::string name_ = "myudf3"; + const char* integerArg_[1] = {kInteger}; + const char* bigintArgs_[2] = {kBigInt, kBigInt}; +}; +} // namespace myudf -const int kNumMyUdf = 4; +namespace mydate { +template +struct MyDateSimpleFunction { + VELOX_DEFINE_FUNCTION_TYPES(T); + + FOLLY_ALWAYS_INLINE void call(int32_t& result, const arg_type& date, const arg_type addition) { + result = date + addition; + } +}; + +// name: mydate +// signatures: +// date, integer -> bigint +// type: SimpleFunction +class MyDateRegisterer final : public UdfRegisterer { + public: + int getNumUdf() override { + return 1; + } + + void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override { + udfEntries[index++] = {name_.c_str(), kDate, 2, myDateArg_}; + } + + void registerSignatures() override { + facebook::velox::registerFunction({name_}); + } + + private: + const std::string name_ = "mydate"; + const char* myDateArg_[2] = {kDate, kInteger}; +}; +} // namespace mydate + +std::vector>& globalRegisters() { + static std::vector> registerers; + return registerers; +} + +void setupRegisterers() { + static bool inited = false; + if (inited) { + return; + } + auto& registerers = globalRegisters(); + registerers.push_back(std::make_shared()); + registerers.push_back(std::make_shared()); + registerers.push_back(std::make_shared()); + registerers.push_back(std::make_shared()); + inited = true; +} DEFINE_GET_NUM_UDF { - return kNumMyUdf; + setupRegisterers(); + + int numUdf = 0; + for (const auto& registerer : globalRegisters()) { + numUdf += registerer->getNumUdf(); + } + return numUdf; } -const char* myUdf1Arg1[] = {kInteger}; -const char* myUdf1Arg2[] = {kBigInt}; -const char* myUdf2Arg1[] = {kBigInt}; -const char* myDateArg[] = {kDate, kInteger}; DEFINE_GET_UDF_ENTRIES { + setupRegisterers(); + int index = 0; - udfEntries[index++] = {"myudf1", kInteger, 1, myUdf1Arg1}; - udfEntries[index++] = {"myudf1", kBigInt, 1, myUdf1Arg2}; - udfEntries[index++] = {"myudf2", kBigInt, 1, myUdf2Arg1}; - udfEntries[index++] = {"mydate", kDate, 2, myDateArg}; + for (const auto& registerer : globalRegisters()) { + registerer->populateUdfEntries(index, udfEntries); + } } DEFINE_REGISTER_UDF { - facebook::velox::exec::registerStatefulVectorFunction("myudf1", integerSignatures(), makeMyUdf1); - facebook::velox::exec::registerVectorFunction( - "myudf2", bigintSignatures(), std::make_unique>(5)); - facebook::velox::registerFunction({"mydate"}); + setupRegisterers(); + + for (const auto& registerer : globalRegisters()) { + registerer->registerSignatures(); + } } diff --git a/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala new file mode 100644 index 000000000000..597b5936f2d2 --- /dev/null +++ b/shims/spark32/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.types + +import org.apache.spark.sql.types.DataType + +object DataTypeUtils { + + /** + * Check if `this` and `other` are the same data type when ignoring nullability + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + */ + def sameType(left: DataType, right: DataType): Boolean = left.sameType(right) +} diff --git a/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala new file mode 100644 index 000000000000..597b5936f2d2 --- /dev/null +++ b/shims/spark33/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.types + +import org.apache.spark.sql.types.DataType + +object DataTypeUtils { + + /** + * Check if `this` and `other` are the same data type when ignoring nullability + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + */ + def sameType(left: DataType, right: DataType): Boolean = left.sameType(right) +} diff --git a/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala b/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala new file mode 100644 index 000000000000..597b5936f2d2 --- /dev/null +++ b/shims/spark34/src/main/scala/org/apache/spark/sql/catalyst/types/DataTypeUtils.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.types + +import org.apache.spark.sql.types.DataType + +object DataTypeUtils { + + /** + * Check if `this` and `other` are the same data type when ignoring nullability + * (`StructField.nullable`, `ArrayType.containsNull`, and `MapType.valueContainsNull`). + */ + def sameType(left: DataType, right: DataType): Boolean = left.sameType(right) +}