
[SPARK-48755][SS][PYTHON] transformWithState pyspark base implementation and ValueState support

### What changes were proposed in this pull request?

- Base implementation for the Python State V2 API
- `ValueState` implementation

Below we highlight the key files/components in this change (a short sketch of the resulting Python API surface follows these lists):
- Python
  - `group_ops.py`: defines the `transformWithStateInPandas` function and its UDF.
  - `serializers.py`: defines how Arrow streams of data rows are loaded and dumped between the JVM and the Python process.
  - `stateful_processor.py`: defines the `StatefulProcessorHandle`, `ValueState` functionality, and the `StatefulProcessor` interface.
  - `stateful_processor_api_client.py` and `value_state_client.py`: contain the logic for sending API requests in protobuf format to the server (the JVM).
- Scala
  - `TransformWithStateInPandasExec`: physical operator for `TransformWithStateInPandas`.
  - `TransformWithStateInPandasPythonRunner`: Python runner that launches the Python worker executing the UDF.
  - `TransformWithStateInPandasStateServer`: class that handles state requests in protobuf format from the Python side.
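
For orientation, here is an abridged sketch of the user-facing Python interface these files introduce. It is reconstructed from the docstring and the integration test in this PR rather than copied from the source; the authoritative definitions live in `stateful_processor.py`.

```python
# Abridged sketch of the user-facing interface, inferred from the examples in this PR.
from abc import ABC, abstractmethod
from typing import Any, Iterator

import pandas as pd
from pyspark.sql.types import StructType


class ValueState:
    """A single value kept per grouping key, persisted across triggers (sketch only)."""

    def exists(self) -> bool:
        """Whether a value is currently stored for the grouping key."""
        ...

    def get(self):
        """Return the stored value (an indexable row, as used in the examples below)."""
        ...

    def update(self, new_value) -> None:
        """Overwrite the stored value; `new_value` is a tuple matching the state schema."""
        ...

    def clear(self) -> None:
        """Remove the stored value for the grouping key."""
        ...


class StatefulProcessorHandle:
    def getValueState(self, state_name: str, schema: StructType) -> ValueState:
        """Create or retrieve a named ValueState variable with the given schema."""
        ...


class StatefulProcessor(ABC):
    @abstractmethod
    def init(self, handle: StatefulProcessorHandle) -> None:
        """Called before processing starts; create state variables here."""

    @abstractmethod
    def handleInputRows(self, key: Any, rows: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
        """Called per grouping key with that key's input rows; may yield output DataFrames."""

    @abstractmethod
    def close(self) -> None:
        """Called after processing finishes; release resources here."""
```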

### Why are the changes needed?

Support Python State V2 API

### Does this PR introduce _any_ user-facing change?

Yes, this PR adds the new `transformWithStateInPandas` API to PySpark.

### How was this patch tested?

Ran a local integration test with the command below:
```
import pandas as pd
from pyspark.sql import Row
from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, StructField, LongType, StringType
from typing import Iterator
spark.conf.set("spark.sql.streaming.stateStore.providerClass","org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
spark.conf.set("spark.sql.shuffle.partitions","1")
output_schema = StructType([
    StructField("value", LongType(), True)
])
state_schema = StructType([
    StructField("id", LongType(), True),
    StructField("value", StringType(), True),
    StructField("comment", StringType(), True)
])

class SimpleStatefulProcessor(StatefulProcessor):
  def init(self, handle: StatefulProcessorHandle) -> None:
    self.value_state = handle.getValueState("testValueState", state_schema)
  def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
    self.value_state.update((1,"test_value","comment"))
    exists = self.value_state.exists()
    print(f"value state exists: {exists}")
    value = self.value_state.get()
    print(f"get value: {value}")
    print("clearing value state")
    self.value_state.clear()
    print("value state cleared")
    return rows
  def close(self) -> None:
    pass

q = (
    spark.readStream.format("rate")
    .option("rowsPerSecond", "1")
    .option("numPartitions", "1")
    .load()
    .groupBy("value")
    .transformWithStateInPandas(
        statefulProcessor=SimpleStatefulProcessor(),
        outputStructType=output_schema,
        outputMode="Update",
        timeMode="None",
    )
    .writeStream.format("console")
    .option("checkpointLocation", "/tmp/streaming/temp_ckp")
    .outputMode("update")
    .start()
)
```
Verified from the logs that the value state methods work as expected for key `11`:
```
value state exists: True
get value:    id       value  comment
0   1  test_value  comment
clearing value state
value state cleared
```

Will add unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #47133 from bogao007/state-v2-initial.

Lead-authored-by: bogao007 <[email protected]>
Co-authored-by: Bhuwan Sahni <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
2 people authored and HeartSaVioR committed Aug 15, 2024
1 parent b9fbdf0 commit def42d4
Showing 31 changed files with 12,153 additions and 5 deletions.
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -3919,6 +3919,12 @@
],
"sqlState" : "42802"
},
"STATEFUL_PROCESSOR_UNKNOWN_TIME_MODE" : {
"message" : [
"Unknown time mode <timeMode>. Accepted timeMode modes are 'none', 'processingTime', 'eventTime'"
],
"sqlState" : "42802"
},
"STATE_STORE_CANNOT_CREATE_COLUMN_FAMILY_WITH_RESERVED_CHARS" : {
"message" : [
"Failed to create column family with unsupported starting character and name=<colFamilyName>."
@@ -749,6 +749,7 @@ private[spark] object LogKeys {
case object START_INDEX extends LogKey
case object START_TIME extends LogKey
case object STATEMENT_ID extends LogKey
case object STATE_NAME extends LogKey
case object STATE_STORE_ID extends LogKey
case object STATE_STORE_PROVIDER extends LogKey
case object STATE_STORE_VERSION extends LogKey
@@ -61,6 +61,7 @@ private[spark] object PythonEvalType {
val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208
val SQL_GROUPED_MAP_ARROW_UDF = 209
val SQL_COGROUPED_MAP_ARROW_UDF = 210
val SQL_TRANSFORM_WITH_STATE_PANDAS_UDF = 211

val SQL_TABLE_UDF = 300
val SQL_ARROW_TABLE_UDF = 301
@@ -82,6 +83,7 @@ private[spark] object PythonEvalType {
case SQL_COGROUPED_MAP_ARROW_UDF => "SQL_COGROUPED_MAP_ARROW_UDF"
case SQL_TABLE_UDF => "SQL_TABLE_UDF"
case SQL_ARROW_TABLE_UDF => "SQL_ARROW_TABLE_UDF"
case SQL_TRANSFORM_WITH_STATE_PANDAS_UDF => "SQL_TRANSFORM_WITH_STATE_PANDAS_UDF"
}
}

2 changes: 2 additions & 0 deletions dev/checkstyle-suppressions.xml
@@ -68,4 +68,6 @@
files="src/main/java/org/apache/spark/network/util/LimitedInputStream.java" />
<suppress checks="Header"
files="src/test/java/org/apache/spark/util/collection/TestTimSort.java" />
<suppress checks=".*"
files="src/main/java/org/apache/spark/sql/execution/streaming/state/StateMessage.java"/>
</suppressions>
1 change: 1 addition & 0 deletions python/docs/source/reference/pyspark.sql/index.rst
@@ -44,3 +44,4 @@ This page gives an overview of all public Spark SQL API.
variant_val
protobuf
datasource
stateful_processor
29 changes: 29 additions & 0 deletions python/docs/source/reference/pyspark.sql/stateful_processor.rst
@@ -0,0 +1,29 @@
.. Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
.. http://www.apache.org/licenses/LICENSE-2.0
.. Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
==================
Stateful Processor
==================
.. currentmodule:: pyspark.sql.streaming

.. autosummary::
:toctree: api/

StatefulProcessor.init
StatefulProcessor.handleInputRows
StatefulProcessor.close
1 change: 1 addition & 0 deletions python/pyspark/sql/pandas/_typing/__init__.pyi
@@ -55,6 +55,7 @@ ArrowMapIterUDFType = Literal[207]
PandasGroupedMapUDFWithStateType = Literal[208]
ArrowGroupedMapUDFType = Literal[209]
ArrowCogroupedMapUDFType = Literal[210]
PandasGroupedMapUDFTransformWithStateType = Literal[211]

class PandasVariadicScalarToScalarFunction(Protocol):
def __call__(self, *_: DataFrameOrSeriesLike_) -> DataFrameOrSeriesLike_: ...
2 changes: 2 additions & 0 deletions python/pyspark/sql/pandas/functions.py
@@ -413,6 +413,7 @@ def calculate(iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
None,
@@ -453,6 +454,7 @@ def _validate_pandas_udf(f, evalType) -> int:
PythonEvalType.SQL_MAP_ARROW_ITER_UDF,
PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE,
PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
PythonEvalType.SQL_GROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_COGROUPED_MAP_ARROW_UDF,
PythonEvalType.SQL_ARROW_BATCHED_UDF,
174 changes: 173 additions & 1 deletion python/pyspark/sql/pandas/group_ops.py
@@ -15,14 +15,19 @@
# limitations under the License.
#
import sys
from typing import List, Union, TYPE_CHECKING, cast
from typing import Any, Iterator, List, Union, TYPE_CHECKING, cast
import warnings

from pyspark.errors import PySparkTypeError
from pyspark.util import PythonEvalType
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.streaming.state import GroupStateTimeout
from pyspark.sql.streaming.stateful_processor_api_client import (
StatefulProcessorApiClient,
StatefulProcessorHandleState,
)
from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import StructType, _parse_datatype_string

if TYPE_CHECKING:
@@ -33,6 +38,7 @@
PandasCogroupedMapFunction,
ArrowGroupedMapFunction,
ArrowCogroupedMapFunction,
DataFrameLike as PandasDataFrameLike,
)
from pyspark.sql.group import GroupedData

@@ -358,6 +364,172 @@ def applyInPandasWithState(
)
return DataFrame(jdf, self.session)

def transformWithStateInPandas(
self,
statefulProcessor: StatefulProcessor,
outputStructType: Union[StructType, str],
outputMode: str,
timeMode: str,
) -> DataFrame:
"""
Invokes methods defined in the stateful processor used in arbitrary state API v2. It
requires protobuf, pandas and pyarrow as dependencies to process input/state data. We
allow the user to act on per-group set of input rows along with keyed state and the user
can choose to output/return 0 or more rows.
For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
in each trigger and the user's state/state variables will be stored persistently across
invocations.
The `statefulProcessor` should be a Python class that implements the interface defined in
:class:`StatefulProcessor`.
The `outputStructType` should be a :class:`StructType` describing the schema of all
elements in the returned value, `pandas.DataFrame`. The column labels of all elements in
returned `pandas.DataFrame` must either match the field names in the defined schema if
specified as strings, or match the field data types by position if not strings,
e.g. integer indices.
The size of each `pandas.DataFrame` in both the input and output can be arbitrary. The
number of `pandas.DataFrame` in both the input and output can also be arbitrary.
.. versionadded:: 4.0.0
Parameters
----------
statefulProcessor : :class:`pyspark.sql.streaming.stateful_processor.StatefulProcessor`
Instance of StatefulProcessor whose functions will be invoked by the operator.
outputStructType : :class:`pyspark.sql.types.DataType` or str
The type of the output records. The value can be either a
:class:`pyspark.sql.types.DataType` object or a DDL-formatted type string.
outputMode : str
The output mode of the stateful processor.
timeMode : str
The time mode semantics of the stateful processor for timers and TTL.
Examples
--------
>>> from typing import Iterator
...
>>> import pandas as pd # doctest: +SKIP
...
>>> from pyspark.sql import Row
>>> from pyspark.sql.functions import col, split
>>> from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
>>> from pyspark.sql.types import IntegerType, LongType, StringType, StructField, StructType
...
>>> spark.conf.set("spark.sql.streaming.stateStore.providerClass",
... "org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
... # Below is a simple example to find erroneous sensors from temperature sensor data. The
... # processor returns a count of total readings, while keeping erroneous reading counts
... # in streaming state. A violation is defined when the temperature is above 100.
... # The input data is a DataFrame with the following schema:
... # `id: string, temperature: long`.
... # The output schema and state schema are defined as below.
>>> output_schema = StructType([
... StructField("id", StringType(), True),
... StructField("count", IntegerType(), True)
... ])
>>> state_schema = StructType([
... StructField("value", IntegerType(), True)
... ])
>>> class SimpleStatefulProcessor(StatefulProcessor):
... def init(self, handle: StatefulProcessorHandle):
... self.num_violations_state = handle.getValueState("numViolations", state_schema)
...
... def handleInputRows(self, key, rows):
... new_violations = 0
... count = 0
... exists = self.num_violations_state.exists()
... if exists:
... existing_violations_row = self.num_violations_state.get()
... existing_violations = existing_violations_row[0]
... else:
... existing_violations = 0
... for pdf in rows:
... pdf_count = pdf.count()
... count += pdf_count.get('temperature')
... violations_pdf = pdf.loc[pdf['temperature'] > 100]
... new_violations += violations_pdf.count().get('temperature')
... updated_violations = new_violations + existing_violations
... self.num_violations_state.update((updated_violations,))
... yield pd.DataFrame({'id': key, 'count': count})
...
... def close(self) -> None:
... pass
Input DataFrame:
+---+-----------+
| id|temperature|
+---+-----------+
| 0| 123|
| 0| 23|
| 1| 33|
| 1| 188|
| 1| 88|
+---+-----------+
>>> df.groupBy("value").transformWithStateInPandas(statefulProcessor =
... SimpleStatefulProcessor(), outputStructType=output_schema, outputMode="Update",
... timeMode="None") # doctest: +SKIP
Output DataFrame:
+---+-----+
| id|count|
+---+-----+
| 0| 2|
| 1| 3|
+---+-----+
Notes
-----
This function requires a full shuffle.
This API is experimental.
"""

from pyspark.sql import GroupedData
from pyspark.sql.functions import pandas_udf

assert isinstance(self, GroupedData)

def transformWithStateUDF(
statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
) -> Iterator["PandasDataFrameLike"]:
handle = StatefulProcessorHandle(statefulProcessorApiClient)

if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
statefulProcessor.init(handle)
statefulProcessorApiClient.set_handle_state(
StatefulProcessorHandleState.INITIALIZED
)

statefulProcessorApiClient.set_implicit_key(key)
result = statefulProcessor.handleInputRows(key, inputRows)

return result

if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))

udf = pandas_udf(
transformWithStateUDF, # type: ignore
returnType=outputStructType,
functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
)
df = self._df
udf_column = udf(*[df[col] for col in df.columns])

jdf = self._jgd.transformWithStateInPandas(
udf_column._jc.expr(),
self.session._jsparkSession.parseDataType(outputStructType.json()),
outputMode,
timeMode,
)
return DataFrame(jdf, self.session)

def applyInArrow(
self, func: "ArrowGroupedMapFunction", schema: Union[StructType, str]
) -> "DataFrame":
76 changes: 75 additions & 1 deletion python/pyspark/sql/pandas/serializers.py
@@ -19,9 +19,16 @@
Serializers for PyArrow and pandas conversions. See `pyspark.serializers` for more details.
"""

from itertools import groupby
from pyspark.errors import PySparkRuntimeError, PySparkTypeError, PySparkValueError
from pyspark.loose_version import LooseVersion
from pyspark.serializers import Serializer, read_int, write_int, UTF8Deserializer, CPickleSerializer
from pyspark.serializers import (
Serializer,
read_int,
write_int,
UTF8Deserializer,
CPickleSerializer,
)
from pyspark.sql.pandas.types import (
from_arrow_type,
to_arrow_type,
@@ -1116,3 +1123,70 @@ def init_stream_yield_batches(batches):
batches_to_write = init_stream_yield_batches(serialize_batches())

return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream)


class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
"""
Serializer used by Python worker to evaluate UDF for
:meth:`pyspark.sql.GroupedData.transformWithStateInPandasSerializer`.
Parameters
----------
timezone : str
A timezone to respect when handling timestamp values
safecheck : bool
If True, conversion from Arrow to Pandas checks for overflow/truncation
assign_cols_by_name : bool
If True, then Pandas DataFrames will get columns by name
arrow_max_records_per_batch : int
Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
"""

def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch):
super(TransformWithStateInPandasSerializer, self).__init__(
timezone, safecheck, assign_cols_by_name
)
self.arrow_max_records_per_batch = arrow_max_records_per_batch
self.key_offsets = None

def load_stream(self, stream):
"""
Read ArrowRecordBatches from stream, deserialize them to populate a list of data chunk, and
convert the data into a list of pandas.Series.
Please refer the doc of inner function `generate_data_batches` for more details how
this function works in overall.
"""
import pyarrow as pa

def generate_data_batches(batches):
"""
Deserialize ArrowRecordBatches and return a generator of pandas.Series list.
The deserialization logic assumes that Arrow RecordBatches contain the data with the
ordering that data chunks for same grouping key will appear sequentially.
This function must avoid materializing multiple Arrow RecordBatches into memory at the
same time. And data chunks from the same grouping key should appear sequentially.
"""
for batch in batches:
data_pandas = [
self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()
]
key_series = [data_pandas[o] for o in self.key_offsets]
batch_key = tuple(s[0] for s in key_series)
yield (batch_key, data_pandas)

_batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)

for k, g in groupby(data_batches, key=lambda x: x[0]):
yield (k, g)

def dump_stream(self, iterator, stream):
"""
Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow
RecordBatches, and write batches to stream.
"""
result = [(b, t) for x in iterator for y, t in x for b in y]
super().dump_stream(result, stream)
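
As an aside for readers unfamiliar with `itertools.groupby`: the `load_stream` implementation above only works because Arrow batches for the same grouping key arrive adjacently in the stream. A minimal, standalone illustration of that contract (plain Python with made-up data, not part of this patch):

```python
from itertools import groupby

# (grouping_key, data_chunk) pairs, in the order generate_data_batches would yield them
batches = [(("a",), [1, 2]), (("a",), [3]), (("b",), [4]), (("a",), [5])]

# groupby only merges *consecutive* items with equal keys
for key, group in groupby(batches, key=lambda x: x[0]):
    print(key, [chunk for _, chunk in group])

# ('a',) [[1, 2], [3]]
# ('b',) [[4]]
# ('a',) [[5]]   <- a non-adjacent repeat of "a" starts a new group rather than merging
```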
