diff --git a/.github/actions/install_requirements/action.yml b/.github/actions/install_requirements/action.yml
index aea1e0d4b..6036af752 100644
--- a/.github/actions/install_requirements/action.yml
+++ b/.github/actions/install_requirements/action.yml
@@ -30,7 +30,7 @@ runs:
     - name: Create lockfile
       run: |
         mkdir -p lockfiles
-        pip freeze --exclude-editable --exclude dodal > lockfiles/${{ inputs.requirements_file }}
+        pip freeze --exclude-editable --exclude ophyd-async --exclude dls-dodal --exclude dls-bluesky-core --exclude bluesky > lockfiles/${{ inputs.requirements_file }}
         # delete the self referencing line and make sure it isn't blank
         sed -i '/file:/d' lockfiles/${{ inputs.requirements_file }}
       shell: bash
diff --git a/.github/workflows/code.yml b/.github/workflows/code.yml
index aadd793ec..878105599 100644
--- a/.github/workflows/code.yml
+++ b/.github/workflows/code.yml
@@ -49,14 +49,13 @@ jobs:
       # https://github.com/pytest-dev/pytest/issues/2042
       PY_IGNORE_IMPORTMISMATCH: "1"
       BLUEAPI_TEST_STOMP_PORTS: "[61613,61614]"
-
     steps:
       - name: Start RabbitMQ
         uses: namoshek/rabbitmq-github-action@v1
         with:
-          ports: '61614:61613'
-          plugins: rabbitmq_stomp
+          ports: "61614:61613"
+          plugins: rabbitmq_stomp

       - name: Checkout
         uses: actions/checkout@v3
@@ -165,7 +164,7 @@ jobs:
         uses: docker/build-push-action@v3
         with:
           build-args: |
-            PIP_OPTIONS=-r lockfiles/requirements.txt dist/*.whl
+            PIP_OPTIONS=-r lockfiles/requirements.txt git+https://github.com/bluesky/bluesky.git git+https://github.com/DiamondLightSource/dls-bluesky-core.git git+https://github.com/DiamondLightSource/dodal.git@directory_provider dist/*.whl
           push: ${{ github.event_name == 'push' && startsWith(github.ref, 'refs/tags') }}
           load: ${{ ! (github.event_name == 'push' && startsWith(github.ref, 'refs/tags')) }}
           tags: ${{ steps.meta.outputs.tags }}
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 7eac220fe..8bf638004 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -14,4 +14,4 @@
     },
     "esbonio.server.enabled": true,
     "esbonio.sphinx.confDir": "",
-}
+}
\ No newline at end of file
diff --git a/docs/developer/explanations/lifecycle.rst b/docs/developer/explanations/lifecycle.rst
index a28004ee1..e80f2eaa7 100644
--- a/docs/developer/explanations/lifecycle.rst
+++ b/docs/developer/explanations/lifecycle.rst
@@ -9,7 +9,8 @@ of being written, loaded and run. Take the following plan.

     from typing import Any, List, Mapping, Optional, Union

     import bluesky.plans as bp
-    from blueapi.core import MsgGenerator, inject
+    from blueapi.core import MsgGenerator
+    from dls_bluesky_core.core import inject
     from bluesky.protocols import Readable

diff --git a/pyproject.toml b/pyproject.toml
index 99ffaa94d..88de76627 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -13,19 +13,20 @@ classifiers = [
 ]
 description = "Lightweight Bluesky-as-a-service wrapper application. Also usable as a library."
 dependencies = [
-    "bluesky<1.11",
+    "bluesky @ git+https://github.com/bluesky/bluesky.git",
     "ophyd",
     "nslsii",
     "pyepics",
     "pydantic<2.0",
     "stomp.py",
+    "aiohttp",
     "PyYAML",
     "click<8.1.4",
     "fastapi[all]<0.100",
     "uvicorn",
     "requests",
-    "dls_bluesky_core",
-    "dls-dodal",
+    "dls-bluesky-core @ git+https://github.com/DiamondLightSource/dls-bluesky-core.git", #requires ophyd-async
+    "dls-dodal @ git+https://github.com/DiamondLightSource/dodal.git@directory_provider", # requires aioca...
"typing_extensions<4.6", ] dynamic = ["version"] @@ -43,6 +44,7 @@ dev = [ "pre-commit", "pydata-sphinx-theme>=0.12", "pytest-cov", + "pytest-asyncio", "sphinx-autobuild", "sphinx-copybutton", "sphinx-click", diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 7ae5c3ad7..368a7a54d 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -49,6 +49,12 @@ class ScratchConfig(BlueapiBaseModel): auto_make_directory: bool = Field(default=False) +class DataWritingConfig(BlueapiBaseModel): + visit_service_url: Optional[str] = None # e.g. "http://localhost:8088/api" + visit_directory: Path = Path("/tmp/0-0") + group_name: str = "example" + + class EnvironmentConfig(BlueapiBaseModel): """ Config for the RunEngine environment @@ -63,11 +69,7 @@ class EnvironmentConfig(BlueapiBaseModel): Source(kind=SourceKind.PLAN_FUNCTIONS, module="dls_bluesky_core.stubs"), ] scratch: Optional[ScratchConfig] = Field(default=None) - - def __eq__(self, other: object) -> bool: - if isinstance(other, EnvironmentConfig): - return str(self.sources) == str(other.sources) - return False + data_writing: DataWritingConfig = Field(default_factory=DataWritingConfig) class LoggingConfig(BlueapiBaseModel): diff --git a/src/blueapi/core/__init__.py b/src/blueapi/core/__init__.py index 0080b2b63..7b2306b0f 100644 --- a/src/blueapi/core/__init__.py +++ b/src/blueapi/core/__init__.py @@ -12,7 +12,6 @@ is_bluesky_plan_generator, ) from .context import BlueskyContext -from .device_lookup import inject from .event import EventPublisher, EventStream __all__ = [ @@ -26,7 +25,6 @@ "EventStream", "DataEvent", "WatchableStatus", - "inject", "is_bluesky_compatible_device", "is_bluesky_plan_generator", "is_bluesky_compatible_device_type", diff --git a/src/blueapi/core/bluesky_types.py b/src/blueapi/core/bluesky_types.py index 9ff3e0074..b09a2648d 100644 --- a/src/blueapi/core/bluesky_types.py +++ b/src/blueapi/core/bluesky_types.py @@ -19,6 +19,7 @@ WritesExternalAssets, ) from dls_bluesky_core.core import MsgGenerator, PlanGenerator +from ophyd_async.core import Device as AsyncDevice from pydantic import BaseModel, Field from blueapi.utils import BlueapiBaseModel @@ -28,6 +29,8 @@ except ImportError: from typing_extensions import Protocol, runtime_checkable # type: ignore +PlanWrapper = Callable[[MsgGenerator], MsgGenerator] + #: An object that encapsulates the device to do useful things to produce # data (e.g. 
 Device = Union[
@@ -45,6 +48,7 @@
     WritesExternalAssets,
     Configurable,
     Triggerable,
+    AsyncDevice,
 ]

 #: Protocols defining interface to hardware
diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py
index e8082132d..ffad9b98a 100644
--- a/src/blueapi/core/context.py
+++ b/src/blueapi/core/context.py
@@ -1,3 +1,4 @@
+import functools
 import logging
 from dataclasses import dataclass, field
 from importlib import import_module
@@ -10,6 +11,7 @@
     Generic,
     List,
     Optional,
+    Sequence,
     Tuple,
     Type,
     TypeVar,
@@ -19,19 +21,24 @@
     get_type_hints,
 )

-from bluesky import RunEngine
+from bluesky.run_engine import RunEngine, call_in_bluesky_event_loop
+from ophyd_async.core import Device as AsyncDevice
+from ophyd_async.core import wait_for_connection
 from pydantic import create_model
 from pydantic.fields import FieldInfo, ModelField

 from blueapi.config import EnvironmentConfig, SourceKind
+from blueapi.data_management.gda_directory_provider import VisitDirectoryProvider
 from blueapi.utils import BlueapiPlanModelConfig, load_module_all

 from .bluesky_types import (
     BLUESKY_PROTOCOLS,
     Device,
     HasName,
+    MsgGenerator,
     Plan,
     PlanGenerator,
+    PlanWrapper,
     is_bluesky_compatible_device,
     is_bluesky_plan_generator,
 )
@@ -51,12 +58,23 @@ class BlueskyContext:
     run_engine: RunEngine = field(
         default_factory=lambda: RunEngine(context_managers=[])
     )
+    plan_wrappers: Sequence[PlanWrapper] = field(default_factory=list)
     plans: Dict[str, Plan] = field(default_factory=dict)
     devices: Dict[str, Device] = field(default_factory=dict)
     plan_functions: Dict[str, PlanGenerator] = field(default_factory=dict)
+    directory_provider: Optional[VisitDirectoryProvider] = field(default=None)
+    sim: bool = field(default=False)

     _reference_cache: Dict[Type, Type] = field(default_factory=dict)

+    def wrap(self, plan: MsgGenerator) -> MsgGenerator:
+        wrapped_plan = functools.reduce(
+            lambda wrapped, next_wrapper: next_wrapper(wrapped),
+            self.plan_wrappers,
+            plan,
+        )
+        yield from wrapped_plan
+
     def find_device(self, addr: Union[str, List[str]]) -> Optional[Device]:
         """
         Find a device in this context, allows for recursive search.
@@ -86,6 +104,18 @@ def with_config(self, config: EnvironmentConfig) -> None:
             elif source.kind is SourceKind.DODAL:
                 self.with_dodal_module(mod)

+        call_in_bluesky_event_loop(self.connect_devices(self.sim))
+
+    async def connect_devices(self, sim: bool = False) -> None:
+        coros = {}
+        for device_name, device in self.devices.items():
+            if isinstance(device, AsyncDevice):
+                device.set_name(device_name)
+                coros[device_name] = device.connect(sim)
+
+        if len(coros) > 0:
+            await wait_for_connection(**coros)
+
     def with_plan_module(self, module: ModuleType) -> None:
         """
         Register all functions in the module supplied as plans.
@@ -113,10 +143,10 @@ def plan_2(...) -> MsgGenerator:
     def with_device_module(self, module: ModuleType) -> None:
         self.with_dodal_module(module)

-    def with_dodal_module(self, module: ModuleType, **kwargs) -> None:
+    def with_dodal_module(self, module: ModuleType) -> None:
         from dodal.utils import make_all_devices

-        for device in make_all_devices(module, **kwargs).values():
+        for device in make_all_devices(module).values():
             self.device(device)

     def plan(self, plan: PlanGenerator) -> PlanGenerator:
diff --git a/src/blueapi/core/device_lookup.py b/src/blueapi/core/device_lookup.py
index ca1506a46..957a057f7 100644
--- a/src/blueapi/core/device_lookup.py
+++ b/src/blueapi/core/device_lookup.py
@@ -48,21 +48,3 @@ def find_component(obj: Any, addr: List[str]) -> Optional[D]:
             f"Found {component} in {obj} while searching for {addr} "
             "but it is not a device"
         )
-
-
-def inject(name: str):
-    """
-    Function to mark a default argument of a plan method as a reference to a device
-    that is stored in the Blueapi context.
-    Bypasses mypy linting, returning x as Any and therefore valid as a default
-    argument.
-
-    Args:
-        name (str): Name of a device to be fetched from the Blueapi context
-
-    Returns:
-        Any: name but without typing checking, valid as any default type
-
-    """
-
-    return name
diff --git a/src/blueapi/data_management/__init__.py b/src/blueapi/data_management/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/src/blueapi/data_management/gda_directory_provider.py b/src/blueapi/data_management/gda_directory_provider.py
new file mode 100644
index 000000000..f966a3d7a
--- /dev/null
+++ b/src/blueapi/data_management/gda_directory_provider.py
@@ -0,0 +1,128 @@
+import logging
+from abc import ABC, abstractmethod
+from pathlib import Path
+from typing import Optional
+
+from aiohttp import ClientSession
+from ophyd_async.core import DirectoryInfo, DirectoryProvider
+from pydantic import BaseModel
+
+
+class DataCollectionIdentifier(BaseModel):
+    collectionNumber: int
+
+
+class VisitServiceClientBase(ABC):
+    """
+    Object responsible for I/O in determining collection number
+    """
+
+    @abstractmethod
+    async def create_new_collection(self) -> DataCollectionIdentifier:
+        ...
+
+    @abstractmethod
+    async def get_current_collection(self) -> DataCollectionIdentifier:
+        ...
+
+
+class VisitServiceClient(VisitServiceClientBase):
+    _url: str
+
+    def __init__(self, url: str) -> None:
+        self._url = url
+
+    async def create_new_collection(self) -> DataCollectionIdentifier:
+        async with ClientSession() as session:
+            async with session.post(f"{self._url}/numtracker") as response:
+                if response.status == 200:
+                    json = await response.json()
+                    return DataCollectionIdentifier.parse_obj(json)
+                else:
+                    raise Exception(response.status)
+
+    async def get_current_collection(self) -> DataCollectionIdentifier:
+        async with ClientSession() as session:
+            async with session.get(f"{self._url}/numtracker") as response:
+                if response.status == 200:
+                    json = await response.json()
+                    return DataCollectionIdentifier.parse_obj(json)
+                else:
+                    raise Exception(response.status)
+
+
+class LocalVisitServiceClient(VisitServiceClientBase):
+    _count: int
+
+    def __init__(self) -> None:
+        self._count = 0
+
+    async def create_new_collection(self) -> DataCollectionIdentifier:
+        self._count += 1
+        return DataCollectionIdentifier(collectionNumber=self._count)
+
+    async def get_current_collection(self) -> DataCollectionIdentifier:
+        return DataCollectionIdentifier(collectionNumber=self._count)
+
+
+class VisitDirectoryProvider(DirectoryProvider):
+    """
+    Gets information from a remote service to construct the path that detectors
+    should write to, and determine how their files should be named.
+    """
+
+    _data_group_name: str
+    _data_directory: Path
+
+    _client: VisitServiceClientBase
+    _current_collection: Optional[DirectoryInfo]
+    _session: Optional[ClientSession]
+
+    def __init__(
+        self,
+        data_group_name: str,
+        data_directory: Path,
+        client: VisitServiceClientBase,
+    ):
+        self._data_group_name = data_group_name
+        self._data_directory = data_directory
+        self._client = client
+
+        self._current_collection = None
+        self._session = None
+
+    async def update(self) -> None:
+        """
+        Calls the visit service to create a new data collection in the current visit.
+ """ + # TODO: After visit service is more feature complete: + # TODO: Allow selecting visit as part of the request to BlueAPI + # TODO: Consume visit information from BlueAPI and pass down to this class + # TODO: Query visit service to get information about visit and data collection + # TODO: Use AuthN information as part of verification with visit service + + try: + collection_id_info = await self._client.create_new_collection() + self._current_collection = self._generate_directory_info(collection_id_info) + except Exception as ex: + # TODO: The catch all is needed because the RunEngine will not + # currently handle it, see + # https://github.com/bluesky/bluesky/pull/1623 + self._current_collection = None + logging.exception(ex) + + def _generate_directory_info( + self, + collection_id_info: DataCollectionIdentifier, + ) -> DirectoryInfo: + collection_id = collection_id_info.collectionNumber + file_prefix = f"{self._data_group_name}-{collection_id}" + return DirectoryInfo(str(self._data_directory), file_prefix) + + def __call__(self) -> DirectoryInfo: + if self._current_collection is not None: + return self._current_collection + else: + raise ValueError( + "No current collection, update() needs to be called at least once" + ) diff --git a/src/blueapi/preprocessors/attach_metadata.py b/src/blueapi/preprocessors/attach_metadata.py new file mode 100644 index 000000000..0b7a0306a --- /dev/null +++ b/src/blueapi/preprocessors/attach_metadata.py @@ -0,0 +1,72 @@ +import bluesky.plan_stubs as bps +from bluesky.utils import make_decorator + +from blueapi.core import MsgGenerator +from blueapi.data_management.gda_directory_provider import VisitDirectoryProvider + +DATA_SESSION = "data_session" +DATA_GROUPS = "data_groups" + + +def attach_metadata( + plan: MsgGenerator, + provider: VisitDirectoryProvider, +) -> MsgGenerator: + """ + Attach data session metadata to the runs within a plan and make it correlate + with an ophyd-async DirectoryProvider. + + This wrapper is meant to ensure (on a best-effort basis) that detectors write + their data to the same place for a given run, and that their writings are + tied together in the run via the data_session metadata keyword in the run + start document. + + The wrapper groups data by staging and bundles it with runs as best it can. + Since staging is inherently decoupled from runs this is done on a best-effort + basis. In the following sequence of messages: + + |stage|, stage, |open_run|, close_run, unstage, unstage, |stage|, stage, + |open_run|, close_run, unstage, unstage + + A new group is created at each |stage| and bundled into the start document + at each |open_run|. + + Args: + plan: The plan to preprocess + provider: The directory provider that participating detectors are aware of. + + Returns: + MsgGenerator: A plan + + Yields: + Iterator[Msg]: Plan messages + """ + + group_in_progress = False + + for message in plan: + # If the first stage in a series of stages is detected, + # update the directory provider and create a new group. + if (message.command == "stage") and (not group_in_progress): + yield from bps.wait_for([provider.update]) + group_in_progress = True + # Mark if detectors are being unstaged so that the start + # of the next sequence of stages is detectable. + elif message.command == "unstage": + group_in_progress = False + + # If a run is being opened, attempt to bundle the information + # on any existing group into the start document. 
+        if message.command == "open_run":
+            # Handle the case where we're opening a run but no detectors
+            # have been staged yet. Common for nested runs.
+            if not group_in_progress:
+                yield from bps.wait_for([provider.update])
+            directory_info = provider()
+            message.kwargs[DATA_SESSION] = directory_info.filename_prefix
+
+        # This is a preprocessor so we yield the original message.
+        yield message
+
+
+attach_metadata_decorator = make_decorator(attach_metadata)
diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py
index 484e17aae..ed63b0e60 100644
--- a/src/blueapi/service/handler.py
+++ b/src/blueapi/service/handler.py
@@ -4,8 +4,15 @@
 from blueapi.config import ApplicationConfig
 from blueapi.core import BlueskyContext
 from blueapi.core.event import EventStream
+from blueapi.data_management.gda_directory_provider import (
+    LocalVisitServiceClient,
+    VisitDirectoryProvider,
+    VisitServiceClient,
+    VisitServiceClientBase,
+)
 from blueapi.messaging import StompMessagingTemplate
 from blueapi.messaging.base import MessagingTemplate
+from blueapi.preprocessors.attach_metadata import attach_metadata
 from blueapi.worker.reworker import RunEngineWorker
 from blueapi.worker.worker import Worker
@@ -80,7 +87,48 @@ def setup_handler(
     config: Optional[ApplicationConfig] = None,
 ) -> None:
     global HANDLER
-    handler = Handler(config)
+
+    provider = None
+    plan_wrappers = []
+
+    if config:
+        visit_service_client: VisitServiceClientBase
+        if config.env.data_writing.visit_service_url is not None:
+            visit_service_client = VisitServiceClient(
+                config.env.data_writing.visit_service_url
+            )
+        else:
+            visit_service_client = LocalVisitServiceClient()
+
+        provider = VisitDirectoryProvider(
+            data_group_name=config.env.data_writing.group_name,
+            data_directory=config.env.data_writing.visit_directory,
+            client=visit_service_client,
+        )
+
+        # Make all dodal devices created by the context use provider if they can
+        try:
+            from dodal.parameters.gda_directory_provider import (
+                set_directory_provider_singleton,
+            )
+
+            set_directory_provider_singleton(provider)
+        except ImportError:
+            logging.error(
+                "Unable to set directory provider for ophyd-async devices, "
+                "a newer version of dodal is required"
+            )
+
+        plan_wrappers.append(lambda plan: attach_metadata(plan, provider))
+
+    handler = Handler(
+        config,
+        context=BlueskyContext(
+            plan_wrappers=plan_wrappers,
+            directory_provider=provider,
+            sim=False,
+        ),
+    )
     handler.start()
     HANDLER = handler
diff --git a/src/blueapi/startup/example_plans.py b/src/blueapi/startup/example_plans.py
index e763f297b..6ec7aa99a 100644
--- a/src/blueapi/startup/example_plans.py
+++ b/src/blueapi/startup/example_plans.py
@@ -1,10 +1,11 @@
 from typing import List

 from bluesky.protocols import Movable, Readable
+from dls_bluesky_core.core import inject
 from dls_bluesky_core.plans import count
 from dls_bluesky_core.stubs import move

-from blueapi.core import MsgGenerator, inject
+from blueapi.core import MsgGenerator


 def stp_snapshot(
diff --git a/src/blueapi/worker/task.py b/src/blueapi/worker/task.py
index d09add5c7..d1b4de07d 100644
--- a/src/blueapi/worker/task.py
+++ b/src/blueapi/worker/task.py
@@ -29,7 +29,8 @@ def do_task(self, ctx: BlueskyContext) -> None:
         func = ctx.plan_functions[self.name]
         prepared_params = self._ensure_params(ctx)
         plan_generator = func(**prepared_params.dict())
-        ctx.run_engine(plan_generator)
+        wrapped_plan_generator = ctx.wrap(plan_generator)
+        ctx.run_engine(wrapped_plan_generator)

     def _ensure_params(self, ctx: BlueskyContext) -> BaseModel:
         if self._prepared_params is None:
diff --git a/tests/conftest.py b/tests/conftest.py
index 07bc4e795..362a91d01 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,9 +1,12 @@
+import asyncio
+
 # Based on https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option  # noqa: E501
 from typing import Iterator
 from unittest.mock import MagicMock

 import pytest
-from bluesky.run_engine import RunEngineStateMachine
+from bluesky import RunEngine
+from bluesky.run_engine import RunEngineStateMachine, TransitionError
 from fastapi.testclient import TestClient

 from blueapi.service.handler import Handler, get_handler
@@ -43,8 +46,28 @@ def client(self) -> TestClient:
         return TestClient(app)


+@pytest.fixture(scope="function")
+def RE(request):
+    loop = asyncio.new_event_loop()
+    loop.set_debug(True)
+    RE = RunEngine({}, call_returns_result=True, loop=loop)
+
+    def clean_event_loop():
+        if RE.state not in ("idle", "panicked"):
+            try:
+                RE.halt()
+            except TransitionError:
+                pass
+        loop.call_soon_threadsafe(loop.stop)
+        RE._th.join()
+        loop.close()
+
+    request.addfinalizer(clean_event_loop)
+    return RE
+
+
 @pytest.fixture
-def handler() -> Iterator[Handler]:
+def handler(RE: RunEngine) -> Iterator[Handler]:
     context: BlueskyContext = BlueskyContext(run_engine=MagicMock())
     context.run_engine.state = RunEngineStateMachine.States.IDLE
     handler = Handler(context=context, messaging_template=MagicMock())
diff --git a/tests/core/test_context.py b/tests/core/test_context.py
index b3a95d1ab..b42dd7594 100644
--- a/tests/core/test_context.py
+++ b/tests/core/test_context.py
@@ -1,16 +1,15 @@
 from __future__ import annotations

 from typing import Dict, List, Type, Union
-from unittest.mock import patch

 import pytest
 from bluesky.protocols import Descriptor, Movable, Readable, Reading, SyncOrAsync
-from dls_bluesky_core.core import MsgGenerator, PlanGenerator
+from dls_bluesky_core.core import MsgGenerator, PlanGenerator, inject
 from ophyd.sim import SynAxis, SynGauss
 from pydantic import ValidationError, parse_obj_as

 from blueapi.config import EnvironmentConfig, Source, SourceKind
-from blueapi.core import BlueskyContext, inject, is_bluesky_compatible_device
+from blueapi.core import BlueskyContext, is_bluesky_compatible_device
 from blueapi.core.context import DefaultFactory

 SIM_MOTOR_NAME = "sim"
@@ -172,21 +171,6 @@ def test_add_devices_from_module(empty_context: BlueskyContext) -> None:
     } == empty_context.devices.keys()


-def test_extra_kwargs_in_with_dodal_module_passed_to_make_all_devices(
-    empty_context: BlueskyContext,
-) -> None:
-    import tests.core.fake_device_module as device_module
-
-    with patch("dodal.utils.make_all_devices") as mock_make_all_devices:
-        empty_context.with_dodal_module(
-            device_module, some_argument=1, another_argument="two"
-        )
-
-        mock_make_all_devices.assert_called_once_with(
-            device_module, some_argument=1, another_argument="two"
-        )
-
-
 @pytest.mark.parametrize(
     "addr", ["sim", "sim_det", "sim.setpoint", ["sim"], ["sim", "setpoint"]]
 )
diff --git a/tests/data_writing/__init__.py b/tests/data_writing/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/data_writing/test_gda_directory_provider.py b/tests/data_writing/test_gda_directory_provider.py
new file mode 100644
index 000000000..10dd76d08
--- /dev/null
+++ b/tests/data_writing/test_gda_directory_provider.py
@@ -0,0 +1,66 @@
+from pathlib import Path
+
+import pytest
+from ophyd_async.core import DirectoryInfo
+
+from blueapi.data_management.gda_directory_provider import (
+    DataCollectionIdentifier,
+    LocalVisitServiceClient,
+    VisitDirectoryProvider,
+    VisitServiceClientBase,
+)
+
+
+@pytest.fixture
+def visit_service_client() -> VisitServiceClientBase:
+    return LocalVisitServiceClient()
+
+
+@pytest.fixture
+def visit_directory_provider(
+    visit_service_client: VisitServiceClientBase,
+) -> VisitDirectoryProvider:
+    return VisitDirectoryProvider("example", Path("/tmp"), visit_service_client)
+
+
+@pytest.mark.asyncio
+async def test_client_can_view_collection(
+    visit_service_client: VisitServiceClientBase,
+) -> None:
+    collection = await visit_service_client.get_current_collection()
+    assert collection == DataCollectionIdentifier(collectionNumber=0)
+
+
+@pytest.mark.asyncio
+async def test_client_can_create_collection(
+    visit_service_client: VisitServiceClientBase,
+) -> None:
+    collection = await visit_service_client.create_new_collection()
+    assert collection == DataCollectionIdentifier(collectionNumber=1)
+
+
+@pytest.mark.asyncio
+async def test_update_sets_collection_number(
+    visit_directory_provider: VisitDirectoryProvider,
+) -> None:
+    await visit_directory_provider.update()
+    assert visit_directory_provider() == DirectoryInfo(
+        directory_path="/tmp",
+        filename_prefix="example-1",
+    )
+
+
+@pytest.mark.asyncio
+async def test_update_sets_collection_number_multi(
+    visit_directory_provider: VisitDirectoryProvider,
+) -> None:
+    await visit_directory_provider.update()
+    assert visit_directory_provider() == DirectoryInfo(
+        directory_path="/tmp",
+        filename_prefix="example-1",
+    )
+    await visit_directory_provider.update()
+    assert visit_directory_provider() == DirectoryInfo(
+        directory_path="/tmp",
+        filename_prefix="example-2",
+    )
diff --git a/tests/preprocessors/__init__.py b/tests/preprocessors/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/tests/preprocessors/test_attach_metadata.py b/tests/preprocessors/test_attach_metadata.py
new file mode 100644
index 000000000..a7ba1b1de
--- /dev/null
+++ b/tests/preprocessors/test_attach_metadata.py
@@ -0,0 +1,389 @@
+from pathlib import Path
+from typing import Any, Callable, Dict, List, Mapping
+
+import bluesky.plan_stubs as bps
+import bluesky.plans as bp
+import pytest
+from bluesky import RunEngine
+from bluesky.preprocessors import (
+    run_decorator,
+    run_wrapper,
+    set_run_key_decorator,
+    set_run_key_wrapper,
+    stage_wrapper,
+)
+from bluesky.protocols import HasName, Readable, Reading, Status, Triggerable
+from event_model.documents.event_descriptor import DataKey
+from ophyd.status import StatusBase
+from ophyd_async.core import DirectoryProvider
+
+from blueapi.core import DataEvent, MsgGenerator
+from blueapi.data_management.gda_directory_provider import (
+    DataCollectionIdentifier,
+    VisitDirectoryProvider,
+    VisitServiceClient,
+)
+from blueapi.preprocessors.attach_metadata import DATA_SESSION, attach_metadata
+
+DATA_DIRECTORY = Path("/tmp")
+DATA_GROUP_NAME = "test"
+
+
+RUN_0 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-0"
+RUN_1 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-1"
+RUN_2 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-2"
+
+
+class MockVisitServiceClient(VisitServiceClient):
+    _count: int
+    _fail: bool
+
+    def __init__(self) -> None:
+        super().__init__("http://example.com")
+        self._count = 0
+        self._fail = False
+
+    def always_fail(self) -> None:
+        self._fail = True
+
+    async def create_new_collection(self) -> DataCollectionIdentifier:
+        if self._fail:
+            raise ConnectionError()
+
+        count = self._count
+        self._count += 1
+        return DataCollectionIdentifier(collectionNumber=count)
+
+    async def get_current_collection(self) -> DataCollectionIdentifier:
+        if self._fail:
+            raise ConnectionError()
+
+        return DataCollectionIdentifier(collectionNumber=self._count)
+
+
+@pytest.fixture
+def client() -> VisitServiceClient:
+    return MockVisitServiceClient()
+
+
+@pytest.fixture
+def provider(client: VisitServiceClient) -> VisitDirectoryProvider:
+    return VisitDirectoryProvider(
+        data_directory=DATA_DIRECTORY,
+        data_group_name=DATA_GROUP_NAME,
+        client=client,
+    )
+
+
+@pytest.fixture
+def run_engine() -> RunEngine:
+    return RunEngine()
+
+
+class FakeDetector(Readable, HasName, Triggerable):
+    _name: str
+    _provider: DirectoryProvider
+
+    def __init__(
+        self,
+        name: str,
+        provider: DirectoryProvider,
+    ) -> None:
+        self._name = name
+        self._provider = provider
+
+    async def read(self) -> Dict[str, Reading]:
+        return {
+            f"{self.name}_data": {
+                "value": "test",
+                "timestamp": 0.0,
+            },
+        }
+
+    async def describe(self) -> Dict[str, DataKey]:
+        directory_info = self._provider()
+        path = f"{directory_info.directory_path}/{directory_info.filename_prefix}"
+        return {
+            f"{self.name}_data": {
+                "dtype": "string",
+                "shape": [1],
+                "source": path,
+            }
+        }
+
+    def trigger(self) -> Status:
+        status = StatusBase()
+        status.set_finished()
+        return status
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def parent(self) -> None:
+        return None
+
+
+@pytest.fixture(params=[1, 2])
+def detectors(request, provider: VisitDirectoryProvider) -> List[Readable]:
+    number_of_detectors = request.param
+    return [
+        FakeDetector(
+            name=f"test_detector_{i}",
+            provider=provider,
+        )
+        for i in range(number_of_detectors)
+    ]
+
+
+def simple_run(detectors: List[Readable]) -> MsgGenerator:
+    yield from bp.count(detectors)
+
+
+def multi_run(detectors: List[Readable]) -> MsgGenerator:
+    yield from bp.count(detectors)
+    yield from bp.count(detectors)
+
+
+def multi_nested_plan(detectors: List[Readable]) -> MsgGenerator:
+    yield from simple_run(detectors)
+    yield from simple_run(detectors)
+
+
+def multi_run_single_stage(detectors: List[Readable]) -> MsgGenerator:
+    def stageless_count() -> MsgGenerator:
+        return (yield from bps.one_shot(detectors))
+
+    def inner_plan() -> MsgGenerator:
+        yield from run_wrapper(stageless_count())
+        yield from run_wrapper(stageless_count())
+
+    yield from stage_wrapper(inner_plan(), detectors)
+
+
+def multi_run_single_stage_multi_group(
+    detectors: List[Readable],
+) -> MsgGenerator:
+    def stageless_count() -> MsgGenerator:
+        return (yield from bps.one_shot(detectors))
+
+    def inner_plan() -> MsgGenerator:
+        yield from run_wrapper(stageless_count(), md={DATA_SESSION: 1})
+        yield from run_wrapper(stageless_count(), md={DATA_SESSION: 1})
+        yield from run_wrapper(stageless_count(), md={DATA_SESSION: 2})
+        yield from run_wrapper(stageless_count(), md={DATA_SESSION: 2})
+
+    yield from stage_wrapper(inner_plan(), detectors)
+
+
+@run_decorator(md={DATA_SESSION: 12345})
+@set_run_key_decorator("outer")
+def nested_run_with_metadata(detectors: List[Readable]) -> MsgGenerator:
+    yield from set_run_key_wrapper(bp.count(detectors), "inner")
+    yield from set_run_key_wrapper(bp.count(detectors), "inner")
+
+
+@run_decorator()
+@set_run_key_decorator("outer")
+def nested_run_without_metadata(
+    detectors: List[Readable],
+) -> MsgGenerator:
+    yield from set_run_key_wrapper(bp.count(detectors), "inner")
+    yield from set_run_key_wrapper(bp.count(detectors), "inner")
+
+
+def test_simple_run_gets_scan_number(
+    run_engine: RunEngine,
+    detectors: List[Readable],
+    provider: DirectoryProvider,
+) -> None:
+    docs = collect_docs(
+        run_engine,
+        simple_run(detectors),
+        provider,
+    )
+    assert docs[0].name == "start"
+    assert docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
+    assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0])
+
+
+@pytest.mark.parametrize("plan", [multi_run, multi_nested_plan])
+def test_multi_run_gets_scan_numbers(
+    run_engine: RunEngine,
+    detectors: List[Readable],
+    plan: Callable[[List[Readable]], MsgGenerator],
+    provider: DirectoryProvider,
+) -> None:
+    docs = collect_docs(
+        run_engine,
+        plan(detectors),
+        provider,
+    )
+    start_docs = find_start_docs(docs)
+    assert len(start_docs) == 2
+    assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
+    assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-1"
+    assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0, RUN_1])
+
+
+def test_multi_run_single_stage(
+    run_engine: RunEngine,
+    detectors: List[Readable],
+    provider: DirectoryProvider,
+) -> None:
+    docs = collect_docs(
+        run_engine,
+        multi_run_single_stage(detectors),
+        provider,
+    )
+    start_docs = find_start_docs(docs)
+    assert len(start_docs) == 2
+    assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
+    assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
+    assert_all_detectors_used_collection_numbers(
+        docs,
+        detectors,
+        [
+            RUN_0,
+            RUN_0,
+        ],
+    )
+
+
+def test_multi_run_single_stage_multi_group(
+    run_engine: RunEngine,
+    detectors: List[Readable],
+    provider: DirectoryProvider,
+) -> None:
+    docs = collect_docs(
+        run_engine,
+        multi_run_single_stage_multi_group(detectors),
+        provider,
+    )
+    start_docs = find_start_docs(docs)
+    assert len(start_docs) == 4
+    assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
+    assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
+    assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
+    assert start_docs[3].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
+    assert_all_detectors_used_collection_numbers(
+        docs,
+        detectors,
+        [
+            RUN_0,
+            RUN_0,
+            RUN_0,
+            RUN_0,
+        ],
+    )
+
+
+def test_nested_run_with_metadata(
+    run_engine: RunEngine,
+    detectors: List[Readable],
+    provider: DirectoryProvider,
+) -> None:
+    docs = collect_docs(
+        run_engine,
+        nested_run_with_metadata(detectors),
+        provider,
+    )
+    start_docs = find_start_docs(docs)
+    assert len(start_docs) == 3
+    assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
+    assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-1"
+    assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-2"
+    assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_1, RUN_2])
+
+
+def test_nested_run_without_metadata(
+    run_engine: RunEngine,
+    detectors: List[Readable],
+    provider: DirectoryProvider,
+) -> None:
+    docs = collect_docs(
+        run_engine,
+        nested_run_without_metadata(detectors),
+        provider,
+    )
+    start_docs = find_start_docs(docs)
+    assert len(start_docs) == 3
+    assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0"
+    assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-1"
+    assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-2"
+    assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_1, RUN_2])
+
+
+def test_visit_directory_provider_fails(
+    run_engine: RunEngine,
+    detectors: List[Readable],
+    provider: DirectoryProvider,
+    client: MockVisitServiceClient,
+) -> None:
+    client.always_fail()
+    with pytest.raises(ValueError):
+        collect_docs(
+            run_engine,
+            simple_run(detectors),
+            provider,
+        )
+
+
+def test_visit_directory_provider_fails_after_one_sucess(
+    run_engine: RunEngine,
+    detectors: List[Readable],
+    provider: DirectoryProvider,
+    client: MockVisitServiceClient,
+) -> None:
+    collect_docs(
+        run_engine,
+        simple_run(detectors),
+        provider,
+    )
+    client.always_fail()
+    with pytest.raises(ValueError):
+        collect_docs(
+            run_engine,
+            simple_run(detectors),
+            provider,
+        )
+
+
+def collect_docs(
+    run_engine: RunEngine,
+    plan: MsgGenerator,
+    provider: DirectoryProvider,
+) -> List[DataEvent]:
+    events = []
+
+    def on_event(name: str, doc: Mapping[str, Any]) -> None:
+        events.append(DataEvent(name=name, doc=doc))
+
+    wrapped_plan = attach_metadata(plan, provider)
+    run_engine(wrapped_plan, on_event)
+    return events
+
+
+def assert_all_detectors_used_collection_numbers(
+    docs: List[DataEvent],
+    detectors: List[Readable],
+    source_history: List[Path],
+) -> None:
+    descriptors = find_descriptor_docs(docs)
+    assert len(descriptors) == len(source_history)
+
+    for descriptor, expected_source in zip(descriptors, source_history):
+        for detector in detectors:
+            source = descriptor.doc.get("data_keys", {}).get(f"{detector.name}_data")[
+                "source"
+            ]
+            assert Path(source) == expected_source
+
+
+def find_start_docs(docs: List[DataEvent]) -> List[DataEvent]:
+    return list(filter(lambda event: event.name == "start", docs))
+
+
+def find_descriptor_docs(docs: List[DataEvent]) -> List[DataEvent]:
+    return list(filter(lambda event: event.name == "descriptor", docs))
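
As a usage sketch (not part of the patch): the new DataWritingConfig in src/blueapi/config.py above could be populated like this once the patch is applied. The URL mirrors the "e.g." comment in that file; the other values are the defaults the diff declares.

from pathlib import Path

from blueapi.config import DataWritingConfig, EnvironmentConfig

# With visit_service_url set, setup_handler builds a VisitServiceClient;
# leaving it as None makes it fall back to LocalVisitServiceClient.
env = EnvironmentConfig(
    data_writing=DataWritingConfig(
        visit_service_url="http://localhost:8088/api",
        visit_directory=Path("/tmp/0-0"),
        group_name="example",
    )
)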
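The plan_wrappers field added to BlueskyContext is applied by wrap() via functools.reduce, so wrappers apply in list order and the last one ends up outermost. A self-contained toy sketch of that ordering; the tag helper and plain-string messages are illustrative stand-ins for bluesky Msg objects, not blueapi code.

import functools
from typing import Callable, Generator

# Toy stand-ins so the example runs without dependencies.
ToyPlan = Generator[str, None, None]
ToyWrapper = Callable[[ToyPlan], ToyPlan]


def tag(label: str) -> ToyWrapper:
    def wrapper(plan: ToyPlan) -> ToyPlan:
        yield f"enter {label}"
        yield from plan
        yield f"exit {label}"

    return wrapper


def plan() -> ToyPlan:
    yield "body"


# The same reduce expression as BlueskyContext.wrap: each wrapper wraps
# the result of the previous one, so tag("outer") becomes outermost.
wrapped = functools.reduce(
    lambda wrapped, next_wrapper: next_wrapper(wrapped),
    [tag("inner"), tag("outer")],
    plan(),
)
print(list(wrapped))
# ['enter outer', 'enter inner', 'body', 'exit inner', 'exit outer']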
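VisitServiceClient assumes the visit service answers GET and POST at {url}/numtracker with JSON that parses into DataCollectionIdentifier. A minimal sketch of that contract as the client sees it, with an arbitrary illustrative payload value.

from blueapi.data_management.gda_directory_provider import DataCollectionIdentifier

# The body the client round-trips via parse_obj; only collectionNumber
# is required by the model defined in this diff.
payload = {"collectionNumber": 7}
ident = DataCollectionIdentifier.parse_obj(payload)
assert ident.collectionNumber == 7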