diff --git a/python/pyspark/sql/plot/core.py b/python/pyspark/sql/plot/core.py index eb00b8a04f977..0a3a0101e1898 100644 --- a/python/pyspark/sql/plot/core.py +++ b/python/pyspark/sql/plot/core.py @@ -96,6 +96,7 @@ class PySparkPlotAccessor: "bar": PySparkTopNPlotBase().get_top_n, "barh": PySparkTopNPlotBase().get_top_n, "line": PySparkSampledPlotBase().get_sampled, + "scatter": PySparkSampledPlotBase().get_sampled, } _backends = {} # type: ignore[var-annotated] @@ -230,3 +231,36 @@ def barh(self, x: str, y: Union[str, list[str]], **kwargs: Any) -> "Figure": ... ) # doctest: +SKIP """ return self(kind="barh", x=x, y=y, **kwargs) + + def scatter(self, x: str, y: str, **kwargs: Any) -> "Figure": + """ + Create a scatter plot with varying marker point size and color. + + The coordinates of each point are defined by two dataframe columns and + filled circles are used to represent each point. This kind of plot is + useful to see complex correlations between two variables. Points could + be for instance natural 2D coordinates like longitude and latitude in + a map or, in general, any pair of metrics that can be plotted against + each other. + + Parameters + ---------- + x : str + Name of column to use as horizontal coordinates for each point. + y : str + Name of column to use as vertical coordinates for each point. + **kwargs: Optional + Additional keyword arguments. 
+ + Returns + ------- + :class:`plotly.graph_objs.Figure` + + Examples + -------- + >>> data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)] + >>> columns = ['length', 'width', 'species'] + >>> df = spark.createDataFrame(data, columns) + >>> df.plot.scatter(x='length', y='width') # doctest: +SKIP + """ + return self(kind="scatter", x=x, y=y, **kwargs) diff --git a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py index 1c52c93a23d3a..ccfe1a75424e0 100644 --- a/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py +++ b/python/pyspark/sql/tests/plot/test_frame_plot_plotly.py @@ -28,6 +28,12 @@ def sdf(self): columns = ["category", "int_val", "float_val"] return self.spark.createDataFrame(data, columns) + @property + def sdf2(self): + data = [(5.1, 3.5, 0), (4.9, 3.0, 0), (7.0, 3.2, 1), (6.4, 3.2, 1), (5.9, 3.0, 2)] + columns = ["length", "width", "species"] + return self.spark.createDataFrame(data, columns) + def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=""): if kind == "line": self.assertEqual(fig_data["mode"], "lines") @@ -37,6 +43,9 @@ def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name= elif kind == "barh": self.assertEqual(fig_data["type"], "bar") self.assertEqual(fig_data["orientation"], "h") + elif kind == "scatter": + self.assertEqual(fig_data["type"], "scatter") + self.assertEqual(fig_data["orientation"], "v") self.assertEqual(fig_data["xaxis"], "x") self.assertEqual(list(fig_data["x"]), expected_x) @@ -79,6 +88,16 @@ def test_barh_plot(self): self._check_fig_data("barh", fig["data"][0], [10, 30, 20], ["A", "B", "C"], "int_val") self._check_fig_data("barh", fig["data"][1], [1.5, 2.5, 3.5], ["A", "B", "C"], "float_val") + def test_scatter_plot(self): + fig = self.sdf2.plot(kind="scatter", x="length", y="width") + self._check_fig_data( + "scatter", fig["data"][0], [5.1, 4.9, 7.0, 6.4, 5.9], [3.5, 3.0, 
3.2, 3.2, 3.0] + ) + fig = self.sdf2.plot.scatter(x="width", y="length") + self._check_fig_data( + "scatter", fig["data"][0], [3.5, 3.0, 3.2, 3.2, 3.0], [5.1, 4.9, 7.0, 6.4, 5.9] + ) + class DataFramePlotPlotlyTests(DataFramePlotPlotlyTestsMixin, ReusedSQLTestCase): pass