Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester committed Apr 8, 2022
1 parent 3d5de60 commit 81d3380
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 28 deletions.
7 changes: 6 additions & 1 deletion src/blueapi/controller/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Any, Mapping

from blueapi.core import Ability, Plan
from blueapi.core import Ability, AsyncEventStreamBase, Plan


class BlueskyControllerBase(ABC):
Expand All @@ -26,6 +26,11 @@ async def run_plan(self, __name: str, __params: Mapping[str, Any]) -> None:

...

@property
@abstractmethod
def worker_events(self) -> AsyncEventStreamBase:
...

@property
@abstractmethod
def plans(self) -> Mapping[str, Plan]:
Expand Down
8 changes: 5 additions & 3 deletions src/blueapi/controller/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from blueapi.core import (
Ability,
AsyncEventStreamBase,
AsyncEventStreamWrapper,
Plan,
create_bluesky_protocol_conversions,
nested_deserialize_with_overrides,
Expand Down Expand Up @@ -55,9 +57,9 @@ async def run_plan(self, name: str, params: Mapping[str, Any]) -> None:
task = RunPlan(plan_function(**sanitized_params))
loop.call_soon_threadsafe(self._worker.submit_task, task)

async def worker_events(self) -> AsyncIterable[WorkerEvent]:
async for value in async_events(self._worker.worker_events.subscribe):
yield value
@property
def worker_events(self) -> AsyncEventStreamBase:
return AsyncEventStreamWrapper(self._worker.worker_events)

@property
def plans(self) -> Mapping[str, Plan]:
Expand Down
9 changes: 8 additions & 1 deletion src/blueapi/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
from .bluesky_types import BLUESKY_PROTOCOLS, Ability, Plan, PlanGenerator
from .device_lookup import create_bluesky_protocol_conversions
from .event import EventStream, EventStreamBase
from .event import (
AsyncEventStreamBase,
AsyncEventStreamWrapper,
EventStream,
EventStreamBase,
)
from .schema import nested_deserialize_with_overrides, schema_for_func

__all__ = [
Expand All @@ -13,4 +18,6 @@
"schema_for_func",
"nested_deserialize_with_overrides",
"create_bluesky_protocol_conversions",
"AsyncEventStreamBase",
"AsyncEventStreamWrapper",
]
46 changes: 45 additions & 1 deletion src/blueapi/core/event.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio
from abc import ABC, abstractmethod
from typing import AsyncIterable, Callable, Dict, Generic, List, Mapping, TypeVar
from typing import Awaitable, Callable, Dict, Generic, Optional, TypeVar

import janus

E = TypeVar("E")
S = TypeVar("S")
Expand All @@ -19,6 +22,20 @@ def unsubscribe_all(self) -> None:
...


class AsyncEventStreamBase(ABC, Generic[E, S]):
@abstractmethod
def subscribe(self, __callback: Callable[[E], Awaitable[None]]) -> S:
...

@abstractmethod
def unsubscribe(self, __subscription: S) -> None:
...

@abstractmethod
def unsubscribe_all(self) -> None:
...


class EventStream(EventStreamBase[E, int]):
_subscriptions: Dict[int, Callable[[E], None]]
_count: int
Expand All @@ -37,3 +54,30 @@ def unsubscribe_all(self) -> None:
def notify(self, value: E) -> None:
for callback in self._subscriptions.values():
callback(value)


class AsyncEventStreamWrapper(AsyncEventStreamBase[E, S]):
_wrapped: EventStreamBase[E, S]
_loop: asyncio.AbstractEventLoop

def __init__(
self,
wrapped: EventStreamBase[E, S],
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
if loop is None:
loop = asyncio.get_event_loop()
self._wrapped = wrapped
self._loop = loop

def subscribe(self, callback: Callable[[E], Awaitable[None]]) -> S:
def sync_callback(value: E) -> None:
asyncio.run_coroutine_threadsafe(callback(value), self._loop)

return self._wrapped.subscribe(sync_callback)

def unsubscribe(self, subscription: S) -> None:
return self._wrapped.unsubscribe(subscription)

def unsubscribe_all(self) -> None:
return self._wrapped.unsubscribe_all()
6 changes: 6 additions & 0 deletions src/blueapi/example/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from blueapi.controller import BlueskyContext, BlueskyController
from blueapi.core import BLUESKY_PROTOCOLS, Ability, Plan
from blueapi.worker import WorkerEvent

ctx = BlueskyContext()
logging.basicConfig(level=logging.INFO)
Expand Down Expand Up @@ -89,6 +90,11 @@ async def run_plan(request: Request, name: str) -> uuid.UUID:
@app.websocket("/run/status")
async def subscribe_run_status(websocket: WebSocket) -> None:
await websocket.accept()

async def reply(event: WorkerEvent):
await websocket.send_json(serialize(event))

controller.worker_events.subscribe(reply)
with await controller.worker_events() as events:
async for event in events:
await websocket.send_json(serialize(event))
Expand Down
4 changes: 2 additions & 2 deletions src/blueapi/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .aio_threading import async_events, concurrent_future_to_aio_future
from .aio_threading import concurrent_future_to_aio_future
from .thread_exception import handle_all_exceptions

__all__ = ["handle_all_exceptions", "concurrent_future_to_aio_future", "async_events"]
__all__ = ["handle_all_exceptions", "concurrent_future_to_aio_future"]
21 changes: 1 addition & 20 deletions src/blueapi/utils/aio_threading.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
import asyncio
from asyncio import Future as AioFuture
from concurrent.futures import Future as ConcurrentFuture
from typing import Any, AsyncIterable, Callable, Optional, TypeVar

import janus
from typing import Optional


def concurrent_future_to_aio_future(
Expand All @@ -26,20 +24,3 @@ def on_complete(future: ConcurrentFuture) -> None:
concurrent_future.add_done_callback(on_complete)

return aio_future


E = TypeVar("E")


async def async_events(
subscribe: Callable[[Callable[[E], Any]], Any],
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> AsyncIterable[E]:
if loop is None:
loop = asyncio.get_event_loop()

queue: janus.Queue = janus.Queue()
subscribe(queue.sync_q.put)

while True:
yield await queue.async_q.get()

0 comments on commit 81d3380

Please sign in to comment.