Skip to content

Commit

Permalink
[SPARK-49694][PYTHON][CONNECT] Support scatter plots
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Support scatter plots with plotly backend on both Spark Connect and Spark classic.

### Why are the changes needed?
While Pandas on Spark supports plotting, PySpark currently lacks this feature. The proposed API will enable users to generate visualizations. This will provide users with an intuitive, interactive way to explore and understand large datasets directly from PySpark DataFrames, streamlining the data analysis workflow in distributed environments.

See more at [PySpark Plotting API Specification](https://docs.google.com/document/d/1IjOEzC8zcetG86WDvqkereQPj_NGLNW7Bdu910g30Dg/edit?usp=sharing) in progress.

Part of https://issues.apache.org/jira/browse/SPARK-49530.

### Does this PR introduce _any_ user-facing change?
Yes. Scatter plots are supported as shown below.

```py
>>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
>>> columns = ["length", "width", "species"]
>>> sdf = spark.createDataFrame(data, columns)
>>> fig = sdf.plot(kind="scatter", x="length", y="width")  # or fig = sdf.plot.scatter(x="length", y="width")
>>> fig.show()
```
![newplot (6)](https://github.com/user-attachments/assets/deef452b-74d1-4f6d-b1ae-60722f3c2b17)

### How was this patch tested?
Unit tests.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48219 from xinrong-meng/plot_scatter.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Xinrong Meng <[email protected]>
  • Loading branch information
xinrong-meng committed Sep 24, 2024
1 parent 438a6e7 commit 6bdd151
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
34 changes: 34 additions & 0 deletions python/pyspark/sql/plot/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class PySparkPlotAccessor:
"bar": PySparkTopNPlotBase().get_top_n,
"barh": PySparkTopNPlotBase().get_top_n,
"line": PySparkSampledPlotBase().get_sampled,
"scatter": PySparkSampledPlotBase().get_sampled,
}
_backends = {} # type: ignore[var-annotated]

Expand Down Expand Up @@ -230,3 +231,36 @@ def barh(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure":
... ) # doctest: +SKIP
"""
return self(kind="barh", x=x, y=y, **kwargs)

def scatter(self, x: str, y: str, **kwargs: Any) -> "Figure":
"""
Create a scatter plot with varying marker point size and color.
The coordinates of each point are defined by two dataframe columns and
filled circles are used to represent each point. This kind of plot is
useful to see complex correlations between two variables. Points could
be for instance natural 2D coordinates like longitude and latitude in
a map or, in general, any pair of metrics that can be plotted against
each other.
Parameters
----------
x : str
Name of column to use as horizontal coordinates for each point.
y : str or list of str
Name of column to use as vertical coordinates for each point.
**kwargs: Optional
Additional keyword arguments.
Returns
-------
:class:`plotly.graph_objs.Figure`
Examples
--------
>>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
>>> columns = ['length', 'width', 'species']
>>> df = spark.createDataFrame(data, columns)
>>> df.plot.scatter(x='length', y='width') # doctest: +SKIP
"""
return self(kind="scatter", x=x, y=y, **kwargs)
19 changes: 19 additions & 0 deletions python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ def sdf(self):
columns = ["category", "int_val", "float_val"]
return self.spark.createDataFrame(data, columns)

@property
def sdf2(self):
data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)]
columns = ["length", "width", "species"]
return self.spark.createDataFrame(data, columns)

def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=""):
if kind == "line":
self.assertEqual(fig_data["mode"], "lines")
Expand All @@ -37,6 +43,9 @@ def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=
elif kind == "barh":
self.assertEqual(fig_data["type"], "bar")
self.assertEqual(fig_data["orientation"], "h")
elif kind == "scatter":
self.assertEqual(fig_data["type"], "scatter")
self.assertEqual(fig_data["orientation"], "v")

self.assertEqual(fig_data["xaxis"], "x")
self.assertEqual(list(fig_data["x"]), expected_x)
Expand Down Expand Up @@ -79,6 +88,16 @@ def test_barh_plot(self):
self._check_fig_data("barh", fig["data"][0], [10, 30, 20], ["A", "B", "C"], "int_val")
self._check_fig_data("barh", fig["data"][1], [1.5, 2.5, 3.5], ["A", "B", "C"], "float_val")

def test_scatter_plot(self):
fig = self.sdf2.plot(kind="scatter", x="length", y="width")
self._check_fig_data(
"scatter", fig["data"][0], [5.1, 4.9, 7.0, 6.4, 5.9], [3.5, 3.0, 3.2, 3.2, 3.0]
)
fig = self.sdf2.plot.scatter(x="width", y="length")
self._check_fig_data(
"scatter", fig["data"][0], [3.5, 3.0, 3.2, 3.2, 3.0], [5.1, 4.9, 7.0, 6.4, 5.9]
)


class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase):
pass
Expand Down

0 comments on commit 6bdd151

Please sign in to comment.