[SPARK-49249][SPARK-49122] Artifact isolation in Spark Classic #48120

Open

wants to merge 43 commits into base: master

Changes from all commits
43 commits
fe143d6
make it work for sql
xupefei Sep 12, 2024
e2597c1
REPL
xupefei Sep 16, 2024
827e01e
.
xupefei Sep 16, 2024
73d13f9
.
xupefei Sep 19, 2024
caa4251
revert fmt change
xupefei Sep 23, 2024
5043ce3
try fix
xupefei Sep 25, 2024
3c8fef5
.
xupefei Sep 25, 2024
255e85d
the other way around
xupefei Sep 25, 2024
7e3ecfe
mima
xupefei Sep 25, 2024
a9c20e0
OOPS
xupefei Sep 25, 2024
355bea8
fmt
xupefei Sep 25, 2024
6564ccf
Merge remote-tracking branch 'databricks/master' into session-artifac…
xupefei Sep 25, 2024
06a658c
handle hive
xupefei Sep 25, 2024
225ec6f
fix addjar, break connect repl
xupefei Sep 26, 2024
e4f5a5c
ugly fix for streaming
xupefei Sep 26, 2024
7630e2f
clone
xupefei Oct 2, 2024
7b8f1da
.
xupefei Oct 2, 2024
786d48f
.
xupefei Oct 2, 2024
1849ac5
Merge branch 'clone-artifact-manager' into session-artifact-apply
xupefei Oct 2, 2024
a0ae922
undo
xupefei Oct 2, 2024
bfa6d85
address comments
xupefei Oct 3, 2024
fe7947f
address comments
xupefei Oct 4, 2024
04a5bb2
Merge branch 'clone-artifact-manager' into session-artifact-apply
xupefei Oct 4, 2024
24f99a5
.
xupefei Oct 4, 2024
39a8086
wip
xupefei Oct 7, 2024
7cce314
address comment
xupefei Oct 7, 2024
80289b8
.
xupefei Oct 8, 2024
4542a21
rvt
xupefei Oct 8, 2024
0b021d9
fix (hopefully) all tests
xupefei Oct 9, 2024
97c7d6c
remove reuse code
xupefei Oct 9, 2024
aa9c21d
.
xupefei Oct 9, 2024
a2849f8
fix pyspark
xupefei Oct 9, 2024
508ee7b
disable hive
xupefei Oct 9, 2024
fdcb05b
omg
xupefei Oct 9, 2024
c9cf1a2
why so slow
xupefei Oct 10, 2024
be49405
why so slow try 2
xupefei Oct 10, 2024
5c15612
Merge remote-tracking branch 'origin/clone-artifact-manager' into ses…
xupefei Oct 10, 2024
3899b22
make streaming great again
xupefei Oct 10, 2024
d8ec1d3
.
xupefei Oct 10, 2024
3bcda6d
optimizzzzze
xupefei Oct 11, 2024
216b467
lemme try if this can make things faster
xupefei Oct 11, 2024
4de3ce8
cache AppClassLoader
xupefei Oct 12, 2024
7a0910b
.
xupefei Oct 12, 2024
@@ -297,6 +297,8 @@ object CheckConnectJvmClientCompatibility {
"org.apache.spark.sql.UDFRegistration.initializeLogIfNecessary"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.initializeLogIfNecessary$default$2"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"org.apache.spark.sql.UDFRegistration.registerJava"),

// Protected DataFrameReader methods...
ProblemFilters.exclude[DirectMissingMethodProblem](
8 changes: 6 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkFiles.scala
@@ -27,8 +27,12 @@ object SparkFiles {
   /**
    * Get the absolute path of a file added through `SparkContext.addFile()`.
    */
-  def get(filename: String): String =
-    new File(getRootDirectory(), filename).getAbsolutePath()
+  def get(filename: String): String = {
+    val jobArtifactUUID = JobArtifactSet
+      .getCurrentJobArtifactState.map(_.uuid).getOrElse("default")
+    val withUuid = if (jobArtifactUUID == "default") filename else s"$jobArtifactUUID/$filename"
+    new File(getRootDirectory(), withUuid).getAbsolutePath
+  }
 
   /**
    * Get the root directory that contains files added through `SparkContext.addFile()`.
@@ -463,7 +463,7 @@ private[spark] class BlockManager(
   * '''Important!''' Callers must not mutate or release the data buffer underlying `bytes`. Doing
   * so may corrupt or change the data stored by the `BlockManager`.
   */
-  private case class ByteBufferBlockStoreUpdater[T](
+  private[spark] case class ByteBufferBlockStoreUpdater[T](
      blockId: BlockId,
      level: StorageLevel,
      classTag: ClassTag[T],
22 changes: 16 additions & 6 deletions core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -176,15 +176,25 @@ private[spark] object Utils
   }
 
   /**
-   * Run a segment of code using a different context class loader in the current thread
-   */
-  def withContextClassLoader[T](ctxClassLoader: ClassLoader)(fn: => T): T = {
-    val oldClassLoader = Thread.currentThread().getContextClassLoader()
+   * Run a segment of code using a different context class loader in the current thread.
+   *
+   * If `retainChange` is `true` and `fn` changed the context class loader during execution,
+   * the class loader will not be reverted to the original one when this method returns.
+   */
+  def withContextClassLoader[T](
+      ctxClassLoader: ClassLoader,
+      retainChange: Boolean = false)(fn: => T): T = {
+    val oldClassLoader = Thread.currentThread().getContextClassLoader
+    var classLoaderAfterFn: ClassLoader = null
     try {
       Thread.currentThread().setContextClassLoader(ctxClassLoader)
-      fn
+      val ret = fn
+      classLoaderAfterFn = Thread.currentThread().getContextClassLoader
+      ret
     } finally {
-      Thread.currentThread().setContextClassLoader(oldClassLoader)
+      if (!retainChange || classLoaderAfterFn == ctxClassLoader) {
+        Thread.currentThread().setContextClassLoader(oldClassLoader)
+      }
     }
   }
 
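A minimal, self-contained sketch of how the new `retainChange` flag behaves (editor illustration, not part of this PR): `withContextClassLoader` below is a local copy of the helper from the diff above, and `SessionLoader` is a made-up classloader standing in for a session-specific artifact classloader.

object RetainChangeSketch {
  // Local copy of the helper shown in the diff above, so the sketch is self-contained.
  def withContextClassLoader[T](
      ctxClassLoader: ClassLoader,
      retainChange: Boolean = false)(fn: => T): T = {
    val oldClassLoader = Thread.currentThread().getContextClassLoader
    var classLoaderAfterFn: ClassLoader = null
    try {
      Thread.currentThread().setContextClassLoader(ctxClassLoader)
      val ret = fn
      classLoaderAfterFn = Thread.currentThread().getContextClassLoader
      ret
    } finally {
      // Revert unless the caller asked to retain a change and fn actually installed another loader.
      if (!retainChange || classLoaderAfterFn == ctxClassLoader) {
        Thread.currentThread().setContextClassLoader(oldClassLoader)
      }
    }
  }

  def main(args: Array[String]): Unit = {
    class SessionLoader extends ClassLoader(Thread.currentThread().getContextClassLoader)
    val sessionLoader = new SessionLoader

    withContextClassLoader(sessionLoader, retainChange = true) {
      // Simulate user code (e.g. the REPL) installing its own wrapper loader inside fn.
      Thread.currentThread().setContextClassLoader(new SessionLoader)
    }
    // retainChange = true and fn swapped loaders, so the wrapper is still active here;
    // with the default retainChange = false the original loader would have been restored.
    println(Thread.currentThread().getContextClassLoader)
  }
}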
2 changes: 2 additions & 0 deletions python/pyspark/core/context.py
@@ -84,6 +84,8 @@
DEFAULT_CONFIGS: Dict[str, Any] = {
"spark.serializer.objectStreamReset": 100,
"spark.rdd.compress": True,
# Disable artifact isolation in PySpark, otherwise user-added .py files won't work
"spark.session.isolate.artifacts": "false",
}

T = TypeVar("T")
6 changes: 5 additions & 1 deletion python/pyspark/sql/connect/session.py
@@ -1027,7 +1027,11 @@ def _start_connect_server(master: str, opts: Dict[str, Any]) -> None:
     os.environ["SPARK_LOCAL_CONNECT"] = "1"
 
     # Configurations to be set if unset.
-    default_conf = {"spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin"}
+    default_conf = {
+        "spark.plugins": "org.apache.spark.sql.connect.SparkConnectPlugin",
+        "spark.repl.isolate.artifacts": "true",
+        "spark.session.isolate.artifacts": "true",
+    }
 
     if "SPARK_TESTING" in os.environ:
         # For testing, we use 0 to use an ephemeral port to allow parallel testing.
Binary file added repl/src/test/resources/IntSumUdf.class
Binary file not shown.
22 changes: 22 additions & 0 deletions repl/src/test/resources/IntSumUdf.scala
@@ -0,0 +1,22 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import org.apache.spark.sql.api.java.UDF2

class IntSumUdf extends UDF2[Long, Long, Long] {
override def call(t1: Long, t2: Long): Long = t1 + t2
}
63 changes: 63 additions & 0 deletions repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -396,4 +396,67 @@ class ReplSuite extends SparkFunSuite {
Main.sparkContext.stop()
System.clearProperty("spark.driver.port")
}

test("register UDF via SparkSession.addArtifact") {
val artifactPath = new File("src/test/resources").toPath
val intSumUdfPath = artifactPath.resolve("IntSumUdf.class")
val output = runInterpreterInPasteMode("local",
s"""
|import org.apache.spark.sql.api.java.UDF2
|import org.apache.spark.sql.types.DataTypes
|
|spark.addArtifact("${intSumUdfPath.toString}")
|
|spark.udf.registerJava("intSum", "IntSumUdf", DataTypes.LongType)
|
|val r = spark.range(5)
| .withColumn("id2", col("id") + 1)
| .selectExpr("intSum(id, id2)")
| .collect()
|assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
|
""".stripMargin)
assertContains("Array([1], [3], [5], [7], [9])", output)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertDoesNotContain("assertion failed", output)

// The UDF should not work in a new REPL session.
val anotherOutput = runInterpreterInPasteMode("local",
s"""
|val r = spark.range(5)
| .withColumn("id2", col("id") + 1)
| .selectExpr("intSum(id, id2)")
| .collect()
|
""".stripMargin)
assertContains(
"[UNRESOLVED_ROUTINE] Cannot resolve routine `intSum` on search path",
anotherOutput)
}

test("register a class via SparkSession.addArtifact") {
val artifactPath = new File("src/test/resources").toPath
val intSumUdfPath = artifactPath.resolve("IntSumUdf.class")
val output = runInterpreterInPasteMode("local",
s"""
|import org.apache.spark.sql.functions.udf
|
|spark.addArtifact("${intSumUdfPath.toString}")
|
|val intSumUdf = udf((x: Long, y: Long) => new IntSumUdf().call(x, y))
|spark.udf.register("intSum", intSumUdf)
|
|val r = spark.range(5)
| .withColumn("id2", col("id") + 1)
| .selectExpr("intSum(id, id2)")
| .collect()
|assert(r.map(_.getLong(0)).toSeq == Seq(1, 3, 5, 7, 9))
|
""".stripMargin)
assertContains("Array([1], [3], [5], [7], [9])", output)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
assertDoesNotContain("assertion failed", output)
}
}
@@ -40,6 +40,8 @@ private[sql] object SimpleSparkConnectService {
def main(args: Array[String]): Unit = {
val conf = new SparkConf()
.set("spark.plugins", "org.apache.spark.sql.connect.SparkConnectPlugin")
.set("spark.repl.isolate.artifacts", "true")
.set("spark.session.isolate.artifacts", "true")
val sparkSession = SparkSession.builder().config(conf).getOrCreate()
val sparkContext = sparkSession.sparkContext // init spark context
// scalastyle:off println
@@ -28,7 +28,11 @@ object SparkConnectServer extends Logging {
   def main(args: Array[String]): Unit = {
     // Set the active Spark Session, and starts SparkEnv instance (via Spark Context)
     logInfo("Starting Spark session.")
-    val session = SparkSession.builder().getOrCreate()
+    val session = SparkSession
+      .builder()
+      .config("spark.repl.isolate.artifacts", "true")
+      .config("spark.session.isolate.artifacts", "true")
+      .getOrCreate()
     try {
       try {
         SparkConnectService.start(session.sparkContext)
@@ -275,6 +275,7 @@ class SparkSession private(
Map.empty,
managedJobTags.asScala.toMap)
result.sessionState // force copy of SessionState
result.sessionState.artifactManager // force copy of ArtifactManager and its resources
result.managedJobTags // force copy of userDefinedToRealTagsMap
result
}
@@ -898,6 +899,7 @@ object SparkSession extends api.SparkSessionCompanion with Logging {
override def enableHiveSupport(): this.type = synchronized {
if (hiveClassesArePresent) {
super.enableHiveSupport()
.config("spark.session.isolate.artifacts", "false")
} else {
throw new IllegalArgumentException(
"Unable to instantiate SparkSession with Hive support because " +
21 changes: 12 additions & 9 deletions sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.expressions.{SparkUserDefinedFunction, UserDefinedAggregateFunction, UserDefinedAggregator, UserDefinedFunction}
 import org.apache.spark.sql.internal.UserDefinedFunctionUtils.toScalaUDF
 import org.apache.spark.sql.types.DataType
-import org.apache.spark.util.Utils
 
 /**
  * Functions for registering user-defined functions. Use `SparkSession.udf` to access this:
@@ -44,7 +43,7 @@ import org.apache.spark.util.Utils
  * @since 1.3.0
  */
 @Stable
-class UDFRegistration private[sql] (functionRegistry: FunctionRegistry)
+class UDFRegistration private[sql] (session: SparkSession, functionRegistry: FunctionRegistry)
   extends api.UDFRegistration
   with Logging {
   protected[sql] def registerPython(name: String, udf: UserDefinedPythonFunction): Unit = {
@@ -121,7 +120,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry)
    */
   private[sql] def registerJavaUDAF(name: String, className: String): Unit = {
     try {
-      val clazz = Utils.classForName[AnyRef](className)
+      val clazz = session.artifactManager.classloader.loadClass(className)

Contributor commented:
One follow-up here would be to cache the ArtifactManager classloader. I think we create that thing over and over.

xupefei (Contributor Author) commented on Sep 23, 2024:
Agree. We can invalidate the cache when a new JAR is added.

       if (!classOf[UserDefinedAggregateFunction].isAssignableFrom(clazz)) {
         throw QueryCompilationErrors
           .classDoesNotImplementUserDefinedAggregateFunctionError(className)
@@ -138,16 +137,20 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry)
 
   // scalastyle:off line.size.limit
   /**
-   * Register a Java UDF class using reflection, for use from pyspark
+   * Register a Java UDF class using its class name. The class must implement one of the UDF
+   * interfaces in the [[org.apache.spark.sql.api.java]] package, and be discoverable by the current
+   * session's class loader.
    *
-   * @param name udf name
-   * @param className fully qualified class name of udf
-   * @param returnDataType return type of udf. If it is null, spark would try to infer
+   * @param name Name of the UDF.
+   * @param className Fully qualified class name of the UDF.
+   * @param returnDataType Return type of UDF. If it is `null`, Spark would try to infer
    *                       via reflection.
+   *
+   * @since 4.0.0
    */
-  private[sql] def registerJava(name: String, className: String, returnDataType: DataType): Unit = {
+  def registerJava(name: String, className: String, returnDataType: DataType): Unit = {
Comment on lines -148 to +151
xupefei (Contributor Author) commented:
I have to make this method public so I can call it from REPL.

Contributor commented:
I am not against this. I am trying to understand the user facing consequences though. I'd probably prefer that we add support for Scala UDFs as well. That can be done in a follow-up though.

Contributor commented:
Can you file a follow-up?

xupefei (Contributor Author) commented:
Will do.

     try {
-      val clazz = Utils.classForName[AnyRef](className)
+      val clazz = session.artifactManager.classloader.loadClass(className)
       val udfInterfaces = clazz.getGenericInterfaces
         .filter(_.isInstanceOf[ParameterizedType])
         .map(_.asInstanceOf[ParameterizedType])
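A rough sketch of the caching follow-up discussed in the review thread above, assuming a cached loader that is invalidated whenever a new JAR is added (editor illustration only; `CachedSessionClassLoader` and its members are hypothetical names, not the actual ArtifactManager API):

import java.net.{URL, URLClassLoader}

// Cache the session classloader and drop the cached instance whenever a new JAR is added,
// so repeated registerJava/registerJavaUDAF calls do not rebuild the loader every time.
final class CachedSessionClassLoader(parent: ClassLoader) {
  private var jars: Vector[URL] = Vector.empty
  private var cached: ClassLoader = _

  // Returns the cached loader, building it lazily on first use.
  def classloader: ClassLoader = synchronized {
    if (cached == null) {
      cached = new URLClassLoader(jars.toArray, parent)
    }
    cached
  }

  // Adding a JAR invalidates the cache; the next lookup will see the new artifact.
  def addJar(url: URL): Unit = synchronized {
    jars = jars :+ url
    cached = null
  }
}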