diff --git a/src/blueapi/controller/base.py b/src/blueapi/controller/base.py index fccb0ef23..4b531bde9 100644 --- a/src/blueapi/controller/base.py +++ b/src/blueapi/controller/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Mapping -from blueapi.core import Ability, Plan +from blueapi.core import Ability, AsyncEventStreamBase, Plan class BlueskyControllerBase(ABC): @@ -26,6 +26,11 @@ async def run_plan(self, __name: str, __params: Mapping[str, Any]) -> None: ... + @property + @abstractmethod + def worker_events(self) -> AsyncEventStreamBase: + ... + @property @abstractmethod def plans(self) -> Mapping[str, Plan]: diff --git a/src/blueapi/controller/controller.py b/src/blueapi/controller/controller.py index bc713c454..da0aefd6a 100644 --- a/src/blueapi/controller/controller.py +++ b/src/blueapi/controller/controller.py @@ -6,6 +6,8 @@ from blueapi.core import ( Ability, + AsyncEventStreamBase, + AsyncEventStreamWrapper, Plan, create_bluesky_protocol_conversions, nested_deserialize_with_overrides, @@ -55,9 +57,9 @@ async def run_plan(self, name: str, params: Mapping[str, Any]) -> None: task = RunPlan(plan_function(**sanitized_params)) loop.call_soon_threadsafe(self._worker.submit_task, task) - async def worker_events(self) -> AsyncIterable[WorkerEvent]: - async for value in async_events(self._worker.worker_events.subscribe): - yield value + @property + def worker_events(self) -> AsyncEventStreamBase: + return AsyncEventStreamWrapper(self._worker.worker_events) @property def plans(self) -> Mapping[str, Plan]: diff --git a/src/blueapi/core/__init__.py b/src/blueapi/core/__init__.py index c9e1ea7dc..c56c2bef8 100644 --- a/src/blueapi/core/__init__.py +++ b/src/blueapi/core/__init__.py @@ -1,6 +1,11 @@ from .bluesky_types import BLUESKY_PROTOCOLS, Ability, Plan, PlanGenerator from .device_lookup import create_bluesky_protocol_conversions -from .event import EventStream, EventStreamBase +from .event import ( + AsyncEventStreamBase, + AsyncEventStreamWrapper, + EventStream, + EventStreamBase, +) from .schema import nested_deserialize_with_overrides, schema_for_func __all__ = [ @@ -13,4 +18,6 @@ "schema_for_func", "nested_deserialize_with_overrides", "create_bluesky_protocol_conversions", + "AsyncEventStreamBase", + "AsyncEventStreamWrapper", ] diff --git a/src/blueapi/core/event.py b/src/blueapi/core/event.py index 4e50161b5..0b15aa365 100644 --- a/src/blueapi/core/event.py +++ b/src/blueapi/core/event.py @@ -1,5 +1,8 @@ +import asyncio from abc import ABC, abstractmethod -from typing import AsyncIterable, Callable, Dict, Generic, List, Mapping, TypeVar +from typing import Awaitable, Callable, Dict, Generic, Optional, TypeVar + +import janus E = TypeVar("E") S = TypeVar("S") @@ -19,6 +22,20 @@ def unsubscribe_all(self) -> None: ... +class AsyncEventStreamBase(ABC, Generic[E, S]): + @abstractmethod + def subscribe(self, __callback: Callable[[E], Awaitable[None]]) -> S: + ... + + @abstractmethod + def unsubscribe(self, __subscription: S) -> None: + ... + + @abstractmethod + def unsubscribe_all(self) -> None: + ... + + class EventStream(EventStreamBase[E, int]): _subscriptions: Dict[int, Callable[[E], None]] _count: int @@ -37,3 +54,30 @@ def unsubscribe_all(self) -> None: def notify(self, value: E) -> None: for callback in self._subscriptions.values(): callback(value) + + +class AsyncEventStreamWrapper(AsyncEventStreamBase[E, S]): + _wrapped: EventStreamBase[E, S] + _loop: asyncio.AbstractEventLoop + + def __init__( + self, + wrapped: EventStreamBase[E, S], + loop: Optional[asyncio.AbstractEventLoop] = None, + ) -> None: + if loop is None: + loop = asyncio.get_event_loop() + self._wrapped = wrapped + self._loop = loop + + def subscribe(self, callback: Callable[[E], Awaitable[None]]) -> S: + def sync_callback(value: E) -> None: + asyncio.run_coroutine_threadsafe(callback(value), self._loop) + + return self._wrapped.subscribe(sync_callback) + + def unsubscribe(self, subscription: S) -> None: + return self._wrapped.unsubscribe(subscription) + + def unsubscribe_all(self) -> None: + return self._wrapped.unsubscribe_all() diff --git a/src/blueapi/example/server/app.py b/src/blueapi/example/server/app.py index f1aace3cd..f801a27ce 100644 --- a/src/blueapi/example/server/app.py +++ b/src/blueapi/example/server/app.py @@ -13,6 +13,7 @@ from blueapi.controller import BlueskyContext, BlueskyController from blueapi.core import BLUESKY_PROTOCOLS, Ability, Plan +from blueapi.worker import WorkerEvent ctx = BlueskyContext() logging.basicConfig(level=logging.INFO) @@ -89,6 +90,11 @@ async def run_plan(request: Request, name: str) -> uuid.UUID: @app.websocket("/run/status") async def subscribe_run_status(websocket: WebSocket) -> None: await websocket.accept() + + async def reply(event: WorkerEvent): + await websocket.send_json(serialize(event)) + + controller.worker_events.subscribe(reply) with await controller.worker_events() as events: async for event in events: await websocket.send_json(serialize(event)) diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index 6a97b488d..816abb8ee 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,4 +1,4 @@ -from .aio_threading import async_events, concurrent_future_to_aio_future +from .aio_threading import concurrent_future_to_aio_future from .thread_exception import handle_all_exceptions -__all__ = ["handle_all_exceptions", "concurrent_future_to_aio_future", "async_events"] +__all__ = ["handle_all_exceptions", "concurrent_future_to_aio_future"] diff --git a/src/blueapi/utils/aio_threading.py b/src/blueapi/utils/aio_threading.py index 24cc4adf0..ee27079b3 100644 --- a/src/blueapi/utils/aio_threading.py +++ b/src/blueapi/utils/aio_threading.py @@ -1,9 +1,7 @@ import asyncio from asyncio import Future as AioFuture from concurrent.futures import Future as ConcurrentFuture -from typing import Any, AsyncIterable, Callable, Optional, TypeVar - -import janus +from typing import Optional def concurrent_future_to_aio_future( @@ -26,20 +24,3 @@ def on_complete(future: ConcurrentFuture) -> None: concurrent_future.add_done_callback(on_complete) return aio_future - - -E = TypeVar("E") - - -async def async_events( - subscribe: Callable[[Callable[[E], Any]], Any], - loop: Optional[asyncio.AbstractEventLoop] = None, -) -> AsyncIterable[E]: - if loop is None: - loop = asyncio.get_event_loop() - - queue: janus.Queue = janus.Queue() - subscribe(queue.sync_q.put) - - while True: - yield await queue.async_q.get()