From 72413289736d59c15912337c38d0982c6aa56598 Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Thu, 26 Oct 2023 10:48:09 +0100 Subject: [PATCH] Bulk out ophyd async connect logic --- src/blueapi/core/context.py | 25 ++-- src/blueapi/utils/__init__.py | 2 + src/blueapi/utils/ophyd_async_connect.py | 49 +++++++ tests/utils/test_ophyd_async_connect.py | 173 +++++++++++++++++++++++ 4 files changed, 235 insertions(+), 14 deletions(-) create mode 100644 src/blueapi/utils/ophyd_async_connect.py create mode 100644 tests/utils/test_ophyd_async_connect.py diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 6118a420a..80fdb7377 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -23,14 +23,16 @@ ) from bluesky.run_engine import RunEngine, call_in_bluesky_event_loop -from ophyd_async.core import Device as AsyncDevice -from ophyd_async.core import wait_for_connection from pydantic import create_model from pydantic.fields import FieldInfo, ModelField from blueapi.config import EnvironmentConfig, SourceKind from blueapi.data_management.gda_directory_provider import VisitDirectoryProvider -from blueapi.utils import BlueapiPlanModelConfig, load_module_all +from blueapi.utils import ( + BlueapiPlanModelConfig, + connect_ophyd_async_devices, + load_module_all, +) from .bluesky_types import ( BLUESKY_PROTOCOLS, @@ -105,17 +107,12 @@ def with_config(self, config: EnvironmentConfig) -> None: elif source.kind is SourceKind.DODAL: self.with_dodal_module(mod) - call_in_bluesky_event_loop(self.connect_devices(self.sim)) - - async def connect_devices(self, sim: bool = False) -> None: - coros = {} - for device_name, device in self.devices.items(): - if isinstance(device, AsyncDevice): - device.set_name(device_name) - coros[device_name] = device.connect(sim) - - if len(coros) > 0: - await asyncio.wait(wait_for_connection(**coros), timeout=30.0) + call_in_bluesky_event_loop( + connect_ophyd_async_devices( + self.devices.values(), + self.sim, + ) + ) def with_plan_module(self, module: ModuleType) -> None: """ diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index b871f842a..b3c212a51 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,6 +1,7 @@ from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig from .invalid_config_error import InvalidConfigError from .modules import load_module_all +from .ophyd_async_connect import connect_ophyd_async_devices from .serialization import serialize from .thread_exception import handle_all_exceptions @@ -13,4 +14,5 @@ "BlueapiModelConfig", "BlueapiPlanModelConfig", "InvalidConfigError", + "connect_ophyd_async_devices", ] diff --git a/src/blueapi/utils/ophyd_async_connect.py b/src/blueapi/utils/ophyd_async_connect.py new file mode 100644 index 000000000..81514d0b1 --- /dev/null +++ b/src/blueapi/utils/ophyd_async_connect.py @@ -0,0 +1,49 @@ +import asyncio +import logging +from contextlib import suppress +from typing import Any, Dict, Iterable + +from ophyd_async.core import DEFAULT_TIMEOUT +from ophyd_async.core import Device as OphydAsyncDevice +from ophyd_async.core import NotConnected + + +async def connect_ophyd_async_devices( + devices: Iterable[Any], + sim: bool = False, + timeout: float = DEFAULT_TIMEOUT, +) -> None: + tasks: Dict[asyncio.Task, str] = {} + for device in devices: + if isinstance(device, OphydAsyncDevice): + task = asyncio.Task(device.connect(sim=sim)) + tasks[task] = device.name + await _wait_for_tasks(tasks, timeout=timeout) + + +async def _wait_for_tasks( + tasks: Dict[asyncio.Task, str], + timeout: float, +): + done, pending = await asyncio.wait(tasks, timeout=timeout) + if pending: + msg = f"{len(pending)} Devices did not connect:" + for t in pending: + t.cancel() + with suppress(Exception): + await t + e = t.exception() + msg += f"\n {tasks[t]}: {type(e).__name__}" + lines = str(e).splitlines() + if len(lines) <= 1: + msg += f": {e}" + else: + msg += "".join(f"\n {line}" for line in lines) + logging.error(msg) + raised = [t for t in done if t.exception()] + if raised: + logging.error(f"{len(raised)} Devices raised an error:") + for t in raised: + logging.exception(f" {tasks[t]}:", exc_info=t.exception()) + if pending or raised: + raise NotConnected("Not all Devices connected") diff --git a/tests/utils/test_ophyd_async_connect.py b/tests/utils/test_ophyd_async_connect.py new file mode 100644 index 000000000..aef42c9b9 --- /dev/null +++ b/tests/utils/test_ophyd_async_connect.py @@ -0,0 +1,173 @@ +import asyncio +import itertools +import logging +from typing import Callable, Iterable, Tuple, Type, cast + +import pytest +from ophyd_async.core import Device, DeviceVector, NotConnected, StandardReadable + +from blueapi.utils import connect_ophyd_async_devices + + +class DummyBaseDevice(Device): + def __init__(self) -> None: + self.connected = False + + async def connect(self, sim=False): + self.connected = True + + +class DummyDeviceThatErrorsWhenConnecting(Device): + async def connect(self, sim: bool = False): + raise IOError("Connection failed") + + +class DummyDeviceThatTimesOutWhenConnecting(StandardReadable): + async def connect(self, sim: bool = False): + try: + await asyncio.Future() + except asyncio.CancelledError: + raise NotConnected("source: foo") + + +class DummyDeviceGroup(Device): + def __init__(self, name: str) -> None: + self.child1 = DummyBaseDevice() + self.child2 = DummyBaseDevice() + self.dict_with_children: DeviceVector[DummyBaseDevice] = DeviceVector( + {123: DummyBaseDevice()} + ) + self.set_name(name) + + +class DummyDeviceGroupThatTimesOut(Device): + def __init__(self, name: str) -> None: + self.child1 = DummyDeviceThatTimesOutWhenConnecting() + self.set_name(name) + + +class DummyDeviceGroupThatErrors(Device): + def __init__(self, name: str) -> None: + self.child1 = DummyDeviceThatErrorsWhenConnecting() + self.set_name(name) + + +class DummyDeviceGroupThatErrorsAndTimesOut(Device): + def __init__(self, name: str) -> None: + self.child1 = DummyDeviceThatErrorsWhenConnecting() + self.child2 = DummyDeviceThatTimesOutWhenConnecting() + self.set_name(name) + + +ALL_DEVICE_CONSTRUCTORS = [ + DummyDeviceThatErrorsWhenConnecting, + DummyDeviceThatTimesOutWhenConnecting, + DummyDeviceGroupThatErrors, + DummyDeviceGroupThatTimesOut, + DummyDeviceGroupThatErrorsAndTimesOut, +] + + +@pytest.mark.parametrize("device_constructor", ALL_DEVICE_CONSTRUCTORS) +async def test_device_collector_propagates_errors_and_timeouts( + device_constructor: Callable[[str], Device] +): + await _assert_failing_device_does_not_connect(device_constructor("test")) + + +@pytest.mark.parametrize( + "device_constructor_1,device_constructor_2", + list(itertools.permutations(ALL_DEVICE_CONSTRUCTORS, 2)), +) +async def test_device_collector_propagates_errors_and_timeouts_from_multiple_devices( + device_constructor_1: Callable[[str], Device], + device_constructor_2: Callable[[str], Device], +): + await _assert_failing_devices_do_not_connect( + [device_constructor_1("test1"), device_constructor_2("test2")] + ) + + +async def test_device_collector_logs_exceptions_for_raised_errors( + caplog: pytest.LogCaptureFixture, +): + caplog.set_level(logging.INFO) + await _assert_failing_device_does_not_connect(DummyDeviceGroupThatErrors) + assert caplog.records[0].message == "1 Devices raised an error:" + assert caplog.records[1].message == " should_fail:" + _assert_exception_type_and_message( + caplog.records[1], + OSError, + "Connection failed", + ) + + +async def test_device_collector_logs_exceptions_for_timeouts( + caplog: pytest.LogCaptureFixture, +): + caplog.set_level(logging.INFO) + await _assert_failing_device_does_not_connect(DummyDeviceGroupThatTimesOut) + assert caplog.records[0].message == "1 Devices did not connect:" + assert caplog.records[1].message == " should_fail:" + _assert_exception_type_and_message( + caplog.records[1], + NotConnected, + "child1: source: foo", + ) + + +async def test_device_collector_logs_exceptions_for_multiple_devices( + caplog: pytest.LogCaptureFixture, +): + caplog.set_level(logging.INFO) + await _assert_failing_devices_do_not_connect( + [ + DummyDeviceGroupThatErrorsAndTimesOut("test1"), + DummyDeviceGroupThatErrors("test2"), + ] + ) + assert caplog.records[0].message == "1 Devices did not connect:" + assert caplog.records[1].message == " should_fail_1:" + _assert_exception_type_and_message( + caplog.records[1], + OSError, + "Connection failed", + ) + assert caplog.records[2].message == "1 Devices raised an error:" + assert caplog.records[3].message == " should_fail_2:" + _assert_exception_type_and_message( + caplog.records[3], + OSError, + "Connection failed", + ) + + +async def _assert_failing_device_does_not_connect( + device: Device, +) -> pytest.ExceptionInfo[NotConnected]: + return await _assert_failing_devices_do_not_connect([device]) + + +async def _assert_failing_devices_do_not_connect( + devices: Iterable[Device], +) -> pytest.ExceptionInfo[NotConnected]: + with pytest.raises(NotConnected) as excepton_info: + await connect_ophyd_async_devices( + devices, + sim=True, + timeout=0.1, + ) + return excepton_info + + +def _assert_exception_type_and_message( + record: logging.LogRecord, + expected_type: Type[Exception], + expected_message: str, +): + exception_type, exception, _ = cast( + Tuple[Type[Exception], Exception, str], + record.exc_info, + ) + assert expected_type is exception_type + assert (expected_message,) == exception.args