[jvm-packages] resolve spark compatibility issue (#10917)

---------

Co-authored-by: Hyunsu Cho <[email protected]>
wbo4958 and hcho3 authored Oct 23, 2024
1 parent c06994b commit 8a24892
Showing 4 changed files with 96 additions and 31 deletions.
@@ -561,7 +561,11 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML
val featureName = getFeaturesCol
val missing = getMissing

val output = dataset.toDF().mapPartitions { rowIter =>
// Here we use an RDD instead of a DataFrame to avoid the row-encoder API
// difference across Spark versions, which would otherwise break compatibility:
// Spark 3.5+: Encoders.row(schema)
// Spark < 3.5: RowEncoder(schema)
val outRDD = dataset.asInstanceOf[Dataset[Row]].rdd.mapPartitions { rowIter =>
rowIter.grouped(inferBatchSize).flatMap { batchRow =>
val features = batchRow.iterator.map(row => row.getAs[Vector](
row.fieldIndex(featureName)))
@@ -573,8 +577,9 @@ private[spark] trait XGBoostModel[M <: XGBoostModel[M]] extends Model[M] with ML
dm.delete()
}
}
}
val output = dataset.sparkSession.createDataFrame(outRDD, schema)

}(Encoders.row(schema))
bBooster.unpersist(blocking = false)
postTransform(output, pred).toDF()
}
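
The hunk above swaps the DataFrame `mapPartitions` (which needs a row encoder) for an RDD round trip. Below is a minimal, self-contained sketch of why that sidesteps the version split the comment mentions; the `EncoderCompatSketch` object, toy schema, and the doubling "prediction" are illustrative stand-ins, not part of the patch.

```scala
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

object EncoderCompatSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("encoder-compat").getOrCreate()
    import spark.implicits._

    val df: DataFrame = Seq(1.0, 2.0, 3.0).toDF("x")
    val outSchema = StructType(Seq(StructField("x", DoubleType), StructField("pred", DoubleType)))

    // Version-specific alternatives (NOT used below), shown only for contrast:
    //   Spark 3.5+ : df.mapPartitions(fn)(org.apache.spark.sql.Encoders.row(outSchema))
    //   Spark < 3.5: df.mapPartitions(fn)(org.apache.spark.sql.catalyst.encoders.RowEncoder(outSchema))
    // Dropping to the RDD API never names either encoder factory, so the same
    // source compiles against any Spark 3.x release.
    val outRDD = df.rdd.mapPartitions { rows =>
      rows.map(r => Row(r.getDouble(0), r.getDouble(0) * 2.0)) // stand-in for booster prediction
    }
    val output = spark.createDataFrame(outRDD, outSchema)
    output.show()
    spark.stop()
  }
}
```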
@@ -16,14 +16,15 @@

package org.apache.spark.ml.xgboost

import org.apache.spark.SparkContext
import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.ml.classification.ProbabilisticClassifierParams
import org.apache.spark.ml.linalg.VectorUDT
import org.apache.spark.ml.param.Params
import org.apache.spark.ml.util.{DatasetUtils, DefaultParamsReader, DefaultParamsWriter, SchemaUtils}
import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetadataUtils, SchemaUtils}
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
import org.apache.spark.sql.{Column, Dataset, Row}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, IntegerType, StructType}
import org.json4s.{JObject, JValue}

import ml.dmlc.xgboost4j.scala.spark.params.NonXGBoostParams
@@ -57,8 +58,52 @@ trait XGBProbabilisticClassifierParams[T <: Params]
/** Utils to access the spark internal functions */
object SparkUtils {

private def checkClassificationLabels(
labelCol: String,
numClasses: Option[Int]): Column = {
val casted = col(labelCol).cast(DoubleType)
numClasses match {
case Some(2) =>
when(casted.isNull || casted.isNaN, raise_error(lit("Labels MUST NOT be Null or NaN")))
.when(casted =!= 0 && casted =!= 1,
raise_error(concat(lit("Labels MUST be in {0, 1}, but got "), casted)))
.otherwise(casted)

case _ =>
val n = numClasses.getOrElse(Int.MaxValue)
require(0 < n && n <= Int.MaxValue)
when(casted.isNull || casted.isNaN, raise_error(lit("Labels MUST NOT be Null or NaN")))
.when(casted < 0 || casted >= n,
raise_error(concat(lit(s"Labels MUST be in [0, $n), but got "), casted)))
.when(casted =!= casted.cast(IntegerType),
raise_error(concat(lit("Labels MUST be Integers, but got "), casted)))
.otherwise(casted)
}
}

// Copied from Spark's DatasetUtils to remain compatible with Spark below 3.4
def getNumClasses(dataset: Dataset[_], labelCol: String, maxNumClasses: Int = 100): Int = {
DatasetUtils.getNumClasses(dataset, labelCol, maxNumClasses)
MetadataUtils.getNumClasses(dataset.schema(labelCol)) match {
case Some(n: Int) => n
case None =>
// Get number of classes from dataset itself.
val maxLabelRow: Array[Row] = dataset
.select(max(checkClassificationLabels(labelCol, Some(maxNumClasses))))
.take(1)
if (maxLabelRow.isEmpty || maxLabelRow(0).get(0) == null) {
throw new SparkException("ML algorithm was given empty dataset.")
}
val maxDoubleLabel: Double = maxLabelRow.head.getDouble(0)
require((maxDoubleLabel + 1).isValidInt, s"Classifier found max label value =" +
s" $maxDoubleLabel but requires integers in range [0, ... ${Int.MaxValue})")
val numClasses = maxDoubleLabel.toInt + 1
require(numClasses <= maxNumClasses, s"Classifier inferred $numClasses from label values" +
s" in column $labelCol, but this exceeded the max numClasses ($maxNumClasses) allowed" +
s" to be inferred from values. To avoid this error for labels with > $maxNumClasses" +
s" classes, specify numClasses explicitly in the metadata; this can be done by applying" +
s" StringIndexer to the label column.")
numClasses
}
}

def checkNumericType(schema: StructType, colName: String, msg: String = ""): Unit = {
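
For context on the copied `getNumClasses` logic above, here is a small sketch of the two paths it covers: reading the class count from ML attribute metadata on the label column (what `MetadataUtils.getNumClasses` consults, and what `StringIndexer` produces automatically), versus scanning for the maximum label and inferring `max + 1`. The `NumClassesSketch` object and toy data are illustrative, and the scan below omits the null/NaN/integer validation that `checkClassificationLabels` adds.

```scala
import org.apache.spark.ml.attribute.NominalAttribute
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, max}

object NumClassesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("num-classes").getOrCreate()
    import spark.implicits._

    val df = Seq(0.0, 1.0, 2.0, 1.0).toDF("label")

    // Path 1: numClasses carried in the label column's ML attribute metadata,
    // so the class count can be answered from the schema alone, with no job run.
    val meta = NominalAttribute.defaultAttr.withName("label").withNumValues(3).toMetadata()
    val withMeta = df.select(col("label").as("label", meta))
    println(withMeta.schema("label").metadata) // ML attribute metadata with numValues = 3

    // Path 2: no metadata, so the fallback scans for max(label) and infers max + 1,
    // mirroring the copied Spark logic above.
    val maxLabel = df.select(max(col("label"))).head().getDouble(0)
    val numClasses = maxLabel.toInt + 1
    println(s"inferred numClasses = $numClasses")
    spark.stop()
  }
}
```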
@@ -100,29 +100,32 @@ class XGBoostSuite extends AnyFunSuite with PerTest {
.config("spark.executor.cores", 4)
.config("spark.executor.resource.gpu.amount", 1)
.config("spark.task.resource.gpu.amount", 0.25)

val ss = builder.getOrCreate()

try {
val df = ss.range(1, 10)
val rdd = df.rdd

val runtimeParams = new XGBoostClassifier(
Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1)
.getRuntimeParameters(true)
assert(runtimeParams.runOnGpu)

val finalRDD = FakedXGBoost.tryStageLevelScheduling(ss.sparkContext, runtimeParams,
rdd.asInstanceOf[RDD[(Booster, Map[String, Array[Float]])]])

val taskResources = finalRDD.getResourceProfile().taskResources
assert(taskResources.contains("cpus"))
assert(taskResources.get("cpus").get.amount == 3)

assert(taskResources.contains("gpu"))
assert(taskResources.get("gpu").get.amount == 1.0)
} finally {
if (ss.version < "3.4.1") {
// Pass
ss.stop()
} else {
try {
val df = ss.range(1, 10)
val rdd = df.rdd

val runtimeParams = new XGBoostClassifier(
Map("device" -> "cuda")).setNumWorkers(1).setNumRound(1)
.getRuntimeParameters(true)
assert(runtimeParams.runOnGpu)

val finalRDD = FakedXGBoost.tryStageLevelScheduling(ss.sparkContext, runtimeParams,
rdd.asInstanceOf[RDD[(Booster, Map[String, Array[Float]])]])

val taskResources = finalRDD.getResourceProfile().taskResources
assert(taskResources.contains("cpus"))
assert(taskResources.get("cpus").get.amount == 3)

assert(taskResources.contains("gpu"))
assert(taskResources.get("gpu").get.amount == 1.0)
} finally {
ss.stop()
}
}
}
}
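
The test above guards the stage-level-scheduling assertions behind a Spark version check. As background, here is a hedged sketch (the `StageLevelProfileSketch` name is illustrative) of the kind of task-level resource profile `tryStageLevelScheduling` is expected to build: 3 CPUs and 1 GPU per task. Constructing and inspecting the profile needs no cluster; it is attaching it to an RDD via `withResources` that requires a Spark release and cluster manager with stage-level scheduling support, which the `ss.version` guard appears to account for.

```scala
import org.apache.spark.resource.{ResourceProfileBuilder, TaskResourceRequests}

object StageLevelProfileSketch {
  def main(args: Array[String]): Unit = {
    // Request 3 CPUs and 1 GPU per task, then build the profile.
    val treqs = new TaskResourceRequests().cpus(3).resource("gpu", 1.0)
    val profile = new ResourceProfileBuilder().require(treqs).build()

    // Inspect the per-task requests, matching the assertions in the test above.
    val taskResources = profile.taskResources
    assert(taskResources("cpus").amount == 3)
    assert(taskResources("gpu").amount == 1.0)
    println(taskResources)
  }
}
```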
18 changes: 15 additions & 3 deletions tests/ci_build/build_jvm_packages.sh
@@ -23,13 +23,25 @@ if [ "x$gpu_arch" != "x" ]; then
export GPU_ARCH_FLAG=$gpu_arch
fi

# Purge artifacts and set correct Scala version
pushd ..
if [ "x$use_scala213" != "x" ]; then
cd ..
python dev/change_scala_version.py --scala-version 2.13 --purge-artifacts
cd jvm-packages
else
python dev/change_scala_version.py --scala-version 2.12 --purge-artifacts
fi
popd

# Build and test XGBoost4j-spark against different spark versions only for CPU and scala=2.12
if [ "x$gpu_options" == "x" ] && [ "x$use_scala213" == "x" ]; then
mvn --no-transfer-progress clean package -Dspark.version=3.1.3 -pl xgboost4j,xgboost4j-spark
mvn --no-transfer-progress clean package -Dspark.version=3.2.4 -pl xgboost4j,xgboost4j-spark
mvn --no-transfer-progress clean package -Dspark.version=3.3.4 -pl xgboost4j,xgboost4j-spark
mvn --no-transfer-progress clean package -Dspark.version=3.4.3 -pl xgboost4j,xgboost4j-spark
fi

mvn --no-transfer-progress clean package -Dspark.version=${spark_version} $gpu_options

mvn --no-transfer-progress package -Dspark.version=${spark_version} $gpu_options

set +x
set +e
