[CORE] Use collection interface in method parameter and return type #3603

Merged
merged 1 commit on Nov 7, 2023
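
The change applies the usual program-to-an-interface guideline: parameters, return types, and locals are declared as java.util.List / java.util.Map instead of concrete ArrayList / HashMap, and the Scala sources import the Java types under short aliases (JList, JMap, JArrayList, JLong) so they do not collide with Scala's own List and Long. Below is a minimal Scala sketch of that pattern, not code from this PR; the names collectNames, lookup, and buildIndex are made up for illustration.

import java.lang.{Long => JLong}
import java.util.{ArrayList => JArrayList, HashMap => JHashMap, List => JList, Map => JMap}

object CollectionInterfaceSketch {
  // Callers see only the JList interface; the concrete JArrayList stays an
  // implementation detail that can be swapped without touching the signature.
  def collectNames(attrs: Seq[String]): JList[String] = {
    val names = new JArrayList[String]()
    attrs.foreach(a => names.add(a))
    names
  }

  // Parameters are typed against the interface too, so any Map implementation
  // (HashMap, TreeMap, an unmodifiable wrapper) is accepted.
  def lookup(index: JMap[String, JLong], key: String): Option[Long] =
    Option(index.get(key)).map(_.longValue())

  def buildIndex(names: Seq[String]): JMap[String, JLong] = {
    val index = new JHashMap[String, JLong]()
    names.zipWithIndex.foreach { case (n, i) => index.put(n, JLong.valueOf(i.toLong)) }
    index
  }
}
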
@@ -28,7 +28,6 @@
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.storage.CHShuffleReadStreamFactory;

-import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

@@ -88,8 +87,8 @@ public static long build(

/** create table named struct */
private static NamedStruct toNameStruct(List<Attribute> output) {
-ArrayList<TypeNode> typeList = ConverterUtils.collectAttributeTypeNodes(output);
-ArrayList<String> nameList = ConverterUtils.collectAttributeNamesWithExprId(output);
+List<TypeNode> typeList = ConverterUtils.collectAttributeTypeNodes(output);
+List<String> nameList = ConverterUtils.collectAttributeNamesWithExprId(output);
Type.Struct.Builder structBuilder = Type.Struct.newBuilder();
for (TypeNode typeNode : typeList) {
structBuilder.addTypes(typeNode.toProtobuf());
@@ -41,7 +41,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch

import java.lang.{Long => JLong}
import java.net.URI
-import java.util
+import java.util.{ArrayList => JArrayList}

import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -68,9 +68,9 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
.makeExtensionTable(p.minParts, p.maxParts, p.database, p.table, p.tablePath),
SoftAffinityUtil.getNativeMergeTreePartitionLocations(p))
case f: FilePartition =>
-val paths = new util.ArrayList[String]()
-val starts = new util.ArrayList[JLong]()
-val lengths = new util.ArrayList[JLong]()
+val paths = new JArrayList[String]()
+val starts = new JArrayList[JLong]()
+val lengths = new JArrayList[JLong]()
val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]]
f.files.foreach {
file =>
@@ -122,7 +122,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
val resIter: GeneralOutIterator = GlutenTimeMetric.millis(pipelineTime) {
_ =>
val transKernel = new CHNativeExpressionEvaluator()
-val inBatchIters = new util.ArrayList[GeneralInIterator](inputIterators.map {
+val inBatchIters = new JArrayList[GeneralInIterator](inputIterators.map {
iter => new ColumnarNativeIterator(genCloseableColumnBatchIterator(iter).asJava)
}.asJava)
transKernel.createKernelWithBatchIterator(inputPartition.plan, inBatchIters, false)
@@ -180,7 +180,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
_ =>
val transKernel = new CHNativeExpressionEvaluator()
val columnarNativeIterator =
-new java.util.ArrayList[GeneralInIterator](inputIterators.map {
+new JArrayList[GeneralInIterator](inputIterators.map {
iter => new ColumnarNativeIterator(genCloseableColumnBatchIterator(iter).asJava)
}.asJava)
// we need to complete dependency RDD's firstly
@@ -26,14 +26,15 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

-import java.{lang, util}
+import java.lang.{Long => JLong}
+import java.util.{List => JList, Map => JMap}

class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
override def metricsUpdatingFunction(
child: SparkPlan,
-relMap: util.HashMap[lang.Long, util.ArrayList[lang.Long]],
-joinParamsMap: util.HashMap[lang.Long, JoinParams],
-aggParamsMap: util.HashMap[lang.Long, AggregationParams]): IMetrics => Unit = {
+relMap: JMap[JLong, JList[JLong]],
+joinParamsMap: JMap[JLong, JoinParams],
+aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {
MetricsUtil.updateNativeMetrics(child, relMap, joinParamsMap, aggParamsMap)
}

@@ -55,7 +55,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import com.google.common.collect.Lists
import org.apache.commons.lang3.ClassUtils

-import java.{lang, util}
+import java.lang.{Long => JLong}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.mutable.ArrayBuffer

@@ -64,7 +65,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
/** Transform GetArrayItem to Substrait. */
override def genGetArrayItemExpressionNode(
substraitExprName: String,
-functionMap: java.util.HashMap[String, java.lang.Long],
+functionMap: JMap[String, JLong],
leftNode: ExpressionNode,
rightNode: ExpressionNode,
original: GetArrayItem): ExpressionNode = {
@@ -436,9 +437,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
/** Generate window function node */
override def genWindowFunctionsNode(
windowExpression: Seq[NamedExpression],
-windowExpressionNodes: util.ArrayList[WindowFunctionNode],
+windowExpressionNodes: JList[WindowFunctionNode],
originalInputAttributes: Seq[Attribute],
-args: util.HashMap[String, lang.Long]): Unit = {
+args: JMap[String, JLong]): Unit = {

windowExpression.map {
windowExpr =>
@@ -451,7 +452,7 @@
val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
WindowFunctionsBuilder.create(args, aggWindowFunc).toInt,
-new util.ArrayList[ExpressionNode](),
+new JArrayList[ExpressionNode](),
columnName,
ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable),
WindowExecTransformer.getFrameBound(frame.upper),
@@ -467,7 +468,7 @@
throw new UnsupportedOperationException(s"Not currently supported: $aggregateFunc.")
}

-val childrenNodeList = new util.ArrayList[ExpressionNode]()
+val childrenNodeList = new JArrayList[ExpressionNode]()
aggregateFunc.children.foreach(
expr =>
childrenNodeList.add(
@@ -505,7 +506,7 @@
}
}

-val childrenNodeList = new util.ArrayList[ExpressionNode]()
+val childrenNodeList = new JArrayList[ExpressionNode]()
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(
@@ -23,6 +23,9 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric

+import java.lang.{Long => JLong}
+import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.JavaConverters._

object MetricsUtil extends Logging {
@@ -56,9 +59,9 @@
*/
def updateNativeMetrics(
child: SparkPlan,
-relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]],
-joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams],
-aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): IMetrics => Unit = {
+relMap: JMap[JLong, JList[JLong]],
+joinParamsMap: JMap[JLong, JoinParams],
+aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {

val mut: MetricsUpdaterTree = treeifyMetricsUpdaters(child)

@@ -90,10 +93,10 @@
*/
def updateTransformerMetrics(
mutNode: MetricsUpdaterTree,
-relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]],
-operatorIdx: java.lang.Long,
-joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams],
-aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): IMetrics => Unit = {
+relMap: JMap[JLong, JList[JLong]],
+operatorIdx: JLong,
+joinParamsMap: JMap[JLong, JoinParams],
+aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {
imetrics =>
try {
val metrics = imetrics.asInstanceOf[NativeMetrics]
@@ -129,13 +132,13 @@
*/
def updateTransformerMetricsInternal(
mutNode: MetricsUpdaterTree,
-relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]],
-operatorIdx: java.lang.Long,
+relMap: JMap[JLong, JList[JLong]],
+operatorIdx: JLong,
metrics: NativeMetrics,
metricsIdx: Int,
-joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams],
-aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): (java.lang.Long, Int) = {
-val nodeMetricsList = new java.util.ArrayList[MetricsData]()
+joinParamsMap: JMap[JLong, JoinParams],
+aggParamsMap: JMap[JLong, AggregationParams]): (JLong, Int) = {
+val nodeMetricsList = new JArrayList[MetricsData]()
var curMetricsIdx = metricsIdx
relMap
.get(operatorIdx)
@@ -42,10 +42,11 @@ import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ExecutorManager

+import java.lang.{Long => JLong}
import java.net.URLDecoder
import java.nio.charset.StandardCharsets
import java.time.ZoneOffset
-import java.util
+import java.util.{ArrayList => JArrayList}
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
@@ -67,14 +68,14 @@ class IteratorApiImpl extends IteratorApi with Logging {

def constructSplitInfo(schema: StructType, files: Array[PartitionedFile]) = {
val paths = mutable.ArrayBuffer.empty[String]
-val starts = mutable.ArrayBuffer.empty[java.lang.Long]
-val lengths = mutable.ArrayBuffer.empty[java.lang.Long]
+val starts = mutable.ArrayBuffer.empty[JLong]
+val lengths = mutable.ArrayBuffer.empty[JLong]
val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]]
files.foreach {
file =>
paths.append(URLDecoder.decode(file.filePath.toString, StandardCharsets.UTF_8.name()))
-starts.append(java.lang.Long.valueOf(file.start))
-lengths.append(java.lang.Long.valueOf(file.length))
+starts.append(JLong.valueOf(file.start))
+lengths.append(JLong.valueOf(file.length))

val partitionColumn = mutable.Map.empty[String, String]
for (i <- 0 until file.partitionValues.numFields) {
@@ -90,7 +91,7 @@
case _: TimestampType =>
TimestampFormatter
.getFractionFormatter(ZoneOffset.UTC)
-.format(pn.asInstanceOf[java.lang.Long])
+.format(pn.asInstanceOf[JLong])
case _ => pn.toString
}
}
@@ -139,7 +140,7 @@
inputIterators: Seq[Iterator[ColumnarBatch]] = Seq()): Iterator[ColumnarBatch] = {
val beforeBuild = System.nanoTime()
val columnarNativeIterators =
-new util.ArrayList[GeneralInIterator](inputIterators.map {
+new JArrayList[GeneralInIterator](inputIterators.map {
iter => new ColumnarBatchInIterator(iter.asJava)
}.asJava)
val transKernel = NativePlanEvaluator.create()
@@ -183,7 +184,7 @@

val transKernel = NativePlanEvaluator.create()
val columnarNativeIterator =
-new util.ArrayList[GeneralInIterator](inputIterators.map {
+new JArrayList[GeneralInIterator](inputIterators.map {
iter => new ColumnarBatchInIterator(iter.asJava)
}.asJava)
val nativeResultIterator =
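
The java.lang.Long alias serves the same purpose on the scalar side: JLong is the boxed Java type, kept distinct from Scala's Long, and values are boxed explicitly with JLong.valueOf before they go into Java collections. A small self-contained sketch of that boxing pattern, using made-up example values rather than the PR's file offsets:

import java.lang.{Long => JLong}
import java.util.{ArrayList => JArrayList}

import scala.collection.mutable

object BoxingSketch {
  def main(args: Array[String]): Unit = {
    // Buffer of boxed java.lang.Long values on the Scala side.
    val starts = mutable.ArrayBuffer.empty[JLong]
    starts.append(JLong.valueOf(0L)) // explicit boxing of a Scala Long
    starts.append(JLong.valueOf(128L))

    // Copy the boxed values into a java.util.List, as a Java API
    // expecting List<Long> would require.
    val jStarts = new JArrayList[JLong]()
    starts.foreach(v => jStarts.add(v))
    println(jStarts) // prints [0, 128]
  }
}
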
@@ -25,14 +25,15 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

-import java.{lang, util}
+import java.lang.{Long => JLong}
+import java.util.{List => JList, Map => JMap}

class MetricsApiImpl extends MetricsApi with Logging {
override def metricsUpdatingFunction(
child: SparkPlan,
-relMap: util.HashMap[lang.Long, util.ArrayList[lang.Long]],
-joinParamsMap: util.HashMap[lang.Long, JoinParams],
-aggParamsMap: util.HashMap[lang.Long, AggregationParams]): IMetrics => Unit = {
+relMap: JMap[JLong, JList[JLong]],
+joinParamsMap: JMap[JLong, JoinParams],
+aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {
MetricsUtil.updateNativeMetrics(child, relMap, joinParamsMap, aggParamsMap)
}

@@ -56,6 +56,9 @@ import org.apache.commons.lang3.ClassUtils

import javax.ws.rs.core.UriBuilder

+import java.lang.{Long => JLong}
+import java.util.{Map => JMap}

import scala.collection.mutable.ArrayBuffer

class SparkPlanExecApiImpl extends SparkPlanExecApi {
@@ -67,7 +70,7 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
*/
override def genGetArrayItemExpressionNode(
substraitExprName: String,
-functionMap: java.util.HashMap[String, java.lang.Long],
+functionMap: JMap[String, JLong],
leftNode: ExpressionNode,
rightNode: ExpressionNode,
original: GetArrayItem): ExpressionNode = {
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDi
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.BitSet

-import java.util
+import java.util.{Map => JMap}

class TransformerApiImpl extends TransformerApi with Logging {

@@ -65,7 +65,7 @@
}

override def postProcessNativeConfig(
-nativeConfMap: util.Map[String, String],
+nativeConfMap: JMap[String, String],
backendPrefix: String): Unit = {
// TODO: IMPLEMENT SPECIAL PROCESS FOR VELOX BACKEND
}