Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add recovery mechanism to streaming class #84

Draft
wants to merge 1 commit into
base: v0.x.x
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/frequenz/client/base/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,17 @@
"""The output type of the stream."""


class GrpcStreamBroadcaster(Generic[InputT, OutputT]):
class GrpcStreamBroadcaster(
Generic[InputT, OutputT]
): # pylint: disable=too-many-instance-attributes
"""Helper class to handle grpc streaming methods."""

def __init__(
def __init__( # pylint: disable=too-many-arguments
self,
stream_name: str,
stream_method: Callable[[], AsyncIterator[InputT]],
transform: Callable[[InputT], OutputT],
recovery_method: Callable[[OutputT], None] | None = None,
retry_strategy: retry.Strategy | None = None,
):
"""Initialize the streaming helper.
Expand All @@ -41,12 +44,15 @@ def __init__(
stream_method: A function that returns the grpc stream. This function is
called everytime the connection is lost and we want to retry.
transform: A function to transform the input type to the output type.
recovery_method: A function to call when the connection is lost.
Receives the last message received before the connection was lost.
retry_strategy: The retry strategy to use, when the connection is lost. Defaults
to retries every 3 seconds, with a jitter of 1 second, indefinitely.
"""
self._stream_name = stream_name
self._stream_method = stream_method
self._transform = transform
self._recovery_method = recovery_method
self._retry_strategy = (
retry.LinearBackoff() if retry_strategy is None else retry_strategy.copy()
)
Expand All @@ -55,6 +61,7 @@ def __init__(
name=f"GrpcStreamBroadcaster-{stream_name}"
)
self._task = asyncio.create_task(self._run())
self._last_message: OutputT | None = None

def new_receiver(self, maxsize: int = 50) -> channels.Receiver[OutputT]:
"""Create a new receiver for the stream.
Expand Down Expand Up @@ -87,8 +94,15 @@ async def _run(self) -> None:
_logger.info("%s: starting to stream", self._stream_name)
try:
call = self._stream_method()

# Call the recovery method with the last message received before the
# connection was lost.
if self._last_message is not None and self._recovery_method is not None:
self._recovery_method(self._last_message)

async for msg in call:
await sender.send(self._transform(msg))
self._last_message = self._transform(msg)
await sender.send(self._last_message)
except grpc.aio.AioRpcError as err:
error = err
error_str = f"Error: {error}" if error else "Stream exhausted"
Expand Down
Loading