Commit: test passed
jingz-db committed Oct 11, 2024
1 parent 9e3bc77 commit 2876875
Showing 13 changed files with 215 additions and 1,015 deletions.
59 changes: 34 additions & 25 deletions python/pyspark/sql/pandas/group_ops.py
@@ -15,6 +15,7 @@
# limitations under the License.
#
import sys
import itertools
from typing import Any, Iterator, List, Union, TYPE_CHECKING, cast
import warnings

@@ -403,6 +404,9 @@ def transformWithStateInPandas(
The output mode of the stateful processor.
timeMode : str
The time mode semantics of the stateful processor for timers and TTL.
initialState : "GroupedData"
Optional. A grouped DataFrame on the same grouping key whose rows provide the initial state
used to initialize state variables in the first batch.
Examples
--------
@@ -487,6 +491,10 @@ def transformWithStateInPandas(
from pyspark.sql.functions import pandas_udf

assert isinstance(self, GroupedData)
if initialState is not None:
assert isinstance(initialState, GroupedData)
if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))

def transformWithStateUDF(
statefulProcessorApiClient: StatefulProcessorApiClient,
@@ -510,64 +518,65 @@ def transformWithStateWithInitStateUDF(
statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
initialStates: Iterator["PandasDataFrameLike"]
# for any batch after the first, initialStates will be None
initialStates: Iterator["PandasDataFrameLike"] = None
) -> Iterator["PandasDataFrameLike"]:
handle = StatefulProcessorHandle(statefulProcessorApiClient)

if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
statefulProcessor.init(handle)
# only process initial state on the first batch
is_first_batch = statefulProcessorApiClient.is_first_batch()
if is_first_batch:
initial_state_iter = initialStates
# if we don't have initial state for the given key, iterator could be None
if initial_state_iter is not None:
for cur_initial_state in initial_state_iter:
print(f"got initial state here for key: {key},"
f" initial state: {cur_initial_state}\n")
if cur_initial_state.empty:
print(f"got empty initial state here for key: {key},"
f" initial state: {cur_initial_state}\n")
else:
statefulProcessorApiClient.set_implicit_key(key)
statefulProcessor.handleInitialState(key, cur_initial_state)
statefulProcessorApiClient.remove_implicit_key()
if is_first_batch and initialStates is not None:
seen_init_state_on_key = False
for cur_initial_state in initialStates:
if seen_init_state_on_key:
raise Exception(f"TransformWithStateWithInitState: Cannot have more "
f"than one row in the initial states for the same key. "
f"Grouping key: {key}.")
statefulProcessorApiClient.set_implicit_key(key)
statefulProcessor.handleInitialState(key, cur_initial_state)
seen_init_state_on_key = True
statefulProcessorApiClient.set_handle_state(
StatefulProcessorHandleState.INITIALIZED
)

# if we don't have state for the given key, iterator could be None
if inputRows is not None:
# if we don't have input rows for the given key but only have initial state
# for it, the inputRows iterator could be empty
input_rows_empty = False
try:
first = next(inputRows)
except StopIteration:
input_rows_empty = True
else:
inputRows = itertools.chain([first], inputRows)

if not input_rows_empty:
statefulProcessorApiClient.set_implicit_key(key)
result = statefulProcessor.handleInputRows(key, inputRows)
else:
result = iter([])

return result
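The empty-input check above uses a peek-ahead idiom: pull one element off the iterator and, if it exists, stitch it back on with itertools.chain so the downstream handler still sees the full stream. A minimal standalone sketch of the same idiom (the helper name is illustrative, not part of this change):

import itertools
from typing import Any, Iterator, Tuple

def peek(it: Iterator[Any]) -> Tuple[bool, Iterator[Any]]:
    # Return (is_empty, iterator) without losing the first element.
    try:
        first = next(it)
    except StopIteration:
        return True, iter([])
    # Re-attach the consumed element so callers see the original sequence.
    return False, itertools.chain([first], it)

empty, rows = peek(iter([1, 2, 3]))
assert not empty and list(rows) == [1, 2, 3]
empty, rows = peek(iter([]))
assert empty and list(rows) == []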

if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))

df = self._df

if initialState is None:
initial_state_java_obj = None

udf = pandas_udf(
transformWithStateUDF, # type: ignore
returnType=outputStructType,
functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_UDF,
)
udf_column = udf(*[df[col] for col in df.columns])
else:
print(f"I am here, not empty initial state\n")
initial_state_java_obj = initialState._jgd

udf = pandas_udf(
transformWithStateWithInitStateUDF, # type: ignore
returnType=outputStructType,
functionType=PythonEvalType.SQL_TRANSFORM_WITH_STATE_PANDAS_INIT_STATE_UDF,
)
udf_column = udf(*[df[col] for col in df.columns])

udf_column = udf(*[df[col] for col in df.columns])
jdf = self._jgd.transformWithStateInPandas(
udf_column._jc,
self.session._jsparkSession.parseDataType(outputStructType.json()),
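For context on the API change above, a hedged usage sketch of passing an initial state to transformWithStateInPandas; the Spark session, input stream, processor class, and column names here are assumed for illustration and are not part of this commit:

from pyspark.sql.types import LongType, StringType, StructField, StructType

output_schema = StructType([
    StructField("id", StringType()),
    StructField("count", LongType()),
])

# Hypothetical initial state: at most one row per grouping key, handed to
# StatefulProcessor.handleInitialState on the first micro-batch.
init_df = spark.createDataFrame([("a", 10), ("b", 5)], ["id", "count"])

result = input_stream_df.groupBy("id").transformWithStateInPandas(
    statefulProcessor=MyProcessor(),      # assumed StatefulProcessor subclass
    outputStructType=output_schema,
    outputMode="Update",                  # assumed output mode value
    timeMode="None",                      # assumed time mode value
    initialState=init_df.groupBy("id"),   # GroupedData on the same grouping key
)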
65 changes: 54 additions & 11 deletions python/pyspark/sql/pandas/serializers.py
@@ -1148,7 +1148,6 @@ def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_p
)
self.arrow_max_records_per_batch = arrow_max_records_per_batch
self.key_offsets = None
self.init_key_offsets = None

def load_stream(self, stream):
"""
@@ -1163,13 +1162,64 @@ def load_stream(self, stream):
def generate_data_batches(batches):
"""
Deserialize ArrowRecordBatches and return a generator of pandas.Series list.
The deserialization logic assumes that the Arrow RecordBatches are ordered so that data
chunks for the same grouping key appear sequentially.
This function must avoid materializing multiple Arrow RecordBatches in memory at the
same time, and data chunks for the same grouping key should appear sequentially.
"""
for batch in batches:
data_pandas = [
self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()
]
key_series = [data_pandas[o] for o in self.key_offsets]
batch_key = tuple(s[0] for s in key_series)
yield (batch_key, data_pandas)

_batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
data_batches = generate_data_batches(_batches)

for k, g in groupby(data_batches, key=lambda x: x[0]):
yield (k, g)
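load_stream groups the generated (key, data) tuples with itertools.groupby, which only merges consecutive elements; that is why the docstring requires chunks for the same grouping key to arrive sequentially. A small standalone illustration of that behavior:

from itertools import groupby

# Batches tagged with their grouping key, arriving key-contiguously.
batches = [("a", 1), ("a", 2), ("b", 3)]
grouped = [(k, [v for _, v in g]) for k, g in groupby(batches, key=lambda x: x[0])]
assert grouped == [("a", [1, 2]), ("b", [3])]

# If a key is not contiguous, groupby emits it twice, so ordering matters.
scattered = [("a", 1), ("b", 2), ("a", 3)]
assert [k for k, _ in groupby(scattered, key=lambda x: x[0])] == ["a", "b", "a"]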

def dump_stream(self, iterator, stream):
"""
Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow
RecordBatches, and write batches to stream.
"""
result = [(b, t) for x in iterator for y, t in x for b in y]
super().dump_stream(result, stream)
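The nested comprehension in dump_stream flattens an iterator of per-key outputs, each an iterable of (dataframes, return_type) pairs, into a flat list of (dataframe, return_type) tuples before delegating to the parent serializer. An equivalent, more explicit spelling (a sketch; names are illustrative):

def flatten_results(iterator):
    # Equivalent to: [(b, t) for x in iterator for y, t in x for b in y]
    result = []
    for per_key_output in iterator:                      # x
        for dataframes, return_type in per_key_output:   # (y, t)
            for df in dataframes:                        # b
                result.append((df, return_type))
    return result

# Toy check against the one-liner, using strings in place of DataFrames.
data = [[(["df1", "df2"], "type_a")], [(["df3"], "type_b")]]
assert flatten_results(data) == [(b, t) for x in data for y, t in x for b in y]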


class TransformWithStateInPandasInitStateSerializer(TransformWithStateInPandasSerializer):
"""
Serializer used by Python worker to evaluate UDF for
:meth:`pyspark.sql.GroupedData.transformWithStateInPandas` when an initial state is provided.
Parameters
----------
Same as input parameters in TransformWithStateInPandasSerializer.
"""

def __init__(self, timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch):
super(TransformWithStateInPandasInitStateSerializer, self).__init__(
timezone, safecheck, assign_cols_by_name, arrow_max_records_per_batch
)
self.init_key_offsets = None

def load_stream(self, stream):
import pyarrow as pa

def generate_data_batches(batches):
"""
Deserialize ArrowRecordBatches and return a generator of pandas.Series list.
The deserialization logic assumes that the Arrow RecordBatches are ordered so that data
chunks for the same grouping key appear sequentially.
See `TransformWithStateInPandasPythonBaseRunner` for the Arrow batch schema sent from the JVM.
This function flattens the columns of the input rows and the initial-state rows and feeds
them into the data generator.
"""
def flatten_columns(cur_batch, col_name):
state_column = cur_batch.column(cur_batch.schema.get_field_index(col_name))
state_field_names = [state_column.type[i].name for i in range(state_column.type.num_fields)]
@@ -1189,6 +1239,7 @@ def flatten_columns(cur_batch, col_name):
]
key_series = [data_pandas[o] for o in self.key_offsets]
init_key_series = [init_data_pandas[o] for o in self.init_key_offsets]

if any(s.empty for s in key_series):
# If the input key series is empty (no input rows for this key), take batch_key from init_key_series
batch_key = tuple(s[0] for s in init_key_series)
@@ -1202,11 +1253,3 @@

for k, g in groupby(data_batches, key=lambda x: x[0]):
yield (k, g)

def dump_stream(self, iterator, stream):
"""
Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow
RecordBatches, and write batches to stream.
"""
result = [(b, t) for x in iterator for y, t in x for b in y]
super().dump_stream(result, stream)
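The flatten_columns helper in the init-state serializer pulls a struct column out of the Arrow batch and expands its child fields into separate columns, so input rows and initial-state rows can be converted to pandas independently. A minimal pyarrow sketch of that expansion, with made-up field and column names:

import pyarrow as pa

# One struct column, mimicking the nested layout described above.
struct_type = pa.struct([("id", pa.string()), ("value", pa.int64())])
batch = pa.record_batch(
    [pa.array([{"id": "a", "value": 1}, {"id": "b", "value": 2}], type=struct_type)],
    names=["inputData"],
)

column = batch.column(batch.schema.get_field_index("inputData"))
field_names = [column.type[i].name for i in range(column.type.num_fields)]
flattened = pa.RecordBatch.from_arrays(column.flatten(), names=field_names)
assert flattened.schema.names == ["id", "value"]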
16 changes: 7 additions & 9 deletions python/pyspark/sql/streaming/StateMessage_pb2.py

Some generated files are not rendered by default.