Skip to content

Commit

Permalink
Bulk out ophyd async connect logic
Browse files Browse the repository at this point in the history
  • Loading branch information
callumforrester committed Oct 26, 2023
1 parent 7d89292 commit 7241328
Show file tree
Hide file tree
Showing 4 changed files with 235 additions and 14 deletions.
25 changes: 11 additions & 14 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,16 @@
)

from bluesky.run_engine import RunEngine, call_in_bluesky_event_loop
from ophyd_async.core import Device as AsyncDevice
from ophyd_async.core import wait_for_connection
from pydantic import create_model
from pydantic.fields import FieldInfo, ModelField

from blueapi.config import EnvironmentConfig, SourceKind
from blueapi.data_management.gda_directory_provider import VisitDirectoryProvider
from blueapi.utils import BlueapiPlanModelConfig, load_module_all
from blueapi.utils import (
BlueapiPlanModelConfig,
connect_ophyd_async_devices,
load_module_all,
)

from .bluesky_types import (
BLUESKY_PROTOCOLS,
Expand Down Expand Up @@ -105,17 +107,12 @@ def with_config(self, config: EnvironmentConfig) -> None:
elif source.kind is SourceKind.DODAL:
self.with_dodal_module(mod)

call_in_bluesky_event_loop(self.connect_devices(self.sim))

async def connect_devices(self, sim: bool = False) -> None:
coros = {}
for device_name, device in self.devices.items():
if isinstance(device, AsyncDevice):
device.set_name(device_name)
coros[device_name] = device.connect(sim)

if len(coros) > 0:
await asyncio.wait(wait_for_connection(**coros), timeout=30.0)
call_in_bluesky_event_loop(
connect_ophyd_async_devices(
self.devices.values(),
self.sim,
)
)

def with_plan_module(self, module: ModuleType) -> None:
"""
Expand Down
2 changes: 2 additions & 0 deletions src/blueapi/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig
from .invalid_config_error import InvalidConfigError
from .modules import load_module_all
from .ophyd_async_connect import connect_ophyd_async_devices
from .serialization import serialize
from .thread_exception import handle_all_exceptions

Expand All @@ -13,4 +14,5 @@
"BlueapiModelConfig",
"BlueapiPlanModelConfig",
"InvalidConfigError",
"connect_ophyd_async_devices",
]
49 changes: 49 additions & 0 deletions src/blueapi/utils/ophyd_async_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import asyncio
import logging
from contextlib import suppress
from typing import Any, Dict, Iterable

from ophyd_async.core import DEFAULT_TIMEOUT
from ophyd_async.core import Device as OphydAsyncDevice
from ophyd_async.core import NotConnected


async def connect_ophyd_async_devices(
devices: Iterable[Any],
sim: bool = False,
timeout: float = DEFAULT_TIMEOUT,
) -> None:
tasks: Dict[asyncio.Task, str] = {}
for device in devices:
if isinstance(device, OphydAsyncDevice):
task = asyncio.Task(device.connect(sim=sim))
tasks[task] = device.name
await _wait_for_tasks(tasks, timeout=timeout)


async def _wait_for_tasks(
tasks: Dict[asyncio.Task, str],
timeout: float,
):
done, pending = await asyncio.wait(tasks, timeout=timeout)
if pending:
msg = f"{len(pending)} Devices did not connect:"
for t in pending:
t.cancel()
with suppress(Exception):
await t
e = t.exception()
msg += f"\n {tasks[t]}: {type(e).__name__}"
lines = str(e).splitlines()
if len(lines) <= 1:
msg += f": {e}"
else:
msg += "".join(f"\n {line}" for line in lines)
logging.error(msg)
raised = [t for t in done if t.exception()]
if raised:
logging.error(f"{len(raised)} Devices raised an error:")
for t in raised:
logging.exception(f" {tasks[t]}:", exc_info=t.exception())
if pending or raised:
raise NotConnected("Not all Devices connected")
173 changes: 173 additions & 0 deletions tests/utils/test_ophyd_async_connect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import asyncio
import itertools
import logging
from typing import Callable, Iterable, Tuple, Type, cast

import pytest
from ophyd_async.core import Device, DeviceVector, NotConnected, StandardReadable

from blueapi.utils import connect_ophyd_async_devices


class DummyBaseDevice(Device):
def __init__(self) -> None:
self.connected = False

async def connect(self, sim=False):
self.connected = True


class DummyDeviceThatErrorsWhenConnecting(Device):
async def connect(self, sim: bool = False):
raise IOError("Connection failed")


class DummyDeviceThatTimesOutWhenConnecting(StandardReadable):
async def connect(self, sim: bool = False):
try:
await asyncio.Future()
except asyncio.CancelledError:
raise NotConnected("source: foo")


class DummyDeviceGroup(Device):
def __init__(self, name: str) -> None:
self.child1 = DummyBaseDevice()
self.child2 = DummyBaseDevice()
self.dict_with_children: DeviceVector[DummyBaseDevice] = DeviceVector(
{123: DummyBaseDevice()}
)
self.set_name(name)


class DummyDeviceGroupThatTimesOut(Device):
def __init__(self, name: str) -> None:
self.child1 = DummyDeviceThatTimesOutWhenConnecting()
self.set_name(name)


class DummyDeviceGroupThatErrors(Device):
def __init__(self, name: str) -> None:
self.child1 = DummyDeviceThatErrorsWhenConnecting()
self.set_name(name)


class DummyDeviceGroupThatErrorsAndTimesOut(Device):
def __init__(self, name: str) -> None:
self.child1 = DummyDeviceThatErrorsWhenConnecting()
self.child2 = DummyDeviceThatTimesOutWhenConnecting()
self.set_name(name)


ALL_DEVICE_CONSTRUCTORS = [
DummyDeviceThatErrorsWhenConnecting,
DummyDeviceThatTimesOutWhenConnecting,
DummyDeviceGroupThatErrors,
DummyDeviceGroupThatTimesOut,
DummyDeviceGroupThatErrorsAndTimesOut,
]


@pytest.mark.parametrize("device_constructor", ALL_DEVICE_CONSTRUCTORS)
async def test_device_collector_propagates_errors_and_timeouts(
device_constructor: Callable[[str], Device]
):
await _assert_failing_device_does_not_connect(device_constructor("test"))


@pytest.mark.parametrize(
"device_constructor_1,device_constructor_2",
list(itertools.permutations(ALL_DEVICE_CONSTRUCTORS, 2)),
)
async def test_device_collector_propagates_errors_and_timeouts_from_multiple_devices(
device_constructor_1: Callable[[str], Device],
device_constructor_2: Callable[[str], Device],
):
await _assert_failing_devices_do_not_connect(
[device_constructor_1("test1"), device_constructor_2("test2")]
)


async def test_device_collector_logs_exceptions_for_raised_errors(
caplog: pytest.LogCaptureFixture,
):
caplog.set_level(logging.INFO)
await _assert_failing_device_does_not_connect(DummyDeviceGroupThatErrors)
assert caplog.records[0].message == "1 Devices raised an error:"
assert caplog.records[1].message == " should_fail:"
_assert_exception_type_and_message(
caplog.records[1],
OSError,
"Connection failed",
)


async def test_device_collector_logs_exceptions_for_timeouts(
caplog: pytest.LogCaptureFixture,
):
caplog.set_level(logging.INFO)
await _assert_failing_device_does_not_connect(DummyDeviceGroupThatTimesOut)
assert caplog.records[0].message == "1 Devices did not connect:"
assert caplog.records[1].message == " should_fail:"
_assert_exception_type_and_message(
caplog.records[1],
NotConnected,
"child1: source: foo",
)


async def test_device_collector_logs_exceptions_for_multiple_devices(
caplog: pytest.LogCaptureFixture,
):
caplog.set_level(logging.INFO)
await _assert_failing_devices_do_not_connect(
[
DummyDeviceGroupThatErrorsAndTimesOut("test1"),
DummyDeviceGroupThatErrors("test2"),
]
)
assert caplog.records[0].message == "1 Devices did not connect:"
assert caplog.records[1].message == " should_fail_1:"
_assert_exception_type_and_message(
caplog.records[1],
OSError,
"Connection failed",
)
assert caplog.records[2].message == "1 Devices raised an error:"
assert caplog.records[3].message == " should_fail_2:"
_assert_exception_type_and_message(
caplog.records[3],
OSError,
"Connection failed",
)


async def _assert_failing_device_does_not_connect(
device: Device,
) -> pytest.ExceptionInfo[NotConnected]:
return await _assert_failing_devices_do_not_connect([device])


async def _assert_failing_devices_do_not_connect(
devices: Iterable[Device],
) -> pytest.ExceptionInfo[NotConnected]:
with pytest.raises(NotConnected) as excepton_info:
await connect_ophyd_async_devices(
devices,
sim=True,
timeout=0.1,
)
return excepton_info


def _assert_exception_type_and_message(
record: logging.LogRecord,
expected_type: Type[Exception],
expected_message: str,
):
exception_type, exception, _ = cast(
Tuple[Type[Exception], Exception, str],
record.exc_info,
)
assert expected_type is exception_type
assert (expected_message,) == exception.args

0 comments on commit 7241328

Please sign in to comment.