[SPARK-49609][PYTHON][CONNECT] Add API compatibility check between Classic and Connect

### What changes were proposed in this pull request?

This PR proposes to add an API compatibility check between Spark Classic and Spark Connect.

It also updates the existing APIs on both sides so that they share the same signatures.

### Why are the changes needed?

APIs supported on both Spark Connect and Spark Classic should guarantee the same signatures, including argument and return types.

For example, the test fails when an API's signature is mismatched:

```
Signature mismatch in Column method 'dropFields'
Classic: (self, *fieldNames: str) -> pyspark.sql.column.Column
Connect: (self, *fieldNames: 'ColumnOrName') -> pyspark.sql.column.Column
<Signature (self, *fieldNames: 'ColumnOrName') -> pyspark.sql.column.Column> != <Signature (self, *fieldNames: str) -> pyspark.sql.column.Column>

Expected :<Signature (self, *fieldNames: str) -> pyspark.sql.column.Column>
Actual   :<Signature (self, *fieldNames: 'ColumnOrName') -> pyspark.sql.column.Column>
```
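
The check relies on `inspect.signature`. As a rough sketch (not the exact test code, which is added in `python/pyspark/sql/tests/test_connect_compatibility.py` below), a mismatch like the one above could be surfaced as follows, assuming a build where both the Classic and Connect modules are importable:

```python
import inspect

from pyspark.sql.classic.column import Column as ClassicColumn
from pyspark.sql.connect.column import Column as ConnectColumn

# Compare the signature of the same public method on both implementations.
classic_sig = inspect.signature(ClassicColumn.dropFields)
connect_sig = inspect.signature(ConnectColumn.dropFields)

# If the annotations ever diverge (e.g. str vs. 'ColumnOrName' as shown above),
# this assertion fails with a message similar to the one in the example.
assert classic_sig == connect_sig, (
    "Signature mismatch in Column method 'dropFields'\n"
    f"Classic: {classic_sig}\n"
    f"Connect: {connect_sig}"
)
```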

### Does this PR introduce _any_ user-facing change?

No. It adds a test to prevent future API inconsistencies between Classic and Connect.

### How was this patch tested?

Added UTs.
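
For reference, a minimal way to run just the new suite (a sketch assuming a PySpark development build with the Spark Connect dependencies installed):

```python
# Roughly equivalent to `python -m unittest pyspark.sql.tests.test_connect_compatibility -v`;
# the new test file can also be executed directly as a script.
import unittest

unittest.main(
    module="pyspark.sql.tests.test_connect_compatibility",
    argv=["unittest", "-v"],  # argv[0] is a placeholder; "-v" enables verbose output
    exit=False,
)
```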

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48085 from itholic/SPARK-49609.

Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Haejoon Lee <[email protected]>
itholic committed Sep 24, 2024
1 parent 6bdd151 commit 982028e
Showing 5 changed files with 209 additions and 15 deletions.
1 change: 1 addition & 0 deletions dev/sparktestsupport/modules.py
@@ -550,6 +550,7 @@ def __hash__(self):
"pyspark.sql.tests.test_resources",
"pyspark.sql.tests.plot.test_frame_plot",
"pyspark.sql.tests.plot.test_frame_plot_plotly",
"pyspark.sql.tests.test_connect_compatibility",
],
)

6 changes: 3 additions & 3 deletions python/pyspark/sql/classic/dataframe.py
@@ -1068,7 +1068,7 @@ def selectExpr(self, *expr: Union[str, List[str]]) -> ParentDataFrame:
jdf = self._jdf.selectExpr(self._jseq(expr))
return DataFrame(jdf, self.sparkSession)

def filter(self, condition: "ColumnOrName") -> ParentDataFrame:
def filter(self, condition: Union[Column, str]) -> ParentDataFrame:
if isinstance(condition, str):
jdf = self._jdf.filter(condition)
elif isinstance(condition, Column):
@@ -1809,10 +1809,10 @@ def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ign
def drop_duplicates(self, subset: Optional[List[str]] = None) -> ParentDataFrame:
return self.dropDuplicates(subset)

def writeTo(self, table: str) -> DataFrameWriterV2:
def writeTo(self, table: str) -> "DataFrameWriterV2":
return DataFrameWriterV2(self, table)

def mergeInto(self, table: str, condition: Column) -> MergeIntoWriter:
def mergeInto(self, table: str, condition: Column) -> "MergeIntoWriter":
return MergeIntoWriter(self, table, condition)

def pandas_api(
26 changes: 15 additions & 11 deletions python/pyspark/sql/connect/dataframe.py
@@ -535,7 +535,7 @@ def groupby(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
def groupby(self, __cols: Union[List[Column], List[str], List[int]]) -> "GroupedData":
...

def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> GroupedData:
def groupBy(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData":
if len(cols) == 1 and isinstance(cols[0], list):
cols = cols[0]

@@ -570,7 +570,7 @@ def rollup(self, *cols: "ColumnOrName") -> "GroupedData":
def rollup(self, __cols: Union[List[Column], List[str]]) -> "GroupedData":
...

def rollup(self, *cols: "ColumnOrName") -> "GroupedData": # type: ignore[misc]
def rollup(self, *cols: "ColumnOrNameOrOrdinal") -> "GroupedData": # type: ignore[misc]
_cols: List[Column] = []
for c in cols:
if isinstance(c, Column):
@@ -731,8 +731,8 @@ def _convert_col(df: ParentDataFrame, col: "ColumnOrName") -> Column:
session=self._session,
)

def limit(self, n: int) -> ParentDataFrame:
res = DataFrame(plan.Limit(child=self._plan, limit=n), session=self._session)
def limit(self, num: int) -> ParentDataFrame:
res = DataFrame(plan.Limit(child=self._plan, limit=num), session=self._session)
res._cached_schema = self._cached_schema
return res

@@ -931,7 +931,11 @@ def _show_string(
)._to_table()
return table[0][0].as_py()

def withColumns(self, colsMap: Dict[str, Column]) -> ParentDataFrame:
def withColumns(self, *colsMap: Dict[str, Column]) -> ParentDataFrame:
# Below code is to help enable kwargs in future.
assert len(colsMap) == 1
colsMap = colsMap[0] # type: ignore[assignment]

if not isinstance(colsMap, dict):
raise PySparkTypeError(
errorClass="NOT_DICT",
@@ -1256,7 +1260,7 @@ def intersectAll(self, other: ParentDataFrame) -> ParentDataFrame:
res._cached_schema = self._merge_cached_schema(other)
return res

def where(self, condition: Union[Column, str]) -> ParentDataFrame:
def where(self, condition: "ColumnOrName") -> ParentDataFrame:
if not isinstance(condition, (str, Column)):
raise PySparkTypeError(
errorClass="NOT_COLUMN_OR_STR",
@@ -2193,18 +2197,18 @@ def cb(ei: "ExecutionInfo") -> None:

return DataFrameWriterV2(self._plan, self._session, table, cb)

def mergeInto(self, table: str, condition: Column) -> MergeIntoWriter:
def mergeInto(self, table: str, condition: Column) -> "MergeIntoWriter":
def cb(ei: "ExecutionInfo") -> None:
self._execution_info = ei

return MergeIntoWriter(
self._plan, self._session, table, condition, cb # type: ignore[arg-type]
)

def offset(self, n: int) -> ParentDataFrame:
return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session)
def offset(self, num: int) -> ParentDataFrame:
return DataFrame(plan.Offset(child=self._plan, offset=num), session=self._session)

def checkpoint(self, eager: bool = True) -> "DataFrame":
def checkpoint(self, eager: bool = True) -> ParentDataFrame:
cmd = plan.Checkpoint(child=self._plan, local=False, eager=eager)
_, properties, self._execution_info = self._session.client.execute_command(
cmd.command(self._session.client)
@@ -2214,7 +2218,7 @@ def checkpoint(self, eager: bool = True) -> "DataFrame":
assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
return checkpointed

def localCheckpoint(self, eager: bool = True) -> "DataFrame":
def localCheckpoint(self, eager: bool = True) -> ParentDataFrame:
cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager)
_, properties, self._execution_info = self._session.client.execute_command(
cmd.command(self._session.client)
3 changes: 2 additions & 1 deletion python/pyspark/sql/session.py
@@ -77,6 +77,7 @@
from pyspark.sql.udf import UDFRegistration
from pyspark.sql.udtf import UDTFRegistration
from pyspark.sql.datasource import DataSourceRegistration
from pyspark.sql.dataframe import DataFrame as ParentDataFrame

# Running MyPy type checks will always require pandas and
# other dependencies so importing here is fine.
@@ -1641,7 +1642,7 @@ def prepare(obj: Any) -> Any:

def sql(
self, sqlQuery: str, args: Optional[Union[Dict[str, Any], List]] = None, **kwargs: Any
) -> DataFrame:
) -> "ParentDataFrame":
"""Returns a :class:`DataFrame` representing the result of the given query.
When ``kwargs`` is specified, this method formats the given string by using the Python
standard formatter. The method binds named parameters to SQL literals or
188 changes: 188 additions & 0 deletions python/pyspark/sql/tests/test_connect_compatibility.py
@@ -0,0 +1,188 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import unittest
import inspect

from pyspark.testing.sqlutils import ReusedSQLTestCase
from pyspark.sql.classic.dataframe import DataFrame as ClassicDataFrame
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
from pyspark.sql.classic.column import Column as ClassicColumn
from pyspark.sql.connect.column import Column as ConnectColumn
from pyspark.sql.session import SparkSession as ClassicSparkSession
from pyspark.sql.connect.session import SparkSession as ConnectSparkSession


class ConnectCompatibilityTestsMixin:
def get_public_methods(self, cls):
"""Get public methods of a class."""
return {
name: method
for name, method in inspect.getmembers(cls, predicate=inspect.isfunction)
if not name.startswith("_")
}

def get_public_properties(self, cls):
"""Get public properties of a class."""
return {
name: member
for name, member in inspect.getmembers(cls)
if isinstance(member, property) and not name.startswith("_")
}

def test_signature_comparison_between_classic_and_connect(self):
def compare_method_signatures(classic_cls, connect_cls, cls_name):
"""Compare method signatures between classic and connect classes."""
classic_methods = self.get_public_methods(classic_cls)
connect_methods = self.get_public_methods(connect_cls)

common_methods = set(classic_methods.keys()) & set(connect_methods.keys())

for method in common_methods:
classic_signature = inspect.signature(classic_methods[method])
connect_signature = inspect.signature(connect_methods[method])

# createDataFrame cannot be the same since RDD is not supported from Spark Connect
if not method == "createDataFrame":
self.assertEqual(
classic_signature,
connect_signature,
f"Signature mismatch in {cls_name} method '{method}'\n"
f"Classic: {classic_signature}\n"
f"Connect: {connect_signature}",
)

# DataFrame API signature comparison
compare_method_signatures(ClassicDataFrame, ConnectDataFrame, "DataFrame")

# Column API signature comparison
compare_method_signatures(ClassicColumn, ConnectColumn, "Column")

# SparkSession API signature comparison
compare_method_signatures(ClassicSparkSession, ConnectSparkSession, "SparkSession")

def test_property_comparison_between_classic_and_connect(self):
def compare_property_lists(classic_cls, connect_cls, cls_name, expected_missing_properties):
"""Compare properties between classic and connect classes."""
classic_properties = self.get_public_properties(classic_cls)
connect_properties = self.get_public_properties(connect_cls)

# Identify missing properties
classic_only_properties = set(classic_properties.keys()) - set(
connect_properties.keys()
)

# Compare the actual missing properties with the expected ones
self.assertEqual(
classic_only_properties,
expected_missing_properties,
f"{cls_name}: Unexpected missing properties in Connect: {classic_only_properties}",
)

# Expected missing properties for DataFrame
expected_missing_properties_for_dataframe = {"sql_ctx", "isStreaming"}

# DataFrame properties comparison
compare_property_lists(
ClassicDataFrame,
ConnectDataFrame,
"DataFrame",
expected_missing_properties_for_dataframe,
)

# Expected missing properties for Column (if any, replace with actual values)
expected_missing_properties_for_column = set()

# Column properties comparison
compare_property_lists(
ClassicColumn, ConnectColumn, "Column", expected_missing_properties_for_column
)

# Expected missing properties for SparkSession
expected_missing_properties_for_spark_session = {"sparkContext", "version"}

# SparkSession properties comparison
compare_property_lists(
ClassicSparkSession,
ConnectSparkSession,
"SparkSession",
expected_missing_properties_for_spark_session,
)

def test_missing_methods(self):
def check_missing_methods(classic_cls, connect_cls, cls_name, expected_missing_methods):
"""Check for expected missing methods between classic and connect classes."""
classic_methods = self.get_public_methods(classic_cls)
connect_methods = self.get_public_methods(connect_cls)

# Identify missing methods
classic_only_methods = set(classic_methods.keys()) - set(connect_methods.keys())

# Compare the actual missing methods with the expected ones
self.assertEqual(
classic_only_methods,
expected_missing_methods,
f"{cls_name}: Unexpected missing methods in Connect: {classic_only_methods}",
)

# Expected missing methods for DataFrame
expected_missing_methods_for_dataframe = {
"inputFiles",
"isLocal",
"semanticHash",
"isEmpty",
}

# DataFrame missing method check
check_missing_methods(
ClassicDataFrame, ConnectDataFrame, "DataFrame", expected_missing_methods_for_dataframe
)

# Expected missing methods for Column (if any, replace with actual values)
expected_missing_methods_for_column = set()

# Column missing method check
check_missing_methods(
ClassicColumn, ConnectColumn, "Column", expected_missing_methods_for_column
)

# Expected missing methods for SparkSession (if any, replace with actual values)
expected_missing_methods_for_spark_session = {"newSession"}

# SparkSession missing method check
check_missing_methods(
ClassicSparkSession,
ConnectSparkSession,
"SparkSession",
expected_missing_methods_for_spark_session,
)


class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase):
pass


if __name__ == "__main__":
from pyspark.sql.tests.test_connect_compatibility import * # noqa: F401

try:
import xmlrunner # type: ignore

testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)
