
Commit

Fix java collection usage
liujiayi771 committed Nov 2, 2023
1 parent f82737a commit b538196
Showing 78 changed files with 641 additions and 655 deletions.
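
Most hunks in this commit follow the same pattern: the Scala sources drop blanket imports such as import java.util or import java.{lang, util} in favour of aliased imports (JArrayList, JList, JMap, JSet, JLong), and signatures switch from concrete classes like util.HashMap and util.ArrayList to the matching interface types; the Java change likewise drops an unused ArrayList import and declares locals as List. Below is a minimal, self-contained Scala sketch of that aliasing style; the object and method names are illustrative and do not appear in the diff.

import java.lang.{Long => JLong}
import java.util.{ArrayList => JArrayList, List => JList}

import scala.collection.JavaConverters._

// Illustrative only: an interface-typed signature (JList) backed by a JArrayList,
// with explicit boxing of Scala Longs via JLong.valueOf.
object CollectionAliasSketch {
  def toOffsets(starts: Seq[Long]): JList[JLong] = {
    val list = new JArrayList[JLong]()
    starts.foreach(s => list.add(JLong.valueOf(s)))
    list
  }

  def main(args: Array[String]): Unit = {
    val offsets = toOffsets(Seq(0L, 128L, 256L))
    println(offsets.asScala.mkString(", ")) // prints: 0, 128, 256
  }
}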
@@ -28,7 +28,6 @@
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.storage.CHShuffleReadStreamFactory;

- import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;

@@ -88,8 +87,8 @@ public static long build(

/** create table named struct */
private static NamedStruct toNameStruct(List<Attribute> output) {
- ArrayList<TypeNode> typeList = ConverterUtils.collectAttributeTypeNodes(output);
- ArrayList<String> nameList = ConverterUtils.collectAttributeNamesWithExprId(output);
+ List<TypeNode> typeList = ConverterUtils.collectAttributeTypeNodes(output);
+ List<String> nameList = ConverterUtils.collectAttributeNamesWithExprId(output);
Type.Struct.Builder structBuilder = Type.Struct.newBuilder();
for (TypeNode typeNode : typeList) {
structBuilder.addTypes(typeNode.toProtobuf());
@@ -41,7 +41,7 @@ import org.apache.spark.sql.vectorized.ColumnarBatch

import java.lang.{Long => JLong}
import java.net.URI
- import java.util
+ import java.util.{ArrayList => JArrayList}

import scala.collection.JavaConverters._
import scala.collection.mutable
@@ -68,9 +68,9 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
.makeExtensionTable(p.minParts, p.maxParts, p.database, p.table, p.tablePath),
SoftAffinityUtil.getNativeMergeTreePartitionLocations(p))
case f: FilePartition =>
- val paths = new util.ArrayList[String]()
- val starts = new util.ArrayList[JLong]()
- val lengths = new util.ArrayList[JLong]()
+ val paths = new JArrayList[String]()
+ val starts = new JArrayList[JLong]()
+ val lengths = new JArrayList[JLong]()
val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]]
f.files.foreach {
file =>
@@ -122,7 +122,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
val resIter: GeneralOutIterator = GlutenTimeMetric.millis(pipelineTime) {
_ =>
val transKernel = new CHNativeExpressionEvaluator()
- val inBatchIters = new util.ArrayList[GeneralInIterator](inputIterators.map {
+ val inBatchIters = new JArrayList[GeneralInIterator](inputIterators.map {
iter => new ColumnarNativeIterator(genCloseableColumnBatchIterator(iter).asJava)
}.asJava)
transKernel.createKernelWithBatchIterator(inputPartition.plan, inBatchIters, false)
@@ -180,7 +180,7 @@ class CHIteratorApi extends IteratorApi with Logging with LogLevelUtil {
_ =>
val transKernel = new CHNativeExpressionEvaluator()
val columnarNativeIterator =
- new java.util.ArrayList[GeneralInIterator](inputIterators.map {
+ new JArrayList[GeneralInIterator](inputIterators.map {
iter => new ColumnarNativeIterator(genCloseableColumnBatchIterator(iter).asJava)
}.asJava)
// we need to complete the dependency RDDs first
@@ -26,14 +26,15 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

- import java.{lang, util}
+ import java.lang.{Long => JLong}
+ import java.util.{List => JList, Map => JMap}

class CHMetricsApi extends MetricsApi with Logging with LogLevelUtil {
override def metricsUpdatingFunction(
child: SparkPlan,
- relMap: util.HashMap[lang.Long, util.ArrayList[lang.Long]],
- joinParamsMap: util.HashMap[lang.Long, JoinParams],
- aggParamsMap: util.HashMap[lang.Long, AggregationParams]): IMetrics => Unit = {
+ relMap: JMap[JLong, JList[JLong]],
+ joinParamsMap: JMap[JLong, JoinParams],
+ aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {
MetricsUtil.updateNativeMetrics(child, relMap, joinParamsMap, aggParamsMap)
}

@@ -55,7 +55,8 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
import com.google.common.collect.Lists
import org.apache.commons.lang3.ClassUtils

- import java.{lang, util}
+ import java.lang.{Long => JLong}
+ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.mutable.ArrayBuffer

@@ -64,7 +65,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
/** Transform GetArrayItem to Substrait. */
override def genGetArrayItemExpressionNode(
substraitExprName: String,
- functionMap: java.util.HashMap[String, java.lang.Long],
+ functionMap: JMap[String, JLong],
leftNode: ExpressionNode,
rightNode: ExpressionNode,
original: GetArrayItem): ExpressionNode = {
@@ -436,9 +437,9 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
/** Generate window function node */
override def genWindowFunctionsNode(
windowExpression: Seq[NamedExpression],
- windowExpressionNodes: util.ArrayList[WindowFunctionNode],
+ windowExpressionNodes: JList[WindowFunctionNode],
originalInputAttributes: Seq[Attribute],
- args: util.HashMap[String, lang.Long]): Unit = {
+ args: JMap[String, JLong]): Unit = {

windowExpression.map {
windowExpr =>
@@ -451,7 +452,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
val frame = aggWindowFunc.frame.asInstanceOf[SpecifiedWindowFrame]
val windowFunctionNode = ExpressionBuilder.makeWindowFunction(
WindowFunctionsBuilder.create(args, aggWindowFunc).toInt,
- new util.ArrayList[ExpressionNode](),
+ new JArrayList[ExpressionNode](),
columnName,
ConverterUtils.getTypeNode(aggWindowFunc.dataType, aggWindowFunc.nullable),
WindowExecTransformer.getFrameBound(frame.upper),
@@ -467,7 +468,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
throw new UnsupportedOperationException(s"Not currently supported: $aggregateFunc.")
}

- val childrenNodeList = new util.ArrayList[ExpressionNode]()
+ val childrenNodeList = new JArrayList[ExpressionNode]()
aggregateFunc.children.foreach(
expr =>
childrenNodeList.add(
@@ -505,7 +506,7 @@ class CHSparkPlanExecApi extends SparkPlanExecApi {
}
}

- val childrenNodeList = new util.ArrayList[ExpressionNode]()
+ val childrenNodeList = new JArrayList[ExpressionNode]()
childrenNodeList.add(
ExpressionConverter
.replaceWithExpressionTransformer(
@@ -38,7 +38,7 @@ import org.apache.spark.util.collection.BitSet

import com.google.common.collect.Lists

- import java.util
+ import java.util.{Map => JMap, Set => JSet}

class CHTransformerApi extends TransformerApi with Logging {

@@ -103,7 +103,7 @@ class CHTransformerApi extends TransformerApi with Logging {
}

override def postProcessNativeConfig(
- nativeConfMap: util.Map[String, String],
+ nativeConfMap: JMap[String, String],
backendPrefix: String): Unit = {
val settingPrefix = backendPrefix + ".runtime_settings."
if (nativeConfMap.getOrDefault("spark.memory.offHeap.enabled", "false").toBoolean) {
@@ -182,7 +182,7 @@
}
}

- override def getSupportExpressionClassName: util.Set[String] = {
+ override def getSupportExpressionClassName: JSet[String] = {
ExpressionDocUtil.supportExpression()
}

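The postProcessNativeConfig hunk above works against a plain java.util.Map of string settings, so the JMap[String, String] alias is all the signature needs; getOrDefault is the ordinary java.util.Map default method. A small Scala sketch of that style follows, assuming a hypothetical backend setting name (only the spark.memory.offHeap keys come from the hunk above).

import java.util.{HashMap => JHashMap, Map => JMap}

object NativeConfSketch {
  // Copies the configured off-heap size into a hypothetical backend runtime setting
  // when off-heap memory is enabled.
  def applyOffHeap(conf: JMap[String, String], settingPrefix: String): Unit = {
    if (conf.getOrDefault("spark.memory.offHeap.enabled", "false").toBoolean) {
      val size = conf.getOrDefault("spark.memory.offHeap.size", "0")
      conf.put(settingPrefix + "max_memory_usage_hint", size) // illustrative key
    }
  }

  def main(args: Array[String]): Unit = {
    val conf = new JHashMap[String, String]()
    conf.put("spark.memory.offHeap.enabled", "true")
    conf.put("spark.memory.offHeap.size", "2147483648")
    applyOffHeap(conf, "some.backend.runtime_settings.")
    println(conf.get("some.backend.runtime_settings.max_memory_usage_hint")) // 2147483648
  }
}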
@@ -23,6 +23,9 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.SQLMetric

+ import java.lang.{Long => JLong}
+ import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}

import scala.collection.JavaConverters._

object MetricsUtil extends Logging {
@@ -56,9 +59,9 @@
*/
def updateNativeMetrics(
child: SparkPlan,
- relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]],
- joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams],
- aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): IMetrics => Unit = {
+ relMap: JMap[JLong, JList[JLong]],
+ joinParamsMap: JMap[JLong, JoinParams],
+ aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {

val mut: MetricsUpdaterTree = treeifyMetricsUpdaters(child)

@@ -90,10 +93,10 @@
*/
def updateTransformerMetrics(
mutNode: MetricsUpdaterTree,
- relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]],
- operatorIdx: java.lang.Long,
- joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams],
- aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): IMetrics => Unit = {
+ relMap: JMap[JLong, JList[JLong]],
+ operatorIdx: JLong,
+ joinParamsMap: JMap[JLong, JoinParams],
+ aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {
imetrics =>
try {
val metrics = imetrics.asInstanceOf[NativeMetrics]
@@ -129,13 +132,13 @@
*/
def updateTransformerMetricsInternal(
mutNode: MetricsUpdaterTree,
- relMap: java.util.HashMap[java.lang.Long, java.util.ArrayList[java.lang.Long]],
- operatorIdx: java.lang.Long,
+ relMap: JMap[JLong, JList[JLong]],
+ operatorIdx: JLong,
metrics: NativeMetrics,
metricsIdx: Int,
- joinParamsMap: java.util.HashMap[java.lang.Long, JoinParams],
- aggParamsMap: java.util.HashMap[java.lang.Long, AggregationParams]): (java.lang.Long, Int) = {
- val nodeMetricsList = new java.util.ArrayList[MetricsData]()
+ joinParamsMap: JMap[JLong, JoinParams],
+ aggParamsMap: JMap[JLong, AggregationParams]): (JLong, Int) = {
+ val nodeMetricsList = new JArrayList[MetricsData]()
var curMetricsIdx = metricsIdx
relMap
.get(operatorIdx)
@@ -42,10 +42,11 @@ import org.apache.spark.sql.utils.OASPackageBridge.InputMetricsWrapper
import org.apache.spark.sql.vectorized.ColumnarBatch
import org.apache.spark.util.ExecutorManager

+ import java.lang.{Long => JLong}
import java.net.URLDecoder
import java.nio.charset.StandardCharsets
import java.time.ZoneOffset
- import java.util
+ import java.util.{ArrayList => JArrayList}
import java.util.concurrent.TimeUnit

import scala.collection.JavaConverters._
@@ -67,14 +68,14 @@ class IteratorApiImpl extends IteratorApi with Logging {

def constructSplitInfo(schema: StructType, files: Array[PartitionedFile]) = {
val paths = mutable.ArrayBuffer.empty[String]
- val starts = mutable.ArrayBuffer.empty[java.lang.Long]
- val lengths = mutable.ArrayBuffer.empty[java.lang.Long]
+ val starts = mutable.ArrayBuffer.empty[JLong]
+ val lengths = mutable.ArrayBuffer.empty[JLong]
val partitionColumns = mutable.ArrayBuffer.empty[Map[String, String]]
files.foreach {
file =>
paths.append(URLDecoder.decode(file.filePath.toString, StandardCharsets.UTF_8.name()))
- starts.append(java.lang.Long.valueOf(file.start))
- lengths.append(java.lang.Long.valueOf(file.length))
+ starts.append(JLong.valueOf(file.start))
+ lengths.append(JLong.valueOf(file.length))

val partitionColumn = mutable.Map.empty[String, String]
for (i <- 0 until file.partitionValues.numFields) {
@@ -90,7 +91,7 @@ class IteratorApiImpl extends IteratorApi with Logging {
case _: TimestampType =>
TimestampFormatter
.getFractionFormatter(ZoneOffset.UTC)
- .format(pn.asInstanceOf[java.lang.Long])
+ .format(pn.asInstanceOf[JLong])
case _ => pn.toString
}
}
@@ -139,7 +140,7 @@ class IteratorApiImpl extends IteratorApi with Logging {
inputIterators: Seq[Iterator[ColumnarBatch]] = Seq()): Iterator[ColumnarBatch] = {
val beforeBuild = System.nanoTime()
val columnarNativeIterators =
- new util.ArrayList[GeneralInIterator](inputIterators.map {
+ new JArrayList[GeneralInIterator](inputIterators.map {
iter => new ColumnarBatchInIterator(iter.asJava)
}.asJava)
val transKernel = NativePlanEvaluator.create()
@@ -183,7 +184,7 @@

val transKernel = NativePlanEvaluator.create()
val columnarNativeIterator =
- new util.ArrayList[GeneralInIterator](inputIterators.map {
+ new JArrayList[GeneralInIterator](inputIterators.map {
iter => new ColumnarBatchInIterator(iter.asJava)
}.asJava)
val nativeResultIterator =
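The iterator hunks above hand a Scala Seq of input iterators to the native evaluator as a java.util.ArrayList. Below is a minimal sketch of that wrapping, assuming placeholder element types in place of GeneralInIterator and ColumnarBatchInIterator.

import java.util.{ArrayList => JArrayList, Iterator => JIterator, List => JList}

import scala.collection.JavaConverters._

object IteratorWrapSketch {
  // Each Scala iterator becomes a java.util.Iterator, and the Seq is copied into a
  // mutable JArrayList (the plain asJava view over a Seq is not growable).
  def wrap(batches: Seq[Iterator[Int]]): JList[JIterator[Int]] =
    new JArrayList[JIterator[Int]](batches.map(_.asJava).asJava)

  def main(args: Array[String]): Unit = {
    val wrapped = wrap(Seq(Iterator(1, 2), Iterator(3)))
    println(wrapped.size()) // prints: 2
  }
}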
@@ -25,14 +25,15 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}

- import java.{lang, util}
+ import java.lang.{Long => JLong}
+ import java.util.{List => JList, Map => JMap}

class MetricsApiImpl extends MetricsApi with Logging {
override def metricsUpdatingFunction(
child: SparkPlan,
- relMap: util.HashMap[lang.Long, util.ArrayList[lang.Long]],
- joinParamsMap: util.HashMap[lang.Long, JoinParams],
- aggParamsMap: util.HashMap[lang.Long, AggregationParams]): IMetrics => Unit = {
+ relMap: JMap[JLong, JList[JLong]],
+ joinParamsMap: JMap[JLong, JoinParams],
+ aggParamsMap: JMap[JLong, AggregationParams]): IMetrics => Unit = {
MetricsUtil.updateNativeMetrics(child, relMap, joinParamsMap, aggParamsMap)
}

@@ -56,6 +56,9 @@ import org.apache.commons.lang3.ClassUtils

import javax.ws.rs.core.UriBuilder

+ import java.lang.{Long => JLong}
+ import java.util.{Map => JMap}

import scala.collection.mutable.ArrayBuffer

class SparkPlanExecApiImpl extends SparkPlanExecApi {
@@ -67,7 +70,7 @@ class SparkPlanExecApiImpl extends SparkPlanExecApi {
*/
override def genGetArrayItemExpressionNode(
substraitExprName: String,
- functionMap: java.util.HashMap[String, java.lang.Long],
+ functionMap: JMap[String, JLong],
leftNode: ExpressionNode,
rightNode: ExpressionNode,
original: GetArrayItem): ExpressionNode = {
@@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, PartitionDi
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.BitSet

- import java.util
+ import java.util.{Map => JMap}

class TransformerApiImpl extends TransformerApi with Logging {

@@ -65,7 +65,7 @@ class TransformerApiImpl extends TransformerApi with Logging {
}

override def postProcessNativeConfig(
- nativeConfMap: util.Map[String, String],
+ nativeConfMap: JMap[String, String],
backendPrefix: String): Unit = {
// TODO: IMPLEMENT SPECIAL PROCESS FOR VELOX BACKEND
}
@@ -33,6 +33,8 @@ import org.apache.spark.sql.expression.UDFResolver
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._

+ import java.util.{Map => JMap}

import scala.util.control.Breaks.breakable

class VeloxBackend extends Backend {
@@ -302,7 +304,7 @@

override def shuffleSupportedCodec(): Set[String] = SHUFFLE_SUPPORTED_CODEC

- override def resolveNativeConf(nativeConf: java.util.Map[String, String]): Unit = {
+ override def resolveNativeConf(nativeConf: JMap[String, String]): Unit = {
UDFResolver.resolveUdfConf(nativeConf)
}

@@ -23,7 +23,7 @@ import io.glutenproject.substrait.rel.RelBuilder
import org.apache.spark.sql.catalyst.expressions.{And, Attribute, Expression}
import org.apache.spark.sql.execution.SparkPlan

- import java.util
+ import java.util.{ArrayList => JArrayList}

import scala.collection.JavaConverters._

@@ -73,7 +73,7 @@ case class FilterExecTransformer(condition: Expression, child: SparkPlan)
} else {
// This means the input is just an iterator, so a ReadRel will be created as the child.
// Prepare the input schema.
- val attrList = new util.ArrayList[Attribute](child.output.asJava)
+ val attrList = new JArrayList[Attribute](child.output.asJava)
getRelNode(
context,
leftCondition,

0 comments on commit b538196
