Skip to content

Commit

Permalink
streaming transform sync
Browse files Browse the repository at this point in the history
Signed-off-by: Sidhant Kohli <[email protected]>
  • Loading branch information
kohlisid committed Oct 22, 2024
1 parent 57d7636 commit 71c7cfc
Show file tree
Hide file tree
Showing 6 changed files with 284 additions and 110 deletions.
2 changes: 1 addition & 1 deletion pynumaflow/sourcetransformer/multiproc_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pynumaflow.info.types import ServerInfo, MINIMUM_NUMAFLOW_VERSION, ContainerType
from pynumaflow.sourcetransformer.servicer.server import SourceTransformServicer
from pynumaflow.sourcetransformer.servicer.servicer import SourceTransformServicer

from pynumaflow.shared.server import start_multiproc_server

Expand Down
2 changes: 1 addition & 1 deletion pynumaflow/sourcetransformer/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pynumaflow.shared import NumaflowServer
from pynumaflow.shared.server import sync_server_start
from pynumaflow.sourcetransformer._dtypes import SourceTransformCallable
from pynumaflow.sourcetransformer.servicer.server import SourceTransformServicer
from pynumaflow.sourcetransformer.servicer.servicer import SourceTransformServicer


class SourceTransformServer(NumaflowServer):
Expand Down
73 changes: 0 additions & 73 deletions pynumaflow/sourcetransformer/servicer/server.py

This file was deleted.

150 changes: 150 additions & 0 deletions pynumaflow/sourcetransformer/servicer/servicer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import threading
from concurrent.futures import ThreadPoolExecutor
from collections.abc import Iterable

from google.protobuf import empty_pb2 as _empty_pb2
from google.protobuf import timestamp_pb2 as _timestamp_pb2

from pynumaflow.shared.server import exit_on_error
from pynumaflow.shared.synciter import SyncIterator
from pynumaflow.sourcetransformer import Datum
from pynumaflow.sourcetransformer._dtypes import SourceTransformCallable
from pynumaflow.proto.sourcetransformer import transform_pb2
from pynumaflow.proto.sourcetransformer import transform_pb2_grpc
from pynumaflow.types import NumaflowServicerContext
from pynumaflow._constants import _LOGGER, STREAM_EOF


def _create_read_handshake_response() -> transform_pb2.SourceTransformResponse:
"""
Create a handshake response for the SourceTransform function.
Returns:
SourceTransformResponse: A SourceTransformResponse object indicating a successful handshake.
"""
return transform_pb2.SourceTransformResponse(
handshake=transform_pb2.Handshake(sot=True),
)


class SourceTransformServicer(transform_pb2_grpc.SourceTransformServicer):
"""
This class is used to create a new grpc SourceTransform servicer instance.
It implements the SourceTransformServicer interface from the proto transform.proto file.
Provides the functionality for the required rpc methods.
"""

def __init__(self, handler: SourceTransformCallable, multiproc: bool = False):
self.__transform_handler: SourceTransformCallable = handler
# This indicates whether the grpc server attached is multiproc or not
self.multiproc = multiproc
# create a thread pool with 50 worker threads
self.executor = ThreadPoolExecutor(max_workers=50)

def SourceTransformFn(
self,
request_iterator: Iterable[transform_pb2.SourceTransformRequest],
context: NumaflowServicerContext,
) -> Iterable[transform_pb2.SourceTransformResponse]:
"""
Applies a function to each datum element.
The pascal case function name comes from the generated transform_pb2_grpc.py file.
"""

# proto repeated field(keys) is of type google._upb._message.RepeatedScalarContainer
# we need to explicitly convert it to list
try:
# The first message to be received should be a valid handshake
req = next(request_iterator)
# check if it is a valid handshake req
if not (req.handshake and req.handshake.sot):
raise Exception("SourceTransformFn: expected handshake message")
yield _create_read_handshake_response()

result_queue = SyncIterator()

# Reader thread to keep reading from the request iterator and once done close it.
reader_thread = threading.Thread(
target=self._process_requests, args=(context, request_iterator, result_queue)
)
reader_thread.start()

# Read the result queue and keep forwarding the results
for res in result_queue.read_iterator():
# if error check for that
if isinstance(res, BaseException):
# Terminate the current server process due to exception
exit_on_error(context, repr(res), parent=self.multiproc)
return
# keep returning the results back upstream
yield res
# yield _create_transform_response(res)

reader_thread.join()

except BaseException as err:
_LOGGER.critical("UDFError, re-raising the error", exc_info=True)
# Terminate the current server process due to exception
exit_on_error(context, repr(err), parent=self.multiproc)
return

def _process_requests(
self,
context: NumaflowServicerContext,
request_iterator: Iterable[transform_pb2.SourceTransformRequest],
result_queue: SyncIterator,
):
try:
# read through all incoming requests and submit to the
# threadpool for invocation
for request in request_iterator:
_ = self.executor.submit(self._invoke_transformer, context, request, result_queue)
# wait for all tasks to finish
self.executor.shutdown(wait=True)
# Indicate to the result queue that no more messages left
result_queue.put(STREAM_EOF)
except BaseException as e:
_LOGGER.critical("SourceTransformFnError, re-raising the error", exc_info=True)
result_queue.put(e)

def _invoke_transformer(
self, context, request: transform_pb2.SourceTransformRequest, result_queue: SyncIterator
):
try:
d = Datum(
keys=list(request.request.keys),
value=request.request.value,
event_time=request.request.event_time.ToDatetime(),
watermark=request.request.watermark.ToDatetime(),
headers=dict(request.request.headers),
)
responses = self.__transform_handler(list(request.request.keys), d)

results = []
for resp in responses:
event_time_timestamp = _timestamp_pb2.Timestamp()
event_time_timestamp.FromDatetime(dt=resp.event_time)
results.append(
transform_pb2.SourceTransformResponse.Result(
keys=list(resp.keys),
value=resp.value,
tags=resp.tags,
event_time=event_time_timestamp,
)
)
result_queue.put(
transform_pb2.SourceTransformResponse(results=results, id=request.request.id)
)
except BaseException as e:
_LOGGER.critical("SourceTransform handler error", exc_info=True)
result_queue.put(e)
return

def IsReady(
self, request: _empty_pb2.Empty, context: NumaflowServicerContext
) -> transform_pb2.ReadyResponse:
"""
IsReady is the heartbeat endpoint for gRPC.
The pascal case function name comes from the proto transform_pb2_grpc.py file.
"""
return transform_pb2.ReadyResponse(ready=True)
1 change: 0 additions & 1 deletion tests/sink/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,6 @@ def test_udsink_err_handshake(self):
method.send_request(test_datums[0])

metadata, code, details = method.termination()
print("HERE", details)
self.assertTrue("UDSinkError: Exception('SinkFn: expected handshake message')" in details)
self.assertEqual(StatusCode.UNKNOWN, code)

Expand Down
Loading

0 comments on commit 71c7cfc

Please sign in to comment.