[VL] Support create temporary function for native hive udf (#6829)
marin-ma authored Aug 25, 2024
1 parent a575395 commit d4d7241
Showing 10 changed files with 235 additions and 16 deletions.
7 changes: 7 additions & 0 deletions backends-velox/pom.xml
@@ -140,6 +140,13 @@
<artifactId>spark-core_${scala.binary.version}</artifactId>
<type>test-jar</type>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-hive_${scala.binary.version}</artifactId>
<version>${spark.version}</version>
<type>test-jar</type>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-sql_${scala.binary.version}</artifactId>
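Note: the spark-hive test-jar added above is what provides Hive example UDF classes such as org.apache.spark.sql.hive.execution.UDFStringString to the updated VeloxUdfSuite below, which now builds its SparkSession with enableHiveSupport().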
@@ -47,7 +47,7 @@ private object VeloxRuleApi {
// Regular Spark rules.
injector.injectOptimizerRule(CollectRewriteRule.apply)
injector.injectOptimizerRule(HLLRewriteRule.apply)
- UDFResolver.getFunctionSignatures.foreach(injector.injectFunction)
+ UDFResolver.getFunctionSignatures().foreach(injector.injectFunction)
injector.injectPostHocResolutionRule(ArrowConvertorRule.apply)
}
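Note: each entry injected above is a (FunctionIdentifier, ExpressionInfo, FunctionBuilder) triple from UDFResolver.getFunctionSignatures() (see below), registering every discovered native function in Spark's function registry. A minimal sketch of one entry's shape, with a placeholder function name:

import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo}
import org.apache.spark.sql.expression.{UDFExpression, UDFResolver}

// Sketch only: "my_udf" is a placeholder; real names come from the loaded native library.
val entry: (FunctionIdentifier, ExpressionInfo, Seq[Expression] => Expression) = (
  new FunctionIdentifier("my_udf"),
  new ExpressionInfo(classOf[UDFExpression].getName, "my_udf"),
  (e: Seq[Expression]) => UDFResolver.getUdfExpression("my_udf", "my_udf")(e))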

@@ -50,6 +50,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.ArrowEvalPythonExec
import org.apache.spark.sql.execution.utils.ExecUtil
import org.apache.spark.sql.expression.{UDFExpression, UserDefinedAggregateFunction}
import org.apache.spark.sql.hive.VeloxHiveUDFTransformer
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -819,4 +820,10 @@ class VeloxSparkPlanExecApi extends SparkPlanExecApi {
case other => other
}
}

override def genHiveUDFTransformer(
expr: Expression,
attributeSeq: Seq[Attribute]): ExpressionTransformer = {
VeloxHiveUDFTransformer.replaceWithExpressionTransformer(expr, attributeSeq)
}
}
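Note: this override implements a backend hook; judging from the new VeloxHiveUDFTransformer below, the shared expression converter presumably invokes genHiveUDFTransformer when it meets a HiveSimpleUDF or HiveGenericUDF, letting the Velox backend choose between a native UDF and the existing UDF-mapping path.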
@@ -17,7 +17,7 @@
package org.apache.spark.sql.expression

import org.apache.gluten.backendsapi.velox.VeloxBackendSettings
- import org.apache.gluten.exception.GlutenException
+ import org.apache.gluten.exception.{GlutenException, GlutenNotSupportException}
import org.apache.gluten.expression.{ConverterUtils, ExpressionTransformer, ExpressionType, GenericExpressionTransformer, Transformable}
import org.apache.gluten.udf.UdfJniWrapper
import org.apache.gluten.vectorized.JniWorkspace
@@ -95,11 +95,14 @@ case class UDAFSignature(

case class UDFExpression(
name: String,
alias: String,
dataType: DataType,
nullable: Boolean,
children: Seq[Expression])
extends Unevaluable
with Transformable {
override def nodeName: String = alias

override protected def withNewChildrenInternal(
newChildren: IndexedSeq[Expression]): Expression = {
this.copy(children = newChildren)
@@ -118,11 +121,11 @@
}

object UDFResolver extends Logging {
- private val UDFNames = mutable.HashSet[String]()
+ val UDFNames = mutable.HashSet[String]()
// (udf_name, arg1, arg2, ...) => return type
private val UDFMap = mutable.HashMap[String, mutable.ListBuffer[UDFSignature]]()

- private val UDAFNames = mutable.HashSet[String]()
+ val UDAFNames = mutable.HashSet[String]()
// (udaf_name, arg1, arg2, ...) => return type, intermediate attributes
private val UDAFMap =
mutable.HashMap[String, mutable.ListBuffer[UDAFSignature]]()
@@ -331,7 +334,7 @@
.mkString(",")
}

- def getFunctionSignatures: Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
+ def getFunctionSignatures(): Seq[(FunctionIdentifier, ExpressionInfo, FunctionBuilder)] = {
val sparkContext = SparkContext.getActive.get
val sparkConf = sparkContext.conf
val udfLibPaths = sparkConf.getOption(VeloxBackendSettings.GLUTEN_VELOX_UDF_LIB_PATHS)
@@ -341,13 +344,12 @@
Seq.empty
case Some(_) =>
UdfJniWrapper.getFunctionSignatures()

UDFNames.map {
name =>
(
new FunctionIdentifier(name),
new ExpressionInfo(classOf[UDFExpression].getName, name),
- (e: Seq[Expression]) => getUdfExpression(name)(e))
+ (e: Seq[Expression]) => getUdfExpression(name, name)(e))
}.toSeq ++ UDAFNames.map {
name =>
(
@@ -364,35 +366,37 @@
.toBoolean
}

- private def getUdfExpression(name: String)(children: Seq[Expression]) = {
+ def getUdfExpression(name: String, alias: String)(children: Seq[Expression]): UDFExpression = {
def errorMessage: String =
s"UDF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered."

val allowTypeConversion = checkAllowTypeConversion
val signatures =
- UDFMap.getOrElse(name, throw new UnsupportedOperationException(errorMessage));
+ UDFMap.getOrElse(name, throw new GlutenNotSupportException(errorMessage));
signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
case Some(sig) =>
UDFExpression(
name,
alias,
sig.expressionType.dataType,
sig.expressionType.nullable,
if (!allowTypeConversion && !sig.allowTypeConversion) children
- else applyCast(children, sig))
+ else applyCast(children, sig)
+ )
case None =>
- throw new UnsupportedOperationException(errorMessage)
+ throw new GlutenNotSupportException(errorMessage)
}
}

- private def getUdafExpression(name: String)(children: Seq[Expression]) = {
+ def getUdafExpression(name: String)(children: Seq[Expression]): UserDefinedAggregateFunction = {
def errorMessage: String =
s"UDAF $name -> ${children.map(_.dataType.simpleString).mkString(", ")} is not registered."

val allowTypeConversion = checkAllowTypeConversion
val signatures =
UDAFMap.getOrElse(
name,
- throw new UnsupportedOperationException(errorMessage)
+ throw new GlutenNotSupportException(errorMessage)
)
signatures.find(sig => tryBind(sig, children.map(_.dataType), allowTypeConversion)) match {
case Some(sig) =>
Expand All @@ -405,7 +409,7 @@ object UDFResolver extends Logging {
sig.intermediateAttrs
)
case None =>
- throw new UnsupportedOperationException(errorMessage)
+ throw new GlutenNotSupportException(errorMessage)
}
}
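Note: the new alias parameter only controls how the expression renders (nodeName above); resolution still keys on name, which for native Hive UDFs is the implementing class name. A minimal sketch, assuming the loaded native library has registered UDFStringString:

import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.expression.UDFResolver

// Resolve by registered (class) name; display under the user-facing alias.
val udf = UDFResolver.getUdfExpression(
  name = "org.apache.spark.sql.hive.execution.UDFStringString",
  alias = "hive_string_string")(Seq(Literal("a"), Literal("b")))
assert(udf.nodeName == "hive_string_string")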

@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.hive

import org.apache.gluten.exception.GlutenNotSupportException
import org.apache.gluten.expression.{ExpressionConverter, ExpressionTransformer}

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.expression.UDFResolver

object VeloxHiveUDFTransformer {
def replaceWithExpressionTransformer(
expr: Expression,
attributeSeq: Seq[Attribute]): ExpressionTransformer = {
val (udfName, udfClassName) = expr match {
case s: HiveSimpleUDF =>
(s.name.stripPrefix("default."), s.funcWrapper.functionClassName)
case g: HiveGenericUDF =>
(g.name.stripPrefix("default."), g.funcWrapper.functionClassName)
case _ =>
throw new GlutenNotSupportException(
s"Expression $expr is not a HiveSimpleUDF or HiveGenericUDF")
}

if (UDFResolver.UDFNames.contains(udfClassName)) {
UDFResolver
.getUdfExpression(udfClassName, udfName)(expr.children)
.getTransformer(
ExpressionConverter.replaceWithExpressionTransformer(expr.children, attributeSeq)
)
} else {
HiveUDFTransformer.genTransformerFromUDFMappings(udfName, expr, attributeSeq)
}
}
}
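Note: the lookup key is the UDF's implementing class name, so a temporary function created over a class the native library has registered offloads transparently, and anything else falls back to the existing UDF-mappings path. A usage sketch mirroring the suite below (table and column names are illustrative):

// Offloads to the native Velox function when UDFResolver.UDFNames contains the class;
// otherwise HiveUDFTransformer.genTransformerFromUDFMappings handles it.
spark.sql("""
  |CREATE TEMPORARY FUNCTION hive_string_string
  |AS 'org.apache.spark.sql.hive.execution.UDFStringString'
  |""".stripMargin)
spark.sql("SELECT hive_string_string(col2, 'a') FROM some_table")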
@@ -22,6 +22,7 @@ import org.apache.gluten.tags.{SkipTestTags, UDFTest}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{GlutenQueryTest, Row, SparkSession}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.expression.UDFResolver

import java.nio.file.Paths
import java.sql.Date
@@ -56,12 +57,31 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
.builder()
.master(master)
.config(sparkConf)
.enableHiveSupport()
.getOrCreate()
}

_spark.sparkContext.setLogLevel("info")
}

override def afterAll(): Unit = {
try {
super.afterAll()
if (_spark != null) {
try {
_spark.sessionState.catalog.reset()
} finally {
_spark.stop()
_spark = null
}
}
} finally {
SparkSession.clearActiveSession()
SparkSession.clearDefaultSession()
doThreadPostAudit()
}
}

override protected def spark = _spark

protected def sparkConf: SparkConf = {
@@ -128,6 +148,85 @@ abstract class VeloxUdfSuite extends GlutenQueryTest with SQLHelper {
.sameElements(Array(Row(1.0, 1.0, 1L))))
}
}

test("test hive udf replacement") {
val tbl = "test_hive_udf_replacement"
withTempPath {
dir =>
try {
spark.sql(s"""
|CREATE EXTERNAL TABLE $tbl
|LOCATION 'file://$dir'
|AS select * from values (1, '1'), (2, '2'), (3, '3')
|""".stripMargin)

// Check that the native hive udf has been registered.
assert(
UDFResolver.UDFNames.contains("org.apache.spark.sql.hive.execution.UDFStringString"))

spark.sql("""
|CREATE TEMPORARY FUNCTION hive_string_string
|AS 'org.apache.spark.sql.hive.execution.UDFStringString'
|""".stripMargin)

val nativeResult =
spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""").collect()
// Unregister the native hive udf to force fallback.
UDFResolver.UDFNames.remove("org.apache.spark.sql.hive.execution.UDFStringString")
val fallbackResult =
spark.sql(s"""SELECT hive_string_string(col2, 'a') FROM $tbl""").collect()
assert(nativeResult.sameElements(fallbackResult))

// Add an unimplemented udf to the map to test fallback of registered native hive udf.
UDFResolver.UDFNames.add("org.apache.spark.sql.hive.execution.UDFIntegerToString")
spark.sql("""
|CREATE TEMPORARY FUNCTION hive_int_to_string
|AS 'org.apache.spark.sql.hive.execution.UDFIntegerToString'
|""".stripMargin)
val df = spark.sql(s"""select hive_int_to_string(col1) from $tbl""")
checkAnswer(df, Seq(Row("1"), Row("2"), Row("3")))
} finally {
spark.sql(s"DROP TABLE IF EXISTS $tbl")
spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_string_string")
spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_int_to_string")
}
}
}

test("test udf fallback in partition filter") {
withTempPath {
dir =>
try {
spark.sql("""
|CREATE TEMPORARY FUNCTION hive_int_to_string
|AS 'org.apache.spark.sql.hive.execution.UDFIntegerToString'
|""".stripMargin)

spark.sql(s"""
|CREATE EXTERNAL TABLE t(i INT, p INT)
|LOCATION 'file://$dir'
|PARTITIONED BY (p)""".stripMargin)

spark
.range(0, 10, 1)
.selectExpr("id as col")
.createOrReplaceTempView("temp")

for (part <- Seq(1, 2, 3, 4)) {
spark.sql(s"""
|INSERT OVERWRITE TABLE t PARTITION (p=$part)
|SELECT col FROM temp""".stripMargin)
}

val df = spark.sql("SELECT i FROM t WHERE hive_int_to_string(p) = '4'")
checkAnswer(df, (0 until 10).map(Row(_)))
} finally {
spark.sql("DROP TABLE IF EXISTS t")
spark.sql("DROP VIEW IF EXISTS temp")
spark.sql(s"DROP TEMPORARY FUNCTION IF EXISTS hive_string_string")
}
}
}
}
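Note on the second test: a partition filter is evaluated on the driver over partition metadata rather than inside the offloaded plan, so the Hive UDF must still fall back to its JVM implementation there; that fallback is what this test appears to exercise.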

@UDFTest
39 changes: 39 additions & 0 deletions cpp/velox/udf/examples/MyUDF.cc
@@ -30,6 +30,7 @@ namespace {
static const char* kInteger = "int";
static const char* kBigInt = "bigint";
static const char* kDate = "date";
static const char* kVarChar = "varchar";

namespace myudf {

@@ -248,6 +249,43 @@ class MyDate2Registerer final : public gluten::UdfRegisterer {
};
} // namespace mydate

namespace hivestringstring {
template <typename T>
struct HiveStringStringFunction {
VELOX_DEFINE_FUNCTION_TYPES(T);

FOLLY_ALWAYS_INLINE void call(out_type<Varchar>& result, const arg_type<Varchar>& a, const arg_type<Varchar>& b) {
result.append(a.data());
result.append(" ");
result.append(b.data());
}
};

// name: org.apache.spark.sql.hive.execution.UDFStringString
// signatures:
// varchar, varchar -> varchar
// type: SimpleFunction
class HiveStringStringRegisterer final : public gluten::UdfRegisterer {
public:
int getNumUdf() override {
return 1;
}

void populateUdfEntries(int& index, gluten::UdfEntry* udfEntries) override {
// Set `allowTypeConversion` for hive udf.
udfEntries[index++] = {name_.c_str(), kVarChar, 2, arg_, false, true};
}

void registerSignatures() override {
facebook::velox::registerFunction<HiveStringStringFunction, Varchar, Varchar, Varchar>({name_});
}

private:
const std::string name_ = "org.apache.spark.sql.hive.execution.UDFStringString";
const char* arg_[2] = {kVarChar, kVarChar};
};
} // namespace hivestringstring

std::vector<std::shared_ptr<gluten::UdfRegisterer>>& globalRegisters() {
static std::vector<std::shared_ptr<gluten::UdfRegisterer>> registerers;
return registerers;
@@ -264,6 +302,7 @@ void setupRegisterers() {
registerers.push_back(std::make_shared<myudf::MyUdf3Registerer>());
registerers.push_back(std::make_shared<mydate::MyDateRegisterer>());
registerers.push_back(std::make_shared<mydate::MyDate2Registerer>());
registerers.push_back(std::make_shared<hivestringstring::HiveStringStringRegisterer>());
inited = true;
}
} // namespace
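For reference, the UdfEntry initializer above packs the function name, return type, argument count, and argument types, followed by two flags: per the inline comment the trailing true enables allowTypeConversion (cf. checkAllowTypeConversion and applyCast in UDFResolver), while the preceding false is presumably a variable-arity flag.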