diff --git a/backends-clickhouse/src/main/java/io/glutenproject/vectorized/StorageJoinBuilder.java b/backends-clickhouse/src/main/java/io/glutenproject/vectorized/StorageJoinBuilder.java index 333889939a00e..7bbc3ef528ea3 100644 --- a/backends-clickhouse/src/main/java/io/glutenproject/vectorized/StorageJoinBuilder.java +++ b/backends-clickhouse/src/main/java/io/glutenproject/vectorized/StorageJoinBuilder.java @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.storage.CHShuffleReadStreamFactory; -import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -88,8 +87,8 @@ public static long build( /** create table named struct */ private static NamedStruct toNameStruct(List output) { - ArrayList typeList = ConverterUtils.collectAttributeTypeNodes(output); - ArrayList nameList = ConverterUtils.collectAttributeNamesWithExprId(output); + List typeList = ConverterUtils.collectAttributeTypeNodes(output); + List nameList = ConverterUtils.collectAttributeNamesWithExprId(output); Type.Struct.Builder structBuilder = Type.Struct.newBuilder(); for (TypeNode typeNode : typeList) { structBuilder.addTypes(typeNode.toProtobuf()); diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala index 4820e4f96aeb3..6213650d6c69a 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHIteratorApi.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import java.lang.{Long => JLong} import java.net.URI -import java.util +import java.util.{ArrayList => JArrayList} import scala.collection.JavaConverters._ import scala.collection.mutable @@ -68,9 +68,9 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { .makeExtensionTable(p.minParts, p.maxParts, p.database, p.table, p.tablePath), SoftAffinityUtil.getNativeMergeTreePartitionLocations(p)) case f: FilePartition => - val paths = new util.ArrayList[String]() - val starts = new util.ArrayList[JLong]() - val lengths = new util.ArrayList[JLong]() + val paths = new JArrayList[String]() + val starts = new JArrayList[JLong]() + val lengths = new JArrayList[JLong]() val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]] f.files.foreach { file => @@ -122,7 +122,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { val resIter: GeneralOutIterator = GlutenTimeMetric.millis(pipelineTime) { _ => val transKernel = new CHNativeExpressionEvaluator() - val inBatchIters = new util.ArrayList[GeneralInIterator](inputIterators.map { + val inBatchIters = new JArrayList[GeneralInIterator](inputIterators.map { iter => new ColumnarNativeIterator(genCloseableColumnBatchIterator(iter).asJava) }.asJava) transKernel.createKernelWithBatchIterator(inputPartition.plan, inBatchIters, false) @@ -180,7 +180,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil { _ => val transKernel = new CHNativeExpressionEvaluator() val columnarNativeIterator = - new java.util.ArrayList[GeneralInIterator](inputIterators.map { + new JArrayList[GeneralInIterator](inputIterators.map { iter => new ColumnarNativeIterator(genCloseableColumnBatchIterator(iter).asJava) }.asJava) // we need to complete dependency RDD's firstly diff --git 
a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHMetricsApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHMetricsApi.scala index aaaff0c44a0a5..dcae0ad9e9ea3 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHMetricsApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHMetricsApi.scala @@ -26,14 +26,15 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import java.{lang, util} +import java.lang.{Long => JLong} +import java.util.{List => JList, Map => JMap} class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil { override def metricsUpdatingFunction( child: SparkPlan, - relMap: util.HashMap[lang.Long, util.ArrayList[lang.Long]], - joinParamsMap: util.HashMap[lang.Long, JoinParams], - aggParamsMap: util.HashMap[lang.Long, AggregationParams]): IMetrics => Unit = { + relMap: JMap[JLong, JList[JLong]], + joinParamsMap: JMap[JLong, JoinParams], + aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = { MetricsUtil.updateNativeMetrics(child, relMap, joinParamsMap, aggParamsMap) } diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala index 7607440226efc..d4dbd392ddc0b 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/backendsapi/clickhouse/CHSparkPlanExecApi.scala @@ -55,7 +55,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.common.collect.Lists import org.apache.commons.lang3.ClassUtils -import java.{lang, util} +import java.lang.{Long => JLong} +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} import scala.collection.mutable.ArrayBuffer @@ -64,7 +65,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { /** Transform GetArrayItem to Substrait. 
*/ override def genGetArrayItemExpressionNode( substraitExprName: String, - functionMap: java.util.HashMap[String, java.lang.Long], + functionMap: JMap[String, JLong], leftNode: ExpressionNode, rightNode: ExpressionNode, original: GetArrayItem): ExpressionNode = { @@ -436,9 +437,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { /** Generate window function node */ override def genWindowFunctionsNode( windowExpression: Seq[NamedExpression], - windowExpressionNodes: util.ArrayList[WindowFunctionNode], + windowExpressionNodes: JList[WindowFunctionNode], originalInputAttributes: Seq[Attribute], - args: util.HashMap[String, lang.Long]): Unit = { + args: JMap[String, JLong]): Unit = { windowExpression.map { windowExpr => @@ -451,7 +452,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame] val windowFunctionNode = ExpressionBuilder.makeWindowFunction( WindowFunctionsBuilder.create(args, aggWindowFunc).toInt, - new util.ArrayList[ExpressionNode](), + new JArrayList[ExpressionNode](), columnName, ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable), WindowExecTransformer.getFrameBound(frame.upper), @@ -467,7 +468,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { throw new UnsupportedOperationException(s"Not currently supported: $aggregateFunc.") } - val childrenNodeList = new util.ArrayList[ExpressionNode]() + val childrenNodeList = new JArrayList[ExpressionNode]() aggregateFunc.children.foreach( expr => childrenNodeList.add( @@ -505,7 +506,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi { } } - val childrenNodeList = new util.ArrayList[ExpressionNode]() + val childrenNodeList = new JArrayList[ExpressionNode]() childrenNodeList.add( ExpressionConverter .replaceWithExpressionTransformer( diff --git a/backends-clickhouse/src/main/scala/io/glutenproject/metrics/MetricsUtil.scala b/backends-clickhouse/src/main/scala/io/glutenproject/metrics/MetricsUtil.scala index 556b9dac674b0..6c2af3eeaa5c4 100644 --- a/backends-clickhouse/src/main/scala/io/glutenproject/metrics/MetricsUtil.scala +++ b/backends-clickhouse/src/main/scala/io/glutenproject/metrics/MetricsUtil.scala @@ -23,6 +23,9 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric +import java.lang.{Long => JLong} +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} + import scala.collection.JavaConverters._ object MetricsUtil extends Logging { @@ -56,9 +59,9 @@ object MetricsUtil extends Logging { */ def updateNativeMetrics( child: SparkPlan, - relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]], - joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams], - aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): IMetrics => Unit = { + relMap: JMap[JLong, JList[JLong]], + joinParamsMap: JMap[JLong, JoinParams], + aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = { val mut: MetricsUpdaterTree = treeifyMetricsUpdaters(child) @@ -90,10 +93,10 @@ object MetricsUtil extends Logging { */ def updateTransformerMetrics( mutNode: MetricsUpdaterTree, - relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]], - operatorIdx: java.lang.Long, - joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams], - aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): IMetrics => Unit = { + relMap: JMap[JLong, JList[JLong]], + operatorIdx: JLong, + 
joinParamsMap: JMap[JLong, JoinParams], + aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = { imetrics => try { val metrics = imetrics.asInstanceOf[NativeMetrics] @@ -129,13 +132,13 @@ object MetricsUtil extends Logging { */ def updateTransformerMetricsInternal( mutNode: MetricsUpdaterTree, - relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]], - operatorIdx: java.lang.Long, + relMap: JMap[JLong, JList[JLong]], + operatorIdx: JLong, metrics: NativeMetrics, metricsIdx: Int, - joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams], - aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): (java.lang.Long, Int) = { - val nodeMetricsList = new java.util.ArrayList[MetricsData]() + joinParamsMap: JMap[JLong, JoinParams], + aggParamsMap: JMap[JLong, AggregationParams]): (JLong, Int) = { + val nodeMetricsList = new JArrayList[MetricsData]() var curMetricsIdx = metricsIdx relMap .get(operatorIdx) diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala index 2aae3d9c8ca02..36ce0ffae1356 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/IteratorApiImpl.scala @@ -42,10 +42,11 @@ import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.ExecutorManager +import java.lang.{Long => JLong} import java.net.URLDecoder import java.nio.charset.StandardCharsets import java.time.ZoneOffset -import java.util +import java.util.{ArrayList => JArrayList} import java.util.concurrent.TimeUnit import scala.collection.JavaConverters._ @@ -67,14 +68,14 @@ class IteratorApiImpl extends IteratorApi with Logging { def constructSplitInfo(schema: StructType, files: Array[PartitionedFile]) = { val paths = mutable.ArrayBuffer.empty[String] - val starts = mutable.ArrayBuffer.empty[java.lang.Long] - val lengths = mutable.ArrayBuffer.empty[java.lang.Long] + val starts = mutable.ArrayBuffer.empty[JLong] + val lengths = mutable.ArrayBuffer.empty[JLong] val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]] files.foreach { file => paths.append(URLDecoder.decode(file.filePath.toString, StandardCharsets.UTF_8.name())) - starts.append(java.lang.Long.valueOf(file.start)) - lengths.append(java.lang.Long.valueOf(file.length)) + starts.append(JLong.valueOf(file.start)) + lengths.append(JLong.valueOf(file.length)) val partitionColumn = mutable.Map.empty[String, String] for (i <- 0 until file.partitionValues.numFields) { @@ -90,7 +91,7 @@ class IteratorApiImpl extends IteratorApi with Logging { case _: TimestampType => TimestampFormatter .getFractionFormatter(ZoneOffset.UTC) - .format(pn.asInstanceOf[java.lang.Long]) + .format(pn.asInstanceOf[JLong]) case _ => pn.toString } } @@ -139,7 +140,7 @@ class IteratorApiImpl extends IteratorApi with Logging { inputIterators: Seq[Iterator[ColumnarBatch]] = Seq()): Iterator[ColumnarBatch] = { val beforeBuild = System.nanoTime() val columnarNativeIterators = - new util.ArrayList[GeneralInIterator](inputIterators.map { + new JArrayList[GeneralInIterator](inputIterators.map { iter => new ColumnarBatchInIterator(iter.asJava) }.asJava) val transKernel = NativePlanEvaluator.create() @@ -183,7 +184,7 @@ class IteratorApiImpl extends IteratorApi with Logging { val transKernel = 
NativePlanEvaluator.create() val columnarNativeIterator = - new util.ArrayList[GeneralInIterator](inputIterators.map { + new JArrayList[GeneralInIterator](inputIterators.map { iter => new ColumnarBatchInIterator(iter.asJava) }.asJava) val nativeResultIterator = diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/MetricsApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/MetricsApiImpl.scala index cb72ab5caeeb0..95f4ec30a5f68 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/MetricsApiImpl.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/MetricsApiImpl.scala @@ -25,14 +25,15 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import java.{lang, util} +import java.lang.{Long => JLong} +import java.util.{List => JList, Map => JMap} class MetricsApiImpl extends MetricsApi with Logging { override def metricsUpdatingFunction( child: SparkPlan, - relMap: util.HashMap[lang.Long, util.ArrayList[lang.Long]], - joinParamsMap: util.HashMap[lang.Long, JoinParams], - aggParamsMap: util.HashMap[lang.Long, AggregationParams]): IMetrics => Unit = { + relMap: JMap[JLong, JList[JLong]], + joinParamsMap: JMap[JLong, JoinParams], + aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = { MetricsUtil.updateNativeMetrics(child, relMap, joinParamsMap, aggParamsMap) } diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala index 697739ef3d62d..db77ce2ae93ec 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/SparkPlanExecApiImpl.scala @@ -56,6 +56,9 @@ import org.apache.commons.lang3.ClassUtils import javax.ws.rs.core.UriBuilder +import java.lang.{Long => JLong} +import java.util.{Map => JMap} + import scala.collection.mutable.ArrayBuffer class SparkPlanExecApiImpl extends SparkPlanExecApi { @@ -67,7 +70,7 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi { */ override def genGetArrayItemExpressionNode( substraitExprName: String, - functionMap: java.util.HashMap[String, java.lang.Long], + functionMap: JMap[String, JLong], leftNode: ExpressionNode, rightNode: ExpressionNode, original: GetArrayItem): ExpressionNode = { diff --git a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/TransformerApiImpl.scala b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/TransformerApiImpl.scala index 58e6a14646b7a..b2e4d9b6b7891 100644 --- a/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/TransformerApiImpl.scala +++ b/backends-velox/src/main/scala/io/glutenproject/backendsapi/velox/TransformerApiImpl.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDi import org.apache.spark.sql.types._ import org.apache.spark.util.collection.BitSet -import java.util +import java.util.{Map => JMap} class TransformerApiImpl extends TransformerApi with Logging { @@ -65,7 +65,7 @@ class TransformerApiImpl extends TransformerApi with Logging { } override def postProcessNativeConfig( - nativeConfMap: util.Map[String, String], + nativeConfMap: JMap[String, String], backendPrefix: String): Unit = { // TODO: IMPLEMENT SPECIAL PROCESS FOR VELOX BACKEND } diff 
--git a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala index 92b8e2c9d7d12..667656c769982 100644 --- a/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/execution/HashAggregateExecTransformer.scala @@ -34,7 +34,8 @@ import org.apache.spark.sql.types._ import com.google.protobuf.Any -import java.util +import java.lang.{Long => JLong} +import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList} import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer @@ -135,7 +136,7 @@ case class HashAggregateExecTransformer( aggRel: RelNode, operatorId: Long, validation: Boolean): RelNode = { - val expressionNodes = new util.ArrayList[ExpressionNode]() + val expressionNodes = new JArrayList[ExpressionNode]() var colIdx = 0 while (colIdx < groupingExpressions.size) { val groupingExpr: ExpressionNode = ExpressionBuilder.makeSelection(colIdx) @@ -233,7 +234,7 @@ case class HashAggregateExecTransformer( * The type of partial outputs. */ private def getIntermediateTypeNode(aggregateFunction: AggregateFunction): TypeNode = { - val structTypeNodes = new util.ArrayList[TypeNode]() + val structTypeNodes = new JArrayList[TypeNode]() aggregateFunction match { case avg: Average => structTypeNodes.add( @@ -305,9 +306,9 @@ case class HashAggregateExecTransformer( override protected def addFunctionNode( args: java.lang.Object, aggregateFunction: AggregateFunction, - childrenNodeList: java.util.ArrayList[ExpressionNode], + childrenNodeList: JList[ExpressionNode], aggregateMode: AggregateMode, - aggregateNodeList: java.util.ArrayList[AggregateFunctionNode]): Unit = { + aggregateNodeList: JList[AggregateFunctionNode]): Unit = { // This is a special handling for PartialMerge in the execution of distinct. // Use Partial phase instead for this aggregation. val modeKeyWord = modeToKeyWord(aggregateMode) @@ -392,8 +393,8 @@ case class HashAggregateExecTransformer( * Return the output types after partial aggregation through Velox. * @return */ - def getPartialAggOutTypes: java.util.ArrayList[TypeNode] = { - val typeNodeList = new java.util.ArrayList[TypeNode]() + def getPartialAggOutTypes: JList[TypeNode] = { + val typeNodeList = new JArrayList[TypeNode]() groupingExpressions.foreach( expression => { typeNodeList.add(ConverterUtils.getTypeNode(expression.dataType, expression.nullable)) @@ -452,20 +453,19 @@ case class HashAggregateExecTransformer( // Return a scalar function node representing row construct function in Velox. private def getRowConstructNode( args: java.lang.Object, - childNodes: util.ArrayList[ExpressionNode], + childNodes: JList[ExpressionNode], rowConstructAttributes: Seq[Attribute], withNull: Boolean = true): ScalarFunctionNode = { - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionMap = args.asInstanceOf[JHashMap[String, JLong]] val functionName = ConverterUtils.makeFuncName( if (withNull) "row_constructor_with_null" else "row_constructor", rowConstructAttributes.map(attr => attr.dataType)) val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) // Use struct type to represent Velox RowType. 
- val structTypeNodes = new util.ArrayList[TypeNode]( - rowConstructAttributes - .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - .asJava) + val structTypeNodes = rowConstructAttributes + .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) + .asJava ExpressionBuilder.makeScalarFunction( functionId, @@ -484,7 +484,7 @@ case class HashAggregateExecTransformer( validation: Boolean): RelNode = { val args = context.registeredFunction // Create a projection for row construct. - val exprNodes = new util.ArrayList[ExpressionNode]() + val exprNodes = new JArrayList[ExpressionNode]() groupingExpressions.foreach( expr => { exprNodes.add( @@ -498,7 +498,7 @@ case class HashAggregateExecTransformer( val aggregateFunction = aggregateExpression.aggregateFunction aggregateFunction match { case _ if mixedPartialAndMerge && aggregateExpression.mode == Partial => - val childNodes = new util.ArrayList[ExpressionNode]( + val childNodes = new JArrayList[ExpressionNode]( aggregateFunction.children .map( attr => { @@ -515,15 +515,13 @@ case class HashAggregateExecTransformer( functionInputAttributes.size == 2, s"${aggregateExpression.mode.toString} of Average expects two input attributes.") // Use a Velox function to combine the intermediate columns into struct. - val childNodes = new util.ArrayList[ExpressionNode]( + val childNodes = functionInputAttributes.toList .map( - attr => { - ExpressionConverter - .replaceWithExpressionTransformer(attr, originalInputAttributes) - .doTransform(args) - }) - .asJava) + ExpressionConverter + .replaceWithExpressionTransformer(_, originalInputAttributes) + .doTransform(args)) + .asJava exprNodes.add( getRowConstructNode( args, @@ -540,15 +538,13 @@ case class HashAggregateExecTransformer( functionInputAttributes.size == 2, s"${aggregateExpression.mode.toString} of First/Last expects two input attributes.") // Use a Velox function to combine the intermediate columns into struct. - val childNodes = new util.ArrayList[ExpressionNode]( - functionInputAttributes.toList - .map( - attr => { - ExpressionConverter - .replaceWithExpressionTransformer(attr, originalInputAttributes) - .doTransform(args) - }) - .asJava) + val childNodes = functionInputAttributes.toList + .map( + ExpressionConverter + .replaceWithExpressionTransformer(_, originalInputAttributes) + .doTransform(args) + ) + .asJava exprNodes.add(getRowConstructNode(args, childNodes, functionInputAttributes)) case other => throw new UnsupportedOperationException(s"$other is not supported.") @@ -564,31 +560,28 @@ case class HashAggregateExecTransformer( // Use a Velox function to combine the intermediate columns into struct. var index = 0 var newInputAttributes: Seq[Attribute] = Seq() - val childNodes = new util.ArrayList[ExpressionNode]( - functionInputAttributes.toList - .map( - attr => { - val aggExpr: ExpressionTransformer = ExpressionConverter - .replaceWithExpressionTransformer(attr, originalInputAttributes) - val aggNode = aggExpr.doTransform(args) - val expressionNode = if (index == 0) { - // Cast count from DoubleType into LongType to align with Velox semantics. 
- newInputAttributes = newInputAttributes :+ - attr.copy(attr.name, LongType, attr.nullable, attr.metadata)( - attr.exprId, - attr.qualifier) - ExpressionBuilder.makeCast( - ConverterUtils.getTypeNode(LongType, attr.nullable), - aggNode, - SQLConf.get.ansiEnabled) - } else { - newInputAttributes = newInputAttributes :+ attr - aggNode - } - index += 1 - expressionNode - }) - .asJava) + val childNodes = functionInputAttributes.toList.map { + attr => + val aggExpr: ExpressionTransformer = ExpressionConverter + .replaceWithExpressionTransformer(attr, originalInputAttributes) + val aggNode = aggExpr.doTransform(args) + val expressionNode = if (index == 0) { + // Cast count from DoubleType into LongType to align with Velox semantics. + newInputAttributes = newInputAttributes :+ + attr.copy(attr.name, LongType, attr.nullable, attr.metadata)( + attr.exprId, + attr.qualifier) + ExpressionBuilder.makeCast( + ConverterUtils.getTypeNode(LongType, attr.nullable), + aggNode, + SQLConf.get.ansiEnabled) + } else { + newInputAttributes = newInputAttributes :+ attr + aggNode + } + index += 1 + expressionNode + }.asJava exprNodes.add(getRowConstructNode(args, childNodes, newInputAttributes)) case other => throw new UnsupportedOperationException(s"$other is not supported.") @@ -602,7 +595,7 @@ case class HashAggregateExecTransformer( // Use a Velox function to combine the intermediate columns into struct. var index = 0 var newInputAttributes: Seq[Attribute] = Seq() - val childNodes = new util.ArrayList[ExpressionNode]() + val childNodes = new JArrayList[ExpressionNode]() // Velox's Corr order is [ck, n, xMk, yMk, xAvg, yAvg] // Spark's Corr order is [n, xAvg, yAvg, ck, xMk, yMk] val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name) @@ -645,7 +638,7 @@ case class HashAggregateExecTransformer( // Use a Velox function to combine the intermediate columns into struct. var index = 0 var newInputAttributes: Seq[Attribute] = Seq() - val childNodes = new util.ArrayList[ExpressionNode]() + val childNodes = new JArrayList[ExpressionNode]() // Velox's Covar order is [ck, n, xAvg, yAvg] // Spark's Covar order is [n, xAvg, yAvg, ck] val sparkCorrOutputAttr = aggregateFunction.inputAggBufferAttributes.map(_.name) @@ -685,15 +678,13 @@ case class HashAggregateExecTransformer( functionInputAttributes.size == 2, "Final stage of Average expects two input attributes.") // Use a Velox function to combine the intermediate columns into struct. 
- val childNodes = new util.ArrayList[ExpressionNode]( - functionInputAttributes.toList - .map( - attr => { - ExpressionConverter - .replaceWithExpressionTransformer(attr, originalInputAttributes) - .doTransform(args) - }) - .asJava) + val childNodes = functionInputAttributes.toList + .map( + ExpressionConverter + .replaceWithExpressionTransformer(_, originalInputAttributes) + .doTransform(args) + ) + .asJava exprNodes.add( getRowConstructNode(args, childNodes, functionInputAttributes, withNull = false)) case other => @@ -703,15 +694,13 @@ case class HashAggregateExecTransformer( if (functionInputAttributes.size != 1) { throw new UnsupportedOperationException("Only one input attribute is expected.") } - val childNodes = new util.ArrayList[ExpressionNode]( - functionInputAttributes.toList - .map( - attr => { - ExpressionConverter - .replaceWithExpressionTransformer(attr, originalInputAttributes) - .doTransform(args) - }) - .asJava) + val childNodes = functionInputAttributes.toList + .map( + ExpressionConverter + .replaceWithExpressionTransformer(_, originalInputAttributes) + .doTransform(args) + ) + .asJava exprNodes.addAll(childNodes) } } @@ -722,10 +711,9 @@ case class HashAggregateExecTransformer( RelBuilder.makeProjectRel(inputRel, exprNodes, context, operatorId, emitStartIndex) } else { // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- originalInputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } + val inputTypeNodeList = originalInputAttributes + .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) + .asJava val extensionNode = ExtensionBuilder.makeAdvancedExtension( Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) RelBuilder.makeProjectRel( @@ -738,7 +726,7 @@ case class HashAggregateExecTransformer( } // Create aggregation rel. 
- val groupingList = new util.ArrayList[ExpressionNode]() + val groupingList = new JArrayList[ExpressionNode]() var colIdx = 0 groupingExpressions.foreach( _ => { @@ -746,8 +734,8 @@ case class HashAggregateExecTransformer( colIdx += 1 }) - val aggFilterList = new util.ArrayList[ExpressionNode]() - val aggregateFunctionList = new util.ArrayList[AggregateFunctionNode]() + val aggFilterList = new JArrayList[ExpressionNode]() + val aggregateFunctionList = new JArrayList[AggregateFunctionNode]() aggregateExpressions.foreach( aggExpr => { if (aggExpr.filter.isDefined) { @@ -758,7 +746,7 @@ case class HashAggregateExecTransformer( } val aggregateFunc = aggExpr.aggregateFunction - val childrenNodes = new util.ArrayList[ExpressionNode]() + val childrenNodes = new JArrayList[ExpressionNode]() aggregateFunc match { case _: Average | _: First | _: Last | _: StddevSamp | _: StddevPop | _: VarianceSamp | _: VariancePop | _: Corr | _: CovPopulation | _: CovSample @@ -949,7 +937,7 @@ object VeloxAggregateFunctionsBuilder { args: java.lang.Object, aggregateFunc: AggregateFunction, forMergeCompanion: Boolean = false): Long = { - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionMap = args.asInstanceOf[JHashMap[String, JLong]] var sigName = ExpressionMappings.expressionsMap.get(aggregateFunc.getClass) if (sigName.isEmpty) { diff --git a/backends-velox/src/main/scala/io/glutenproject/expression/ExpressionTransformer.scala b/backends-velox/src/main/scala/io/glutenproject/expression/ExpressionTransformer.scala index cd98b94549bb7..34db622835f08 100644 --- a/backends-velox/src/main/scala/io/glutenproject/expression/ExpressionTransformer.scala +++ b/backends-velox/src/main/scala/io/glutenproject/expression/ExpressionTransformer.scala @@ -25,6 +25,9 @@ import org.apache.spark.sql.types.{IntegerType, LongType} import com.google.common.collect.Lists +import java.lang.{Integer => JInteger, Long => JLong} +import java.util.{ArrayList => JArrayList, HashMap => JHashMap} + import scala.language.existentials case class VeloxAliasTransformer( @@ -49,7 +52,7 @@ case class VeloxNamedStructTransformer( child => expressionNodes.add( replaceWithExpressionTransformer(child, attributeSeq).doTransform(args))) - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionMap = args.asInstanceOf[JHashMap[String, JLong]] val functionName = ConverterUtils .makeFuncName(substraitExprName, Seq(original.dataType), FunctionConfig.OPT) val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) @@ -71,7 +74,7 @@ case class VeloxGetStructFieldTransformer( node.getFieldLiteral(ordinal) case node: SelectionNode => // Append the nested index to selection node. 
- node.addNestedChildIdx(java.lang.Integer.valueOf(ordinal)) + node.addNestedChildIdx(JInteger.valueOf(ordinal)) case other => throw new UnsupportedOperationException(s"$other is not supported.") } @@ -94,7 +97,7 @@ case class VeloxHashExpressionTransformer( case HiveHash(_) => (ExpressionBuilder.makeIntLiteral(0), IntegerType) } - val nodes = new java.util.ArrayList[ExpressionNode]() + val nodes = new JArrayList[ExpressionNode]() // Seed as the first argument nodes.add(seedNode) exps.foreach( @@ -102,7 +105,7 @@ case class VeloxHashExpressionTransformer( nodes.add(expression.doTransform(args)) }) val childrenTypes = seedType +: original.children.map(child => child.dataType) - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionMap = args.asInstanceOf[JHashMap[String, JLong]] val functionName = ConverterUtils.makeFuncName(substraitExprName, childrenTypes, FunctionConfig.OPT) val functionId = ExpressionBuilder.newScalarFunction(functionMap, functionName) diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/expression/AggregateFunctionNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/expression/AggregateFunctionNode.java index 7651bfed28157..f7b0659a028b2 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/expression/AggregateFunctionNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/expression/AggregateFunctionNode.java @@ -24,16 +24,17 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.List; public class AggregateFunctionNode implements Serializable { private final Long functionId; - private final ArrayList expressionNodes = new ArrayList<>(); + private final List expressionNodes = new ArrayList<>(); private final String phase; private final TypeNode outputTypeNode; AggregateFunctionNode( Long functionId, - ArrayList expressionNodes, + List expressionNodes, String phase, TypeNode outputTypeNode) { this.functionId = functionId; diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java index a0f2375d19deb..81194fedcef54 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/expression/ExpressionBuilder.java @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; -import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -239,7 +238,7 @@ public static SelectionNode makeSelection(Integer fieldIdx, Integer childFieldId public static AggregateFunctionNode makeAggregateFunction( Long functionId, - ArrayList expressionNodes, + List expressionNodes, String phase, TypeNode outputTypeNode) { return new AggregateFunctionNode(functionId, expressionNodes, phase, outputTypeNode); @@ -261,7 +260,7 @@ public static SingularOrListNode makeSingularOrListNode( public static WindowFunctionNode makeWindowFunction( Integer functionId, - ArrayList expressionNodes, + List expressionNodes, String columnName, TypeNode outputTypeNode, String upperBound, diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/expression/IfThenNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/expression/IfThenNode.java index fbdca934e26fb..40133ccf6af52 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/expression/IfThenNode.java +++ 
b/gluten-core/src/main/java/io/glutenproject/substrait/expression/IfThenNode.java @@ -20,18 +20,17 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.List; public class IfThenNode implements ExpressionNode, Serializable { - private final ArrayList ifNodes = new ArrayList<>(); - private final ArrayList thenNodes = new ArrayList<>(); + private final List ifNodes = new ArrayList<>(); + private final List thenNodes = new ArrayList<>(); private final ExpressionNode elseValue; public IfThenNode( - ArrayList ifNodes, - ArrayList thenNodes, - ExpressionNode elseValue) { + List ifNodes, List thenNodes, ExpressionNode elseValue) { this.ifNodes.addAll(ifNodes); this.thenNodes.addAll(thenNodes); this.elseValue = elseValue; diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/expression/SelectionNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/expression/SelectionNode.java index ef8bcd558bd0a..7dc5e53cb15cf 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/expression/SelectionNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/expression/SelectionNode.java @@ -20,12 +20,13 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.List; public class SelectionNode implements ExpressionNode, Serializable { private final Integer fieldIndex; // The nested indices of child field. For case like a.b.c, the index of c is put at last. - private final ArrayList nestedChildIndices = new ArrayList<>(); + private final List nestedChildIndices = new ArrayList<>(); SelectionNode(Integer fieldIndex) { this.fieldIndex = fieldIndex; diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/expression/StringMapNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/expression/StringMapNode.java index 62909c171f014..b9ad7120320ad 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/expression/StringMapNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/expression/StringMapNode.java @@ -23,7 +23,7 @@ import java.util.Map; public class StringMapNode implements ExpressionNode, Serializable { - private final Map values = new HashMap(); + private final Map values = new HashMap<>(); public StringMapNode(Map values) { this.values.putAll(values); diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/expression/WindowFunctionNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/expression/WindowFunctionNode.java index 20716b752676d..fa450eeea9462 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/expression/WindowFunctionNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/expression/WindowFunctionNode.java @@ -24,10 +24,11 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.List; public class WindowFunctionNode implements Serializable { private final Integer functionId; - private final ArrayList expressionNodes = new ArrayList<>(); + private final List expressionNodes = new ArrayList<>(); private final String columnName; private final TypeNode outputTypeNode; @@ -40,7 +41,7 @@ public class WindowFunctionNode implements Serializable { WindowFunctionNode( Integer functionId, - ArrayList expressionNodes, + List expressionNodes, String columnName, TypeNode outputTypeNode, String upperBound, diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/plan/PlanBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/plan/PlanBuilder.java index 1452af194a484..5ea68926088c2 100644 --- 
a/gluten-core/src/main/java/io/glutenproject/substrait/plan/PlanBuilder.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/plan/PlanBuilder.java @@ -24,6 +24,7 @@ import io.glutenproject.substrait.type.TypeNode; import java.util.ArrayList; +import java.util.List; import java.util.Map; public class PlanBuilder { @@ -33,16 +34,14 @@ public class PlanBuilder { private PlanBuilder() {} public static PlanNode makePlan( - ArrayList mappingNodes, - ArrayList relNodes, - ArrayList outNames) { + List mappingNodes, List relNodes, List outNames) { return new PlanNode(mappingNodes, relNodes, outNames); } public static PlanNode makePlan( - ArrayList mappingNodes, - ArrayList relNodes, - ArrayList outNames, + List mappingNodes, + List relNodes, + List outNames, TypeNode outputSchema, AdvancedExtensionNode extension) { return new PlanNode(mappingNodes, relNodes, outNames, outputSchema, extension); @@ -53,20 +52,20 @@ public static PlanNode makePlan(AdvancedExtensionNode extension) { } public static PlanNode makePlan( - SubstraitContext subCtx, ArrayList relNodes, ArrayList outNames) { + SubstraitContext subCtx, List relNodes, List outNames) { return makePlan(subCtx, relNodes, outNames, null, null); } public static PlanNode makePlan( SubstraitContext subCtx, - ArrayList relNodes, - ArrayList outNames, + List relNodes, + List outNames, TypeNode outputSchema, AdvancedExtensionNode extension) { if (subCtx == null) { throw new NullPointerException("ColumnarWholestageTransformer cannot doTansform."); } - ArrayList mappingNodes = new ArrayList<>(); + List mappingNodes = new ArrayList<>(); for (Map.Entry entry : subCtx.registeredFunction().entrySet()) { FunctionMappingNode mappingNode = diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/plan/PlanNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/plan/PlanNode.java index c678f7abc6965..498dfe7b0f50f 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/plan/PlanNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/plan/PlanNode.java @@ -27,28 +27,26 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.List; public class PlanNode implements Serializable { - private final ArrayList mappingNodes = new ArrayList<>(); - private final ArrayList relNodes = new ArrayList<>(); - private final ArrayList outNames = new ArrayList<>(); + private final List mappingNodes = new ArrayList<>(); + private final List relNodes = new ArrayList<>(); + private final List outNames = new ArrayList<>(); private TypeNode outputSchema = null; private AdvancedExtensionNode extension = null; - PlanNode( - ArrayList mappingNodes, - ArrayList relNodes, - ArrayList outNames) { + PlanNode(List mappingNodes, List relNodes, List outNames) { this.mappingNodes.addAll(mappingNodes); this.relNodes.addAll(relNodes); this.outNames.addAll(outNames); } PlanNode( - ArrayList mappingNodes, - ArrayList relNodes, - ArrayList outNames, + List mappingNodes, + List relNodes, + List outNames, TypeNode outputSchema, AdvancedExtensionNode extension) { this.mappingNodes.addAll(mappingNodes); diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/AggregateRelNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/AggregateRelNode.java index 75a0931b90742..8db80f9ebc613 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/AggregateRelNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/AggregateRelNode.java @@ -26,20 +26,21 @@ import java.io.Serializable; import 
java.util.ArrayList; +import java.util.List; public class AggregateRelNode implements RelNode, Serializable { private final RelNode input; - private final ArrayList groupings = new ArrayList<>(); - private final ArrayList aggregateFunctionNodes = new ArrayList<>(); + private final List groupings = new ArrayList<>(); + private final List aggregateFunctionNodes = new ArrayList<>(); - private final ArrayList filters = new ArrayList<>(); + private final List filters = new ArrayList<>(); private final AdvancedExtensionNode extensionNode; AggregateRelNode( RelNode input, - ArrayList groupings, - ArrayList aggregateFunctionNodes, - ArrayList filters) { + List groupings, + List aggregateFunctionNodes, + List filters) { this.input = input; this.groupings.addAll(groupings); this.aggregateFunctionNodes.addAll(aggregateFunctionNodes); @@ -49,9 +50,9 @@ public class AggregateRelNode implements RelNode, Serializable { AggregateRelNode( RelNode input, - ArrayList groupings, - ArrayList aggregateFunctionNodes, - ArrayList filters, + List groupings, + List aggregateFunctionNodes, + List filters, AdvancedExtensionNode extensionNode) { this.input = input; this.groupings.addAll(groupings); diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExpandRelNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExpandRelNode.java index 853c1c62b0110..49f714b8a7f64 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExpandRelNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ExpandRelNode.java @@ -25,23 +25,22 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.List; public class ExpandRelNode implements RelNode, Serializable { private final RelNode input; - private final ArrayList> projections = new ArrayList<>(); + private final List> projections = new ArrayList<>(); private final AdvancedExtensionNode extensionNode; public ExpandRelNode( - RelNode input, - ArrayList> projections, - AdvancedExtensionNode extensionNode) { + RelNode input, List> projections, AdvancedExtensionNode extensionNode) { this.input = input; this.projections.addAll(projections); this.extensionNode = extensionNode; } - public ExpandRelNode(RelNode input, ArrayList> projections) { + public ExpandRelNode(RelNode input, List> projections) { this.input = input; this.projections.addAll(projections); this.extensionNode = null; @@ -59,7 +58,7 @@ public Rel toProtobuf() { expandBuilder.setInput(input.toProtobuf()); } - for (ArrayList projectList : projections) { + for (List projectList : projections) { ExpandRel.ExpandField.Builder expandFieldBuilder = ExpandRel.ExpandField.newBuilder(); ExpandRel.SwitchingField.Builder switchingField = ExpandRel.SwitchingField.newBuilder(); for (ExpressionNode exprNode : projectList) { diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/GenerateRelNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/GenerateRelNode.java index 828931c4a38d4..4a53b2bd08f72 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/GenerateRelNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/GenerateRelNode.java @@ -24,22 +24,22 @@ import io.substrait.proto.RelCommon; import java.io.Serializable; -import java.util.ArrayList; +import java.util.List; public class GenerateRelNode implements RelNode, Serializable { private final RelNode input; private final ExpressionNode generator; - private final ArrayList childOutput; + private final List childOutput; private final 
AdvancedExtensionNode extensionNode; - GenerateRelNode(RelNode input, ExpressionNode generator, ArrayList childOutput) { + GenerateRelNode(RelNode input, ExpressionNode generator, List childOutput) { this(input, generator, childOutput, null); } GenerateRelNode( RelNode input, ExpressionNode generator, - ArrayList childOutput, + List childOutput, AdvancedExtensionNode extensionNode) { this.input = input; this.generator = generator; diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ProjectRelNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ProjectRelNode.java index 595b823d38f21..ee88390c74ed6 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ProjectRelNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ProjectRelNode.java @@ -25,14 +25,15 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.List; public class ProjectRelNode implements RelNode, Serializable { private final RelNode input; - private final ArrayList expressionNodes = new ArrayList<>(); + private final List expressionNodes = new ArrayList<>(); private final AdvancedExtensionNode extensionNode; private final int emitStartIndex; - ProjectRelNode(RelNode input, ArrayList expressionNodes, int emitStartIndex) { + ProjectRelNode(RelNode input, List expressionNodes, int emitStartIndex) { this.input = input; this.expressionNodes.addAll(expressionNodes); this.extensionNode = null; @@ -41,7 +42,7 @@ public class ProjectRelNode implements RelNode, Serializable { ProjectRelNode( RelNode input, - ArrayList expressionNodes, + List expressionNodes, AdvancedExtensionNode extensionNode, int emitStartIndex) { this.input = input; diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java index 9fb47b9e14cfb..ddf381a4a08c6 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/ReadRelNode.java @@ -31,12 +31,13 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.List; import java.util.Map; public class ReadRelNode implements RelNode, Serializable { - private final ArrayList types = new ArrayList<>(); - private final ArrayList names = new ArrayList<>(); - private final ArrayList columnTypeNodes = new ArrayList<>(); + private final List types = new ArrayList<>(); + private final List names = new ArrayList<>(); + private final List columnTypeNodes = new ArrayList<>(); private final SubstraitContext context; private final ExpressionNode filterNode; private final Long iteratorIndex; @@ -44,8 +45,8 @@ public class ReadRelNode implements RelNode, Serializable { private Map properties; ReadRelNode( - ArrayList types, - ArrayList names, + List types, + List names, SubstraitContext context, ExpressionNode filterNode, Long iteratorIndex) { @@ -57,12 +58,12 @@ public class ReadRelNode implements RelNode, Serializable { } ReadRelNode( - ArrayList types, - ArrayList names, + List types, + List names, SubstraitContext context, ExpressionNode filterNode, Long iteratorIndex, - ArrayList columnTypeNodes) { + List columnTypeNodes) { this.types.addAll(types); this.names.addAll(names); this.context = context; diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/RelBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/RelBuilder.java index 681c4ecf42b41..8dfb2f4a20afc 100644 --- 
a/gluten-core/src/main/java/io/glutenproject/substrait/rel/RelBuilder.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/RelBuilder.java @@ -29,7 +29,7 @@ import io.substrait.proto.SortField; import org.apache.spark.sql.catalyst.expressions.Attribute; -import java.util.ArrayList; +import java.util.List; /** Contains helper functions for constructing substrait relations. */ public class RelBuilder { @@ -53,7 +53,7 @@ public static RelNode makeFilterRel( public static RelNode makeProjectRel( RelNode input, - ArrayList expressionNodes, + List expressionNodes, SubstraitContext context, Long operatorId) { context.registerRelToOperator(operatorId); @@ -62,7 +62,7 @@ public static RelNode makeProjectRel( public static RelNode makeProjectRel( RelNode input, - ArrayList expressionNodes, + List expressionNodes, SubstraitContext context, Long operatorId, int emitStartIndex) { @@ -72,7 +72,7 @@ public static RelNode makeProjectRel( public static RelNode makeProjectRel( RelNode input, - ArrayList expressionNodes, + List expressionNodes, AdvancedExtensionNode extensionNode, SubstraitContext context, Long operatorId, @@ -83,9 +83,9 @@ public static RelNode makeProjectRel( public static RelNode makeAggregateRel( RelNode input, - ArrayList groupings, - ArrayList aggregateFunctionNodes, - ArrayList filters, + List groupings, + List aggregateFunctionNodes, + List filters, SubstraitContext context, Long operatorId) { context.registerRelToOperator(operatorId); @@ -94,9 +94,9 @@ public static RelNode makeAggregateRel( public static RelNode makeAggregateRel( RelNode input, - ArrayList groupings, - ArrayList aggregateFunctionNodes, - ArrayList filters, + List groupings, + List aggregateFunctionNodes, + List filters, AdvancedExtensionNode extensionNode, SubstraitContext context, Long operatorId) { @@ -105,8 +105,8 @@ public static RelNode makeAggregateRel( } public static RelNode makeReadRel( - ArrayList types, - ArrayList names, + List types, + List names, ExpressionNode filter, SubstraitContext context, Long operatorId) { @@ -115,9 +115,9 @@ public static RelNode makeReadRel( } public static RelNode makeReadRel( - ArrayList types, - ArrayList names, - ArrayList columnTypeNodes, + List types, + List names, + List columnTypeNodes, ExpressionNode filter, SubstraitContext context, Long operatorId) { @@ -126,8 +126,8 @@ public static RelNode makeReadRel( } public static RelNode makeReadRel( - ArrayList types, - ArrayList names, + List types, + List names, ExpressionNode filter, Long iteratorIndex, SubstraitContext context, @@ -137,15 +137,15 @@ public static RelNode makeReadRel( } public static RelNode makeReadRel( - ArrayList attributes, SubstraitContext context, Long operatorId) { + List attributes, SubstraitContext context, Long operatorId) { if (operatorId >= 0) { // If the operator id is negative, will not register the rel to operator. // Currently, only for the special handling in join. context.registerRelToOperator(operatorId); } - ArrayList typeList = ConverterUtils.collectAttributeTypeNodes(attributes); - ArrayList nameList = ConverterUtils.collectAttributeNamesWithExprId(attributes); + List typeList = ConverterUtils.collectAttributeTypeNodes(attributes); + List nameList = ConverterUtils.collectAttributeNamesWithExprId(attributes); // The iterator index will be added in the path of LocalFiles. 
Long iteratorIndex = context.nextIteratorIndex(); @@ -184,7 +184,7 @@ public static RelNode makeJoinRel( public static RelNode makeExpandRel( RelNode input, - ArrayList> projections, + List> projections, AdvancedExtensionNode extensionNode, SubstraitContext context, Long operatorId) { @@ -194,7 +194,7 @@ public static RelNode makeExpandRel( public static RelNode makeExpandRel( RelNode input, - ArrayList> projections, + List> projections, SubstraitContext context, Long operatorId) { context.registerRelToOperator(operatorId); @@ -203,7 +203,7 @@ public static RelNode makeExpandRel( public static RelNode makeSortRel( RelNode input, - ArrayList sorts, + List sorts, AdvancedExtensionNode extensionNode, SubstraitContext context, Long operatorId) { @@ -212,7 +212,7 @@ public static RelNode makeSortRel( } public static RelNode makeSortRel( - RelNode input, ArrayList sorts, SubstraitContext context, Long operatorId) { + RelNode input, List sorts, SubstraitContext context, Long operatorId) { context.registerRelToOperator(operatorId); return new SortRelNode(input, sorts); } @@ -236,9 +236,9 @@ public static RelNode makeFetchRel( public static RelNode makeWindowRel( RelNode input, - ArrayList windowFunctionNodes, - ArrayList partitionExpressions, - ArrayList sorts, + List windowFunctionNodes, + List partitionExpressions, + List sorts, AdvancedExtensionNode extensionNode, SubstraitContext context, Long operatorId) { @@ -249,9 +249,9 @@ public static RelNode makeWindowRel( public static RelNode makeWindowRel( RelNode input, - ArrayList windowFunctionNodes, - ArrayList partitionExpressions, - ArrayList sorts, + List windowFunctionNodes, + List partitionExpressions, + List sorts, SubstraitContext context, Long operatorId) { context.registerRelToOperator(operatorId); @@ -263,7 +263,7 @@ public static RelNode makeWindowRel( public static RelNode makeGenerateRel( RelNode input, ExpressionNode generator, - ArrayList childOutput, + List childOutput, SubstraitContext context, Long operatorId) { context.registerRelToOperator(operatorId); @@ -273,7 +273,7 @@ public static RelNode makeGenerateRel( public static RelNode makeGenerateRel( RelNode input, ExpressionNode generator, - ArrayList childOutput, + List childOutput, AdvancedExtensionNode extensionNode, SubstraitContext context, Long operatorId) { diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/rel/SortRelNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/SortRelNode.java index 2fd129d5045e3..fb2b8b128e04c 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/SortRelNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/SortRelNode.java @@ -25,20 +25,20 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.List; public class SortRelNode implements RelNode, Serializable { private final RelNode input; - private final ArrayList sorts = new ArrayList<>(); + private final List sorts = new ArrayList<>(); private final AdvancedExtensionNode extensionNode; - public SortRelNode( - RelNode input, ArrayList sorts, AdvancedExtensionNode extensionNode) { + public SortRelNode(RelNode input, List sorts, AdvancedExtensionNode extensionNode) { this.input = input; this.sorts.addAll(sorts); this.extensionNode = extensionNode; } - public SortRelNode(RelNode input, ArrayList sorts) { + public SortRelNode(RelNode input, List sorts) { this.input = input; this.sorts.addAll(sorts); this.extensionNode = null; diff --git 
a/gluten-core/src/main/java/io/glutenproject/substrait/rel/WindowRelNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/rel/WindowRelNode.java index 475f9f6f3cd07..f120c37c73f44 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/rel/WindowRelNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/rel/WindowRelNode.java @@ -27,19 +27,20 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.List; public class WindowRelNode implements RelNode, Serializable { private final RelNode input; - private final ArrayList windowFunctionNodes = new ArrayList<>(); - private final ArrayList partitionExpressions = new ArrayList<>(); - private final ArrayList sorts = new ArrayList<>(); + private final List windowFunctionNodes = new ArrayList<>(); + private final List partitionExpressions = new ArrayList<>(); + private final List sorts = new ArrayList<>(); private final AdvancedExtensionNode extensionNode; public WindowRelNode( RelNode input, - ArrayList windowFunctionNodes, - ArrayList partitionExpressions, - ArrayList sorts) { + List windowFunctionNodes, + List partitionExpressions, + List sorts) { this.input = input; this.windowFunctionNodes.addAll(windowFunctionNodes); this.partitionExpressions.addAll(partitionExpressions); @@ -49,9 +50,9 @@ public WindowRelNode( public WindowRelNode( RelNode input, - ArrayList windowFunctionNodes, - ArrayList partitionExpressions, - ArrayList sorts, + List windowFunctionNodes, + List partitionExpressions, + List sorts, AdvancedExtensionNode extensionNode) { this.input = input; this.windowFunctionNodes.addAll(windowFunctionNodes); diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/type/MapNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/type/MapNode.java index 35ede7ed09508..76c4c76734cc4 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/type/MapNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/type/MapNode.java @@ -20,6 +20,7 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.List; public class MapNode implements TypeNode, Serializable { private final Boolean nullable; @@ -34,7 +35,7 @@ public MapNode(Boolean nullable, TypeNode keyType, TypeNode valType) { // It's used in ExplodeTransformer to determine output datatype from children. 
public TypeNode getNestedType() { - ArrayList types = new ArrayList<>(); + List types = new ArrayList<>(); types.add(keyType); types.add(valType); return TypeBuilder.makeStruct(false, types); diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/type/StructNode.java b/gluten-core/src/main/java/io/glutenproject/substrait/type/StructNode.java index 8b81b330518f6..ff7c0922d8229 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/type/StructNode.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/type/StructNode.java @@ -20,24 +20,25 @@ import java.io.Serializable; import java.util.ArrayList; +import java.util.List; public class StructNode implements TypeNode, Serializable { private final Boolean nullable; - private final ArrayList types = new ArrayList<>(); - private final ArrayList names = new ArrayList<>(); + private final List types = new ArrayList<>(); + private final List names = new ArrayList<>(); - public StructNode(Boolean nullable, ArrayList types, ArrayList names) { + public StructNode(Boolean nullable, List types, List names) { this.nullable = nullable; this.types.addAll(types); this.names.addAll(names); } - public StructNode(Boolean nullable, ArrayList types) { + public StructNode(Boolean nullable, List types) { this.nullable = nullable; this.types.addAll(types); } - public ArrayList getFieldTypes() { + public List getFieldTypes() { return types; } diff --git a/gluten-core/src/main/java/io/glutenproject/substrait/type/TypeBuilder.java b/gluten-core/src/main/java/io/glutenproject/substrait/type/TypeBuilder.java index bdaed93238251..a3efd67653c39 100644 --- a/gluten-core/src/main/java/io/glutenproject/substrait/type/TypeBuilder.java +++ b/gluten-core/src/main/java/io/glutenproject/substrait/type/TypeBuilder.java @@ -16,7 +16,7 @@ */ package io.glutenproject.substrait.type; -import java.util.ArrayList; +import java.util.List; public class TypeBuilder { private TypeBuilder() {} @@ -77,12 +77,11 @@ public static TypeNode makeTimestamp(Boolean nullable) { return new TimestampTypeNode(nullable); } - public static TypeNode makeStruct( - Boolean nullable, ArrayList types, ArrayList names) { + public static TypeNode makeStruct(Boolean nullable, List types, List names) { return new StructNode(nullable, types, names); } - public static TypeNode makeStruct(Boolean nullable, ArrayList types) { + public static TypeNode makeStruct(Boolean nullable, List types) { return new StructNode(nullable, types); } diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/MetricsApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/MetricsApi.scala index 9c24c1284d841..2425c9fc6f513 100644 --- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/MetricsApi.scala +++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/MetricsApi.scala @@ -23,6 +23,9 @@ import org.apache.spark.SparkContext import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} +import java.lang.{Long => JLong} +import java.util.{List => JList, Map => JMap} + trait MetricsApi extends Serializable { def genWholeStageTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] = @@ -30,9 +33,9 @@ trait MetricsApi extends Serializable { def metricsUpdatingFunction( child: SparkPlan, - relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]], - joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams], - aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): IMetrics 
=> Unit + relMap: JMap[JLong, JList[JLong]], + joinParamsMap: JMap[JLong, JoinParams], + aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit def genBatchScanTransformerMetrics(sparkContext: SparkContext): Map[String, SQLMetric] diff --git a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala index 3f0790ffca09f..3a7bfaa560dab 100644 --- a/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala +++ b/gluten-core/src/main/scala/io/glutenproject/backendsapi/SparkPlanExecApi.scala @@ -41,7 +41,10 @@ import org.apache.spark.sql.hive.HiveTableScanExecTransformer import org.apache.spark.sql.types.StructType import org.apache.spark.sql.vectorized.ColumnarBatch -import java.util +import java.lang.{Long => JLong} +import java.util.{ArrayList => JArrayList, List => JList, Map => JMap} + +import scala.collection.JavaConverters._ trait SparkPlanExecApi { @@ -134,7 +137,7 @@ trait SparkPlanExecApi { /** Transform GetArrayItem to Substrait. */ def genGetArrayItemExpressionNode( substraitExprName: String, - functionMap: java.util.HashMap[String, java.lang.Long], + functionMap: JMap[String, JLong], leftNode: ExpressionNode, rightNode: ExpressionNode, original: GetArrayItem): ExpressionNode @@ -342,9 +345,9 @@ trait SparkPlanExecApi { /** default function to generate window function node */ def genWindowFunctionsNode( windowExpression: Seq[NamedExpression], - windowExpressionNodes: util.ArrayList[WindowFunctionNode], + windowExpressionNodes: JList[WindowFunctionNode], originalInputAttributes: Seq[Attribute], - args: util.HashMap[String, java.lang.Long]): Unit = { + args: JMap[String, JLong]): Unit = { windowExpression.map { windowExpr => @@ -357,7 +360,7 @@ trait SparkPlanExecApi { val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame] val windowFunctionNode = ExpressionBuilder.makeWindowFunction( WindowFunctionsBuilder.create(args, aggWindowFunc).toInt, - new util.ArrayList[ExpressionNode](), + new JArrayList[ExpressionNode](), columnName, ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable), WindowExecTransformer.getFrameBound(frame.upper), @@ -373,13 +376,12 @@ trait SparkPlanExecApi { throw new UnsupportedOperationException(s"Not currently supported: $aggregateFunc.") } - val childrenNodeList = new util.ArrayList[ExpressionNode]() - aggregateFunc.children.foreach( - expr => - childrenNodeList.add( - ExpressionConverter - .replaceWithExpressionTransformer(expr, originalInputAttributes) - .doTransform(args))) + val childrenNodeList = aggregateFunc.children + .map( + ExpressionConverter + .replaceWithExpressionTransformer(_, originalInputAttributes) + .doTransform(args)) + .asJava val windowFunctionNode = ExpressionBuilder.makeWindowFunction( AggregateFunctionsBuilder.create(args, aggExpression.aggregateFunction).toInt, @@ -394,7 +396,7 @@ trait SparkPlanExecApi { case wf @ (Lead(_, _, _, _) | Lag(_, _, _, _)) => val offset_wf = wf.asInstanceOf[FrameLessOffsetWindowFunction] val frame = offset_wf.frame.asInstanceOf[SpecifiedWindowFrame] - val childrenNodeList = new util.ArrayList[ExpressionNode]() + val childrenNodeList = new JArrayList[ExpressionNode]() childrenNodeList.add( ExpressionConverter .replaceWithExpressionTransformer( @@ -425,12 +427,12 @@ trait SparkPlanExecApi { windowExpressionNodes.add(windowFunctionNode) case wf @ NthValue(input, offset: Literal, _) => val frame = 
wExpression.windowSpec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] - val childrenNodeList = new util.ArrayList[ExpressionNode]() + val childrenNodeList = new JArrayList[ExpressionNode]() childrenNodeList.add( ExpressionConverter .replaceWithExpressionTransformer(input, attributeSeq = originalInputAttributes) .doTransform(args)) - childrenNodeList.add(new LiteralTransformer(offset).doTransform(args)) + childrenNodeList.add(LiteralTransformer(offset).doTransform(args)) val windowFunctionNode = ExpressionBuilder.makeWindowFunction( WindowFunctionsBuilder.create(args, wf).toInt, childrenNodeList, diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala index 44905796c3af1..74ded1bdc224a 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/BasicPhysicalOperatorTransformer.scala @@ -22,9 +22,8 @@ import io.glutenproject.extension.{GlutenPlan, ValidationResult} import io.glutenproject.extension.columnar.TransformHints import io.glutenproject.metrics.MetricsUpdater import io.glutenproject.sql.shims.SparkShimLoader -import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} +import io.glutenproject.substrait.`type`.TypeBuilder import io.glutenproject.substrait.SubstraitContext -import io.glutenproject.substrait.expression.ExpressionNode import io.glutenproject.substrait.extensions.ExtensionBuilder import io.glutenproject.substrait.rel.{RelBuilder, RelNode} @@ -39,8 +38,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.protobuf.Any -import java.util - import scala.collection.JavaConverters._ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkPlan) @@ -94,10 +91,9 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP RelBuilder.makeFilterRel(input, condExprNode, context, operatorId) } else { // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- originalInputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } + val inputTypeNodeList = originalInputAttributes + .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) + .asJava val extensionNode = ExtensionBuilder.makeAdvancedExtension( Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) RelBuilder.makeFilterRel(input, condExprNode, extensionNode, context, operatorId) @@ -160,13 +156,12 @@ abstract class FilterExecTransformerBase(val cond: Expression, val input: SparkP } else { // This means the input is just an iterator, so an ReadRel will be created as child. // Prepare the input schema. - val attrList = new util.ArrayList[Attribute](child.output.asJava) getRelNode( context, cond, child.output, operatorId, - RelBuilder.makeReadRel(attrList, context, operatorId), + RelBuilder.makeReadRel(child.output.asJava, context, operatorId), validation = false) } assert(currRel != null, "Filter rel should be valid.") @@ -241,11 +236,7 @@ case class ProjectExecTransformer private (projectList: Seq[NamedExpression], ch } else { // This means the input is just an iterator, so an ReadRel will be created as child. // Prepare the input schema. 
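Several Scala hunks here replace the "new `util.ArrayList` plus foreach/add" idiom with a single `map(...).asJava` expression (for example `childrenNodeList` in `genWindowFunctionsNode` above). A sketch of the resulting shape, with `String`s standing in for `ExpressionNode` and a hypothetical `transform` in place of `doTransform(args)`:

```scala
import java.util.{List => JList}

import scala.collection.JavaConverters._

object MapAsJavaSketch extends App {
  // Before: allocate a util.ArrayList and add() to it inside a foreach loop.
  // After: derive the Java list in one expression.
  // Strings stand in for ExpressionNode; transform() for doTransform(args).
  def transform(expr: String): String = s"node($expr)"

  val children = Seq("col_a", "col_b")
  val childrenNodeList: JList[String] = children.map(transform).asJava

  println(childrenNodeList) // [node(col_a), node(col_b)]
}
```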
- val attrList = new util.ArrayList[Attribute]() - for (attr <- child.output) { - attrList.add(attr) - } - val readRel = RelBuilder.makeReadRel(attrList, context, operatorId) + val readRel = RelBuilder.makeReadRel(child.output.asJava, context, operatorId) ( getRelNode(context, projectList, child.output, operatorId, readRel, validation = false), child.output) @@ -268,23 +259,18 @@ case class ProjectExecTransformer private (projectList: Seq[NamedExpression], ch validation: Boolean): RelNode = { val args = context.registeredFunction val columnarProjExprs: Seq[ExpressionTransformer] = projectList.map( - expr => { + expr => ExpressionConverter - .replaceWithExpressionTransformer(expr, attributeSeq = originalInputAttributes) - }) - val projExprNodeList = new java.util.ArrayList[ExpressionNode]() - for (expr <- columnarProjExprs) { - projExprNodeList.add(expr.doTransform(args)) - } + .replaceWithExpressionTransformer(expr, attributeSeq = originalInputAttributes)) + val projExprNodeList = columnarProjExprs.map(_.doTransform(args)).asJava val emitStartIndex = originalInputAttributes.size if (!validation) { RelBuilder.makeProjectRel(input, projExprNodeList, context, operatorId, emitStartIndex) } else { // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- originalInputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } + val inputTypeNodeList = originalInputAttributes + .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) + .asJava val extensionNode = ExtensionBuilder.makeAdvancedExtension( Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) RelBuilder.makeProjectRel( @@ -317,28 +303,24 @@ object ProjectExecTransformer { // after executing the MergeScalarSubqueries. 
var needToReplace = false val newProjectList = projectList.map { - pro => - pro match { - case alias @ Alias(cns @ CreateNamedStruct(children: Seq[Expression]), "mergedValue") => - // check whether there are some duplicate names - if (cns.nameExprs.distinct.size == cns.nameExprs.size) { - alias - } else { - val newChildren = children - .grouped(2) - .map { - case Seq(name: Literal, value: NamedExpression) => - val newLiteral = Literal(name.toString() + "#" + value.exprId.id) - Seq(newLiteral, value) - case Seq(name, value) => Seq(name, value) - } - .flatten - .toSeq - needToReplace = true - Alias.apply(CreateNamedStruct(newChildren), "mergedValue")(alias.exprId) + case alias @ Alias(cns @ CreateNamedStruct(children: Seq[Expression]), "mergedValue") => + // check whether there are some duplicate names + if (cns.nameExprs.distinct.size == cns.nameExprs.size) { + alias + } else { + val newChildren = children + .grouped(2) + .flatMap { + case Seq(name: Literal, value: NamedExpression) => + val newLiteral = Literal(name.toString() + "#" + value.exprId.id) + Seq(newLiteral, value) + case Seq(name, value) => Seq(name, value) } - case other: NamedExpression => other + .toSeq + needToReplace = true + Alias.apply(CreateNamedStruct(newChildren), "mergedValue")(alias.exprId) } + case other: NamedExpression => other } if (!needToReplace) { diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala index 08f13f3204966..c39a5e4465611 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/BasicScanExecTransformer.scala @@ -34,6 +34,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.common.collect.Lists +import scala.collection.JavaConverters._ + trait BasicScanExecTransformer extends LeafTransformSupport with SupportFormat { // The key of merge schema option in Parquet reader. 
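In the `ProjectExecTransformer` hunk above, `CreateNamedStruct` children alternate name, value, name, value, so the rewrite groups them in pairs and flatMaps, renaming only the name element. A self-contained sketch of that pairing trick, where `String`s stand in for Catalyst expressions and the `"#7"` suffix plays the role of `"#" + exprId.id`:

```scala
object GroupedPairsSketch extends App {
  // name, value, name, value, ... as in CreateNamedStruct children.
  val children = Seq("n1", "v1", "n2", "v2")

  val deduplicated = children
    .grouped(2)
    .flatMap {
      case Seq(name, value) => Seq(name + "#7", value) // only the name is rewritten
      case other => other
    }
    .toList // the patch uses .toSeq; toList keeps this sketch eagerly printable

  println(deduplicated) // List(n1#7, v1, n2#7, v2)
}
```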
@@ -58,10 +60,7 @@ trait BasicScanExecTransformer extends LeafTransformSupport with SupportFormat { val scanTime = longMetric("scanTime") val substraitContext = new SubstraitContext val transformContext = doTransform(substraitContext) - val outNames = new java.util.ArrayList[String]() - for (attr <- outputAttributes()) { - outNames.add(ConverterUtils.genColumnNameWithExprId(attr)) - } + val outNames = outputAttributes().map(ConverterUtils.genColumnNameWithExprId).asJava val planNode = PlanBuilder.makePlan(substraitContext, Lists.newArrayList(transformContext.root), outNames) val fileFormat = ConverterUtils.getFileFormat(this) @@ -102,14 +101,14 @@ trait BasicScanExecTransformer extends LeafTransformSupport with SupportFormat { val typeNodes = ConverterUtils.collectAttributeTypeNodes(output) val nameList = ConverterUtils.collectAttributeNamesWithoutExprId(output) val partitionSchemas = getPartitionSchemas - val columnTypeNodes = new java.util.ArrayList[ColumnTypeNode]() - for (attr <- output) { - if (partitionSchemas.exists(_.name.equals(attr.name))) { - columnTypeNodes.add(new ColumnTypeNode(1)) - } else { - columnTypeNodes.add(new ColumnTypeNode(0)) - } - } + val columnTypeNodes = output.map { + attr => + if (partitionSchemas.exists(_.name.equals(attr.name))) { + new ColumnTypeNode(1) + } else { + new ColumnTypeNode(0) + } + }.asJava // Will put all filter expressions into an AND expression val transformer = filterExprs() .reduceLeftOption(And) diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala index 05d60b51a2ea6..245e40d6c42c3 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/ExpandExecTransformer.scala @@ -34,8 +34,9 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.protobuf.Any -import java.util +import java.util.{ArrayList => JArrayList, List => JList} +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer case class ExpandExecTransformer( @@ -76,9 +77,9 @@ case class ExpandExecTransformer( val preExprs = ArrayBuffer.empty[Expression] val selectionMaps = ArrayBuffer.empty[Seq[Int]] var preExprIndex = 0 - for (i <- 0 until projections.size) { + for (i <- projections.indices) { val selections = ArrayBuffer.empty[Int] - for (j <- 0 until projections(i).size) { + for (j <- projections(i).indices) { val proj = projections(i)(j) if (!proj.isInstanceOf[Literal]) { val exprIdx = preExprs.indexWhere(expr => expr.semanticEquals(proj)) @@ -96,14 +97,12 @@ case class ExpandExecTransformer( selectionMaps += selections } // make project - val preExprNodes = new util.ArrayList[ExpressionNode]() - preExprs.foreach { - expr => - val exprNode = ExpressionConverter - .replaceWithExpressionTransformer(expr, originalInputAttributes) - .doTransform(args) - preExprNodes.add(exprNode) - } + val preExprNodes = preExprs + .map( + ExpressionConverter + .replaceWithExpressionTransformer(_, originalInputAttributes) + .doTransform(args)) + .asJava val emitStartIndex = originalInputAttributes.size val inputRel = if (!validation) { @@ -126,10 +125,10 @@ case class ExpandExecTransformer( } // make expand - val projectSetExprNodes = new util.ArrayList[util.ArrayList[ExpressionNode]]() - for (i <- 0 until projections.size) { - val projectExprNodes = new util.ArrayList[ExpressionNode]() - for (j <- 0 until projections(i).size) { + val 
projectSetExprNodes = new JArrayList[JList[ExpressionNode]]() + for (i <- projections.indices) { + val projectExprNodes = new JArrayList[ExpressionNode]() + for (j <- projections(i).indices) { val projectExprNode = projections(i)(j) match { case l: Literal => LiteralTransformer(l).doTransform(args) @@ -143,10 +142,10 @@ case class ExpandExecTransformer( } RelBuilder.makeExpandRel(inputRel, projectSetExprNodes, context, operatorId) } else { - val projectSetExprNodes = new util.ArrayList[util.ArrayList[ExpressionNode]]() + val projectSetExprNodes = new JArrayList[JList[ExpressionNode]]() projections.foreach { projectSet => - val projectExprNodes = new util.ArrayList[ExpressionNode]() + val projectExprNodes = new JArrayList[ExpressionNode]() projectSet.foreach { project => val projectExprNode = ExpressionConverter @@ -218,11 +217,7 @@ case class ExpandExecTransformer( } else { // This means the input is just an iterator, so an ReadRel will be created as child. // Prepare the input schema. - val attrList = new util.ArrayList[Attribute]() - for (attr <- child.output) { - attrList.add(attr) - } - val readRel = RelBuilder.makeReadRel(attrList, context, operatorId) + val readRel = RelBuilder.makeReadRel(child.output.asJava, context, operatorId) ( getRelNode(context, projections, child.output, operatorId, readRel, validation = false), child.output) diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala index 3b69f3c89686e..c895971a92969 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/GenerateExecTransformer.scala @@ -21,7 +21,7 @@ import io.glutenproject.exception.GlutenException import io.glutenproject.expression.{ConverterUtils, ExpressionConverter, ExpressionTransformer} import io.glutenproject.extension.ValidationResult import io.glutenproject.metrics.MetricsUpdater -import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} +import io.glutenproject.substrait.`type`.TypeBuilder import io.glutenproject.substrait.SubstraitContext import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode} import io.glutenproject.substrait.extensions.ExtensionBuilder @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.SparkPlan import com.google.protobuf.Any -import java.util +import java.util.{ArrayList => JArrayList, List => JList} import scala.collection.JavaConverters._ @@ -105,7 +105,7 @@ case class GenerateExecTransformer( val generatorExpr = ExpressionConverter.replaceWithExpressionTransformer(generator, child.output) val generatorNode = generatorExpr.doTransform(args) - val requiredChildOutputNodes = new java.util.ArrayList[ExpressionNode] + val requiredChildOutputNodes = new JArrayList[ExpressionNode] for (target <- requiredChildOutput) { val found = child.output.zipWithIndex.filter(_._1.name == target.name) if (found.nonEmpty) { @@ -119,11 +119,7 @@ case class GenerateExecTransformer( val inputRel = if (childCtx != null) { childCtx.root } else { - val attrList = new java.util.ArrayList[Attribute]() - for (attr <- child.output) { - attrList.add(attr) - } - val readRel = RelBuilder.makeReadRel(attrList, context, operatorId) + val readRel = RelBuilder.makeReadRel(child.output.asJava, context, operatorId) readRel } val projRel = @@ -131,7 +127,7 @@ case class GenerateExecTransformer( BackendsApiManager.getSettings.insertPostProjectForGenerate() 
&& needsProjection(generator) ) { // need to insert one projection node for velox backend - val projectExpressions = new util.ArrayList[ExpressionNode]() + val projectExpressions = new JArrayList[ExpressionNode]() val childOutputNodes = child.output.indices .map(i => ExpressionBuilder.makeSelection(i).asInstanceOf[ExpressionNode]) .asJava @@ -174,16 +170,14 @@ case class GenerateExecTransformer( inputAttributes: Seq[Attribute], input: RelNode, generator: ExpressionNode, - childOutput: util.ArrayList[ExpressionNode], + childOutput: JList[ExpressionNode], validation: Boolean): RelNode = { if (!validation) { RelBuilder.makeGenerateRel(input, generator, childOutput, context, operatorId) } else { // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- inputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } + val inputTypeNodeList = + inputAttributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava val extensionNode = ExtensionBuilder.makeAdvancedExtension( Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) RelBuilder.makeGenerateRel(input, generator, childOutput, extensionNode, context, operatorId) diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala index 4422fcc586bc6..00b3cc7e8c332 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashAggregateExecBaseTransformer.scala @@ -20,7 +20,7 @@ import io.glutenproject.backendsapi.BackendsApiManager import io.glutenproject.expression._ import io.glutenproject.extension.ValidationResult import io.glutenproject.metrics.MetricsUpdater -import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} +import io.glutenproject.substrait.`type`.TypeBuilder import io.glutenproject.substrait.{AggregationParams, SubstraitContext} import io.glutenproject.substrait.expression.{AggregateFunctionNode, ExpressionBuilder, ExpressionNode} import io.glutenproject.substrait.extensions.ExtensionBuilder @@ -38,8 +38,9 @@ import org.apache.spark.util.sketch.BloomFilter import com.google.protobuf.Any -import java.util +import java.util.{ArrayList => JArrayList, List => JList} +import scala.collection.JavaConverters._ import scala.collection.mutable.ListBuffer import scala.util.control.Breaks.{break, breakable} @@ -104,10 +105,10 @@ abstract class HashAggregateExecBaseTransformer( case BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | StringType | TimestampType | DateType | BinaryType => true - case d: DecimalType => true - case a: ArrayType => true - case n: NullType => true - case other => false + case _: DecimalType => true + case _: ArrayType => true + case _: NullType => true + case _ => false } } @@ -146,11 +147,7 @@ abstract class HashAggregateExecBaseTransformer( // This means the input is just an iterator, so an ReadRel will be created as child. // Prepare the input schema. 
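A small cleanup in the `HashAggregateExecBaseTransformer` hunk above: type tests that never use the bound name (`case d: DecimalType`) become wildcard type patterns (`case _: DecimalType`). Sketch with a stand-in type hierarchy instead of Spark's `DataType`:

```scala
object WildcardPatternSketch extends App {
  sealed trait DataTypeLike
  case object IntegerLike extends DataTypeLike
  final case class DecimalLike(precision: Int, scale: Int) extends DataTypeLike

  // `case _: DecimalLike` keeps the type test but drops the unused binding
  // that `case d: DecimalLike` would introduce.
  def isSupported(t: DataTypeLike): Boolean = t match {
    case IntegerLike => true
    case _: DecimalLike => true
    case _ => false
  }

  println(isSupported(DecimalLike(10, 2))) // true
}
```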
aggParams.isReadRel = true - val attrList = new util.ArrayList[Attribute]() - for (attr <- child.output) { - attrList.add(attr) - } - val readRel = RelBuilder.makeReadRel(attrList, context, operatorId) + val readRel = RelBuilder.makeReadRel(child.output.asJava, context, operatorId) (getAggRel(context, operatorId, aggParams, readRel), child.output, output) } TransformContext(inputAttributes, outputAttributes, relNode) @@ -283,22 +280,20 @@ abstract class HashAggregateExecBaseTransformer( }) // Create the expression nodes needed by Project node. - val preExprNodes = new util.ArrayList[ExpressionNode]() - for (expr <- preExpressions) { - preExprNodes.add( + val preExprNodes = preExpressions + .map( ExpressionConverter - .replaceWithExpressionTransformer(expr, originalInputAttributes) + .replaceWithExpressionTransformer(_, originalInputAttributes) .doTransform(args)) - } + .asJava val emitStartIndex = originalInputAttributes.size val inputRel = if (!validation) { RelBuilder.makeProjectRel(input, preExprNodes, context, operatorId, emitStartIndex) } else { // Use a extension node to send the input types through Substrait plan for a validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- originalInputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } + val inputTypeNodeList = originalInputAttributes + .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) + .asJava val extensionNode = ExtensionBuilder.makeAdvancedExtension( Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) RelBuilder.makeProjectRel( @@ -321,7 +316,7 @@ abstract class HashAggregateExecBaseTransformer( filterSelections: Seq[Int], inputRel: RelNode, operatorId: Long): RelNode = { - val groupingList = new util.ArrayList[ExpressionNode]() + val groupingList = new JArrayList[ExpressionNode]() var colIdx = 0 while (colIdx < groupingExpressions.size) { val groupingExpr: ExpressionNode = ExpressionBuilder.makeSelection(selections(colIdx)) @@ -330,11 +325,11 @@ abstract class HashAggregateExecBaseTransformer( } // Create Aggregation functions. 
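Worth noting for the `asJava` conversions introduced throughout: `asJava` wraps the Scala collection rather than copying it, and the node classes earlier in this patch copy into their own `ArrayList`s via `addAll`, so the cheap view is sufficient. A small demonstration of the view semantics:

```scala
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

object AsJavaViewSketch extends App {
  // asJava returns a wrapper over the Scala collection, not a snapshot.
  val attrs = ArrayBuffer("a#1", "b#2")
  val view = attrs.asJava

  attrs += "c#3"
  println(view.size()) // 3: the Java list is a live view over the buffer
}
```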
- val aggregateFunctionList = new util.ArrayList[AggregateFunctionNode]() + val aggregateFunctionList = new JArrayList[AggregateFunctionNode]() aggregateExpressions.foreach( aggExpr => { val aggregateFunc = aggExpr.aggregateFunction - val childrenNodeList = new util.ArrayList[ExpressionNode]() + val childrenNodeList = new JArrayList[ExpressionNode]() val childrenNodes = aggregateFunc.children.toList.map( _ => { val aggExpr = ExpressionBuilder.makeSelection(selections(colIdx)) @@ -352,7 +347,7 @@ abstract class HashAggregateExecBaseTransformer( aggregateFunctionList) }) - val aggFilterList = new util.ArrayList[ExpressionNode]() + val aggFilterList = new JArrayList[ExpressionNode]() aggregateExpressions.foreach( aggExpr => { if (aggExpr.filter.isDefined) { @@ -376,9 +371,9 @@ abstract class HashAggregateExecBaseTransformer( protected def addFunctionNode( args: java.lang.Object, aggregateFunction: AggregateFunction, - childrenNodeList: util.ArrayList[ExpressionNode], + childrenNodeList: JList[ExpressionNode], aggregateMode: AggregateMode, - aggregateNodeList: util.ArrayList[AggregateFunctionNode]): Unit = { + aggregateNodeList: JList[AggregateFunctionNode]): Unit = { aggregateNodeList.add( ExpressionBuilder.makeAggregateFunction( AggregateFunctionsBuilder.create(args, aggregateFunction), @@ -396,23 +391,20 @@ abstract class HashAggregateExecBaseTransformer( val args = context.registeredFunction // Will add an projection after Agg. - val resExprNodes = new util.ArrayList[ExpressionNode]() - resultExpressions.foreach( - expr => { - resExprNodes.add( - ExpressionConverter - .replaceWithExpressionTransformer(expr, allAggregateResultAttributes) - .doTransform(args)) - }) + val resExprNodes = resultExpressions + .map( + ExpressionConverter + .replaceWithExpressionTransformer(_, allAggregateResultAttributes) + .doTransform(args)) + .asJava val emitStartIndex = allAggregateResultAttributes.size if (!validation) { RelBuilder.makeProjectRel(aggRel, resExprNodes, context, operatorId, emitStartIndex) } else { // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- allAggregateResultAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } + val inputTypeNodeList = allAggregateResultAttributes + .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) + .asJava val extensionNode = ExtensionBuilder.makeAdvancedExtension( Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) RelBuilder.makeProjectRel( @@ -429,7 +421,7 @@ abstract class HashAggregateExecBaseTransformer( protected def getAttrForAggregateExprs( aggregateExpressions: Seq[AggregateExpression], aggregateAttributeList: Seq[Attribute]): List[Attribute] = { - var aggregateAttr = new ListBuffer[Attribute]() + val aggregateAttr = new ListBuffer[Attribute]() val size = aggregateExpressions.size var resIndex = 0 for (expIdx <- 0 until size) { @@ -657,19 +649,17 @@ abstract class HashAggregateExecBaseTransformer( validation: Boolean): RelNode = { val args = context.registeredFunction // Get the grouping nodes. - val groupingList = new util.ArrayList[ExpressionNode]() - groupingExpressions.foreach( - expr => { - // Use 'child.output' as based Seq[Attribute], the originalInputAttributes - // may be different for each backend. 
- val exprNode = ExpressionConverter - .replaceWithExpressionTransformer(expr, child.output) - .doTransform(args) - groupingList.add(exprNode) - }) + // Use 'child.output' as based Seq[Attribute], the originalInputAttributes + // may be different for each backend. + val groupingList = groupingExpressions + .map( + ExpressionConverter + .replaceWithExpressionTransformer(_, child.output) + .doTransform(args)) + .asJava // Get the aggregate function nodes. - val aggFilterList = new util.ArrayList[ExpressionNode]() - val aggregateFunctionList = new util.ArrayList[AggregateFunctionNode]() + val aggFilterList = new JArrayList[ExpressionNode]() + val aggregateFunctionList = new JArrayList[AggregateFunctionNode]() aggregateExpressions.foreach( aggExpr => { if (aggExpr.filter.isDefined) { @@ -682,7 +672,6 @@ abstract class HashAggregateExecBaseTransformer( aggFilterList.add(null) } val aggregateFunc = aggExpr.aggregateFunction - val childrenNodeList = new util.ArrayList[ExpressionNode]() val childrenNodes = aggExpr.mode match { case Partial => aggregateFunc.children.toList.map( @@ -701,10 +690,12 @@ abstract class HashAggregateExecBaseTransformer( case other => throw new UnsupportedOperationException(s"$other not supported.") } - for (node <- childrenNodes) { - childrenNodeList.add(node) - } - addFunctionNode(args, aggregateFunc, childrenNodeList, aggExpr.mode, aggregateFunctionList) + addFunctionNode( + args, + aggregateFunc, + childrenNodes.asJava, + aggExpr.mode, + aggregateFunctionList) }) if (!validation) { RelBuilder.makeAggregateRel( @@ -716,10 +707,9 @@ abstract class HashAggregateExecBaseTransformer( operatorId) } else { // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- originalInputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } + val inputTypeNodeList = originalInputAttributes + .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) + .asJava val extensionNode = ExtensionBuilder.makeAdvancedExtension( Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) RelBuilder.makeAggregateRel( diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/HashJoinExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/HashJoinExecTransformer.scala index f34910db7ef0a..b95aa765d3148 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/HashJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/HashJoinExecTransformer.scala @@ -43,7 +43,7 @@ import com.google.protobuf.{Any, StringValue} import io.substrait.proto.JoinRel import java.lang.{Long => JLong} -import java.util.{ArrayList => JArrayList, HashMap => JHashMap} +import java.util.{Map => JMap} import scala.collection.JavaConverters._ @@ -226,7 +226,7 @@ trait HashJoinLikeExecTransformer (transformContext.root, transformContext.outputAttributes, false) case _ => val readRel = RelBuilder.makeReadRel( - new JArrayList[Attribute](plan.output.asJava), + plan.output.asJava, substraitContext, -1 ) /* A special handling in Join to delay the rel registration. 
*/ @@ -335,7 +335,7 @@ object HashJoinLikeExecTransformer { leftType: DataType, rightNode: ExpressionNode, rightType: DataType, - functionMap: JHashMap[String, JLong]): ExpressionNode = { + functionMap: JMap[String, JLong]): ExpressionNode = { val functionId = ExpressionBuilder.newScalarFunction( functionMap, ConverterUtils.makeFuncName(ExpressionNames.EQUAL, Seq(leftType, rightType))) @@ -349,7 +349,7 @@ object HashJoinLikeExecTransformer { def makeAndExpression( leftNode: ExpressionNode, rightNode: ExpressionNode, - functionMap: JHashMap[String, JLong]): ExpressionNode = { + functionMap: JMap[String, JLong]): ExpressionNode = { val functionId = ExpressionBuilder.newScalarFunction( functionMap, ConverterUtils.makeFuncName(ExpressionNames.AND, Seq(BooleanType, BooleanType))) diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/JoinUtils.scala b/gluten-core/src/main/scala/io/glutenproject/execution/JoinUtils.scala index 0ca1654a41d5f..4f256debb058b 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/JoinUtils.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/JoinUtils.scala @@ -17,7 +17,7 @@ package io.glutenproject.execution import io.glutenproject.expression.{AttributeReferenceTransformer, ConverterUtils, ExpressionConverter} -import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} +import io.glutenproject.substrait.`type`.TypeBuilder import io.glutenproject.substrait.SubstraitContext import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode} import io.glutenproject.substrait.extensions.{AdvancedExtensionNode, ExtensionBuilder} @@ -30,8 +30,6 @@ import org.apache.spark.sql.types.DataType import com.google.protobuf.Any import io.substrait.proto.JoinRel -import java.util - import scala.collection.JavaConverters._ object JoinUtils { @@ -43,7 +41,7 @@ object JoinUtils { // is also used in execution phase. In this case an empty typeUrlPrefix need to be passed, // so that it can be correctly parsed into json string on the cpp side. 
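`makeEqualityExpression` and `makeAndExpression` now take the function registry as `java.util.Map` instead of `HashMap`; they only read from or update it, so the interface is enough. A sketch of that calling pattern, where `newScalarFunction` is a simplified stand-in for `ExpressionBuilder.newScalarFunction` (the real id-assignment scheme is not shown in this diff):

```scala
import java.lang.{Long => JLong}
import java.util.{HashMap => JHashMap, Map => JMap}

object FunctionMapSketch extends App {
  // Stand-in: look up or register a function id using only java.util.Map
  // operations, so a JMap parameter type is all the helper needs.
  def newScalarFunction(functionMap: JMap[String, JLong], name: String): JLong =
    functionMap.computeIfAbsent(name, _ => JLong.valueOf(functionMap.size().toLong))

  val functionMap: JMap[String, JLong] = new JHashMap[String, JLong]()
  println(newScalarFunction(functionMap, "equal")) // 0
  println(newScalarFunction(functionMap, "and"))   // 1
  println(newScalarFunction(functionMap, "equal")) // 0 again, already registered
}
```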
Any.pack( - TypeBuilder.makeStruct(false, new util.ArrayList[TypeNode](inputTypeNodes.asJava)).toProtobuf, + TypeBuilder.makeStruct(false, inputTypeNodes.asJava).toProtobuf, /* typeUrlPrefix */ "") } diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/LimitTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/LimitTransformer.scala index 3401b2866701a..ad7c68a6a9f03 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/LimitTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/LimitTransformer.scala @@ -20,7 +20,7 @@ import io.glutenproject.backendsapi.BackendsApiManager import io.glutenproject.expression.ConverterUtils import io.glutenproject.extension.ValidationResult import io.glutenproject.metrics.MetricsUpdater -import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} +import io.glutenproject.substrait.`type`.TypeBuilder import io.glutenproject.substrait.SubstraitContext import io.glutenproject.substrait.extensions.ExtensionBuilder import io.glutenproject.substrait.rel.{RelBuilder, RelNode} @@ -32,7 +32,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.protobuf.Any -import java.util +import scala.collection.JavaConverters._ case class LimitTransformer(child: SparkPlan, offset: Long, count: Long) extends UnaryTransformSupport { @@ -71,11 +71,7 @@ case class LimitTransformer(child: SparkPlan, offset: Long, count: Long) val relNode = if (childCtx != null) { getRelNode(context, operatorId, offset, count, child.output, childCtx.root, false) } else { - val attrList = new util.ArrayList[Attribute]() - for (attr <- child.output) { - attrList.add(attr) - } - val readRel = RelBuilder.makeReadRel(attrList, context, operatorId) + val readRel = RelBuilder.makeReadRel(child.output.asJava, context, operatorId) getRelNode(context, operatorId, offset, count, child.output, readRel, false) } TransformContext(child.output, child.output, relNode) @@ -92,10 +88,8 @@ case class LimitTransformer(child: SparkPlan, offset: Long, count: Long) if (!validation) { RelBuilder.makeFetchRel(input, offset, count, context, operatorId) } else { - val inputTypeNodes = new util.ArrayList[TypeNode]() - for (attr <- inputAttributes) { - inputTypeNodes.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } + val inputTypeNodes = + inputAttributes.map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)).asJava val extensionNode = ExtensionBuilder.makeAdvancedExtension( Any.pack(TypeBuilder.makeStruct(false, inputTypeNodes).toProtobuf)) RelBuilder.makeFetchRel(input, offset, count, extensionNode, context, operatorId) diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/SortExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/SortExecTransformer.scala index 34403dc5fa9b7..7280891cbc862 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/SortExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/SortExecTransformer.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.protobuf.Any import io.substrait.proto.SortField -import java.util +import java.util.{ArrayList => JArrayList} import scala.collection.JavaConverters._ import scala.util.control.Breaks.{break, breakable} @@ -72,13 +72,13 @@ case class SortExecTransformer( validation: Boolean): RelNode = { val args = context.registeredFunction - val sortFieldList = new util.ArrayList[SortField]() - val projectExpressions = new 
util.ArrayList[ExpressionNode]() - val sortExprArttributes = new util.ArrayList[AttributeReference]() + val sortFieldList = new JArrayList[SortField]() + val projectExpressions = new JArrayList[ExpressionNode]() + val sortExprAttributes = new JArrayList[AttributeReference]() val selectOrigins = - originalInputAttributes.indices.map(ExpressionBuilder.makeSelection(_)) - projectExpressions.addAll(selectOrigins.asJava) + originalInputAttributes.indices.map(ExpressionBuilder.makeSelection(_)).asJava + projectExpressions.addAll(selectOrigins) var colIdx = originalInputAttributes.size sortOrder.foreach( @@ -90,7 +90,7 @@ case class SortExecTransformer( projectExpressions.add(projectExprNode) val exprNode = ExpressionBuilder.makeSelection(colIdx) - sortExprArttributes.add(AttributeReference(s"col_$colIdx", order.child.dataType)()) + sortExprAttributes.add(AttributeReference(s"col_$colIdx", order.child.dataType)()) colIdx += 1 builder.setExpr(exprNode.toProtobuf) @@ -109,7 +109,7 @@ case class SortExecTransformer( for (attr <- originalInputAttributes) { inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) } - sortExprArttributes.forEach { + sortExprAttributes.forEach { attr => inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) } @@ -133,7 +133,7 @@ case class SortExecTransformer( inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) } - sortExprArttributes.forEach { + sortExprAttributes.forEach { attr => inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) } @@ -147,7 +147,7 @@ case class SortExecTransformer( if (!validation) { RelBuilder.makeProjectRel( sortRel, - new java.util.ArrayList[ExpressionNode](selectOrigins.asJava), + new JArrayList[ExpressionNode](selectOrigins), context, operatorId, originalInputAttributes.size + sortFieldList.size) @@ -162,7 +162,7 @@ case class SortExecTransformer( Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) RelBuilder.makeProjectRel( sortRel, - new java.util.ArrayList[ExpressionNode](selectOrigins.asJava), + new JArrayList[ExpressionNode](selectOrigins), extensionNode, context, operatorId, @@ -178,9 +178,8 @@ case class SortExecTransformer( input: RelNode, validation: Boolean): RelNode = { val args = context.registeredFunction - val sortFieldList = new util.ArrayList[SortField]() - sortOrder.foreach( - order => { + val sortFieldList = sortOrder.map { + order => val builder = SortField.newBuilder() val exprNode = ExpressionConverter .replaceWithExpressionTransformer(order.child, attributeSeq = child.output) @@ -189,20 +188,18 @@ case class SortExecTransformer( builder.setDirectionValue( SortExecTransformer.transformSortDirection(order.direction.sql, order.nullOrdering.sql)) - sortFieldList.add(builder.build()) - }) + builder.build() + } if (!validation) { - RelBuilder.makeSortRel(input, sortFieldList, context, operatorId) + RelBuilder.makeSortRel(input, sortFieldList.asJava, context, operatorId) } else { // Use a extension node to send the input types through Substrait plan for validation. 
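Both the `SortExecTransformer` hunk here and the `WindowExecTransformer` hunk further down turn the foreach-plus-add loop over `SortField` builders into a `map` whose result is handed to `RelBuilder` via `asJava`. A sketch using the direction codes spelled out in the window hunk; `SortFieldLike` is a stand-in for the protobuf builder:

```scala
object SortFieldSketch extends App {
  // Stand-in for io.substrait.proto.SortField; only the direction code matters here.
  final case class SortFieldLike(expr: String, directionValue: Int)

  val sortOrder = Seq(("a", "ASC", "NULLS FIRST"), ("b", "DESC", "NULLS LAST"))

  val sortFieldList = sortOrder.map {
    case (col, direction, nullOrdering) =>
      val code = (direction, nullOrdering) match {
        case ("ASC", "NULLS FIRST")  => 1
        case ("ASC", "NULLS LAST")   => 2
        case ("DESC", "NULLS FIRST") => 3
        case ("DESC", "NULLS LAST")  => 4
        case _                       => 0
      }
      SortFieldLike(col, code)
  }

  println(sortFieldList) // List(SortFieldLike(a,1), SortFieldLike(b,4))
}
```

The real code then converts with `.asJava` before calling `RelBuilder.makeSortRel` or `makeWindowRel`, as shown in the hunks.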
- val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- originalInputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } + val inputTypeNodeList = originalInputAttributes.map( + attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) val extensionNode = ExtensionBuilder.makeAdvancedExtension( - Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) + Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList.asJava).toProtobuf)) - RelBuilder.makeSortRel(input, sortFieldList, extensionNode, context, operatorId) + RelBuilder.makeSortRel(input, sortFieldList.asJava, extensionNode, context, operatorId) } } @@ -263,11 +260,7 @@ case class SortExecTransformer( } else { // This means the input is just an iterator, so an ReadRel will be created as child. // Prepare the input schema. - val attrList = new util.ArrayList[Attribute]() - for (attr <- child.output) { - attrList.add(attr) - } - val readRel = RelBuilder.makeReadRel(attrList, context, operatorId) + val readRel = RelBuilder.makeReadRel(child.output.asJava, context, operatorId) ( getRelNode(context, sortOrder, child.output, operatorId, readRel, validation = false), child.output) diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/SortMergeJoinExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/SortMergeJoinExecTransformer.scala index 3b3db8fca4411..3c2214356d7ec 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/SortMergeJoinExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/SortMergeJoinExecTransformer.scala @@ -33,8 +33,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.protobuf.StringValue import io.substrait.proto.JoinRel -import java.util.{ArrayList => JArrayList} - import scala.collection.JavaConverters._ /** Performs a sort merge join of two child relations. */ @@ -258,7 +256,7 @@ case class SortMergeJoinExecTransformer( (transformContext.root, transformContext.outputAttributes, false) case _ => val readRel = RelBuilder.makeReadRel( - new JArrayList[Attribute](plan.output.asJava), + plan.output.asJava, context, -1 ) /* A special handling in Join to delay the rel registration. 
*/ diff --git a/gluten-core/src/main/scala/io/glutenproject/execution/WindowExecTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/execution/WindowExecTransformer.scala index 2b88639aae387..fdf23f2838532 100644 --- a/gluten-core/src/main/scala/io/glutenproject/execution/WindowExecTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/execution/WindowExecTransformer.scala @@ -20,9 +20,9 @@ import io.glutenproject.backendsapi.BackendsApiManager import io.glutenproject.expression._ import io.glutenproject.extension.ValidationResult import io.glutenproject.metrics.MetricsUpdater -import io.glutenproject.substrait.`type`.{TypeBuilder, TypeNode} +import io.glutenproject.substrait.`type`.TypeBuilder import io.glutenproject.substrait.SubstraitContext -import io.glutenproject.substrait.expression.{ExpressionNode, WindowFunctionNode} +import io.glutenproject.substrait.expression.WindowFunctionNode import io.glutenproject.substrait.extensions.ExtensionBuilder import io.glutenproject.substrait.rel.{RelBuilder, RelNode} @@ -36,7 +36,9 @@ import org.apache.spark.sql.vectorized.ColumnarBatch import com.google.protobuf.Any import io.substrait.proto.SortField -import java.util +import java.util.{ArrayList => JArrayList} + +import scala.collection.JavaConverters._ case class WindowExecTransformer( windowExpression: Seq[NamedExpression], @@ -89,7 +91,7 @@ case class WindowExecTransformer( validation: Boolean): RelNode = { val args = context.registeredFunction // WindowFunction Expressions - val windowExpressions = new util.ArrayList[WindowFunctionNode]() + val windowExpressions = new JArrayList[WindowFunctionNode]() BackendsApiManager.getSparkPlanExecApiInstance.genWindowFunctionsNode( windowExpression, windowExpressions, @@ -98,39 +100,37 @@ case class WindowExecTransformer( ) // Partition By Expressions - val partitionsExpressions = new util.ArrayList[ExpressionNode]() - partitionSpec.map { - partitionExpr => - val exprNode = ExpressionConverter - .replaceWithExpressionTransformer(partitionExpr, attributeSeq = child.output) - .doTransform(args) - partitionsExpressions.add(exprNode) - } + val partitionsExpressions = partitionSpec + .map( + ExpressionConverter + .replaceWithExpressionTransformer(_, attributeSeq = child.output) + .doTransform(args)) + .asJava // Sort By Expressions - val sortFieldList = new util.ArrayList[SortField]() - sortOrder.map( - order => { - val builder = SortField.newBuilder() - val exprNode = ExpressionConverter - .replaceWithExpressionTransformer(order.child, attributeSeq = child.output) - .doTransform(args) - builder.setExpr(exprNode.toProtobuf) - - (order.direction.sql, order.nullOrdering.sql) match { - case ("ASC", "NULLS FIRST") => - builder.setDirectionValue(1); - case ("ASC", "NULLS LAST") => - builder.setDirectionValue(2); - case ("DESC", "NULLS FIRST") => - builder.setDirectionValue(3); - case ("DESC", "NULLS LAST") => - builder.setDirectionValue(4); - case _ => - builder.setDirectionValue(0); - } - sortFieldList.add(builder.build()) - }) + val sortFieldList = + sortOrder.map { + order => + val builder = SortField.newBuilder() + val exprNode = ExpressionConverter + .replaceWithExpressionTransformer(order.child, attributeSeq = child.output) + .doTransform(args) + builder.setExpr(exprNode.toProtobuf) + + (order.direction.sql, order.nullOrdering.sql) match { + case ("ASC", "NULLS FIRST") => + builder.setDirectionValue(1); + case ("ASC", "NULLS LAST") => + builder.setDirectionValue(2); + case ("DESC", "NULLS FIRST") => + builder.setDirectionValue(3); + 
case ("DESC", "NULLS LAST") => + builder.setDirectionValue(4); + case _ => + builder.setDirectionValue(0); + } + builder.build() + }.asJava if (!validation) { RelBuilder.makeWindowRel( input, @@ -141,10 +141,9 @@ case class WindowExecTransformer( operatorId) } else { // Use a extension node to send the input types through Substrait plan for validation. - val inputTypeNodeList = new java.util.ArrayList[TypeNode]() - for (attr <- originalInputAttributes) { - inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) - } + val inputTypeNodeList = originalInputAttributes + .map(attr => ConverterUtils.getTypeNode(attr.dataType, attr.nullable)) + .asJava val extensionNode = ExtensionBuilder.makeAdvancedExtension( Any.pack(TypeBuilder.makeStruct(false, inputTypeNodeList).toProtobuf)) @@ -210,11 +209,7 @@ case class WindowExecTransformer( } else { // This means the input is just an iterator, so an ReadRel will be created as child. // Prepare the input schema. - val attrList = new util.ArrayList[Attribute]() - for (attr <- child.output) { - attrList.add(attr) - } - val readRel = RelBuilder.makeReadRel(attrList, context, operatorId) + val readRel = RelBuilder.makeReadRel(child.output.asJava, context, operatorId) ( getRelNode( context, diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ArrayExpressionTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ArrayExpressionTransformer.scala index faa5a1a01ad30..a5ddad257446a 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/ArrayExpressionTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/ArrayExpressionTransformer.scala @@ -25,6 +25,8 @@ import org.apache.spark.sql.types._ import com.google.common.collect.Lists +import scala.collection.JavaConverters._ + case class CreateArrayTransformer( substraitExprName: String, children: Seq[ExpressionTransformer], @@ -40,12 +42,7 @@ case class CreateArrayTransformer( throw new UnsupportedOperationException(s"$original not supported yet.") } - val childNodes = new java.util.ArrayList[ExpressionNode]() - children.foreach( - child => { - val childNode = child.doTransform(args) - childNodes.add(childNode) - }) + val childNodes = children.map(_.doTransform(args)).asJava val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] val functionName = ConverterUtils.makeFuncName( @@ -87,11 +84,11 @@ case class GetArrayItemTransformer( ConverterUtils.getTypeNode(original.right.dataType, original.right.nullable)) BackendsApiManager.getSparkPlanExecApiInstance.genGetArrayItemExpressionNode( - substraitExprName: String, - functionMap: java.util.HashMap[String, java.lang.Long], - leftNode: ExpressionNode, - rightNode: ExpressionNode, - original: GetArrayItem + substraitExprName, + functionMap, + leftNode, + rightNode, + original ) } } diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala b/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala index 4555c88f61e5c..2e1b415e3f01c 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/ConverterUtils.scala @@ -104,32 +104,30 @@ object ConverterUtils extends Logging { getShortAttributeName(attr) + "#" + attr.exprId.id } - def collectAttributeTypeNodes(attributes: JList[Attribute]): JArrayList[TypeNode] = { + def collectAttributeTypeNodes(attributes: JList[Attribute]): JList[TypeNode] = { 
collectAttributeTypeNodes(attributes.asScala) } - def collectAttributeTypeNodes(attributes: Seq[Attribute]): JArrayList[TypeNode] = { - val typeList = new JArrayList[TypeNode]() - attributes.foreach(attr => typeList.add(getTypeNode(attr.dataType, attr.nullable))) - typeList + def collectAttributeTypeNodes(attributes: Seq[Attribute]): JList[TypeNode] = { + attributes.map(attr => getTypeNode(attr.dataType, attr.nullable)).asJava } - def collectAttributeNamesWithExprId(attributes: JList[Attribute]): JArrayList[String] = { + def collectAttributeNamesWithExprId(attributes: JList[Attribute]): JList[String] = { collectAttributeNamesWithExprId(attributes.asScala) } - def collectAttributeNamesWithExprId(attributes: Seq[Attribute]): JArrayList[String] = { + def collectAttributeNamesWithExprId(attributes: Seq[Attribute]): JList[String] = { collectAttributeNamesDFS(attributes)(genColumnNameWithExprId) } // TODO: This is used only by `BasicScanExecTransformer`, // perhaps we can remove this in the future and use `withExprId` version consistently. - def collectAttributeNamesWithoutExprId(attributes: Seq[Attribute]): JArrayList[String] = { + def collectAttributeNamesWithoutExprId(attributes: Seq[Attribute]): JList[String] = { collectAttributeNamesDFS(attributes)(genColumnNameWithoutExprId) } private def collectAttributeNamesDFS(attributes: Seq[Attribute])( - f: Attribute => String): JArrayList[String] = { + f: Attribute => String): JList[String] = { val nameList = new JArrayList[String]() attributes.foreach( attr => { @@ -146,7 +144,7 @@ object ConverterUtils extends Logging { nameList } - def collectStructFieldNames(dataType: DataType): JArrayList[String] = { + def collectStructFieldNames(dataType: DataType): JList[String] = { val nameList = new JArrayList[String]() dataType match { case structType: StructType => @@ -196,10 +194,10 @@ object ConverterUtils extends Logging { (DecimalType(precision, scale), isNullable(decimal.getNullability)) case Type.KindCase.STRUCT => val struct_ = substraitType.getStruct - val fields = new JArrayList[StructField] - for (typ <- struct_.getTypesList.asScala) { - val (field, nullable) = parseFromSubstraitType(typ) - fields.add(StructField("", field, nullable)) + val fields = struct_.getTypesList.asScala.map { + typ => + val (field, nullable) = parseFromSubstraitType(typ) + StructField("", field, nullable) } (StructType(fields), isNullable(substraitType.getStruct.getNullability)) case Type.KindCase.LIST => diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/DateTimeExpressionsTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/expression/DateTimeExpressionsTransformer.scala index 38c86d46a6a78..7753690ddc699 100644 --- a/gluten-core/src/main/scala/io/glutenproject/expression/DateTimeExpressionsTransformer.scala +++ b/gluten-core/src/main/scala/io/glutenproject/expression/DateTimeExpressionsTransformer.scala @@ -25,6 +25,9 @@ import org.apache.spark.sql.types._ import com.google.common.collect.Lists +import java.lang.{Long => JLong} +import java.util.{ArrayList => JArrayList, HashMap => JHashMap} + import scala.collection.JavaConverters._ /** The extract trait for 'GetDateField' from Date */ @@ -37,7 +40,7 @@ case class ExtractDateTransformer( override def doTransform(args: java.lang.Object): ExpressionNode = { val childNode = child.doTransform(args) - val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]] + val functionMap = args.asInstanceOf[JHashMap[String, JLong]] val functionName = ConverterUtils.makeFuncName( 
       substraitExprName,
       original.children.map(_.dataType),
@@ -67,7 +70,7 @@ case class DateDiffTransformer(
     val endDateNode = endDate.doTransform(args)
     val startDateNode = startDate.doTransform(args)
 
-    val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
+    val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
     val functionName = ConverterUtils.makeFuncName(
       substraitExprName,
       Seq(StringType, original.startDate.dataType, original.endDate.dataType),
@@ -96,20 +99,20 @@ case class FromUnixTimeTransformer(
     val secNode = sec.doTransform(args)
     val formatNode = format.doTransform(args)
 
-    val dataTypes = if (timeZoneId != None) {
+    val dataTypes = if (timeZoneId.isDefined) {
       Seq(original.sec.dataType, original.format.dataType, StringType)
     } else {
       Seq(original.sec.dataType, original.format.dataType)
     }
-    val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
+    val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
     val functionId = ExpressionBuilder.newScalarFunction(
       functionMap,
       ConverterUtils.makeFuncName(substraitExprName, dataTypes))
 
-    val expressionNodes = new java.util.ArrayList[ExpressionNode]()
+    val expressionNodes = new JArrayList[ExpressionNode]()
     expressionNodes.add(secNode)
     expressionNodes.add(formatNode)
-    if (timeZoneId != None) {
+    if (timeZoneId.isDefined) {
       expressionNodes.add(ExpressionBuilder.makeStringLiteral(timeZoneId.get))
     }
 
@@ -133,12 +136,12 @@ case class ToUnixTimestampTransformer(
 
   override def doTransform(args: java.lang.Object): ExpressionNode = {
     val dataTypes = Seq(original.timeExp.dataType, StringType)
-    val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
+    val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
     val functionId = ExpressionBuilder.newScalarFunction(
       functionMap,
       ConverterUtils.makeFuncName(substraitExprName, dataTypes))
 
-    val expressionNodes = new java.util.ArrayList[ExpressionNode]()
+    val expressionNodes = new JArrayList[ExpressionNode]()
     val timeExpNode = timeExp.doTransform(args)
     expressionNodes.add(timeExpNode)
     val formatNode = format.doTransform(args)
@@ -160,8 +163,8 @@ case class TruncTimestampTransformer(
     val timestampNode = timestamp.doTransform(args)
     val formatNode = format.doTransform(args)
 
-    val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
-    val dataTypes = if (timeZoneId != None) {
+    val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
+    val dataTypes = if (timeZoneId.isDefined) {
       Seq(original.format.dataType, original.timestamp.dataType, StringType)
     } else {
       Seq(original.format.dataType, original.timestamp.dataType)
@@ -171,10 +174,10 @@ case class TruncTimestampTransformer(
       functionMap,
       ConverterUtils.makeFuncName(substraitExprName, dataTypes))
 
-    val expressionNodes = new java.util.ArrayList[ExpressionNode]()
+    val expressionNodes = new JArrayList[ExpressionNode]()
     expressionNodes.add(formatNode)
     expressionNodes.add(timestampNode)
-    if (timeZoneId != None) {
+    if (timeZoneId.isDefined) {
       expressionNodes.add(ExpressionBuilder.makeStringLiteral(timeZoneId.get))
     }
 
@@ -197,8 +200,8 @@ case class MonthsBetweenTransformer(
     val data2Node = date2.doTransform(args)
     val roundOffNode = roundOff.doTransform(args)
 
-    val functionMap = args.asInstanceOf[java.util.HashMap[String, java.lang.Long]]
-    val dataTypes = if (timeZoneId != None) {
+    val functionMap = args.asInstanceOf[JHashMap[String, JLong]]
+    val dataTypes = if (timeZoneId.isDefined) {
       Seq(original.date1.dataType, original.date2.dataType, original.roundOff.dataType, StringType)
     } else {
       Seq(original.date1.dataType, original.date2.dataType, original.roundOff.dataType)
@@ -208,11 +211,11 @@ case class MonthsBetweenTransformer(
       functionMap,
       ConverterUtils.makeFuncName(substraitExprName, dataTypes))
 
-    val expressionNodes = new java.util.ArrayList[ExpressionNode]()
+    val expressionNodes = new JArrayList[ExpressionNode]()
     expressionNodes.add(date1Node)
     expressionNodes.add(data2Node)
     expressionNodes.add(roundOffNode)
-    if (timeZoneId != None) {
+    if (timeZoneId.isDefined) {
       expressionNodes.add(ExpressionBuilder.makeStringLiteral(timeZoneId.get))
     }
 
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/LambdaFunctionTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/expression/LambdaFunctionTransformer.scala
index ce0466afc19b8..05ba330256c7e 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/LambdaFunctionTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/LambdaFunctionTransformer.scala
@@ -20,8 +20,6 @@ import io.glutenproject.substrait.expression.{ExpressionBuilder, ExpressionNode}
 
 import org.apache.spark.sql.catalyst.expressions.LambdaFunction
 
-import java.util.ArrayList
-
 case class LambdaFunctionTransformer(
     substraitExprName: String,
     function: ExpressionTransformer,
@@ -42,7 +40,7 @@ case class LambdaFunctionTransformer(
         substraitExprName,
         Seq(original.dataType),
         ConverterUtils.FunctionConfig.OPT))
-    val expressionNodes = new ArrayList[ExpressionNode]
+    val expressionNodes = new java.util.ArrayList[ExpressionNode]
     expressionNodes.add(function.doTransform(args))
     arguments.foreach(argument => expressionNodes.add(argument.doTransform(args)))
     val typeNode = ConverterUtils.getTypeNode(original.dataType, original.nullable)
diff --git a/gluten-core/src/main/scala/io/glutenproject/expression/StructExpressionTransformer.scala b/gluten-core/src/main/scala/io/glutenproject/expression/StructExpressionTransformer.scala
index 363408e567c89..e5af5400ba883 100644
--- a/gluten-core/src/main/scala/io/glutenproject/expression/StructExpressionTransformer.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/expression/StructExpressionTransformer.scala
@@ -31,7 +31,7 @@ case class GetStructFieldTransformer(
     original: GetStructField)
   extends ExpressionTransformer {
 
-  override def doTransform(args: Object): ExpressionNode = {
+  override def doTransform(args: java.lang.Object): ExpressionNode = {
     val childNode = childTransformer.doTransform(args)
     childNode match {
       case node: StructLiteralNode =>
diff --git a/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala b/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala
index b1cacd354d33b..ce37cbfc96f96 100644
--- a/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala
+++ b/gluten-core/src/main/scala/io/glutenproject/substrait/SubstraitContext.scala
@@ -22,7 +22,7 @@ import io.glutenproject.substrait.rel.LocalFilesNode.ReadFileFormat
 
 import java.lang.{Integer => JInt, Long => JLong}
 import java.security.InvalidParameterException
-import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList}
+import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap}
 
 case class JoinParams() {
   // Whether the input of streamed side is a ReadRel represented iterator.
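The SubstraitContext and MetricsUtil hunks that follow keep applying the same convention: alias the java.util types once at the top of the file, type fields and signatures against the JMap/JList interfaces, and name the concrete JHashMap/JArrayList only where a collection is actually constructed. A minimal standalone sketch of that convention, not part of the patch (AliasedCollectionsDemo, demoRelMap, registerRel and relIdsFor are made-up names):

    import java.lang.{Long => JLong}
    import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap}

    import scala.collection.JavaConverters._

    object AliasedCollectionsDemo {
      // Field declared against the java.util interfaces; the concrete classes appear
      // only where the collections are constructed.
      private val demoRelMap: JMap[JLong, JList[JLong]] = new JHashMap[JLong, JList[JLong]]()

      def registerRel(operatorId: JLong, relId: JLong): Unit = {
        if (!demoRelMap.containsKey(operatorId)) {
          demoRelMap.put(operatorId, new JArrayList[JLong]())
        }
        demoRelMap.get(operatorId).add(relId)
      }

      // Scala-side producers can satisfy the same interface type with `.asJava`
      // instead of filling a mutable ArrayList by hand.
      def relIdsFor(operatorId: JLong, fallback: Seq[Long]): JList[JLong] =
        Option(demoRelMap.get(operatorId)).getOrElse(fallback.map(l => JLong.valueOf(l)).asJava)

      def main(args: Array[String]): Unit = {
        registerRel(0L, 1L)
        registerRel(0L, 2L)
        println(demoRelMap) // {0=[1, 2]}
        println(relIdsFor(7L, Seq(3L))) // [3]
      }
    }

With the fields typed this way, accessors such as registeredRelMap and updateNativeMetrics in the hunks below expose only JMap and JList, so a later change of the backing collection would not ripple through these signatures again.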
@@ -69,7 +69,7 @@ class SubstraitContext extends Serializable {
   private val iteratorNodes = new JHashMap[JLong, LocalFilesNode]()
 
   // A map stores the relationship between Spark operator id and its respective Substrait Rel ids.
-  private val operatorToRelsMap = new JHashMap[JLong, JArrayList[JLong]]()
+  private val operatorToRelsMap: JMap[JLong, JList[JLong]] = new JHashMap[JLong, JList[JLong]]()
 
   // Only for debug conveniently
   private val operatorToPlanNameMap = new JHashMap[JLong, String]()
@@ -181,7 +181,7 @@ class SubstraitContext extends Serializable {
    * Return the registered map.
    * @return
    */
-  def registeredRelMap: JHashMap[JLong, JArrayList[JLong]] = operatorToRelsMap
+  def registeredRelMap: JMap[JLong, JList[JLong]] = operatorToRelsMap
 
   /**
    * Register the join params to certain operator id.
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala
index e281d7f36f828..d6c13537ac6d6 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/GlutenExplainUtils.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveS
 import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange.{Exchange, ReusedExchangeExec}
 
-import java.util.{IdentityHashMap, Set}
+import java.util
 import java.util.Collections.newSetFromMap
 
 import scala.collection.mutable
@@ -124,7 +124,7 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper {
     try {
       // Initialize a reference-unique set of Operators to avoid accdiental overwrites and to allow
       // intentional overwriting of IDs generated in previous AQE iteration
-      val operators = newSetFromMap[QueryPlan[_]](new IdentityHashMap())
+      val operators = newSetFromMap[QueryPlan[_]](new util.IdentityHashMap())
       // Initialize an array of ReusedExchanges to help find Adaptively Optimized Out
       // Exchanges as part of SPARK-42753
       val reusedExchanges = ArrayBuffer.empty[ReusedExchangeExec]
@@ -224,7 +224,7 @@ object GlutenExplainUtils extends AdaptiveSparkPlanHelper {
   private def generateOperatorIDs(
       plan: QueryPlan[_],
       startOperatorID: Int,
-      visited: Set[QueryPlan[_]],
+      visited: util.Set[QueryPlan[_]],
       reusedExchanges: ArrayBuffer[ReusedExchangeExec],
       addReusedExchanges: Boolean): Int = {
     var currentOperationID = startOperatorID
diff --git a/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala b/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala
index 7186cad59fbf7..4146cbc46b668 100644
--- a/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExecTransformer.scala
@@ -36,7 +36,7 @@ import org.apache.spark.sql.types.StructType
 
 import com.google.protobuf.Any
 
-import java.util.ArrayList
+import java.util.{ArrayList => JArrayList, List => JList}
 
 case class EvalPythonExecTransformer(
     udfs: Seq[PythonUDF],
@@ -72,7 +72,7 @@ case class EvalPythonExecTransformer(
     val args = context.registeredFunction
     val operatorId = context.nextOperatorId(this.nodeName)
-    val expressionNodes = new java.util.ArrayList[ExpressionNode]
+    val expressionNodes = new JArrayList[ExpressionNode]
     child.output.zipWithIndex.foreach(
       x =>
         expressionNodes.add(ExpressionBuilder.makeSelection(x._2)))
     udfs.foreach(
@@ -94,7 +94,7 @@ case class EvalPythonExecTransformer(
     val args = context.registeredFunction
     val operatorId = context.nextOperatorId(this.nodeName)
-    val expressionNodes = new java.util.ArrayList[ExpressionNode]
+    val expressionNodes = new JArrayList[ExpressionNode]
     child.output.zipWithIndex.foreach(
       x =>
         expressionNodes.add(ExpressionBuilder.makeSelection(x._2)))
     udfs.foreach(
@@ -106,7 +106,7 @@ case class EvalPythonExecTransformer(
     val relNode = if (childCtx != null) {
       getRelNode(childCtx.root, expressionNodes, context, operatorId, child.output, false)
     } else {
-      val attrList = new java.util.ArrayList[Attribute]()
+      val attrList = new JArrayList[Attribute]()
       for (attr <- child.output) {
         attrList.add(attr)
       }
@@ -119,7 +119,7 @@ case class EvalPythonExecTransformer(
   def getRelNode(
       input: RelNode,
-      expressionNodes: ArrayList[ExpressionNode],
+      expressionNodes: JList[ExpressionNode],
       context: SubstraitContext,
       operatorId: Long,
       inputAttributes: Seq[Attribute],
@@ -128,7 +128,7 @@ case class EvalPythonExecTransformer(
       RelBuilder.makeProjectRel(input, expressionNodes, context, operatorId)
     } else {
       // Use a extension node to send the input types through Substrait plan for validation.
-      val inputTypeNodeList = new java.util.ArrayList[TypeNode]()
+      val inputTypeNodeList = new JArrayList[TypeNode]()
       for (attr <- inputAttributes) {
         inputTypeNodeList.add(ConverterUtils.getTypeNode(attr.dataType, attr.nullable))
       }
diff --git a/gluten-data/src/main/scala/io/glutenproject/metrics/MetricsUtil.scala b/gluten-data/src/main/scala/io/glutenproject/metrics/MetricsUtil.scala
index d5afef91d58b3..0c70679eda641 100644
--- a/gluten-data/src/main/scala/io/glutenproject/metrics/MetricsUtil.scala
+++ b/gluten-data/src/main/scala/io/glutenproject/metrics/MetricsUtil.scala
@@ -22,6 +22,9 @@ import io.glutenproject.substrait.{AggregationParams, JoinParams}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.execution.SparkPlan
 
+import java.lang.{Long => JLong}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
+
 object MetricsUtil extends Logging {
 
   /**
@@ -38,9 +41,9 @@ object MetricsUtil extends Logging {
    */
   def updateNativeMetrics(
       child: SparkPlan,
-      relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]],
-      joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams],
-      aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): IMetrics => Unit = {
+      relMap: JMap[JLong, JList[JLong]],
+      joinParamsMap: JMap[JLong, JoinParams],
+      aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {
     def treeifyMetricsUpdaters(plan: SparkPlan): MetricsUpdaterTree = {
       plan match {
         case j: HashJoinLikeExecTransformer =>
@@ -59,7 +62,7 @@ object MetricsUtil extends Logging {
     updateTransformerMetrics(
       mut,
       relMap,
-      java.lang.Long.valueOf(relMap.size() - 1),
+      JLong.valueOf(relMap.size() - 1),
       joinParamsMap,
       aggParamsMap)
   }
@@ -72,8 +75,7 @@ object MetricsUtil extends Logging {
    * @return
    *   the merged metrics
    */
-  private def mergeMetrics(
-      operatorMetrics: java.util.ArrayList[OperatorMetrics]): OperatorMetrics = {
+  private def mergeMetrics(operatorMetrics: JList[OperatorMetrics]): OperatorMetrics = {
     if (operatorMetrics.size() == 0) {
       return null
     }
@@ -165,13 +167,13 @@ object MetricsUtil extends Logging {
    */
   def updateTransformerMetricsInternal(
       mutNode: MetricsUpdaterTree,
-      relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]],
-      operatorIdx: java.lang.Long,
+      relMap: JMap[JLong, JList[JLong]],
+      operatorIdx: JLong,
       metrics: Metrics,
       metricsIdx: Int,
-      joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams],
-      aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): (java.lang.Long, Int) = {
-    val operatorMetrics = new java.util.ArrayList[OperatorMetrics]()
+      joinParamsMap: JMap[JLong, JoinParams],
+      aggParamsMap: JMap[JLong, AggregationParams]): (JLong, Int) = {
+    val operatorMetrics = new JArrayList[OperatorMetrics]()
     var curMetricsIdx = metricsIdx
     relMap
       .get(operatorIdx)
@@ -205,7 +207,7 @@ object MetricsUtil extends Logging {
         u.updateNativeMetrics(opMetrics)
     }
 
-    var newOperatorIdx: java.lang.Long = operatorIdx - 1
+    var newOperatorIdx: JLong = operatorIdx - 1
     var newMetricsIdx: Int =
       if (
         mutNode.updater.isInstanceOf[LimitMetricsUpdater] &&
@@ -256,10 +258,10 @@ object MetricsUtil extends Logging {
    */
   def updateTransformerMetrics(
       mutNode: MetricsUpdaterTree,
-      relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]],
-      operatorIdx: java.lang.Long,
-      joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams],
-      aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): IMetrics => Unit = {
+      relMap: JMap[JLong, JList[JLong]],
+      operatorIdx: JLong,
+      joinParamsMap: JMap[JLong, JoinParams],
+      aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {
     imetrics =>
       try {
         val metrics = imetrics.asInstanceOf[Metrics]
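The JLong-typed operator indices in these MetricsUtil hunks rely on Scala's Predef conversions: arithmetic such as operatorIdx - 1 unboxes the java.lang.Long through Predef.Long2long, and the scala.Long result is re-boxed through Predef.long2Long when assigned back to a JLong variable, while JLong.valueOf(relMap.size() - 1) boxes the Int-derived starting index explicitly. A small self-contained sketch of that behaviour, not taken from the patch (OperatorIndexDemo and walk are invented names):

    import java.lang.{Long => JLong}
    import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap}

    object OperatorIndexDemo {
      def walk(relMap: JMap[JLong, JList[JLong]]): Unit = {
        // JLong.valueOf boxes the Int-derived index once, up front.
        var operatorIdx: JLong = JLong.valueOf(relMap.size() - 1)
        while (operatorIdx >= 0) {
          println(s"operator $operatorIdx -> rels ${relMap.get(operatorIdx)}")
          // Predef.Long2long unboxes for the subtraction; Predef.long2Long re-boxes
          // the result when it is assigned back to the JLong variable.
          operatorIdx = operatorIdx - 1
        }
      }

      def main(args: Array[String]): Unit = {
        val relMap: JMap[JLong, JList[JLong]] = new JHashMap[JLong, JList[JLong]]()
        val rels = new JArrayList[JLong]()
        rels.add(1L)
        rels.add(2L)
        relMap.put(0L, rels)
        walk(relMap) // operator 0 -> rels [1, 2]
      }
    }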