From d4d724143286eec63605e10585a6133ce3c52b9d Mon Sep 17 00:00:00 2001
From: Rong Ma
Date: Sun, 25 Aug 2024 13:11:33 +0800
Subject: [PATCH] [VL] Support create temporary function for native hive udf
 (#6829)

---
 backends-velox/pom.xml                        |  7 ++
 .../backendsapi/velox/VeloxRuleApi.scala      |  2 +-
 .../velox/VeloxSparkPlanExecApi.scala         |  7 ++
 .../spark/sql/expression/UDFResolver.scala    | 30 +++---
 .../sql/hive/VeloxHiveUDFTransformer.scala    | 49 +++++++++
 .../gluten/expression/VeloxUdfSuite.scala     | 99 +++++++++++++++++++
 cpp/velox/udf/examples/MyUDF.cc               | 39 ++++++++
 .../gluten/backendsapi/SparkPlanExecApi.scala |  8 +-
 .../expression/ExpressionConverter.scala      |  4 +-
 .../spark/sql/hive/HiveUDFTransformer.scala   |  6 ++
 10 files changed, 235 insertions(+), 16 deletions(-)
 create mode 100644 backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala

diff --git a/backends-velox/pom.xml b/backends-velox/pom.xml
index 0fe8f5f6fd8e..417f64999b95 100755
--- a/backends-velox/pom.xml
+++ b/backends-velox/pom.xml
@@ -140,6 +140,13 @@
       <artifactId>spark-core_${scala.binary.version}</artifactId>
       <type>test-jar</type>
     </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-hive_${scala.binary.version}</artifactId>
+      <version>${spark.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
       <artifactId>spark-sql_${scala.binary.version}</artifactId>
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
index abb39c5bb23d..438895b25ae9 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxRuleApi.scala
@@ -47,7 +47,7 @@ private object VeloxRuleApi {
     // Regular Spark rules.
     injector.injectOptimizerRule(CollectRewriteRule.apply)
     injector.injectOptimizerRule(HLLRewriteRule.apply)
-    UDFResolver.getFunctionSignatures.foreach(injector.injectFunction)
+    UDFResolver.getFunctionSignatures().foreach(injector.injectFunction)
     injector.injectPostHocResolutionRule(ArrowConvertorRule.apply)
   }
diff --git a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
index bd390004feda..554b3791dad3 100644
--- a/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
+++ b/backends-velox/src/main/scala/org/apache/gluten/backendsapi/velox/VeloxSparkPlanExecApi.scala
@@ -50,6 +50,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
 import org.apache.spark.sql.execution.utils.ExecUtil
 import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction}
+import org.apache.spark.sql.hive.VeloxHiveUDFTransformer
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -819,4 +820,10 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
       case other => other
     }
   }
+
+  override def genHiveUDFTransformer(
+      expr: Expression,
+      attributeSeq: Seq[Attribute]): ExpressionTransformer = {
+    VeloxHiveUDFTransformer.replaceWithExpressionTransformer(expr, attributeSeq)
+  }
 }
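A side note on the two Velox hooks above: getFunctionSignatures() now takes an explicit empty parameter list because it loads native signatures as a side effect, and each returned element is an ordinary Spark FunctionRegistry triple. A minimal sketch of one such entry, mirroring the UDFResolver code later in this patch; the helper name is illustrative and not part of the change:

    import org.apache.spark.sql.catalyst.FunctionIdentifier
    import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
    import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
    import org.apache.spark.sql.expression.{UDFExpression, UDFResolver}

    // Illustrative helper: the (identifier, info, builder) triple that
    // injector.injectFunction consumes for one native UDF name.
    def registryEntry(name: String): (FunctionIdentifier, ExpressionInfo, FunctionBuilder) = (
      new FunctionIdentifier(name),
      new ExpressionInfo(classOf[UDFExpression].getName, name),
      (e: Seq[Expression]) => UDFResolver.getUdfExpression(name, name)(e))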
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
index ab83c55ee306..39032e46f381 100644
--- a/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/expression/UDFResolver.scala
@@ -17,7 +17,7 @@ package org.apache.spark.sql.expression

 import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
-import org.apache.gluten.exception.GlutenException
+import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
 import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer, ExpressionType, GenericExpressionTransformer, Transformable}
 import org.apache.gluten.udf.UdfJniWrapper
 import org.apache.gluten.vectorized.JniWorkspace
@@ -95,11 +95,14 @@ case class UDAFSignature(

 case class UDFExpression(
     name: String,
+    alias: String,
     dataType: DataType,
     nullable: Boolean,
     children: Seq[Expression])
   extends Unevaluable
   with Transformable {
+
+  override def nodeName: String = alias
+
   override protected def withNewChildrenInternal(
       newChildren: IndexedSeq[Expression]): Expression = {
     this.copy(children = newChildren)
@@ -118,11 +121,11 @@ case class UDFExpression(
 }

 object UDFResolver extends Logging {
-  private val UDFNames = mutable.HashSet[String]()
+  val UDFNames = mutable.HashSet[String]()
   // (udf_name, arg1, arg2, ...) => return type
   private val UDFMap = mutable.HashMap[String, mutable.ListBuffer[UDFSignature]]()

-  private val UDAFNames = mutable.HashSet[String]()
+  val UDAFNames = mutable.HashSet[String]()
   // (udaf_name, arg1, arg2, ...) => return type, intermediate attributes
   private val UDAFMap = mutable.HashMap[String, mutable.ListBuffer[UDAFSignature]]()
@@ -331,7 +334,7 @@ object UDFResolver extends Logging {
       .mkString(",")
   }

-  def getFunctionSignatures: Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
+  def getFunctionSignatures(): Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
     val sparkContext = SparkContext.getActive.get
     val sparkConf = sparkContext.conf
     val udfLibPaths = sparkConf.getOption(VeloxBackendSettings.GLUTEN_VELOX_UDF_LIB_PATHS)
@@ -341,13 +344,12 @@ object UDFResolver extends Logging {
         Seq.empty
       case Some(_) =>
         UdfJniWrapper.getFunctionSignatures()
-
         UDFNames.map {
           name =>
             (
               new FunctionIdentifier(name),
               new ExpressionInfo(classOf[UDFExpression].getName, name),
-              (e: Seq[Expression]) => getUdfExpression(name)(e))
+              (e: Seq[Expression]) => getUdfExpression(name, name)(e))
         }.toSeq ++
           UDAFNames.map {
             name =>
               (
@@ -364,27 +366,29 @@ object UDFResolver extends Logging {
       .toBoolean
   }

-  private def getUdfExpression(name: String)(children: Seq[Expression]) = {
+  def getUdfExpression(name: String, alias: String)(children: Seq[Expression]): UDFExpression = {
     def errorMessage: String =
       s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered."

     val allowTypeConversion = checkAllowTypeConversion
     val signatures =
-      UDFMap.getOrElse(name, throw new UnsupportedOperationException(errorMessage));
+      UDFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage));
     signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
       case Some(sig) =>
         UDFExpression(
           name,
+          alias,
           sig.expressionType.dataType,
           sig.expressionType.nullable,
           if (!allowTypeConversion && !sig.allowTypeConversion) children
-          else applyCast(children, sig))
+          else applyCast(children, sig)
+        )
       case None =>
-        throw new UnsupportedOperationException(errorMessage)
+        throw new GlutenNotSupportException(errorMessage)
     }
   }

-  private def getUdafExpression(name: String)(children: Seq[Expression]) = {
+  def getUdafExpression(name: String)(children: Seq[Expression]): UserDefinedAggregateFunction = {
     def errorMessage: String =
       s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered."
@@ -392,7 +396,7 @@ object UDFResolver extends Logging {
     val signatures =
       UDAFMap.getOrElse(
         name,
-        throw new UnsupportedOperationException(errorMessage)
+        throw new GlutenNotSupportException(errorMessage)
       )
     signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
       case Some(sig) =>
@@ -405,7 +409,7 @@ object UDFResolver extends Logging {
           sig.intermediateAttrs
         )
       case None =>
-        throw new UnsupportedOperationException(errorMessage)
+        throw new GlutenNotSupportException(errorMessage)
     }
   }
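Note the asymmetry the new parameter introduces: the alias only changes presentation. Resolution stays keyed on the registered name (for Hive UDFs, the implementing class name), while nodeName, and hence EXPLAIN output, shows the SQL-facing alias. A sketch of the behavior, assuming the example UDF library from this patch is loaded so the signature below is registered:

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.expression.UDFResolver

    // Lookup key is the class name; the alias is what plans display.
    val udf = UDFResolver.getUdfExpression(
      name = "org.apache.spark.sql.hive.execution.UDFStringString",
      alias = "hive_string_string")(Seq(Literal("a"), Literal("b")))
    assert(udf.nodeName == "hive_string_string")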
diff --git a/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala b/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
new file mode 100644
index 000000000000..d895faa31702
--- /dev/null
+++ b/backends-velox/src/main/scala/org/apache/spark/sql/hive/VeloxHiveUDFTransformer.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.hive
+
+import org.apache.gluten.exception.GlutenNotSupportException
+import org.apache.gluten.expression.{ExpressionConverter, ExpressionTransformer}
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
+import org.apache.spark.sql.expression.UDFResolver
+
+object VeloxHiveUDFTransformer {
+  def replaceWithExpressionTransformer(
+      expr: Expression,
+      attributeSeq: Seq[Attribute]): ExpressionTransformer = {
+    val (udfName, udfClassName) = expr match {
+      case s: HiveSimpleUDF =>
+        (s.name.stripPrefix("default."), s.funcWrapper.functionClassName)
+      case g: HiveGenericUDF =>
+        (g.name.stripPrefix("default."), g.funcWrapper.functionClassName)
+      case _ =>
+        throw new GlutenNotSupportException(
+          s"Expression $expr is not a HiveSimpleUDF or HiveGenericUDF")
+    }
+
+    if (UDFResolver.UDFNames.contains(udfClassName)) {
+      UDFResolver
+        .getUdfExpression(udfClassName, udfName)(expr.children)
+        .getTransformer(
+          ExpressionConverter.replaceWithExpressionTransformer(expr.children, attributeSeq)
+        )
+    } else {
+      HiveUDFTransformer.genTransformerFromUDFMappings(udfName, expr, attributeSeq)
+    }
+  }
+}
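In effect, a CREATE TEMPORARY FUNCTION whose implementing class name appears in UDFResolver.UDFNames is offloaded to the native implementation, and any other Hive UDF keeps the pre-existing UDFMappings route. A usage sketch, assuming a session launched with the example library from this patch configured via spark.gluten.sql.columnar.backend.velox.udfLibraryPaths:

    // The class name below is registered by the example library (MyUDF.cc).
    spark.sql("""
                |CREATE TEMPORARY FUNCTION hive_string_string
                |AS 'org.apache.spark.sql.hive.execution.UDFStringString'
                |""".stripMargin)
    // Offloaded: resolution is keyed on the implementing class name.
    spark.sql("SELECT hive_string_string('hello', 'world')").show()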
diff --git a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
index 008337b9400e..596757df35d9 100644
--- a/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/gluten/expression/VeloxUdfSuite.scala
@@ -22,6 +22,7 @@ import org.apache.gluten.tags.{SkipTestTags, UDFTest}
 import org.apache.spark.SparkConf
 import org.apache.spark.sql.{GlutenQueryTest, Row, SparkSession}
 import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.expression.UDFResolver

 import java.nio.file.Paths
 import java.sql.Date
@@ -56,12 +57,31 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
         .builder()
         .master(master)
         .config(sparkConf)
+        .enableHiveSupport()
         .getOrCreate()
     }

     _spark.sparkContext.setLogLevel("info")
   }

+  override def afterAll(): Unit = {
+    try {
+      super.afterAll()
+      if (_spark != null) {
+        try {
+          _spark.sessionState.catalog.reset()
+        } finally {
+          _spark.stop()
+          _spark = null
+        }
+      }
+    } finally {
+      SparkSession.clearActiveSession()
+      SparkSession.clearDefaultSession()
+      doThreadPostAudit()
+    }
+  }
+
   override protected def spark = _spark

   protected def sparkConf: SparkConf = {
@@ -128,6 +148,85 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
         .sameElements(Array(Row(1.0, 1.0, 1L))))
     }
   }
+
+  test("test hive udf replacement") {
+    val tbl = "test_hive_udf_replacement"
+    withTempPath {
+      dir =>
+        try {
+          spark.sql(s"""
+                       |CREATE EXTERNAL TABLE $tbl
+                       |LOCATION 'file://$dir'
+                       |AS select * from values (1, '1'), (2, '2'), (3, '3')
+                       |""".stripMargin)
+
+          // Check native hive udf has been registered.
+          assert(
+            UDFResolver.UDFNames.contains("org.apache.spark.sql.hive.execution.UDFStringString"))
+
+          spark.sql("""
+                      |CREATE TEMPORARY FUNCTION hive_string_string
+                      |AS 'org.apache.spark.sql.hive.execution.UDFStringString'
+                      |""".stripMargin)
+
+          val nativeResult =
+            spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""").collect()
+          // Unregister native hive udf to fallback.
+          UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString")
+          val fallbackResult =
+            spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""").collect()
+          assert(nativeResult.sameElements(fallbackResult))
+
+          // Add an unimplemented udf to the map to test fallback of registered native hive udf.
+          UDFResolver.UDFNames.add("org.apache.spark.sql.hive.execution.UDFIntegerToString")
+          spark.sql("""
+                      |CREATE TEMPORARY FUNCTION hive_int_to_string
+                      |AS 'org.apache.spark.sql.hive.execution.UDFIntegerToString'
+                      |""".stripMargin)
+          val df = spark.sql(s"""select hive_int_to_string(col1) from $tbl""")
+          checkAnswer(df, Seq(Row("1"), Row("2"), Row("3")))
+        } finally {
+          spark.sql(s"DROP TABLE IF EXISTS $tbl")
+          spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_string_string")
+          spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_int_to_string")
+        }
+    }
+  }
+
+  test("test udf fallback in partition filter") {
+    withTempPath {
+      dir =>
+        try {
+          spark.sql("""
+                      |CREATE TEMPORARY FUNCTION hive_int_to_string
+                      |AS 'org.apache.spark.sql.hive.execution.UDFIntegerToString'
+                      |""".stripMargin)
+
+          spark.sql(s"""
+                       |CREATE EXTERNAL TABLE t(i INT, p INT)
+                       |LOCATION 'file://$dir'
+                       |PARTITIONED BY (p)""".stripMargin)
+
+          spark
+            .range(0, 10, 1)
+            .selectExpr("id as col")
+            .createOrReplaceTempView("temp")
+
+          for (part <- Seq(1, 2, 3, 4)) {
+            spark.sql(s"""
+                         |INSERT OVERWRITE TABLE t PARTITION (p=$part)
+                         |SELECT col FROM temp""".stripMargin)
+          }
+
+          val df = spark.sql("SELECT i FROM t WHERE hive_int_to_string(p) = '4'")
+          checkAnswer(df, (0 until 10).map(Row(_)))
+        } finally {
+          spark.sql("DROP TABLE IF EXISTS t")
+          spark.sql("DROP VIEW IF EXISTS temp")
+          // This test registers hive_int_to_string, so that is what gets cleaned up here.
+          spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_int_to_string")
+        }
+    }
+  }
 }

 @UDFTest
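A practical note for extending these suites: whether a query actually took the native path can also be checked from the executed plan rather than by diffing results. A sketch, using a generic node-name heuristic (illustrative, not a stable API):

    // Sketch: assert the UDF projection was offloaded to a Gluten transformer node.
    val df = spark.sql("SELECT hive_string_string('a', 'b')")
    val offloaded = df.queryExecution.executedPlan.collect {
      case p if p.nodeName.contains("Transformer") => p
    }
    assert(offloaded.nonEmpty, "expected the UDF to run on the native path")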
diff --git a/cpp/velox/udf/examples/MyUDF.cc b/cpp/velox/udf/examples/MyUDF.cc
index db1c5d7709f0..75e68413a842 100644
--- a/cpp/velox/udf/examples/MyUDF.cc
+++ b/cpp/velox/udf/examples/MyUDF.cc
@@ -30,6 +30,7 @@ namespace {
 static const char* kInteger = "int";
 static const char* kBigInt = "bigint";
 static const char* kDate = "date";
+static const char* kVarChar = "varchar";

 namespace myudf {

@@ -248,6 +249,43 @@ class MyDate2Registerer final : public gluten::UdfRegisterer {
 };
 } // namespace mydate

+namespace hivestringstring {
+template <typename T>
+struct HiveStringStringFunction {
+  VELOX_DEFINE_FUNCTION_TYPES(T);
+
+  FOLLY_ALWAYS_INLINE void call(out_type<Varchar>& result, const arg_type<Varchar>& a, const arg_type<Varchar>& b) {
+    result.append(a.data());
+    result.append(" ");
+    result.append(b.data());
+  }
+};
+
+// name: org.apache.spark.sql.hive.execution.UDFStringString
+// signatures:
+//    varchar, varchar -> varchar
+// type: SimpleFunction
+class HiveStringStringRegisterer final : public gluten::UdfRegisterer {
+ public:
+  int getNumUdf() override {
+    return 1;
+  }
+
+  void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
+    // Set `allowTypeConversion` for hive udf.
+    udfEntries[index++] = {name_.c_str(), kVarChar, 2, arg_, false, true};
+  }
+
+  void registerSignatures() override {
+    facebook::velox::registerFunction<HiveStringStringFunction, Varchar, Varchar, Varchar>({name_});
+  }
+
+ private:
+  const std::string name_ = "org.apache.spark.sql.hive.execution.UDFStringString";
+  const char* arg_[2] = {kVarChar, kVarChar};
+};
+} // namespace hivestringstring
+
 std::vector<std::shared_ptr<gluten::UdfRegisterer>>& globalRegisters() {
   static std::vector<std::shared_ptr<gluten::UdfRegisterer>> registerers;
   return registerers;
@@ -264,6 +302,7 @@ void setupRegisterers() {
   registerers.push_back(std::make_shared<myudf::MyUdf3Registerer>());
   registerers.push_back(std::make_shared<mydate::MyDateRegisterer>());
   registerers.push_back(std::make_shared<mydate::MyDate2Registerer>());
+  registerers.push_back(std::make_shared<hivestringstring::HiveStringStringRegisterer>());
   inited = true;
 }
 } // namespace
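On the UdfEntry above: per the in-line comment, the trailing flag enables type conversion for this Hive UDF (consult cpp/velox/udf/Udf.h for the authoritative field layout). On the Scala side that flag lets UDFResolver.applyCast insert casts while binding, so inputs that are not already varchar may still match the (varchar, varchar) -> varchar signature. A hedged sketch; the exact coercions are whatever applyCast and Hive's own argument conversion implement:

    // Assumes the example library is loaded and hive_string_string was created
    // as a temporary function, as in the test suite above.
    spark.sql("CREATE TEMPORARY VIEW v AS SELECT 1 AS i")
    // i is INT; with allowTypeConversion the native (varchar, varchar) signature
    // may still bind via an inserted cast instead of forcing a fallback.
    spark.sql("SELECT hive_string_string(i, 'x') FROM v").show()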
diff --git a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
index 0227ed5da127..fb87a9ac93c0 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/backendsapi/SparkPlanExecApi.scala
@@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.joins.BuildSideRelation
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
-import org.apache.spark.sql.hive.HiveTableScanExecTransformer
+import org.apache.spark.sql.hive.{HiveTableScanExecTransformer, HiveUDFTransformer}
 import org.apache.spark.sql.types.{DecimalType, LongType, NullType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch

@@ -670,4 +670,10 @@ trait SparkPlanExecApi {
       DecimalType(math.min(integralLeastNumDigits + newScale, 38), newScale)
     }
   }
+
+  def genHiveUDFTransformer(
+      expr: Expression,
+      attributeSeq: Seq[Attribute]): ExpressionTransformer = {
+    HiveUDFTransformer.replaceWithExpressionTransformer(expr, attributeSeq)
+  }
 }
diff --git a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
index 8bca5dbf8605..d5ca31bb5e78 100644
--- a/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
+++ b/gluten-core/src/main/scala/org/apache/gluten/expression/ExpressionConverter.scala
@@ -128,7 +128,9 @@ object ExpressionConverter extends SQLConfHelper with Logging {
       case s: ScalaUDF =>
         return replaceScalaUDFWithExpressionTransformer(s, attributeSeq, expressionsMap)
       case _ if HiveUDFTransformer.isHiveUDF(expr) =>
-        return HiveUDFTransformer.replaceWithExpressionTransformer(expr, attributeSeq)
+        return BackendsApiManager.getSparkPlanExecApiInstance.genHiveUDFTransformer(
+          expr,
+          attributeSeq)
       case i: StaticInvoke =>
         val objectName = i.staticObject.getName.stripSuffix("$")
         if (objectName.endsWith("UrlCodec")) {
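With the converter change above, every Hive UDF now flows through the backend hook, so the effective dispatch is backend-specific resolution first (on Velox, the native registry) with the shared UDFMappings path below as the fallback. For illustration, the equivalent call path (mirroring the patch, no new API):

    import org.apache.gluten.backendsapi.BackendsApiManager
    import org.apache.gluten.expression.ExpressionTransformer
    import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}

    // What ExpressionConverter now does for a HiveSimpleUDF/HiveGenericUDF.
    def routeHiveUDF(expr: Expression, attrs: Seq[Attribute]): ExpressionTransformer =
      BackendsApiManager.getSparkPlanExecApiInstance.genHiveUDFTransformer(expr, attrs)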
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
index 5cd64cc212f8..52739aaca439 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/hive/HiveUDFTransformer.scala
@@ -43,7 +43,13 @@
         throw new GlutenNotSupportException(
           s"Expression $expr is not a HiveSimpleUDF or HiveGenericUDF")
     }
+    genTransformerFromUDFMappings(udfName, expr, attributeSeq)
+  }

+  def genTransformerFromUDFMappings(
+      udfName: String,
+      expr: Expression,
+      attributeSeq: Seq[Attribute]): GenericExpressionTransformer = {
     UDFMappings.hiveUDFMap.get(udfName.toLowerCase(Locale.ROOT)) match {
       case Some(name) =>
         GenericExpressionTransformer(