From 71a86a04eb0a36706cfa38ac1acd779907edc083 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Mon, 23 Oct 2023 18:12:33 +0100 Subject: [PATCH] Begin fixing attach metadata tests --- src/blueapi/preprocessors/attach_metadata.py | 9 +- src/blueapi/service/handler.py | 7 +- tests/plugins/file_writing_detector.py | 27 -- tests/plugins/test_data_writing.py | 251 -------------- tests/{plugins => preprocessors}/__init__.py | 0 tests/preprocessors/test_attach_metadata.py | 343 +++++++++++++++++++ 6 files changed, 353 insertions(+), 284 deletions(-) delete mode 100644 tests/plugins/file_writing_detector.py delete mode 100644 tests/plugins/test_data_writing.py rename tests/{plugins => preprocessors}/__init__.py (100%) create mode 100644 tests/preprocessors/test_attach_metadata.py diff --git a/src/blueapi/preprocessors/attach_metadata.py b/src/blueapi/preprocessors/attach_metadata.py index fb7cd47bd..e11f5596a 100644 --- a/src/blueapi/preprocessors/attach_metadata.py +++ b/src/blueapi/preprocessors/attach_metadata.py @@ -11,9 +11,8 @@ def attach_metadata( - data_groups: List[str], - provider: VisitDirectoryProvider, plan: MsgGenerator, + provider: VisitDirectoryProvider, ) -> MsgGenerator: """Updates a directory provider default location for file storage.""" staging = False @@ -26,10 +25,12 @@ def attach_metadata( if (message.command == "stage") and (not staging and will_write_data): yield from bps.wait_for([provider.update]) staging = True + elif message.command == "unstage": + staging = False if message.command == "open_run": - message.kwargs[DATA_SESSION] = provider().filename_prefix - message.kwargs[DATA_GROUPS] = data_groups + directory_info = provider() + message.kwargs[DATA_SESSION] = directory_info.filename_prefix yield message diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py index c6bc04c7a..de0d680f8 100644 --- a/src/blueapi/service/handler.py +++ b/src/blueapi/service/handler.py @@ -2,7 +2,10 @@ from functools import partial from typing import Mapping, Optional -from dodal.parameters.gda_directory_provider import VisitDirectoryProvider +from dodal.parameters.gda_directory_provider import ( + VisitDirectoryProvider, + VisitServiceClient, +) from blueapi.config import ApplicationConfig from blueapi.core import BlueskyContext @@ -90,9 +93,9 @@ def setup_handler( if config: provider = VisitDirectoryProvider( - url=config.env.data_writing.visit_service_url, data_group_name=config.env.data_writing.group_name, data_directory=config.env.data_writing.visit_directory, + client=VisitServiceClient(config.env.data_writing.visit_service_url), ) # Make all dodal devices created by the context use provider if they can diff --git a/tests/plugins/file_writing_detector.py b/tests/plugins/file_writing_detector.py deleted file mode 100644 index b95aa63c9..000000000 --- a/tests/plugins/file_writing_detector.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import List - -from ophyd import Component, Device, Signal - -from blueapi.plugins.data_writing import DataCollectionProvider - - -class FakeFileWritingDetector(Device): - image_count: Signal = Component(Signal, value=0, kind="hinted") - collection_number: Signal = Component(Signal, value=0, kind="config") - - _provider: DataCollectionProvider - - def __init__(self, name: str, provider: DataCollectionProvider, **kwargs): - super().__init__(name=name, **kwargs) - self.stage_sigs[self.image_count] = 0 - self._provider = provider - - def trigger(self, *args, **kwargs): - return self.image_count.set(self.image_count.get() + 1) - - 
def stage(self) -> List[object]: - collection = self._provider.current_data_collection - if collection is None: - raise Exception("No active collection") - self.stage_sigs[self.collection_number] = collection.collection_number - return super().stage() diff --git a/tests/plugins/test_data_writing.py b/tests/plugins/test_data_writing.py deleted file mode 100644 index 62ce20316..000000000 --- a/tests/plugins/test_data_writing.py +++ /dev/null @@ -1,251 +0,0 @@ -from typing import Any, Callable, List, Mapping - -import bluesky.plan_stubs as bps -import bluesky.plans as bp -import pytest -from bluesky import RunEngine -from bluesky.preprocessors import ( - run_decorator, - run_wrapper, - set_run_key_decorator, - set_run_key_wrapper, - stage_wrapper, -) - -from blueapi.core import DataEvent, MsgGenerator -from blueapi.plugins.data_writing import ( - DATA_SESSION, - DataCollectionProvider, - InMemoryDataCollectionProvider, - data_writing_wrapper, -) - -from .file_writing_detector import FakeFileWritingDetector - - -@pytest.fixture -def provider() -> DataCollectionProvider: - return InMemoryDataCollectionProvider("example") - - -@pytest.fixture -def run_engine() -> RunEngine: - return RunEngine() - - -@pytest.fixture(params=[1, 2]) -def detectors( - request, provider: DataCollectionProvider -) -> List[FakeFileWritingDetector]: - number_of_detectors = request.param - return [ - FakeFileWritingDetector( - name=f"test_detector_{i}", - provider=provider, - ) - for i in range(number_of_detectors) - ] - - -def simple_run(detectors: List[FakeFileWritingDetector]) -> MsgGenerator: - yield from bp.count(detectors) - - -def multi_run(detectors: List[FakeFileWritingDetector]) -> MsgGenerator: - yield from bp.count(detectors) - yield from bp.count(detectors) - - -def multi_nested_plan(detectors: List[FakeFileWritingDetector]) -> MsgGenerator: - yield from simple_run(detectors) - yield from simple_run(detectors) - - -def multi_run_single_stage(detectors: List[FakeFileWritingDetector]) -> MsgGenerator: - def stageless_count() -> MsgGenerator: - return (yield from bps.one_shot(detectors)) - - def inner_plan() -> MsgGenerator: - yield from run_wrapper(stageless_count()) - yield from run_wrapper(stageless_count()) - - yield from stage_wrapper(inner_plan(), detectors) - - -def multi_run_single_stage_multi_group( - detectors: List[FakeFileWritingDetector], -) -> MsgGenerator: - def stageless_count() -> MsgGenerator: - return (yield from bps.one_shot(detectors)) - - def inner_plan() -> MsgGenerator: - yield from run_wrapper(stageless_count(), md={DATA_SESSION: 1}) - yield from run_wrapper(stageless_count(), md={DATA_SESSION: 1}) - yield from run_wrapper(stageless_count(), md={DATA_SESSION: 2}) - yield from run_wrapper(stageless_count(), md={DATA_SESSION: 2}) - - yield from stage_wrapper(inner_plan(), detectors) - - -@run_decorator(md={DATA_SESSION: 12345}) -@set_run_key_decorator("outer") -def nested_run_with_metadata(detectors: List[FakeFileWritingDetector]) -> MsgGenerator: - yield from set_run_key_wrapper(bp.count(detectors), "inner") - yield from set_run_key_wrapper(bp.count(detectors), "inner") - - -@run_decorator() -@set_run_key_decorator("outer") -def nested_run_without_metadata( - detectors: List[FakeFileWritingDetector], -) -> MsgGenerator: - yield from set_run_key_wrapper(bp.count(detectors), "inner") - yield from set_run_key_wrapper(bp.count(detectors), "inner") - - -def test_simple_run_gets_scan_number( - run_engine: RunEngine, - detectors: List[FakeFileWritingDetector], - provider: 
DataCollectionProvider, -) -> None: - docs = collect_docs( - run_engine, - simple_run(detectors), - provider, - ) - assert docs[0].name == "start" - assert docs[0].doc[DATA_SESSION] == 0 - assert_all_detectors_used_collection_numbers(docs, detectors, [0]) - - -@pytest.mark.parametrize("plan", [multi_run, multi_nested_plan]) -def test_multi_run_gets_scan_numbers( - run_engine: RunEngine, - detectors: List[FakeFileWritingDetector], - plan: Callable[[List[FakeFileWritingDetector]], MsgGenerator], - provider: DataCollectionProvider, -) -> None: - docs = collect_docs( - run_engine, - plan(detectors), - provider, - ) - start_docs = find_start_docs(docs) - assert len(start_docs) == 2 - assert start_docs[0].doc[DATA_SESSION] == 0 - assert start_docs[1].doc[DATA_SESSION] == 1 - assert_all_detectors_used_collection_numbers(docs, detectors, [0, 1]) - - -def test_multi_run_single_stage( - run_engine: RunEngine, - detectors: List[FakeFileWritingDetector], - provider: DataCollectionProvider, -) -> None: - docs = collect_docs( - run_engine, - multi_run_single_stage(detectors), - provider, - ) - start_docs = find_start_docs(docs) - assert len(start_docs) == 2 - assert start_docs[0].doc[DATA_SESSION] == 0 - assert start_docs[1].doc[DATA_SESSION] == 0 - assert_all_detectors_used_collection_numbers(docs, detectors, [0, 0]) - - -def test_multi_run_single_stage_multi_group( - run_engine: RunEngine, - detectors: List[FakeFileWritingDetector], - provider: DataCollectionProvider, -) -> None: - docs = collect_docs( - run_engine, - multi_run_single_stage_multi_group(detectors), - provider, - ) - start_docs = find_start_docs(docs) - assert len(start_docs) == 4 - assert start_docs[0].doc[DATA_SESSION] == 0 - assert start_docs[1].doc[DATA_SESSION] == 0 - assert start_docs[2].doc[DATA_SESSION] == 0 - assert start_docs[3].doc[DATA_SESSION] == 0 - assert_all_detectors_used_collection_numbers(docs, detectors, [0, 0, 0, 0]) - - -def test_nested_run_with_metadata( - run_engine: RunEngine, - detectors: List[FakeFileWritingDetector], - provider: DataCollectionProvider, -) -> None: - docs = collect_docs( - run_engine, - nested_run_with_metadata(detectors), - provider, - ) - start_docs = find_start_docs(docs) - assert len(start_docs) == 3 - assert start_docs[0].doc[DATA_SESSION] == 0 - assert start_docs[1].doc[DATA_SESSION] == 1 - assert start_docs[2].doc[DATA_SESSION] == 2 - assert_all_detectors_used_collection_numbers(docs, detectors, [1, 2]) - - -def test_nested_run_without_metadata( - run_engine: RunEngine, - detectors: List[FakeFileWritingDetector], - provider: DataCollectionProvider, -) -> None: - docs = collect_docs( - run_engine, - nested_run_without_metadata(detectors), - provider, - ) - start_docs = find_start_docs(docs) - assert len(start_docs) == 3 - assert start_docs[0].doc[DATA_SESSION] == 0 - assert start_docs[1].doc[DATA_SESSION] == 1 - assert start_docs[2].doc[DATA_SESSION] == 2 - assert_all_detectors_used_collection_numbers(docs, detectors, [1, 2]) - - -def collect_docs( - run_engine: RunEngine, - plan: MsgGenerator, - provider: DataCollectionProvider, -) -> List[DataEvent]: - events = [] - - def on_event(name: str, doc: Mapping[str, Any]) -> None: - events.append(DataEvent(name=name, doc=doc)) - - wrapped_plan = data_writing_wrapper(plan, provider) - run_engine(wrapped_plan, on_event) - return events - - -def assert_all_detectors_used_collection_numbers( - docs: List[DataEvent], - detectors: List[FakeFileWritingDetector], - collection_number_history: List[int], -) -> None: - descriptors = 
find_descriptor_docs(docs) - assert len(descriptors) == len(collection_number_history) - - for descriptor, collection_number in zip(descriptors, collection_number_history): - for detector in detectors: - attr_name = f"{detector.name}_collection_number" - actual_collection_number = ( - descriptor.doc.get("configuration", {}) - .get(detector.name, {}) - .get("data", {})[attr_name] - ) - assert actual_collection_number == collection_number - - -def find_start_docs(docs: List[DataEvent]) -> List[DataEvent]: - return list(filter(lambda event: event.name == "start", docs)) - - -def find_descriptor_docs(docs: List[DataEvent]) -> List[DataEvent]: - return list(filter(lambda event: event.name == "descriptor", docs)) diff --git a/tests/plugins/__init__.py b/tests/preprocessors/__init__.py similarity index 100% rename from tests/plugins/__init__.py rename to tests/preprocessors/__init__.py diff --git a/tests/preprocessors/test_attach_metadata.py b/tests/preprocessors/test_attach_metadata.py new file mode 100644 index 000000000..ba578e584 --- /dev/null +++ b/tests/preprocessors/test_attach_metadata.py @@ -0,0 +1,343 @@ +from pathlib import Path +from typing import Any, Callable, Dict, List, Mapping + +import bluesky.plan_stubs as bps +import bluesky.plans as bp +import pytest +from bluesky import RunEngine +from bluesky.preprocessors import ( + run_decorator, + run_wrapper, + set_run_key_decorator, + set_run_key_wrapper, + stage_wrapper, +) +from bluesky.protocols import HasName, Readable, Reading, Status, Triggerable +from dodal.parameters.gda_directory_provider import ( + DataCollectionIdentifier, + VisitDirectoryProvider, + VisitServiceClient, +) +from event_model.documents.event_descriptor import DataKey +from ophyd.status import StatusBase +from ophyd_async.core import DirectoryProvider + +from blueapi.core import DataEvent, MsgGenerator +from blueapi.preprocessors.attach_metadata import DATA_SESSION, attach_metadata + +DATA_DIRECTORY = Path("/tmp") +DATA_GROUP_NAME = "test" + + +RUN_0 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-0" +RUN_1 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-1" +RUN_2 = DATA_DIRECTORY / f"{DATA_GROUP_NAME}-2" + + +class MockVisitServiceClient(VisitServiceClient): + _count: int + + def __init__(self) -> None: + super().__init__("http://example.com") + self._count = 0 + + async def create_new_collection(self) -> DataCollectionIdentifier: + count = self._count + self._count += 1 + return DataCollectionIdentifier(collectionNumber=count) + + async def get_current_collection(self) -> DataCollectionIdentifier: + return DataCollectionIdentifier(collectionNumber=self._count) + + +@pytest.fixture +def client() -> VisitServiceClient: + return MockVisitServiceClient() + + +@pytest.fixture +def provider(client: VisitServiceClient) -> VisitDirectoryProvider: + return VisitDirectoryProvider( + data_directory=DATA_DIRECTORY, + data_group_name=DATA_GROUP_NAME, + client=client, + ) + + +@pytest.fixture +def run_engine() -> RunEngine: + return RunEngine() + + +class FakeDetector(Readable, HasName, Triggerable): + _name: str + _provider: DirectoryProvider + + def __init__( + self, + name: str, + provider: DirectoryProvider, + ) -> None: + self._name = name + self._provider = provider + + async def read(self) -> Dict[str, Reading]: + return { + f"{self.name}_data": { + "value": "test", + "timestamp": 0.0, + }, + } + + async def describe(self) -> Dict[str, DataKey]: + directory_info = self._provider() + path = f"{directory_info.directory_path}/{directory_info.filename_prefix}" + return { + 
f"{self.name}_data": { + "dtype": "string", + "shape": [1], + "source": path, + } + } + + def trigger(self) -> Status: + status = StatusBase() + status.set_finished() + return status + + @property + def name(self) -> str: + return self._name + + @property + def parent(self) -> None: + return None + + +@pytest.fixture(params=[1, 2]) +def detectors(request, provider: VisitDirectoryProvider) -> List[Readable]: + number_of_detectors = request.param + return [ + FakeDetector( + name=f"test_detector_{i}", + provider=provider, + ) + for i in range(number_of_detectors) + ] + + +def simple_run(detectors: List[Readable]) -> MsgGenerator: + yield from bp.count(detectors) + + +def multi_run(detectors: List[Readable]) -> MsgGenerator: + yield from bp.count(detectors) + yield from bp.count(detectors) + + +def multi_nested_plan(detectors: List[Readable]) -> MsgGenerator: + yield from simple_run(detectors) + yield from simple_run(detectors) + + +def multi_run_single_stage(detectors: List[Readable]) -> MsgGenerator: + def stageless_count() -> MsgGenerator: + return (yield from bps.one_shot(detectors)) + + def inner_plan() -> MsgGenerator: + yield from run_wrapper(stageless_count()) + yield from run_wrapper(stageless_count()) + + yield from stage_wrapper(inner_plan(), detectors) + + +def multi_run_single_stage_multi_group( + detectors: List[Readable], +) -> MsgGenerator: + def stageless_count() -> MsgGenerator: + return (yield from bps.one_shot(detectors)) + + def inner_plan() -> MsgGenerator: + yield from run_wrapper(stageless_count(), md={DATA_SESSION: 1}) + yield from run_wrapper(stageless_count(), md={DATA_SESSION: 1}) + yield from run_wrapper(stageless_count(), md={DATA_SESSION: 2}) + yield from run_wrapper(stageless_count(), md={DATA_SESSION: 2}) + + yield from stage_wrapper(inner_plan(), detectors) + + +@run_decorator(md={DATA_SESSION: 12345}) +@set_run_key_decorator("outer") +def nested_run_with_metadata(detectors: List[Readable]) -> MsgGenerator: + yield from set_run_key_wrapper(bp.count(detectors), "inner") + yield from set_run_key_wrapper(bp.count(detectors), "inner") + + +@run_decorator() +@set_run_key_decorator("outer") +def nested_run_without_metadata( + detectors: List[Readable], +) -> MsgGenerator: + yield from set_run_key_wrapper(bp.count(detectors), "inner") + yield from set_run_key_wrapper(bp.count(detectors), "inner") + + +def test_simple_run_gets_scan_number( + run_engine: RunEngine, + detectors: List[Readable], + provider: DirectoryProvider, +) -> None: + docs = collect_docs( + run_engine, + simple_run(detectors), + provider, + ) + assert docs[0].name == "start" + assert docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0]) + + +@pytest.mark.parametrize("plan", [multi_run, multi_nested_plan]) +def test_multi_run_gets_scan_numbers( + run_engine: RunEngine, + detectors: List[Readable], + plan: Callable[[List[Readable]], MsgGenerator], + provider: DirectoryProvider, +) -> None: + docs = collect_docs( + run_engine, + plan(detectors), + provider, + ) + start_docs = find_start_docs(docs) + assert len(start_docs) == 2 + assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-1" + assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_0, RUN_1]) + + +def test_multi_run_single_stage( + run_engine: RunEngine, + detectors: List[Readable], + provider: DirectoryProvider, +) -> None: + docs = collect_docs( + run_engine, + 
multi_run_single_stage(detectors), + provider, + ) + start_docs = find_start_docs(docs) + assert len(start_docs) == 2 + assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert_all_detectors_used_collection_numbers( + docs, + detectors, + [ + RUN_0, + RUN_0, + ], + ) + + +def test_multi_run_single_stage_multi_group( + run_engine: RunEngine, + detectors: List[Readable], + provider: DirectoryProvider, +) -> None: + docs = collect_docs( + run_engine, + multi_run_single_stage_multi_group(detectors), + provider, + ) + start_docs = find_start_docs(docs) + assert len(start_docs) == 4 + assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[3].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert_all_detectors_used_collection_numbers( + docs, + detectors, + [ + RUN_0, + RUN_0, + RUN_0, + RUN_0, + ], + ) + + +def test_nested_run_with_metadata( + run_engine: RunEngine, + detectors: List[Readable], + provider: DirectoryProvider, +) -> None: + docs = collect_docs( + run_engine, + nested_run_with_metadata(detectors), + provider, + ) + start_docs = find_start_docs(docs) + assert len(start_docs) == 3 + assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-1" + assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-2" + assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_1, RUN_2]) + + +def test_nested_run_without_metadata( + run_engine: RunEngine, + detectors: List[Readable], + provider: DirectoryProvider, +) -> None: + docs = collect_docs( + run_engine, + nested_run_without_metadata(detectors), + provider, + ) + start_docs = find_start_docs(docs) + assert len(start_docs) == 3 + assert start_docs[0].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[1].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert start_docs[2].doc[DATA_SESSION] == f"{DATA_GROUP_NAME}-0" + assert_all_detectors_used_collection_numbers(docs, detectors, [RUN_1, RUN_2]) + + +def collect_docs( + run_engine: RunEngine, + plan: MsgGenerator, + provider: DirectoryProvider, +) -> List[DataEvent]: + events = [] + + def on_event(name: str, doc: Mapping[str, Any]) -> None: + events.append(DataEvent(name=name, doc=doc)) + + wrapped_plan = attach_metadata(plan, provider) + run_engine(wrapped_plan, on_event) + return events + + +def assert_all_detectors_used_collection_numbers( + docs: List[DataEvent], + detectors: List[Readable], + source_history: List[Path], +) -> None: + descriptors = find_descriptor_docs(docs) + assert len(descriptors) == len(source_history) + + for descriptor, expected_source in zip(descriptors, source_history): + for detector in detectors: + source = descriptor.doc.get("data_keys", {}).get(f"{detector.name}_data")[ + "source" + ] + assert Path(source) == expected_source + + +def find_start_docs(docs: List[DataEvent]) -> List[DataEvent]: + return list(filter(lambda event: event.name == "start", docs)) + + +def find_descriptor_docs(docs: List[DataEvent]) -> List[DataEvent]: + return list(filter(lambda event: event.name == "descriptor", docs))
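
Usage note (not part of the patch): with this change, attach_metadata takes the
plan first and the provider second, and stamps DATA_SESSION onto each open_run
from the provider's filename_prefix instead of accepting an explicit
data_groups list. The sketch below mirrors the collect_docs helper in the new
test file; the in-memory client, /tmp directory, and "example" group name are
illustrative stand-ins, and a real deployment would instead build the provider
inside setup_handler from ApplicationConfig, as the handler.py hunk shows.

    from pathlib import Path

    import bluesky.plans as bp
    from bluesky import RunEngine
    from dodal.parameters.gda_directory_provider import (
        DataCollectionIdentifier,
        VisitDirectoryProvider,
        VisitServiceClient,
    )
    from ophyd.sim import det  # simulated detector, gives the plan something to stage

    from blueapi.preprocessors.attach_metadata import DATA_SESSION, attach_metadata


    class InMemoryClient(VisitServiceClient):
        """Offline stand-in for the visit service, mirroring the tests'
        MockVisitServiceClient; the base URL is never contacted."""

        def __init__(self) -> None:
            super().__init__("http://example.com")
            self._count = 0

        async def create_new_collection(self) -> DataCollectionIdentifier:
            count = self._count
            self._count += 1
            return DataCollectionIdentifier(collectionNumber=count)

        async def get_current_collection(self) -> DataCollectionIdentifier:
            return DataCollectionIdentifier(collectionNumber=self._count)


    provider = VisitDirectoryProvider(
        data_directory=Path("/tmp"),  # placeholder visit directory
        data_group_name="example",    # placeholder data group
        client=InMemoryClient(),
    )

    RE = RunEngine()


    def on_event(name, doc):
        # As in the tests: staging a detector triggers provider.update, so by
        # the time open_run is emitted the start document carries the session.
        if name == "start":
            print(doc[DATA_SESSION])  # e.g. "example-0"


    # Note the argument order introduced by this patch: plan, then provider.
    RE(attach_metadata(bp.count([det]), provider), on_event)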