diff --git a/src/blueapi/cli/amq.py b/src/blueapi/cli/amq.py index 6d493a5dc..abd85678d 100644 --- a/src/blueapi/cli/amq.py +++ b/src/blueapi/cli/amq.py @@ -1,8 +1,13 @@ import threading -from typing import Callable, Optional +from typing import Callable, Optional, Union +from bluesky.callbacks.best_effort import BestEffortCallback + +from blueapi.core import DataEvent from blueapi.messaging import MessageContext, MessagingTemplate -from blueapi.worker import WorkerEvent +from blueapi.worker import ProgressEvent, WorkerEvent + +from .updates import CliEventRenderer class BlueskyRemoteError(Exception): @@ -33,12 +38,22 @@ def subscribe_to_topics( ) -> None: """Run callbacks on events/progress events with a given correlation id.""" - def on_event_wrapper(ctx: MessageContext, event: WorkerEvent) -> None: - if (on_event is not None) and (ctx.correlation_id == correlation_id): - on_event(event) + progress_bar = CliEventRenderer(correlation_id) + callback = BestEffortCallback() + + def on_event_wrapper( + ctx: MessageContext, event: Union[WorkerEvent, ProgressEvent, DataEvent] + ) -> None: + if isinstance(event, WorkerEvent): + if (on_event is not None) and (ctx.correlation_id == correlation_id): + on_event(event) - if (event.is_complete()) and (ctx.correlation_id == correlation_id): - self.complete.set() + if (event.is_complete()) and (ctx.correlation_id == correlation_id): + self.complete.set() + elif isinstance(event, ProgressEvent): + progress_bar.on_progress_event(event) + elif isinstance(event, DataEvent): + callback(event.name, event.doc) self.app.subscribe( self.app.destinations.topic("public.worker.event"),