diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala index b3379ccb3df9..a4dadccbae9e 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -18,31 +18,23 @@ package io.glutenproject.backendsapi.clickhouse import scala.collection.mutable.ArrayBuffer - import io.glutenproject.GlutenConfig import io.glutenproject.backendsapi.ISparkPlanExecApi import io.glutenproject.execution._ import io.glutenproject.expression.{AliasBaseTransformer, AliasTransformer} import io.glutenproject.vectorized.{BlockNativeWriter, CHColumnarBatchSerializer} - import org.apache.spark.{ShuffleDependency, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriterWrapper} import org.apache.spark.shuffle.utils.CHShuffleUtil import org.apache.spark.sql.{SparkSession, Strategy} -import org.apache.spark.sql.catalyst.expressions.{ - Alias, - Attribute, - AttributeReference, - BoundReference, - Expression, - ExprId, - NamedExpression -} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, BoundReference, Expression, ExprId, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.delta.DeltaLogFileIndex import org.apache.spark.sql.execution._ @@ -50,11 +42,7 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec import org.apache.spark.sql.execution.datasources.v2.V2CommandExec import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec -import org.apache.spark.sql.execution.joins.{ - BuildSideRelation, - ClickHouseBuildSideRelation, - HashedRelationBroadcastMode -} +import org.apache.spark.sql.execution.joins.{BuildSideRelation, ClickHouseBuildSideRelation, HashedRelationBroadcastMode} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.utils.CHExecUtil import org.apache.spark.sql.extension.{CHDataSourceV2Strategy, ClickHouseAnalysis} @@ -145,6 +133,33 @@ class CHSparkPlanExecApi extends ISparkPlanExecApi with AdaptiveSparkPlanHelper resultExpressions, child) + /** + * Generate ShuffledHashJoinExecTransformer. + */ + def genShuffledHashJoinExecTransformer(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan): ShuffledHashJoinExecTransformer = + CHShuffledHashJoinExecTransformer( + leftKeys, rightKeys, joinType, buildSide, condition, left, right) + + /** + * Generate BroadcastHashJoinExecTransformer. 
+ */ + def genBroadcastHashJoinExecTransformer(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isNullAwareAntiJoin: Boolean = false) + : BroadcastHashJoinExecTransformer = CHBroadcastHashJoinExecTransformer( + leftKeys, rightKeys, joinType, buildSide, condition, left, right) + /** * Generate Alias transformer. * diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashJoinExecTransformer.scala b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashJoinExecTransformer.scala new file mode 100644 index 000000000000..5220110866a1 --- /dev/null +++ b/backends-clickhouse/src/main/scala/io/glutenproject/execution/CHHashJoinExecTransformer.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.glutenproject.execution + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.BuildSide +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution.SparkPlan + +case class CHShuffledHashJoinExecTransformer(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) + extends ShuffledHashJoinExecTransformer( + leftKeys, rightKeys, joinType, buildSide, condition, left, right) { + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): CHShuffledHashJoinExecTransformer = + copy(left = newLeft, right = newRight) +} + +case class CHBroadcastHashJoinExecTransformer(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isNullAwareAntiJoin: Boolean = false) + extends BroadcastHashJoinExecTransformer( + leftKeys, rightKeys, joinType, buildSide, condition, left, right) { + + override protected def withNewChildrenInternal( + newLeft: SparkPlan, newRight: SparkPlan): CHBroadcastHashJoinExecTransformer = + copy(left = newLeft, right = newRight) +} diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/benchmarks/DSV2BenchmarkTest.scala b/backends-clickhouse/src/test/scala/io/glutenproject/benchmarks/DSV2BenchmarkTest.scala index 8d73e783b92a..2e01bf12a6be 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/benchmarks/DSV2BenchmarkTest.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/benchmarks/DSV2BenchmarkTest.scala @@ -290,8 +290,8 @@ object DSV2BenchmarkTest extends AdaptiveSparkPlanHelper { def collectAllJoinSide(executedPlan: SparkPlan): Unit = { val buildSides = collect(executedPlan) { - case 
s: ShuffledHashJoinExecTransformer => "Shuffle-" + s.buildSide.toString - case b: BroadcastHashJoinExecTransformer => "Broadcast-" + b.buildSide.toString + case s: ShuffledHashJoinExecTransformer => "Shuffle-" + s.joinBuildSide.toString + case b: BroadcastHashJoinExecTransformer => "Broadcast-" + b.joinBuildSide.toString case os: ShuffledHashJoinExec => "Shuffle-" + os.buildSide.toString case ob: BroadcastHashJoinExec => "Broadcast-" + ob.buildSide.toString case sm: SortMergeJoinExec => "SortMerge-Join" diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableColumnarShuffleSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableColumnarShuffleSuite.scala index 55c4fcf761b1..bc28287ed66f 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableColumnarShuffleSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableColumnarShuffleSuite.scala @@ -63,7 +63,7 @@ class GlutenClickHouseTPCHNullableColumnarShuffleSuite extends GlutenClickHouseT withSQLConf(("spark.sql.autoBroadcastJoinThreshold", "-1")) { runTPCHQuery(3) { df => val shjBuildLeft = df.queryExecution.executedPlan.collect { - case shj: ShuffledHashJoinExecTransformer if shj.buildSide == BuildLeft => shj + case shj: ShuffledHashJoinExecTransformer if shj.joinBuildSide == BuildLeft => shj } assert(shjBuildLeft.size == 2) } diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala index a8bca23e10d3..dee6a2998332 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHNullableSuite.scala @@ -63,7 +63,7 @@ class GlutenClickHouseTPCHNullableSuite extends GlutenClickHouseTPCHAbstractSuit withSQLConf(("spark.sql.autoBroadcastJoinThreshold", "-1")) { runTPCHQuery(3) { df => val shjBuildLeft = df.queryExecution.executedPlan.collect { - case shj: ShuffledHashJoinExecTransformer if shj.buildSide == BuildLeft => shj + case shj: ShuffledHashJoinExecTransformer if shj.joinBuildSide == BuildLeft => shj } assert(shjBuildLeft.size == 2) } diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetAQESuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetAQESuite.scala index 8a839594a091..8289446eed75 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetAQESuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetAQESuite.scala @@ -75,7 +75,7 @@ class GlutenClickHouseTPCHParquetAQESuite runTPCHQuery(3) { df => assert(df.queryExecution.executedPlan.isInstanceOf[AdaptiveSparkPlanExec]) val shjBuildLeft = collect(df.queryExecution.executedPlan) { - case shj: ShuffledHashJoinExecTransformer if shj.buildSide == BuildLeft => shj + case shj: ShuffledHashJoinExecTransformer if shj.joinBuildSide == BuildLeft => shj } assert(shjBuildLeft.size == 2) } diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetSuite.scala index 
edee548c7753..77ad56cfb23b 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHParquetSuite.scala @@ -219,7 +219,7 @@ class GlutenClickHouseTPCHParquetSuite extends GlutenClickHouseTPCHAbstractSuite withSQLConf(("spark.sql.autoBroadcastJoinThreshold", "-1")) { runTPCHQuery(3) { df => val shjBuildLeft = df.queryExecution.executedPlan.collect { - case shj: ShuffledHashJoinExecTransformer if shj.buildSide == BuildLeft => shj + case shj: ShuffledHashJoinExecTransformer if shj.joinBuildSide == BuildLeft => shj } assert(shjBuildLeft.size == 1) } diff --git a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHSuite.scala b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHSuite.scala index 50a300960f36..ed1ed5e92a9a 100644 --- a/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHSuite.scala +++ b/backends-clickhouse/src/test/scala/io/glutenproject/execution/GlutenClickHouseTPCHSuite.scala @@ -60,7 +60,7 @@ class GlutenClickHouseTPCHSuite extends GlutenClickHouseTPCHAbstractSuite { withSQLConf(("spark.sql.autoBroadcastJoinThreshold", "-1")) { runTPCHQuery(3) { df => val shjBuildLeft = df.queryExecution.executedPlan.collect { - case shj: ShuffledHashJoinExecTransformer if shj.buildSide == BuildLeft => shj + case shj: ShuffledHashJoinExecTransformer if shj.joinBuildSide == BuildLeft => shj } assert(shjBuildLeft.size == 2) } diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxSparkPlanExecApi.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxSparkPlanExecApi.scala index b6cb486559c8..a83b1c6414bb 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxSparkPlanExecApi.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/VeloxSparkPlanExecApi.scala @@ -25,7 +25,6 @@ import io.glutenproject.columnarbatch.ArrowColumnarBatches import io.glutenproject.execution._ import io.glutenproject.expression.{AliasBaseTransformer, ArrowConverterUtils, VeloxAliasTransformer} import io.glutenproject.vectorized.{ArrowColumnarBatchSerializer, ArrowWritableColumnVector} - import org.apache.spark.{ShuffleDependency, SparkException} import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer @@ -35,8 +34,10 @@ import org.apache.spark.sql.VeloxColumnarRules._ import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, ExprId, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan, VeloxBuildSideRelation} import org.apache.spark.sql.execution.datasources.v2.arrow.SparkMemoryUtils @@ -113,6 +114,33 @@ class VeloxSparkPlanExecApi extends ISparkPlanExecApi { resultExpressions, child) + /** + * Generate ShuffledHashJoinExecTransformer. 
+ */ + def genShuffledHashJoinExecTransformer(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan): ShuffledHashJoinExecTransformer = + VeloxShuffledHashJoinExecTransformer( + leftKeys, rightKeys, joinType, buildSide, condition, left, right) + + /** + * Generate BroadcastHashJoinExecTransformer. + */ + def genBroadcastHashJoinExecTransformer(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isNullAwareAntiJoin: Boolean = false) + : BroadcastHashJoinExecTransformer = VeloxBroadcastHashJoinExecTransformer( + leftKeys, rightKeys, joinType, buildSide, condition, left, right) + /** * Generate Alias transformer. * diff --git a/backends-velox/src/main/scala/io/glutenproject/execution/VeloxHashJoinExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/VeloxHashJoinExecTransformer.scala new file mode 100644 index 000000000000..349fcbe717a6 --- /dev/null +++ b/backends-velox/src/main/scala/io/glutenproject/execution/VeloxHashJoinExecTransformer.scala @@ -0,0 +1,283 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.glutenproject.execution + +import io.glutenproject.execution.HashJoinLikeExecTransformer.{makeAndExpression, makeEqualToExpression, makeIsNullExpression} +import io.glutenproject.expression._ +import io.glutenproject.substrait.SubstraitContext +import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode} +import io.glutenproject.substrait.rel.{RelBuilder, RelNode} +import io.substrait.proto.JoinRel +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide} +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.{BooleanType, DataType} +import java.util + +import scala.collection.JavaConverters._ + +trait VeloxHashJoinLikeExecTransformer extends HashJoinLikeExecTransformer { + + // Direct output order of substrait join operation + override protected val substraitJoinType: JoinRel.JoinType = joinType match { + case Inner => + JoinRel.JoinType.JOIN_TYPE_INNER + case FullOuter => + JoinRel.JoinType.JOIN_TYPE_OUTER + case LeftOuter | RightOuter => + // The right side is required to be used for building hash table in Substrait plan. + // Therefore, for RightOuter Join, the left and right relations are exchanged and the + // join type is reverted. 
+ JoinRel.JoinType.JOIN_TYPE_LEFT + case LeftSemi => + JoinRel.JoinType.JOIN_TYPE_SEMI + case LeftAnti => + if (!antiJoinWorkaroundNeeded) { + JoinRel.JoinType.JOIN_TYPE_ANTI + } else { + // Use Left to replace Anti as a workaround. + JoinRel.JoinType.JOIN_TYPE_LEFT + } + case _ => + // TODO: Support cross join with Cross Rel + // TODO: Support existence join + JoinRel.JoinType.UNRECOGNIZED + } + + /** + * Use a workaround for Anti join with 'not exists' semantics. + * Firstly, add an aggregation rel over build rel to make the build keys distinct. + * Secondly, add a project rel over the aggregation rel to append a column of true constant. + * Then, this project rel is returned and will be used as the input rel for join build side. + * @param buildInfo: the original build keys, build rel and build outputs. + * @param context: Substrait context. + * @param operatorId: operator id of this join. + * @return original build keys, new build rel and build outputs. + */ + private def createSpecialRelForAntiBuild( + buildInfo: (Seq[(ExpressionNode, DataType)], RelNode, Seq[Attribute]), + context: SubstraitContext, + operatorId: java.lang.Long, + validation: Boolean): (Seq[(ExpressionNode, DataType)], RelNode, Seq[Attribute]) = { + // Create an Aggregation Rel over build Rel. + val groupingNodes = new util.ArrayList[ExpressionNode]() + var colIdx = 0 + buildInfo._3.foreach(_ => { + groupingNodes.add(ExpressionBuilder.makeSelection(colIdx)) + colIdx += 1 + }) + val aggNode = RelBuilder.makeAggregateRel( + buildInfo._2, + groupingNodes, + new util.ArrayList[AggregateFunctionNode](), + context, + operatorId) + // Create a Project Rel over Aggregation Rel. + val expressionNodes = groupingNodes + // Append a new column of true constant. + expressionNodes.add(ExpressionBuilder.makeBooleanLiteral(true)) + val projectNode = RelBuilder.makeProjectRel( + aggNode, + expressionNodes, + createExtensionNode(buildInfo._3, validation), + context, + operatorId) + ( + buildInfo._1, + projectNode, + buildInfo._3 :+ AttributeReference(s"constant_true", BooleanType)() + ) + } + + override def createJoinRel(inputStreamedRelNode: RelNode, + inputBuildRelNode: RelNode, + inputStreamedOutput: Seq[Attribute], + inputBuildOutput: Seq[Attribute], + substraitContext: SubstraitContext, + operatorId: java.lang.Long, + validation: Boolean = false): RelNode = { + // Create pre-projection for build/streamed plan. Append projected keys to each side. + val (streamedKeys, streamedRelNode, streamedOutput) = createPreProjectionIfNeeded( + streamedKeyExprs, + inputStreamedRelNode, + inputStreamedOutput, + inputStreamedOutput, + substraitContext, + operatorId, + validation) + + val (buildKeys, buildRelNode, buildOutput) = { + val (keys, relNode, output) = createPreProjectionIfNeeded( + buildKeyExprs, + inputBuildRelNode, + inputBuildOutput, + streamedOutput ++ inputBuildOutput, + substraitContext, + operatorId, + validation) + if (!antiJoinWorkaroundNeeded) { + (keys, relNode, output) + } else { + // Use a workaround for Anti join. + createSpecialRelForAntiBuild( + (keys, relNode, output), substraitContext, operatorId, validation) + } + } + + // Combine join keys to make a single expression. 
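+    // Conceptually, for join keys (k1, k2) the combined condition built below is
+    // AND(EQUAL(streamed.k1, build.k1), EQUAL(streamed.k2, build.k2)); the
+    // streamed./build. qualifiers are illustrative names only.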
+    val joinExpressionNode = (streamedKeys zip buildKeys).map {
+      case ((leftKey, leftType), (rightKey, rightType)) =>
+        makeEqualToExpression(
+          leftKey, leftType, rightKey, rightType, substraitContext.registeredFunction)
+    }.reduce((l, r) => makeAndExpression(l, r, substraitContext.registeredFunction))
+
+    // Create post-join filter, which will be computed in hash join.
+    val postJoinFilter = condition.map {
+      expr =>
+        ExpressionConverter
+          .replaceWithExpressionTransformer(expr, streamedOutput ++ buildOutput)
+          .asInstanceOf[ExpressionTransformer]
+          .doTransform(substraitContext.registeredFunction)
+    }
+
+    // Create JoinRel.
+    val joinRel = {
+      val joinNode = RelBuilder.makeJoinRel(
+        streamedRelNode,
+        buildRelNode,
+        substraitJoinType,
+        joinExpressionNode,
+        postJoinFilter.orNull,
+        createJoinExtensionNode(streamedOutput ++ buildOutput, validation),
+        substraitContext,
+        operatorId)
+      if (!antiJoinWorkaroundNeeded) {
+        joinNode
+      } else {
+        // Use an isNull filter to select the rows needed by Anti join from the Left join outputs.
+        val isNullFilter = makeIsNullExpression(
+          ExpressionBuilder.makeSelection(streamedOutput.size + buildOutput.size - 1),
+          substraitContext.registeredFunction)
+        RelBuilder.makeFilterRel(
+          joinNode,
+          isNullFilter,
+          createJoinExtensionNode(streamedOutput ++ buildOutput, validation),
+          substraitContext,
+          operatorId)
+      }
+    }
+
+    // Result projection will drop the appended keys, and exchange the column order if BuildLeft.
+    val resultProjection = joinBuildSide match {
+      case BuildLeft =>
+        val (leftOutput, rightOutput) =
+          getResultProjectionOutput(inputBuildOutput, inputStreamedOutput)
+        // Exchange the order of build and streamed.
+        leftOutput.indices.map(idx =>
+          ExpressionBuilder.makeSelection(idx + streamedOutput.size)) ++
+          rightOutput.indices
+            .map(ExpressionBuilder.makeSelection(_))
+      case BuildRight =>
+        val (leftOutput, rightOutput) =
+          getResultProjectionOutput(inputStreamedOutput, inputBuildOutput)
+        leftOutput.indices.map(ExpressionBuilder.makeSelection(_)) ++
+          rightOutput.indices.map(idx => ExpressionBuilder.makeSelection(idx + streamedOutput.size))
+    }
+
+    RelBuilder.makeProjectRel(
+      joinRel,
+      new java.util.ArrayList[ExpressionNode](resultProjection.asJava),
+      createExtensionNode(streamedOutput ++ buildOutput, validation),
+      substraitContext,
+      operatorId)
+  }
+}
+
+case class VeloxShuffledHashJoinExecTransformer(leftKeys: Seq[Expression],
+                                                rightKeys: Seq[Expression],
+                                                joinType: JoinType,
+                                                buildSide: BuildSide,
+                                                condition: Option[Expression],
+                                                left: SparkPlan,
+                                                right: SparkPlan)
+  extends ShuffledHashJoinExecTransformer(
+    leftKeys,
+    rightKeys,
+    joinType,
+    buildSide,
+    condition,
+    left,
+    right) with VeloxHashJoinLikeExecTransformer {
+
+  /**
+   * Returns whether a workaround for Anti join is needed. True for 'not exists' semantics.
+   * For SHJ, always returns true for Anti join.
+   */
+  override def antiJoinWorkaroundNeeded: Boolean = {
+    joinType match {
+      case LeftAnti => true
+      case _ => false
+    }
+  }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): VeloxShuffledHashJoinExecTransformer =
+    copy(left = newLeft, right = newRight)
+}
+
+case class VeloxBroadcastHashJoinExecTransformer(leftKeys: Seq[Expression],
+                                                 rightKeys: Seq[Expression],
+                                                 joinType: JoinType,
+                                                 buildSide: BuildSide,
+                                                 condition: Option[Expression],
+                                                 left: SparkPlan,
+                                                 right: SparkPlan,
+                                                 isNullAwareAntiJoin: Boolean = false)
+  extends BroadcastHashJoinExecTransformer(
+    leftKeys,
+    rightKeys,
+    joinType,
+    buildSide,
+    condition,
+    left,
+    right) with VeloxHashJoinLikeExecTransformer {
+
+  /**
+   * Returns whether a workaround for Anti join is needed. True for 'not exists' semantics.
+   * For BHJ, returns true only when isNullAwareAntiJoin is disabled.
+   */
+  override def antiJoinWorkaroundNeeded: Boolean = {
+    joinType match {
+      case LeftAnti =>
+        if (isNullAwareAntiJoin) {
+          false
+        } else {
+          // Velox's Anti join semantics match the case where isNullAwareAntiJoin is enabled,
+          // so a workaround is needed when isNullAwareAntiJoin is disabled.
+          true
+        }
+      case _ =>
+        false
+    }
+  }
+
+  override protected def withNewChildrenInternal(
+      newLeft: SparkPlan, newRight: SparkPlan): VeloxBroadcastHashJoinExecTransformer =
+    copy(left = newLeft, right = newRight)
+}
diff --git a/jvm/src/main/java/io/glutenproject/substrait/expression/BooleanLiteralNode.java b/jvm/src/main/java/io/glutenproject/substrait/expression/BooleanLiteralNode.java
new file mode 100644
index 000000000000..b19aa4bd9fc0
--- /dev/null
+++ b/jvm/src/main/java/io/glutenproject/substrait/expression/BooleanLiteralNode.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package io.glutenproject.substrait.expression; + +import io.substrait.proto.Expression; + +import java.io.Serializable; + +public class BooleanLiteralNode implements ExpressionNode, Serializable { + private final Boolean value; + + public BooleanLiteralNode(Boolean value) { + this.value = value; + } + + @Override + public Expression toProtobuf() { + Expression.Literal.Builder booleanBuilder = + Expression.Literal.newBuilder(); + booleanBuilder.setBoolean(value); + + Expression.Builder builder = Expression.newBuilder(); + builder.setLiteral(booleanBuilder.build()); + return builder.build(); + } +} diff --git a/jvm/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java b/jvm/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java index e71d55363d60..6e1a51d6ff0e 100644 --- a/jvm/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java +++ b/jvm/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java @@ -50,6 +50,10 @@ public static NullLiteralNode makeNullLiteral(TypeNode typeNode) { return new NullLiteralNode(typeNode); } + public static BooleanLiteralNode makeBooleanLiteral(Boolean booleanConstant) { + return new BooleanLiteralNode(booleanConstant); + } + public static IntLiteralNode makeIntLiteral(Integer intConstant) { return new IntLiteralNode(intConstant); } diff --git a/jvm/src/main/scala/io/glutenproject/backendsapi/ISparkPlanExecApi.scala b/jvm/src/main/scala/io/glutenproject/backendsapi/ISparkPlanExecApi.scala index 52ac1f1af214..dd5cdf9f19aa 100644 --- a/jvm/src/main/scala/io/glutenproject/backendsapi/ISparkPlanExecApi.scala +++ b/jvm/src/main/scala/io/glutenproject/backendsapi/ISparkPlanExecApi.scala @@ -17,9 +17,8 @@ package io.glutenproject.backendsapi -import io.glutenproject.execution.{FilterExecBaseTransformer, HashAggregateExecBaseTransformer, NativeColumnarToRowExec, RowToArrowColumnarExec} +import io.glutenproject.execution.{BroadcastHashJoinExecTransformer, FilterExecBaseTransformer, HashAggregateExecBaseTransformer, NativeColumnarToRowExec, RowToArrowColumnarExec, ShuffledHashJoinExecTransformer} import io.glutenproject.expression.AliasBaseTransformer - import org.apache.spark.ShuffleDependency import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer @@ -27,8 +26,10 @@ import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriter import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, ExprId, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} import org.apache.spark.sql.execution.joins.BuildSideRelation @@ -82,6 +83,30 @@ trait ISparkPlanExecApi extends IBackendsApi { resultExpressions: Seq[NamedExpression], child: SparkPlan): HashAggregateExecBaseTransformer + /** + * Generate ShuffledHashJoinExecTransformer. 
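+   * Each backend supplies its own concrete subclass (this patch adds the ClickHouse and
+   * Velox implementations); planner rules obtain it through
+   * BackendsApiManager.getSparkPlanExecApiInstance instead of constructing a case class directly.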
+ */ + def genShuffledHashJoinExecTransformer(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan): ShuffledHashJoinExecTransformer + + /** + * Generate BroadcastHashJoinExecTransformer. + */ + def genBroadcastHashJoinExecTransformer(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isNullAwareAntiJoin: Boolean = false) + : BroadcastHashJoinExecTransformer + /** * Generate Alias transformer. * diff --git a/jvm/src/main/scala/io/glutenproject/execution/HashJoinExecTransformer.scala b/jvm/src/main/scala/io/glutenproject/execution/HashJoinExecTransformer.scala index ee1671e6831f..daba115adfbd 100644 --- a/jvm/src/main/scala/io/glutenproject/execution/HashJoinExecTransformer.scala +++ b/jvm/src/main/scala/io/glutenproject/execution/HashJoinExecTransformer.scala @@ -20,7 +20,6 @@ package io.glutenproject.execution import com.google.common.collect.Lists import com.google.protobuf.{Any, ByteString} import io.glutenproject.GlutenConfig -import io.glutenproject.execution.HashJoinLikeExecTransformer.{makeAndExpression, makeEqualToExpression} import io.glutenproject.expression._ import io.glutenproject.substrait.{JoinParams, SubstraitContext} import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} @@ -42,8 +41,8 @@ import org.apache.spark.sql.execution.joins.{BaseJoinExec, BuildSideRelation, Ha import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} import org.apache.spark.sql.types.{BooleanType, DataType} import org.apache.spark.sql.vectorized.ColumnarBatch - import java.{lang, util} + import scala.collection.JavaConverters._ import scala.util.control.Breaks.{break, breakable} @@ -100,16 +99,10 @@ trait ColumnarShuffledJoin extends BaseJoinExec { /** * Performs a hash join of two child relations by first shuffling the data using the join keys. 
*/ -abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - buildSide: BuildSide, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) - extends BaseJoinExec - with TransformSupport - with ColumnarShuffledJoin { +trait HashJoinLikeExecTransformer + extends BaseJoinExec with TransformSupport with ColumnarShuffledJoin { + + def joinBuildSide: BuildSide override lazy val metrics = Map( "streamInputRows" -> SQLMetrics.createMetric( @@ -239,6 +232,56 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], "hashBuildNumMemoryAllocations" -> SQLMetrics.createMetric( sparkContext, "number of hash build memory allocations"), + "antiDistinctInputRows" -> SQLMetrics.createMetric( + sparkContext, "number of anti join distinct aggregation input rows"), + "antiDistinctInputVectors" -> SQLMetrics.createMetric( + sparkContext, "number of anti join distinct aggregation input vectors"), + "antiDistinctInputBytes" -> SQLMetrics.createSizeMetric( + sparkContext, "number of anti join distinct aggregation input bytes"), + "antiDistinctRawInputRows" -> SQLMetrics.createMetric( + sparkContext, "number of anti join distinct aggregation raw input rows"), + "antiDistinctRawInputBytes" -> SQLMetrics.createSizeMetric( + sparkContext, "number of anti join distinct aggregation raw input bytes"), + "antiDistinctOutputRows" -> SQLMetrics.createMetric( + sparkContext, "number of anti join distinct aggregation output rows"), + "antiDistinctOutputVectors" -> SQLMetrics.createMetric( + sparkContext, "number of anti join distinct aggregation output vectors"), + "antiDistinctOutputBytes" -> SQLMetrics.createSizeMetric( + sparkContext, "number of anti join distinct aggregation output bytes"), + "antiDistinctCount" -> SQLMetrics.createMetric( + sparkContext, "anti join distinct aggregation cpu wall time count"), + "antiDistinctWallNanos" -> SQLMetrics.createNanoTimingMetric( + sparkContext, "totaltime_anti_join_distinct_aggregation"), + "antiDistinctPeakMemoryBytes" -> SQLMetrics.createSizeMetric( + sparkContext, "anti join distinct aggregation peak memory bytes"), + "antiDistinctNumMemoryAllocations" -> SQLMetrics.createMetric( + sparkContext, "number of anti join distinct aggregation memory allocations"), + + "antiProjectInputRows" -> SQLMetrics.createMetric( + sparkContext, "number of anti project input rows"), + "antiProjectInputVectors" -> SQLMetrics.createMetric( + sparkContext, "number of anti project input vectors"), + "antiProjectInputBytes" -> SQLMetrics.createSizeMetric( + sparkContext, "number of anti project input bytes"), + "antiProjectRawInputRows" -> SQLMetrics.createMetric( + sparkContext, "number of anti project raw input rows"), + "antiProjectRawInputBytes" -> SQLMetrics.createSizeMetric( + sparkContext, "number of anti project raw input bytes"), + "antiProjectOutputRows" -> SQLMetrics.createMetric( + sparkContext, "number of anti project output rows"), + "antiProjectOutputVectors" -> SQLMetrics.createMetric( + sparkContext, "number of anti project output vectors"), + "antiProjectOutputBytes" -> SQLMetrics.createSizeMetric( + sparkContext, "number of anti project output bytes"), + "antiProjectCount" -> SQLMetrics.createMetric( + sparkContext, "anti project cpu wall time count"), + "antiProjectWallNanos" -> SQLMetrics.createNanoTimingMetric( + sparkContext, "totaltime_anti_project"), + "antiProjectPeakMemoryBytes" -> SQLMetrics.createSizeMetric( + sparkContext, "anti project peak memory bytes"), + 
"antiProjectNumMemoryAllocations" -> SQLMetrics.createMetric( + sparkContext, "number of anti project memory allocations"), + "hashProbeInputRows" -> SQLMetrics.createMetric( sparkContext, "number of hash probe input rows"), "hashProbeInputVectors" -> SQLMetrics.createMetric( @@ -368,6 +411,38 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], val hashBuildPeakMemoryBytes: SQLMetric = longMetric("hashBuildPeakMemoryBytes") val hashBuildNumMemoryAllocations: SQLMetric = longMetric("hashBuildNumMemoryAllocations") + /** + * The metrics of build side distinct aggregation, relating to the special handling for Anti join. + */ + val antiDistinctInputRows: SQLMetric = longMetric("antiDistinctInputRows") + val antiDistinctInputVectors: SQLMetric = longMetric("antiDistinctInputVectors") + val antiDistinctInputBytes: SQLMetric = longMetric("antiDistinctInputBytes") + val antiDistinctRawInputRows: SQLMetric = longMetric("antiDistinctRawInputRows") + val antiDistinctRawInputBytes: SQLMetric = longMetric("antiDistinctRawInputBytes") + val antiDistinctOutputRows: SQLMetric = longMetric("antiDistinctOutputRows") + val antiDistinctOutputVectors: SQLMetric = longMetric("antiDistinctOutputVectors") + val antiDistinctOutputBytes: SQLMetric = longMetric("antiDistinctOutputBytes") + val antiDistinctCount: SQLMetric = longMetric("antiDistinctCount") + val antiDistinctWallNanos: SQLMetric = longMetric("antiDistinctWallNanos") + val antiDistinctPeakMemoryBytes: SQLMetric = longMetric("antiDistinctPeakMemoryBytes") + val antiDistinctNumMemoryAllocations: SQLMetric = longMetric("antiDistinctNumMemoryAllocations") + + /** + * The metrics of build side extra projection, relating to the special handling for Anti join. + */ + val antiProjectInputRows: SQLMetric = longMetric("antiProjectInputRows") + val antiProjectInputVectors: SQLMetric = longMetric("antiProjectInputVectors") + val antiProjectInputBytes: SQLMetric = longMetric("antiProjectInputBytes") + val antiProjectRawInputRows: SQLMetric = longMetric("antiProjectRawInputRows") + val antiProjectRawInputBytes: SQLMetric = longMetric("antiProjectRawInputBytes") + val antiProjectOutputRows: SQLMetric = longMetric("antiProjectOutputRows") + val antiProjectOutputVectors: SQLMetric = longMetric("antiProjectOutputVectors") + val antiProjectOutputBytes: SQLMetric = longMetric("antiProjectOutputBytes") + val antiProjectCount: SQLMetric = longMetric("antiProjectCount") + val antiProjectWallNanos: SQLMetric = longMetric("antiProjectWallNanos") + val antiProjectPeakMemoryBytes: SQLMetric = longMetric("antiProjectPeakMemoryBytes") + val antiProjectNumMemoryAllocations: SQLMetric = longMetric("antiProjectNumMemoryAllocations") + val hashProbeInputRows: SQLMetric = longMetric("hashProbeInputRows") val hashProbeInputVectors: SQLMetric = longMetric("hashProbeInputVectors") val hashProbeInputBytes: SQLMetric = longMetric("hashProbeInputBytes") @@ -408,7 +483,7 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], val finalOutputVectors: SQLMetric = longMetric("finalOutputVectors") def isSkewJoin: Boolean = false - lazy val (buildPlan, streamedPlan) = buildSide match { + lazy val (buildPlan, streamedPlan) = joinBuildSide match { case BuildLeft => (left, right) case BuildRight => (right, left) } @@ -419,19 +494,28 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], "Join keys from two sides should have same types") val lkeys = HashJoin.rewriteKeyExpr(leftKeys) val rkeys = HashJoin.rewriteKeyExpr(rightKeys) - buildSide match 
{ + joinBuildSide match { case BuildLeft => (lkeys, rkeys) case BuildRight => (rkeys, lkeys) } } + /** + * Returns whether a workaround for Anti join is needed. + * True for 'not exists' semantics on Velox backend. + */ + def antiJoinWorkaroundNeeded: Boolean = false + // Direct output order of substrait join operation - private val substraitJoinType = joinType match { + protected val substraitJoinType: JoinRel.JoinType = joinType match { case Inner => JoinRel.JoinType.JOIN_TYPE_INNER case FullOuter => JoinRel.JoinType.JOIN_TYPE_OUTER case LeftOuter | RightOuter => + // The right side is required to be used for building hash table in Substrait plan. + // Therefore, for RightOuter Join, the left and right relations are exchanged and the + // join type is reverted. JoinRel.JoinType.JOIN_TYPE_LEFT case LeftSemi => JoinRel.JoinType.JOIN_TYPE_SEMI @@ -473,6 +557,12 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], idx += 1 } + if (antiJoinWorkaroundNeeded) { + // Filter of isNull over project are mapped into FilterProject operator in Velox. + // Therefore, the filter metrics are empty and no need to be updated. + idx += 1 + } + // HashProbe val hashProbeMetrics = joinMetrics.get(idx) hashProbeInputRows += hashProbeMetrics.inputRows @@ -507,6 +597,38 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], hashBuildNumMemoryAllocations += hashBuildMetrics.numMemoryAllocations idx += 1 + if (antiJoinWorkaroundNeeded) { + var metrics = joinMetrics.get(idx) + antiProjectInputRows += metrics.inputRows + antiProjectInputVectors += metrics.inputVectors + antiProjectInputBytes += metrics.inputBytes + antiProjectRawInputRows += metrics.rawInputRows + antiProjectRawInputBytes += metrics.rawInputBytes + antiProjectOutputRows += metrics.outputRows + antiProjectOutputVectors += metrics.outputVectors + antiProjectOutputBytes += metrics.outputBytes + antiProjectCount += metrics.count + antiProjectWallNanos += metrics.wallNanos + antiProjectPeakMemoryBytes += metrics.peakMemoryBytes + antiProjectNumMemoryAllocations += metrics.numMemoryAllocations + idx += 1 + + metrics = joinMetrics.get(idx) + antiDistinctInputRows += metrics.inputRows + antiDistinctInputVectors += metrics.inputVectors + antiDistinctInputBytes += metrics.inputBytes + antiDistinctRawInputRows += metrics.rawInputRows + antiDistinctRawInputBytes += metrics.rawInputBytes + antiDistinctOutputRows += metrics.outputRows + antiDistinctOutputVectors += metrics.outputVectors + antiDistinctOutputBytes += metrics.outputBytes + antiDistinctCount += metrics.count + antiDistinctWallNanos += metrics.wallNanos + antiDistinctPeakMemoryBytes += metrics.peakMemoryBytes + antiDistinctNumMemoryAllocations += metrics.numMemoryAllocations + idx += 1 + } + if (joinParams.buildPreProjectionNeeded) { val metrics = joinMetrics.get(idx) buildPreProjectionInputRows += metrics.inputRows @@ -576,7 +698,7 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], } } - override def outputPartitioning: Partitioning = buildSide match { + override def outputPartitioning: Partitioning = joinBuildSide match { case BuildLeft => joinType match { case _: InnerLike | RightOuter => right.outputPartitioning @@ -669,7 +791,7 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], if (joinParams.isStreamedReadRel) { substraitContext.registerRelToOperator(operatorId) } - if(joinParams.isBuildReadRel) { + if (joinParams.isBuildReadRel) { substraitContext.registerRelToOperator(operatorId) } @@ -697,7 +819,7 @@ abstract 
class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], !keyExprs.forall(_.isInstanceOf[AttributeReference]) } - private def createPreProjectionIfNeeded(keyExprs: Seq[Expression], + protected def createPreProjectionIfNeeded(keyExprs: Seq[Expression], inputNode: RelNode, inputNodeOutput: Seq[Attribute], joinOutput: Seq[Attribute], @@ -758,7 +880,7 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], } - private def createJoinRel(inputStreamedRelNode: RelNode, + protected def createJoinRel(inputStreamedRelNode: RelNode, inputBuildRelNode: RelNode, inputStreamedOutput: Seq[Attribute], inputBuildOutput: Seq[Attribute], @@ -787,9 +909,10 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], // Combine join keys to make a single expression. val joinExpressionNode = (streamedKeys zip buildKeys).map { case ((leftKey, leftType), (rightKey, rightType)) => - makeEqualToExpression( + HashJoinLikeExecTransformer.makeEqualToExpression( leftKey, leftType, rightKey, rightType, substraitContext.registeredFunction) - }.reduce((l, r) => makeAndExpression(l, r, substraitContext.registeredFunction)) + }.reduce((l, r) => + HashJoinLikeExecTransformer.makeAndExpression(l, r, substraitContext.registeredFunction)) // Create post-join filter, which will be computed in hash join. val postJoinFilter = condition.map { @@ -812,7 +935,7 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], operatorId) // Result projection will drop the appended keys, and exchange columns order if BuildLeft. - val resultProjection = buildSide match { + val resultProjection = joinBuildSide match { case BuildLeft => val (leftOutput, rightOutput) = getResultProjectionOutput(inputBuildOutput, inputStreamedOutput) @@ -839,7 +962,7 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], private def createTransformContext(rel: RelNode, inputStreamedOutput: Seq[Attribute], inputBuildOutput: Seq[Attribute]): TransformContext = { - val inputAttributes = buildSide match { + val inputAttributes = joinBuildSide match { case BuildLeft => inputBuildOutput ++ inputStreamedOutput case BuildRight => inputStreamedOutput ++ inputBuildOutput } @@ -854,8 +977,8 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], new util.ArrayList[TypeNode](inputTypeNodes.asJava)).toProtobuf) } - private def createExtensionNode(output: Seq[Attribute], - validation: Boolean): AdvancedExtensionNode = { + protected def createExtensionNode(output: Seq[Attribute], + validation: Boolean): AdvancedExtensionNode = { // Use field [enhancement] in a extension node for input type validation. if (validation) { ExtensionBuilder.makeAdvancedExtension( @@ -865,7 +988,7 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], } } - private def createJoinExtensionNode(output: Seq[Attribute], + protected def createJoinExtensionNode(output: Seq[Attribute], validation: Boolean): AdvancedExtensionNode = { // Use field [optimization] in a extension node // to send some join parameters through Substrait plan. 
@@ -898,7 +1021,7 @@ abstract class HashJoinLikeExecTransformer(leftKeys: Seq[Expression], } // The output of result projection should be consistent with ShuffledJoin.output - private def getResultProjectionOutput(leftOutput: Seq[Attribute], + protected def getResultProjectionOutput(leftOutput: Seq[Attribute], rightOutput: Seq[Attribute]) : (Seq[Attribute], Seq[Attribute]) = { joinType match { @@ -957,31 +1080,34 @@ object HashJoinLikeExecTransformer { ExpressionBuilder.makeScalarFunction(functionId, expressionNodes, typeNode) } + + def makeIsNullExpression(childNode: ExpressionNode, + functionMap: java.util.HashMap[String, java.lang.Long]) + : ExpressionNode = { + val functionId = ExpressionBuilder.newScalarFunction( + functionMap, ConverterUtils.makeFuncName(ConverterUtils.IS_NULL, Seq(BooleanType))) + + ExpressionBuilder.makeScalarFunction( + functionId, + Lists.newArrayList(childNode), + TypeBuilder.makeBoolean(true)) + } } -case class ShuffledHashJoinExecTransformer(leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - buildSide: BuildSide, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan) - extends HashJoinLikeExecTransformer( - leftKeys, - rightKeys, - joinType, - buildSide, - condition, - left, - right) { +abstract class ShuffledHashJoinExecTransformer(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan) + extends HashJoinLikeExecTransformer { + + override def joinBuildSide: BuildSide = buildSide override def columnarInputRDDs: Seq[RDD[ColumnarBatch]] = { getColumnarInputRDDs(streamedPlan) ++ getColumnarInputRDDs(buildPlan) } - - override protected def withNewChildrenInternal( - newLeft: SparkPlan, newRight: SparkPlan): ShuffledHashJoinExecTransformer = - copy(left = newLeft, right = newRight) } case class BroadCastHashJoinContext(buildSideJoinKeys: Seq[Expression], @@ -989,22 +1115,17 @@ case class BroadCastHashJoinContext(buildSideJoinKeys: Seq[Expression], buildSideStructure: Seq[Attribute], buildHashTableId: String) -case class BroadcastHashJoinExecTransformer(leftKeys: Seq[Expression], - rightKeys: Seq[Expression], - joinType: JoinType, - buildSide: BuildSide, - condition: Option[Expression], - left: SparkPlan, - right: SparkPlan, - isNullAwareAntiJoin: Boolean = false) - extends HashJoinLikeExecTransformer( - leftKeys, - rightKeys, - joinType, - buildSide, - condition, - left, - right) { +abstract class BroadcastHashJoinExecTransformer(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isNullAwareAntiJoin: Boolean = false) + extends HashJoinLikeExecTransformer { + + override def joinBuildSide: BuildSide = buildSide // Unique ID for builded hash table lazy val buildHashTableId = "BuildedHashTable-" + buildPlan.id @@ -1067,8 +1188,4 @@ case class BroadcastHashJoinExecTransformer(leftKeys: Seq[Expression], } streamedRDD :+ buildRDD } - - override protected def withNewChildrenInternal( - newLeft: SparkPlan, newRight: SparkPlan): BroadcastHashJoinExecTransformer = - copy(left = newLeft, right = newRight) } diff --git a/jvm/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala b/jvm/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala index 66c28c90de2c..47c028dfc3d4 100644 --- a/jvm/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala +++ 
b/jvm/src/main/scala/io/glutenproject/extension/ColumnarOverrides.scala @@ -145,14 +145,15 @@ case class TransformPreOverrides() extends Rule[SparkPlan] { val left = replaceWithTransformerPlan(plan.left, isSupportAdaptive) val right = replaceWithTransformerPlan(plan.right, isSupportAdaptive) logDebug(s"Columnar Processing for ${plan.getClass} is currently supported.") - ShuffledHashJoinExecTransformer( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - left, - right) + BackendsApiManager.getSparkPlanExecApiInstance + .genShuffledHashJoinExecTransformer( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right) case plan: SortMergeJoinExec => val left = replaceWithTransformerPlan(plan.left, isSupportAdaptive) val right = replaceWithTransformerPlan(plan.right, isSupportAdaptive) @@ -176,15 +177,16 @@ case class TransformPreOverrides() extends Rule[SparkPlan] { case plan: BroadcastHashJoinExec => val left = replaceWithTransformerPlan(plan.left, isSupportAdaptive) val right = replaceWithTransformerPlan(plan.right, isSupportAdaptive) - BroadcastHashJoinExecTransformer( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - left, - right, - isNullAwareAntiJoin = plan.isNullAwareAntiJoin) + BackendsApiManager.getSparkPlanExecApiInstance + .genBroadcastHashJoinExecTransformer( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + left, + right, + isNullAwareAntiJoin = plan.isNullAwareAntiJoin) case plan: AQEShuffleReadExec if columnarConf.enableColumnarShuffle => plan.child match { case shuffle: ColumnarShuffleExchangeAdaptor => diff --git a/jvm/src/main/scala/io/glutenproject/extension/ColumnarQueryStagePrepRule.scala b/jvm/src/main/scala/io/glutenproject/extension/ColumnarQueryStagePrepRule.scala index 81478b2f5d51..58f7c0f24051 100644 --- a/jvm/src/main/scala/io/glutenproject/extension/ColumnarQueryStagePrepRule.scala +++ b/jvm/src/main/scala/io/glutenproject/extension/ColumnarQueryStagePrepRule.scala @@ -18,8 +18,8 @@ package io.glutenproject.extension import io.glutenproject.{GlutenConfig, GlutenSparkExtensionsInjector} +import io.glutenproject.backendsapi.BackendsApiManager import io.glutenproject.execution.BroadcastHashJoinExecTransformer - import org.apache.spark.sql.{SparkSession, SparkSessionExtensions} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreeNodeTag @@ -47,15 +47,16 @@ case class ColumnarQueryStagePrepRule(session: SparkSession) extends Rule[SparkP case bhj: BroadcastHashJoinExec => if (columnarConf.enableColumnarBroadcastExchange && columnarConf.enableColumnarBroadcastJoin) { - val transformer = BroadcastHashJoinExecTransformer( - bhj.leftKeys, - bhj.rightKeys, - bhj.joinType, - bhj.buildSide, - bhj.condition, - bhj.left, - bhj.right, - bhj.isNullAwareAntiJoin) + val transformer = BackendsApiManager.getSparkPlanExecApiInstance + .genBroadcastHashJoinExecTransformer( + bhj.leftKeys, + bhj.rightKeys, + bhj.joinType, + bhj.buildSide, + bhj.condition, + bhj.left, + bhj.right, + bhj.isNullAwareAntiJoin) if (!transformer.doValidate()) { bhj.children.map { // ResuedExchange is not created yet, so we don't need to handle that case. 
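With these rule changes, neither TransformPreOverrides nor ColumnarQueryStagePrepRule (nor TransformGuardRule below) constructs a join transformer case class directly; each asks the active backend for one and only keeps it when validation succeeds. A simplified sketch of that dispatch, assuming a Gluten backend is on the classpath (the object and method names here are illustrative and not part of this patch):

import io.glutenproject.backendsapi.BackendsApiManager
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.joins.ShuffledHashJoinExec

object BackendJoinDispatchSketch {
  // Ask the active backend for its ShuffledHashJoinExecTransformer subclass and
  // keep the columnar plan only if the backend accepts it.
  def tryTransform(plan: ShuffledHashJoinExec): SparkPlan = {
    val transformer = BackendsApiManager.getSparkPlanExecApiInstance
      .genShuffledHashJoinExecTransformer(
        plan.leftKeys,
        plan.rightKeys,
        plan.joinType,
        plan.buildSide,
        plan.condition,
        plan.left,
        plan.right)
    if (transformer.doValidate()) transformer else plan
  }
}

In the actual patch, TransformPreOverrides performs the replacement while the guard and query-stage rules only consume the validation result; the sketch collapses both steps into one for brevity.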
diff --git a/jvm/src/main/scala/io/glutenproject/extension/columnar/ColumnarGuardRule.scala b/jvm/src/main/scala/io/glutenproject/extension/columnar/ColumnarGuardRule.scala index 749b724e2dcc..594fa3e9ead7 100644 --- a/jvm/src/main/scala/io/glutenproject/extension/columnar/ColumnarGuardRule.scala +++ b/jvm/src/main/scala/io/glutenproject/extension/columnar/ColumnarGuardRule.scala @@ -147,14 +147,15 @@ case class TransformGuardRule() extends Rule[SparkPlan] { exec.doValidate() case plan: ShuffledHashJoinExec => if (!enableColumnarShuffledHashJoin) return false - val transformer = ShuffledHashJoinExecTransformer( - plan.leftKeys, - plan.rightKeys, - plan.joinType, - plan.buildSide, - plan.condition, - plan.left, - plan.right) + val transformer = BackendsApiManager.getSparkPlanExecApiInstance + .genShuffledHashJoinExecTransformer( + plan.leftKeys, + plan.rightKeys, + plan.joinType, + plan.buildSide, + plan.condition, + plan.left, + plan.right) transformer.doValidate() case plan: BroadcastExchangeExec => // columnar broadcast is enabled only when columnar bhj is enabled. @@ -163,7 +164,8 @@ case class TransformGuardRule() extends Rule[SparkPlan] { exec.doValidate() case plan: BroadcastHashJoinExec => if (!enableColumnarBroadcastJoin) return false - val transformer = BroadcastHashJoinExecTransformer( + val transformer = BackendsApiManager.getSparkPlanExecApiInstance + .genBroadcastHashJoinExecTransformer( plan.leftKeys, plan.rightKeys, plan.joinType, diff --git a/jvm/src/test/scala/io/glutenproject/backendsapi/SparkPlanExecApiImplSuite.scala b/jvm/src/test/scala/io/glutenproject/backendsapi/SparkPlanExecApiImplSuite.scala index 89a1d44f9f21..73f3e5372594 100644 --- a/jvm/src/test/scala/io/glutenproject/backendsapi/SparkPlanExecApiImplSuite.scala +++ b/jvm/src/test/scala/io/glutenproject/backendsapi/SparkPlanExecApiImplSuite.scala @@ -17,9 +17,8 @@ package io.glutenproject.backendsapi -import io.glutenproject.execution.{FilterExecBaseTransformer, HashAggregateExecBaseTransformer, NativeColumnarToRowExec, RowToArrowColumnarExec} +import io.glutenproject.execution.{BroadcastHashJoinExecTransformer, FilterExecBaseTransformer, HashAggregateExecBaseTransformer, NativeColumnarToRowExec, RowToArrowColumnarExec, ShuffledHashJoinExecTransformer} import io.glutenproject.expression.AliasBaseTransformer - import org.apache.spark.ShuffleDependency import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer @@ -27,8 +26,10 @@ import org.apache.spark.shuffle.{GenShuffleWriterParameters, GlutenShuffleWriter import org.apache.spark.sql.{SparkSession, Strategy} import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, ExprId, NamedExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, Partitioning} +import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{ColumnarRule, SparkPlan} import org.apache.spark.sql.execution.joins.BuildSideRelation @@ -84,6 +85,30 @@ class SparkPlanExecApiImplSuite extends ISparkPlanExecApi { resultExpressions: Seq[NamedExpression], child: SparkPlan): HashAggregateExecBaseTransformer = null + /** + * Generate ShuffledHashJoinExecTransformer. 
+ */ + def genShuffledHashJoinExecTransformer(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan): ShuffledHashJoinExecTransformer = null + + /** + * Generate BroadcastHashJoinExecTransformer. + */ + def genBroadcastHashJoinExecTransformer(leftKeys: Seq[Expression], + rightKeys: Seq[Expression], + joinType: JoinType, + buildSide: BuildSide, + condition: Option[Expression], + left: SparkPlan, + right: SparkPlan, + isNullAwareAntiJoin: Boolean = false) + : BroadcastHashJoinExecTransformer = null + /** * Generate Alias transformer. *
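Taken together, the Velox changes above implement the LeftAnti ('not exists') workaround as: deduplicate the build-side keys, append a constant-true marker column, perform a left outer join, then keep only the rows whose marker stayed null and drop the appended columns. The DataFrame-level sketch below is purely illustrative of those semantics (the object name and toy data are invented; only the "constant_true" column name mirrors the patch), not of the Substrait plan the transformer actually emits:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object AntiJoinWorkaroundSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("anti-join-sketch").getOrCreate()
    import spark.implicits._

    val streamed = Seq((1, "a"), (2, "b"), (3, "c")).toDF("k", "v")
    val build = Seq(2, 2, 4).toDF("k")

    // Reference semantics: streamed rows with no matching key on the build side.
    val expected = streamed.join(build, Seq("k"), "left_anti")

    // Workaround: distinct build keys + constant-true marker + left outer join + isNull filter.
    val markedBuild = build.distinct().withColumn("constant_true", lit(true))
    val rewritten = streamed.join(markedBuild, Seq("k"), "left_outer")
      .filter(col("constant_true").isNull)
      .select(streamed.columns.map(col): _*)

    expected.show()
    rewritten.show()
    spark.stop()
  }
}

The distinct aggregation bounds the size of the intermediate result: without it, duplicate build keys would duplicate the matching streamed rows in the left outer join before the isNull filter discards them.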