Skip to content

Commit

Permalink
refactor: expose a Stream.get_middleware function
Browse files Browse the repository at this point in the history
  • Loading branch information
woile committed Oct 4, 2024
1 parent ca4a4a4 commit 5f8e069
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 20 deletions.
52 changes: 37 additions & 15 deletions kstreams/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from .backends.kafka import Kafka
from .clients import Consumer, Producer
from .exceptions import DuplicateStreamException, EngineNotStartedException
from .middleware import ExceptionMiddleware, Middleware
from .middleware import Middleware
from .middleware.udf_middleware import UdfHandler
from .prometheus.monitor import PrometheusMonitor
from .rebalance_listener import MetricsRebalanceListener, RebalanceListener
Expand Down Expand Up @@ -343,10 +343,39 @@ def get_stream(self, name: str) -> typing.Optional[Stream]:
return stream

def add_stream(
self, stream: Stream, error_policy: StreamErrorPolicy = StreamErrorPolicy.STOP
self, stream: Stream, error_policy: typing.Optional[StreamErrorPolicy] = None
) -> None:
"""
Add a stream to the engine.
This method registers a new stream with the engine, setting up necessary
configurations and handlers. If a stream with the same name already exists,
a DuplicateStreamException is raised.
Args:
stream: The stream to be added.
error_policy: An optional error policy to be applied to the stream.
You should probably set it directly when instantiating a Stream, not here.
Raises:
DuplicateStreamException: If a stream with the same name already exists.
Notes:
- If the stream does not have a deserializer, the engine's deserializer
is assigned to it.
- If the stream does not have a rebalance listener, a default
MetricsRebalanceListener is assigned.
- The stream's UDF handler is set up with the provided function and
engine's send method.
- If the stream's UDF handler type is not NO_TYPING, a middleware stack
is built for the stream's function.
"""
if self.exist_stream(stream.name):
raise DuplicateStreamException(name=stream.name)

if error_policy is not None:
stream.error_policy = error_policy

stream.backend = self.backend
if stream.deserializer is None:
stream.deserializer = self.deserializer
Expand All @@ -357,8 +386,8 @@ def add_stream(
# when the callbacks are called
stream.rebalance_listener = MetricsRebalanceListener()

stream.rebalance_listener.stream = stream # type: ignore
stream.rebalance_listener.engine = self # type: ignore
stream.rebalance_listener.stream = stream
stream.rebalance_listener.engine = self

stream.udf_handler = UdfHandler(
next_call=stream.func,
Expand All @@ -369,21 +398,14 @@ def add_stream(
# NOTE: When `no typing` support is deprecated this check can
# be removed
if stream.udf_handler.type != UDFType.NO_TYPING:
stream.func = self.build_stream_middleware_stack(
stream=stream, error_policy=error_policy
)
stream.func = self._build_stream_middleware_stack(stream=stream)

def build_stream_middleware_stack(
self, *, stream: Stream, error_policy: StreamErrorPolicy
) -> NextMiddlewareCall:
def _build_stream_middleware_stack(self, *, stream: Stream) -> NextMiddlewareCall:
assert stream.udf_handler, "UdfHandler can not be None"

stream.middlewares = [
Middleware(ExceptionMiddleware, engine=self, error_policy=error_policy),
] + stream.middlewares

middlewares = stream.get_middlewares(self)
next_call = stream.udf_handler
for middleware, options in reversed(stream.middlewares):
for middleware, options in reversed(middlewares):
next_call = middleware(
next_call=next_call, send=self.send, stream=stream, **options
)
Expand Down
31 changes: 30 additions & 1 deletion kstreams/streams.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import collections
import inspect
import logging
import typing
Expand All @@ -9,16 +10,20 @@

from kstreams import ConsumerRecord, TopicPartition
from kstreams.exceptions import BackendNotSet
from kstreams.middleware.middleware import ExceptionMiddleware
from kstreams.structs import TopicPartitionOffset

from .backends.kafka import Kafka
from .clients import Consumer
from .middleware import Middleware, udf_middleware
from .rebalance_listener import RebalanceListener
from .serializers import Deserializer
from .streams_utils import UDFType
from .streams_utils import StreamErrorPolicy, UDFType
from .types import StreamFunc

if typing.TYPE_CHECKING:
from kstreams import StreamEngine

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -152,6 +157,7 @@ def __init__(
initial_offsets: typing.Optional[typing.List[TopicPartitionOffset]] = None,
rebalance_listener: typing.Optional[RebalanceListener] = None,
middlewares: typing.Optional[typing.List[Middleware]] = None,
error_policy: StreamErrorPolicy = StreamErrorPolicy.STOP,
) -> None:
self.func = func
self.backend = backend
Expand All @@ -169,13 +175,36 @@ def __init__(
self.udf_handler: typing.Optional[udf_middleware.UdfHandler] = None
self.topics = [topics] if isinstance(topics, str) else topics
self.subscribe_by_pattern = subscribe_by_pattern
self.error_policy = error_policy

def _create_consumer(self) -> Consumer:
if self.backend is None:
raise BackendNotSet("A backend has not been set for this stream")
config = {**self.backend.model_dump(), **self.config}
return self.consumer_class(**config)

def get_middlewares(
self, engine: "StreamEngine"
) -> collections.abc.Sequence[Middleware]:
"""
Retrieve the list of middlewares for the stream engine.
Use this instead of the `middlewares` attribute to get the list of middlewares.
Args:
engine: The stream engine instance.
Returns:
A sequence of Middleware instances: the ExceptionMiddleware
configured with the stream's error policy, followed by any
additional middlewares.
"""
return [
Middleware(
ExceptionMiddleware, engine=engine, error_policy=self.error_policy
)
] + self.middlewares

async def stop(self) -> None:
if self.running:
# Don't run anymore to prevent new events coming
Expand Down
17 changes: 13 additions & 4 deletions tests/middleware/test_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,19 @@ async def process(cr: ConsumerRecord, stream: Stream):
...

my_stream = stream_engine.get_stream(stream_name)
my_stream_local = stream_engine.get_stream(stream_name)
if my_stream is None:
raise ValueError("Stream not found")
my_stream_local = stream_engine.get_stream(stream_name_local)
if my_stream_local is None:
raise ValueError("Stream not found")

middlewares = [
middleware_factory.middleware for middleware_factory in my_stream.middlewares
middleware_factory.middleware
for middleware_factory in my_stream.get_middlewares(stream_engine)
]
middlewares_stream_local = [
middleware_factory.middleware
for middleware_factory in my_stream_local.middlewares
for middleware_factory in my_stream_local.get_middlewares(stream_engine)
]
assert (
middlewares
Expand All @@ -63,8 +69,11 @@ async def consume(cr: ConsumerRecord):
...

my_stream = stream_engine.get_stream(stream_name)
if my_stream is None:
raise ValueError("Stream not found")
middlewares = [
middleware_factory.middleware for middleware_factory in my_stream.middlewares
middleware_factory.middleware
for middleware_factory in my_stream.get_middlewares(stream_engine)
]
assert middlewares == [
middleware.ExceptionMiddleware,
Expand Down
1 change: 1 addition & 0 deletions tests/test_streams_error_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ async def test_stop_application_error_policy(stream_engine: StreamEngine):
client = TestStreamClient(stream_engine)

with mock.patch("signal.raise_signal"):

@stream_engine.stream(topic, error_policy=StreamErrorPolicy.STOP_APPLICATION)
async def my_stream(cr: ConsumerRecord):
raise ValueError("Crashing Stream...")
Expand Down

0 comments on commit 5f8e069

Please sign in to comment.