diff --git a/src/blueapi/core/bluesky_types.py b/src/blueapi/core/bluesky_types.py index a9c5f448e..c9b285a36 100644 --- a/src/blueapi/core/bluesky_types.py +++ b/src/blueapi/core/bluesky_types.py @@ -34,6 +34,9 @@ #: A function that generates a plan PlanGenerator = Callable[..., MsgGenerator] +#: A wrapper that takes a plan and preprocesses its messages +PlanWrapper = Callable[[MsgGenerator], MsgGenerator] + #: An object that encapsulates the device to do useful things to produce # data (e.g. move and read) Device = Union[ diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index cefdf390e..f8f1b9497 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -1,3 +1,4 @@ +import functools import logging from dataclasses import dataclass, field from importlib import import_module @@ -8,8 +9,10 @@ Callable, Dict, Generic, + Iterable, List, Optional, + Sequence, Tuple, Type, TypeVar, @@ -30,8 +33,10 @@ BLUESKY_PROTOCOLS, Device, HasName, + MsgGenerator, Plan, PlanGenerator, + PlanWrapper, is_bluesky_compatible_device, is_bluesky_plan_generator, ) @@ -51,12 +56,21 @@ class BlueskyContext: run_engine: RunEngine = field( default_factory=lambda: RunEngine(context_managers=[]) ) + plan_wrappers: Sequence[PlanWrapper] = field(default_factory=list) plans: Dict[str, Plan] = field(default_factory=dict) devices: Dict[str, Device] = field(default_factory=dict) plan_functions: Dict[str, PlanGenerator] = field(default_factory=dict) _reference_cache: Dict[Type, Type] = field(default_factory=dict) + def wrap(self, plan: MsgGenerator) -> Iterable[PlanWrapper]: + wrapped_plan = functools.reduce( + lambda wrapped, next_wrapper: next_wrapper(wrapped), + self.plan_wrappers, + plan, + ) + yield from wrapped_plan + def find_device(self, addr: Union[str, List[str]]) -> Optional[Device]: """ Find a device in this context, allows for recursive search. diff --git a/src/blueapi/plugins/data_writing.py b/src/blueapi/plugins/data_writing.py index d662e0a64..58ba282a2 100644 --- a/src/blueapi/plugins/data_writing.py +++ b/src/blueapi/plugins/data_writing.py @@ -27,6 +27,8 @@ from .data_writing_server import DataCollection, DataCollectionSetupResult +DATA_COLLECTION_NUMBER = "data_collection_number" + class DataCollectionProvider(ABC): @abstractmethod @@ -87,7 +89,7 @@ def data_writing_wrapper( if message.command == "open_run": if next_scan_number is None: next_scan_number = next(scan_number) - message.kwargs["scan_number"] = next_scan_number + message.kwargs[DATA_COLLECTION_NUMBER] = next_scan_number yield message diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py index 484e17aae..89874887a 100644 --- a/src/blueapi/service/handler.py +++ b/src/blueapi/service/handler.py @@ -26,7 +26,7 @@ def __init__( worker: Optional[Worker] = None, ) -> None: self.config = config or ApplicationConfig() - self.context = context or BlueskyContext() + self.context = context or BlueskyContext(plan_wrappers=[]) self.context.with_config(self.config.env) diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py index d09add5c7..ced771828 100644 --- a/src/blueapi/worker/task.py +++ b/src/blueapi/worker/task.py @@ -1,9 +1,11 @@ +import functools import logging -from typing import Any, Mapping, Optional +from typing import Any, Iterable, Mapping, Optional from pydantic import BaseModel, Field from blueapi.core import BlueskyContext +from blueapi.core.bluesky_types import MsgGenerator, PlanWrapper from blueapi.utils import BlueapiBaseModel LOGGER = logging.getLogger(__name__) @@ -29,7 +31,8 @@ def do_task(self, ctx: BlueskyContext) -> None: func = ctx.plan_functions[self.name] prepared_params = self._ensure_params(ctx) plan_generator = func(**prepared_params.dict()) - ctx.run_engine(plan_generator) + wrapped_plan_generator = ctx.wrap(plan_generator) + ctx.run_engine(wrapped_plan_generator) def _ensure_params(self, ctx: BlueskyContext) -> BaseModel: if self._prepared_params is None: diff --git a/tests/plugins/test_data_writing.py b/tests/plugins/test_data_writing.py index 6ad1a3593..39a4c48d6 100644 --- a/tests/plugins/test_data_writing.py +++ b/tests/plugins/test_data_writing.py @@ -16,7 +16,7 @@ from ophyd.sim import SynAxis from blueapi.core import DataEvent, MsgGenerator -from blueapi.plugins.data_writing import data_writing_wrapper +from blueapi.plugins.data_writing import DATA_COLLECTION_NUMBER, data_writing_wrapper @pytest.fixture @@ -60,15 +60,15 @@ def stageless_count() -> MsgGenerator: return (yield from bps.one_shot(detectors)) def inner_plan() -> MsgGenerator: - yield from run_wrapper(stageless_count(), md={"scan_number": 1}) - yield from run_wrapper(stageless_count(), md={"scan_number": 1}) - yield from run_wrapper(stageless_count(), md={"scan_number": 2}) - yield from run_wrapper(stageless_count(), md={"scan_number": 2}) + yield from run_wrapper(stageless_count(), md={DATA_COLLECTION_NUMBER: 1}) + yield from run_wrapper(stageless_count(), md={DATA_COLLECTION_NUMBER: 1}) + yield from run_wrapper(stageless_count(), md={DATA_COLLECTION_NUMBER: 2}) + yield from run_wrapper(stageless_count(), md={DATA_COLLECTION_NUMBER: 2}) yield from stage_wrapper(inner_plan(), detectors) -@run_decorator(md={"scan_number": 12345}) +@run_decorator(md={DATA_COLLECTION_NUMBER: 12345}) @set_run_key_decorator("outer") def nested_run_with_metadata(detectors: List[Readable]) -> MsgGenerator: yield from set_run_key_wrapper(bp.count(detectors), "inner") @@ -87,7 +87,7 @@ def test_simple_run_gets_scan_number( ) -> None: docs = collect_docs(run_engine, simple_run(detectors)) assert docs[0].name == "start" - assert docs[0].doc["scan_number"] == 0 + assert docs[0].doc[DATA_COLLECTION_NUMBER] == 0 @pytest.mark.parametrize("plan", [multi_run, multi_nested_plan]) @@ -99,8 +99,8 @@ def test_multi_run_gets_scan_numbers( docs = collect_docs(run_engine, plan(detectors)) start_docs = find_start_docs(docs) assert len(start_docs) == 2 - assert start_docs[0].doc["scan_number"] == 0 - assert start_docs[1].doc["scan_number"] == 1 + assert start_docs[0].doc[DATA_COLLECTION_NUMBER] == 0 + assert start_docs[1].doc[DATA_COLLECTION_NUMBER] == 1 def test_multi_run_single_stage( @@ -110,8 +110,8 @@ def test_multi_run_single_stage( docs = collect_docs(run_engine, multi_run_single_stage(detectors)) start_docs = find_start_docs(docs) assert len(start_docs) == 2 - assert start_docs[0].doc["scan_number"] == 0 - assert start_docs[1].doc["scan_number"] == 0 + assert start_docs[0].doc[DATA_COLLECTION_NUMBER] == 0 + assert start_docs[1].doc[DATA_COLLECTION_NUMBER] == 0 def test_multi_run_single_stage_multi_group( @@ -121,10 +121,10 @@ def test_multi_run_single_stage_multi_group( docs = collect_docs(run_engine, multi_run_single_stage_multi_group(detectors)) start_docs = find_start_docs(docs) assert len(start_docs) == 4 - assert start_docs[0].doc["scan_number"] == 0 - assert start_docs[1].doc["scan_number"] == 0 - assert start_docs[2].doc["scan_number"] == 0 - assert start_docs[3].doc["scan_number"] == 0 + assert start_docs[0].doc[DATA_COLLECTION_NUMBER] == 0 + assert start_docs[1].doc[DATA_COLLECTION_NUMBER] == 0 + assert start_docs[2].doc[DATA_COLLECTION_NUMBER] == 0 + assert start_docs[3].doc[DATA_COLLECTION_NUMBER] == 0 def test_nested_run_with_metadata( @@ -134,9 +134,9 @@ def test_nested_run_with_metadata( docs = collect_docs(run_engine, nested_run_with_metadata(detectors)) start_docs = find_start_docs(docs) assert len(start_docs) == 3 - assert start_docs[0].doc["scan_number"] == 0 - assert start_docs[1].doc["scan_number"] == 1 - assert start_docs[2].doc["scan_number"] == 2 + assert start_docs[0].doc[DATA_COLLECTION_NUMBER] == 0 + assert start_docs[1].doc[DATA_COLLECTION_NUMBER] == 1 + assert start_docs[2].doc[DATA_COLLECTION_NUMBER] == 2 def test_nested_run_without_metadata( @@ -146,9 +146,9 @@ def test_nested_run_without_metadata( docs = collect_docs(run_engine, nested_run_without_metadata(detectors)) start_docs = find_start_docs(docs) assert len(start_docs) == 3 - assert start_docs[0].doc["scan_number"] == 0 - assert start_docs[1].doc["scan_number"] == 1 - assert start_docs[2].doc["scan_number"] == 2 + assert start_docs[0].doc[DATA_COLLECTION_NUMBER] == 0 + assert start_docs[1].doc[DATA_COLLECTION_NUMBER] == 1 + assert start_docs[2].doc[DATA_COLLECTION_NUMBER] == 2 def collect_docs(run_engine: RunEngine, plan: MsgGenerator) -> List[DataEvent]: