Skip to content

Commit

Permalink
[OPPRO-274] Enable GlutenStringFunctionsSuite and lpad()/rpad() funct…
Browse files Browse the repository at this point in the history
…ions in Gluten. (facebookincubator#512)

* Support string functions  for Gluten, including length, lower and upper, ascii, concat, replace and coalesce. oap-project#436

* Support Velox RegExp functions rlike and regexp_extract in Gluten.

* Enable GlutenStringFunctionsSuite and support Velox type VARBINARY and REAL.
  • Loading branch information
lviiii authored Nov 14, 2022
1 parent 468bbc6 commit 7d48403
Show file tree
Hide file tree
Showing 9 changed files with 487 additions and 270 deletions.

Large diffs are not rendered by default.

10 changes: 10 additions & 0 deletions cpp/velox/compute/ArrowTypeUtils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,18 @@ std::shared_ptr<arrow::DataType> toArrowTypeFromName(
if (type_name == "BIGINT") {
return arrow::int64();
}
if (type_name == "REAL") {
return arrow::float32();
}
if (type_name == "DOUBLE") {
return arrow::float64();
}
if (type_name == "VARCHAR") {
return arrow::utf8();
}
if (type_name == "VARBINARY") {
return arrow::utf8();
}
// The type name of Array type is like ARRAY<type>.
std::string arrayType = "ARRAY";
if (type_name.substr(0, arrayType.length()) == arrayType) {
Expand All @@ -63,10 +69,14 @@ std::shared_ptr<arrow::DataType> toArrowType(const TypePtr& type) {
return arrow::int32();
case TypeKind::BIGINT:
return arrow::int64();
case TypeKind::REAL:
return arrow::float32();
case TypeKind::DOUBLE:
return arrow::float64();
case TypeKind::VARCHAR:
return arrow::utf8();
case TypeKind::VARBINARY:
return arrow::utf8();
case TypeKind::TIMESTAMP:
return arrow::timestamp(arrow::TimeUnit::MICRO);
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,11 @@ object ConverterUtils extends Logging {
final val LENGTH = "char_length" // length
final val LOWER = "lower"
final val UPPER = "upper"
final val LOCATE = "locate"
final val LTRIM = "ltrim"
final val RTRIM = "rtrim"
final val LPAD = "lpad"
final val RPAD = "rpad"
final val REPLACE = "replace"
final val SPLIT = "split"
final val STARTS_WITH = "starts_with"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,56 +148,17 @@ object ExpressionConverter extends Logging {
attributeSeq),
i.hset,
expr)
case ss: StringReplace =>
case t: TernaryExpression =>
logInfo(s"${expr.getClass} ${expr} is supported.")
TernaryOperatorTransformer.create(
TernaryExpressionTransformer.create(
replaceWithExpressionTransformer(
ss.srcExpr,
t.first,
attributeSeq),
replaceWithExpressionTransformer(
ss.searchExpr,
t.second,
attributeSeq),
replaceWithExpressionTransformer(
ss.replaceExpr,
attributeSeq),
expr)
case ss: StringSplit =>
logInfo(s"${expr.getClass} ${expr} is supported.")
TernaryOperatorTransformer.create(
replaceWithExpressionTransformer(
ss.str,
attributeSeq),
replaceWithExpressionTransformer(
ss.regex,
attributeSeq),
replaceWithExpressionTransformer(
ss.limit,
attributeSeq),
expr)
case ss: RegExpExtract =>
logInfo(s"${expr.getClass} ${expr} is supported.")
TernaryOperatorTransformer.create(
replaceWithExpressionTransformer(
ss.subject,
attributeSeq),
replaceWithExpressionTransformer(
ss.regexp,
attributeSeq),
replaceWithExpressionTransformer(
ss.idx,
attributeSeq),
expr)
case ss: Substring =>
logInfo(s"${expr.getClass} ${expr} is supported.")
TernaryOperatorTransformer.create(
replaceWithExpressionTransformer(
ss.str,
attributeSeq),
replaceWithExpressionTransformer(
ss.pos,
attributeSeq),
replaceWithExpressionTransformer(
ss.len,
t.third,
attributeSeq),
expr)
case u: UnaryExpression =>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.glutenproject.expression

import com.google.common.collect.Lists
import io.glutenproject.expression.ConverterUtils.FunctionConfig
import io.glutenproject.substrait.`type`.TypeBuilder
import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions._

class LocateTransformer(first: Expression, second: Expression,
third: Expression, original: Expression)
extends StringLocate(first: Expression, second: Expression, second: Expression)
with ExpressionTransformer
with Logging {

override def doTransform(args: java.lang.Object): ExpressionNode = {
val firstNode =
first.asInstanceOf[ExpressionTransformer].doTransform(args)
val secondNode =
second.asInstanceOf[ExpressionTransformer].doTransform(args)
val thirdNode =
third.asInstanceOf[ExpressionTransformer].doTransform(args)
if (!firstNode.isInstanceOf[ExpressionNode] ||
!secondNode.isInstanceOf[ExpressionNode] ||
!thirdNode.isInstanceOf[ExpressionNode]) {
throw new UnsupportedOperationException(s"Not supported yet.")
}

val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
val functionName = ConverterUtils.makeFuncName(ConverterUtils.LOCATE,
Seq(first.dataType, second.dataType, third.dataType), FunctionConfig.OPT)
val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName)
val expressionNodes = Lists.newArrayList(firstNode, secondNode, thirdNode)
val typeNode = TypeBuilder.makeI64(original.nullable)
ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
}
}

class LPadTransformer(first: Expression, second: Expression,
third: Expression, original: Expression)
extends StringLPad(first: Expression, second: Expression, second: Expression)
with ExpressionTransformer
with Logging {

override def doTransform(args: java.lang.Object): ExpressionNode = {
val firstNode =
first.asInstanceOf[ExpressionTransformer].doTransform(args)
val secondNode =
second.asInstanceOf[ExpressionTransformer].doTransform(args)
val thirdNode =
third.asInstanceOf[ExpressionTransformer].doTransform(args)
if (!firstNode.isInstanceOf[ExpressionNode] ||
!secondNode.isInstanceOf[ExpressionNode] ||
!thirdNode.isInstanceOf[ExpressionNode]) {
throw new UnsupportedOperationException(s"Not supported yet.")
}

val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
val functionName = ConverterUtils.makeFuncName(ConverterUtils.LPAD,
Seq(first.dataType, second.dataType, third.dataType),
FunctionConfig.OPT)
val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName)
val expressionNodes = Lists.newArrayList(firstNode, secondNode, thirdNode)
val typeNode = ConverterUtils.getTypeNode(original.dataType, original.nullable)
ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
}
}

class RPadTransformer(first: Expression, second: Expression,
third: Expression, original: Expression)
extends StringRPad(first: Expression, second: Expression, second: Expression)
with ExpressionTransformer
with Logging {

override def doTransform(args: java.lang.Object): ExpressionNode = {
val firstNode =
first.asInstanceOf[ExpressionTransformer].doTransform(args)
val secondNode =
second.asInstanceOf[ExpressionTransformer].doTransform(args)
val thirdNode =
third.asInstanceOf[ExpressionTransformer].doTransform(args)
if (!firstNode.isInstanceOf[ExpressionNode] ||
!secondNode.isInstanceOf[ExpressionNode] ||
!thirdNode.isInstanceOf[ExpressionNode]) {
throw new UnsupportedOperationException(s"Not supported yet.")
}

val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
val functionName = ConverterUtils.makeFuncName(ConverterUtils.RPAD,
Seq(first.dataType, second.dataType, third.dataType),
FunctionConfig.OPT)
val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName)
val expressionNodes = Lists.newArrayList(firstNode, secondNode, thirdNode)
val typeNode = ConverterUtils.getTypeNode(original.dataType, original.nullable)
ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
}
}

class RegExpExtractTransformer(subject: Expression, regexp: Expression,
index: Expression, original: Expression)
extends RegExpExtract(subject: Expression, regexp: Expression, regexp: Expression)
with ExpressionTransformer
with Logging {

override def doTransform(args: java.lang.Object): ExpressionNode = {
val firstNode =
subject.asInstanceOf[ExpressionTransformer].doTransform(args)
val secondNode =
regexp.asInstanceOf[ExpressionTransformer].doTransform(args)
val thirdNode =
index.asInstanceOf[ExpressionTransformer].doTransform(args)
if (!firstNode.isInstanceOf[ExpressionNode] ||
!secondNode.isInstanceOf[ExpressionNode] ||
!thirdNode.isInstanceOf[ExpressionNode]) {
throw new UnsupportedOperationException(s"Not supported yet.")
}

val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
val functionName = ConverterUtils.makeFuncName(ConverterUtils.REGEXP_EXTRACT,
Seq(subject.dataType, regexp.dataType, index.dataType), FunctionConfig.OPT)
val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName)
val expressionNodes = Lists.newArrayList(firstNode, secondNode, thirdNode)
val typeNode = TypeBuilder.makeString(original.nullable)
ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
}
}

class ReplaceTransformer(str: Expression, search: Expression, replace: Expression,
original: Expression)
extends StringReplace(str: Expression, search: Expression, search: Expression)
with ExpressionTransformer
with Logging {

override def doTransform(args: java.lang.Object): ExpressionNode = {
val strNode =
str.asInstanceOf[ExpressionTransformer].doTransform(args)
val searchNode =
search.asInstanceOf[ExpressionTransformer].doTransform(args)
val replaceNode =
replace.asInstanceOf[ExpressionTransformer].doTransform(args)
if (!strNode.isInstanceOf[ExpressionNode] ||
!searchNode.isInstanceOf[ExpressionNode] ||
!replaceNode.isInstanceOf[ExpressionNode]) {
throw new UnsupportedOperationException(s"Not supported yet.")
}

val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
val functionName = ConverterUtils.makeFuncName(ConverterUtils.REPLACE,
Seq(str.dataType, search.dataType, replace.dataType), FunctionConfig.OPT)
val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName)
val expressionNodes = Lists.newArrayList(strNode, searchNode, replaceNode)
val typeNode = TypeBuilder.makeString(original.nullable)
ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
}
}

class SplitTransformer(str: Expression, delimiter: Expression, limit: Expression,
original: Expression)
extends StringSplit(str: Expression, delimiter: Expression, limit: Expression)
with ExpressionTransformer
with Logging {

override def doTransform(args: java.lang.Object): ExpressionNode = {
throw new UnsupportedOperationException("Not supported: Split.")
}
}

class SubStringTransformer(str: Expression, pos: Expression, len: Expression, original: Expression)
extends Substring(str: Expression, pos: Expression, len: Expression)
with ExpressionTransformer
with Logging {

override def doTransform(args: java.lang.Object): ExpressionNode = {
val strNode =
str.asInstanceOf[ExpressionTransformer].doTransform(args)
val posNode =
pos.asInstanceOf[ExpressionTransformer].doTransform(args)
val lenNode =
len.asInstanceOf[ExpressionTransformer].doTransform(args)

if (!strNode.isInstanceOf[ExpressionNode] ||
!posNode.isInstanceOf[ExpressionNode] ||
!lenNode.isInstanceOf[ExpressionNode]) {
throw new UnsupportedOperationException(s"not supported yet.")
}

val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
val functionName = ConverterUtils.makeFuncName(
ConverterUtils.SUBSTRING, Seq(str.dataType), FunctionConfig.OPT)
val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName)
val expressionNodes = Lists.newArrayList(
strNode.asInstanceOf[ExpressionNode],
posNode.asInstanceOf[ExpressionNode],
lenNode.asInstanceOf[ExpressionNode])
// Substring inherits NullIntolerant, the output is nullable
val typeNode = TypeBuilder.makeString(true)

ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode)
}
}

object TernaryExpressionTransformer {

def create(first: Expression, second: Expression, third: Expression,
original: Expression): Expression = original match {
case _: StringLocate =>
// locate() gets incorrect results, so fall back to Vanilla Spark
throw new UnsupportedOperationException("Not supported: locate().")
case _: StringSplit =>
// split() gets incorrect results, so fall back to Vanilla Spark
throw new UnsupportedOperationException("Not supported: locate().")
case lpad: StringLPad =>
new LPadTransformer(first, second, third, lpad)
case rpad: StringRPad =>
new RPadTransformer(first, second, third, rpad)
case extract: RegExpExtract =>
new RegExpExtractTransformer(first, second, third, extract)
case replace: StringReplace =>
new ReplaceTransformer(first, second, third, replace)
case ss: Substring =>
new SubStringTransformer(first, second, third, ss)
case other =>
throw new UnsupportedOperationException(s"not currently supported: $other.")
}
}
Loading

0 comments on commit 7d48403

Please sign in to comment.