Skip to content

Commit

Permalink
Write plan preprocessors
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester committed Aug 1, 2023
1 parent 60354a8 commit f14cca5
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 25 deletions.
3 changes: 3 additions & 0 deletions src/blueapi/core/bluesky_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@
#: A function that generates a plan
PlanGenerator = Callable[..., MsgGenerator]

#: A wrapper that takes a plan and preprocesses its messages
PlanWrapper = Callable[[MsgGenerator], MsgGenerator]

#: An object that encapsulates the device to do useful things to produce
# data (e.g. move and read)
Device = Union[
Expand Down
14 changes: 14 additions & 0 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import functools
import logging
from dataclasses import dataclass, field
from importlib import import_module
Expand All @@ -8,8 +9,10 @@
Callable,
Dict,
Generic,
Iterable,
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Expand All @@ -30,8 +33,10 @@
BLUESKY_PROTOCOLS,
Device,
HasName,
MsgGenerator,
Plan,
PlanGenerator,
PlanWrapper,
is_bluesky_compatible_device,
is_bluesky_plan_generator,
)
Expand All @@ -51,12 +56,21 @@ class BlueskyContext:
run_engine: RunEngine = field(
default_factory=lambda: RunEngine(context_managers=[])
)
plan_wrappers: Sequence[PlanWrapper] = field(default_factory=list)
plans: Dict[str, Plan] = field(default_factory=dict)
devices: Dict[str, Device] = field(default_factory=dict)
plan_functions: Dict[str, PlanGenerator] = field(default_factory=dict)

_reference_cache: Dict[Type, Type] = field(default_factory=dict)

def wrap(self, plan: MsgGenerator) -> Iterable[PlanWrapper]:
wrapped_plan = functools.reduce(
lambda wrapped, next_wrapper: next_wrapper(wrapped),
self.plan_wrappers,
plan,
)
yield from wrapped_plan

def find_device(self, addr: Union[str, List[str]]) -> Optional[Device]:
"""
Find a device in this context, allows for recursive search.
Expand Down
4 changes: 3 additions & 1 deletion src/blueapi/plugins/data_writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@

from .data_writing_server import DataCollection, DataCollectionSetupResult

DATA_COLLECTION_NUMBER = "data_collection_number"


class DataCollectionProvider(ABC):
@abstractmethod
Expand Down Expand Up @@ -87,7 +89,7 @@ def data_writing_wrapper(
if message.command == "open_run":
if next_scan_number is None:
next_scan_number = next(scan_number)
message.kwargs["scan_number"] = next_scan_number
message.kwargs[DATA_COLLECTION_NUMBER] = next_scan_number
yield message


Expand Down
2 changes: 1 addition & 1 deletion src/blueapi/service/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __init__(
worker: Optional[Worker] = None,
) -> None:
self.config = config or ApplicationConfig()
self.context = context or BlueskyContext()
self.context = context or BlueskyContext(plan_wrappers=[])

self.context.with_config(self.config.env)

Expand Down
7 changes: 5 additions & 2 deletions src/blueapi/worker/task.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import functools
import logging
from typing import Any, Mapping, Optional
from typing import Any, Iterable, Mapping, Optional

from pydantic import BaseModel, Field

from blueapi.core import BlueskyContext
from blueapi.core.bluesky_types import MsgGenerator, PlanWrapper
from blueapi.utils import BlueapiBaseModel

LOGGER = logging.getLogger(__name__)
Expand All @@ -29,7 +31,8 @@ def do_task(self, ctx: BlueskyContext) -> None:
func = ctx.plan_functions[self.name]
prepared_params = self._ensure_params(ctx)
plan_generator = func(**prepared_params.dict())
ctx.run_engine(plan_generator)
wrapped_plan_generator = ctx.wrap(plan_generator)
ctx.run_engine(wrapped_plan_generator)

def _ensure_params(self, ctx: BlueskyContext) -> BaseModel:
if self._prepared_params is None:
Expand Down
42 changes: 21 additions & 21 deletions tests/plugins/test_data_writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from ophyd.sim import SynAxis

from blueapi.core import DataEvent, MsgGenerator
from blueapi.plugins.data_writing import data_writing_wrapper
from blueapi.plugins.data_writing import DATA_COLLECTION_NUMBER, data_writing_wrapper


@pytest.fixture
Expand Down Expand Up @@ -60,15 +60,15 @@ def stageless_count() -> MsgGenerator:
return (yield from bps.one_shot(detectors))

def inner_plan() -> MsgGenerator:
yield from run_wrapper(stageless_count(), md={"scan_number": 1})
yield from run_wrapper(stageless_count(), md={"scan_number": 1})
yield from run_wrapper(stageless_count(), md={"scan_number": 2})
yield from run_wrapper(stageless_count(), md={"scan_number": 2})
yield from run_wrapper(stageless_count(), md={DATA_COLLECTION_NUMBER: 1})
yield from run_wrapper(stageless_count(), md={DATA_COLLECTION_NUMBER: 1})
yield from run_wrapper(stageless_count(), md={DATA_COLLECTION_NUMBER: 2})
yield from run_wrapper(stageless_count(), md={DATA_COLLECTION_NUMBER: 2})

yield from stage_wrapper(inner_plan(), detectors)


@run_decorator(md={"scan_number": 12345})
@run_decorator(md={DATA_COLLECTION_NUMBER: 12345})
@set_run_key_decorator("outer")
def nested_run_with_metadata(detectors: List[Readable]) -> MsgGenerator:
yield from set_run_key_wrapper(bp.count(detectors), "inner")
Expand All @@ -87,7 +87,7 @@ def test_simple_run_gets_scan_number(
) -> None:
docs = collect_docs(run_engine, simple_run(detectors))
assert docs[0].name == "start"
assert docs[0].doc["scan_number"] == 0
assert docs[0].doc[DATA_COLLECTION_NUMBER] == 0


@pytest.mark.parametrize("plan", [multi_run, multi_nested_plan])
Expand All @@ -99,8 +99,8 @@ def test_multi_run_gets_scan_numbers(
docs = collect_docs(run_engine, plan(detectors))
start_docs = find_start_docs(docs)
assert len(start_docs) == 2
assert start_docs[0].doc["scan_number"] == 0
assert start_docs[1].doc["scan_number"] == 1
assert start_docs[0].doc[DATA_COLLECTION_NUMBER] == 0
assert start_docs[1].doc[DATA_COLLECTION_NUMBER] == 1


def test_multi_run_single_stage(
Expand All @@ -110,8 +110,8 @@ def test_multi_run_single_stage(
docs = collect_docs(run_engine, multi_run_single_stage(detectors))
start_docs = find_start_docs(docs)
assert len(start_docs) == 2
assert start_docs[0].doc["scan_number"] == 0
assert start_docs[1].doc["scan_number"] == 0
assert start_docs[0].doc[DATA_COLLECTION_NUMBER] == 0
assert start_docs[1].doc[DATA_COLLECTION_NUMBER] == 0


def test_multi_run_single_stage_multi_group(
Expand All @@ -121,10 +121,10 @@ def test_multi_run_single_stage_multi_group(
docs = collect_docs(run_engine, multi_run_single_stage_multi_group(detectors))
start_docs = find_start_docs(docs)
assert len(start_docs) == 4
assert start_docs[0].doc["scan_number"] == 0
assert start_docs[1].doc["scan_number"] == 0
assert start_docs[2].doc["scan_number"] == 0
assert start_docs[3].doc["scan_number"] == 0
assert start_docs[0].doc[DATA_COLLECTION_NUMBER] == 0
assert start_docs[1].doc[DATA_COLLECTION_NUMBER] == 0
assert start_docs[2].doc[DATA_COLLECTION_NUMBER] == 0
assert start_docs[3].doc[DATA_COLLECTION_NUMBER] == 0


def test_nested_run_with_metadata(
Expand All @@ -134,9 +134,9 @@ def test_nested_run_with_metadata(
docs = collect_docs(run_engine, nested_run_with_metadata(detectors))
start_docs = find_start_docs(docs)
assert len(start_docs) == 3
assert start_docs[0].doc["scan_number"] == 0
assert start_docs[1].doc["scan_number"] == 1
assert start_docs[2].doc["scan_number"] == 2
assert start_docs[0].doc[DATA_COLLECTION_NUMBER] == 0
assert start_docs[1].doc[DATA_COLLECTION_NUMBER] == 1
assert start_docs[2].doc[DATA_COLLECTION_NUMBER] == 2


def test_nested_run_without_metadata(
Expand All @@ -146,9 +146,9 @@ def test_nested_run_without_metadata(
docs = collect_docs(run_engine, nested_run_without_metadata(detectors))
start_docs = find_start_docs(docs)
assert len(start_docs) == 3
assert start_docs[0].doc["scan_number"] == 0
assert start_docs[1].doc["scan_number"] == 1
assert start_docs[2].doc["scan_number"] == 2
assert start_docs[0].doc[DATA_COLLECTION_NUMBER] == 0
assert start_docs[1].doc[DATA_COLLECTION_NUMBER] == 1
assert start_docs[2].doc[DATA_COLLECTION_NUMBER] == 2


def collect_docs(run_engine: RunEngine, plan: MsgGenerator) -> List[DataEvent]:
Expand Down

0 comments on commit f14cca5

Please sign in to comment.