Commit

Merge branch 'main' into config
burkeds committed Sep 18, 2024
2 parents 40d22be + 989beeb commit 365fa7e
Showing 67 changed files with 1,458 additions and 625 deletions.
2 changes: 1 addition & 1 deletion docs/examples/foo_detector.py
@@ -69,7 +69,7 @@ def __init__(
             self.hdf,
             path_provider,
             lambda: self.name,
-            adcore.ADBaseShapeProvider(self.drv),
+            adcore.ADBaseDatasetDescriber(self.drv),
         ),
         config_sigs=(self.drv.acquire_time,),
         name=name,
12 changes: 9 additions & 3 deletions src/ophyd_async/core/__init__.py
@@ -33,11 +33,11 @@
 from ._providers import (
     AutoIncrementFilenameProvider,
     AutoIncrementingPathProvider,
+    DatasetDescriber,
     FilenameProvider,
     NameProvider,
     PathInfo,
     PathProvider,
-    ShapeProvider,
     StaticFilenameProvider,
     StaticPathProvider,
     UUIDFilenameProvider,
@@ -61,9 +61,14 @@
     soft_signal_rw,
     wait_for_value,
 )
-from ._signal_backend import RuntimeSubsetEnum, SignalBackend, SubsetEnum
+from ._signal_backend import (
+    RuntimeSubsetEnum,
+    SignalBackend,
+    SubsetEnum,
+)
 from ._soft_signal_backend import SignalMetadata, SoftSignalBackend
 from ._status import AsyncStatus, WatchableAsyncStatus, completed_status
+from ._table import Table
 from ._utils import (
     DEFAULT_TIMEOUT,
     CalculatableTimeout,
@@ -117,7 +122,7 @@
"NameProvider",
"PathInfo",
"PathProvider",
"ShapeProvider",
"DatasetDescriber",
"StaticFilenameProvider",
"StaticPathProvider",
"UUIDFilenameProvider",
@@ -152,6 +157,7 @@
"CalculateTimeout",
"NotConnected",
"ReadingValueCallback",
"Table",
"T",
"WatcherUpdate",
"get_dtype",
94 changes: 54 additions & 40 deletions src/ophyd_async/core/_detector.py
@@ -51,20 +51,24 @@ class TriggerInfo(BaseModel):
"""Minimal set of information required to setup triggering on a detector"""

#: Number of triggers that will be sent, 0 means infinite
number: int = Field(gt=0)
number: int = Field(ge=0)
#: Sort of triggers that will be sent
trigger: DetectorTrigger = Field()
trigger: DetectorTrigger = Field(default=DetectorTrigger.internal)
#: What is the minimum deadtime between triggers
deadtime: float | None = Field(ge=0)
deadtime: float | None = Field(default=None, ge=0)
#: What is the maximum high time of the triggers
livetime: float | None = Field(ge=0)
livetime: float | None = Field(default=None, ge=0)
#: What is the maximum timeout on waiting for a frame
frame_timeout: float | None = Field(None, gt=0)
frame_timeout: float | None = Field(default=None, gt=0)
#: How many triggers make up a single StreamDatum index, to allow multiple frames
#: from a faster detector to be zipped with a single frame from a slow detector
#: e.g. if num=10 and multiplier=5 then the detector will take 10 frames,
#: but publish 2 indices, and describe() will show a shape of (5, h, w)
multiplier: int = 1
#: The number of times the detector can go through a complete cycle of kickoff and
#: complete without needing to re-arm. This is important for detectors where the
#: process of arming is expensive in terms of time
iteration: int = 1


class DetectorControl(ABC):
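
With the new defaults above, a TriggerInfo for internal triggering needs only a frame count, while an externally triggered flyscan spells out its timing. A minimal sketch of constructing the model (the ophyd_async.core import path and the constant_gate member are assumptions, not shown in this diff):

from ophyd_async.core import DetectorTrigger, TriggerInfo

# Internal triggering: trigger defaults to DetectorTrigger.internal and
# deadtime/livetime/frame_timeout default to None.
internal = TriggerInfo(number=10)

# External gating: deadtime must be supplied, number=0 now means an
# infinite stream, and iteration=2 allows two kickoff/complete cycles
# on a single arm.
external = TriggerInfo(
    number=0,
    trigger=DetectorTrigger.constant_gate,
    deadtime=0.001,
    livetime=0.01,
    multiplier=5,
    iteration=2,
)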
@@ -78,27 +82,35 @@ def get_deadtime(self, exposure: float | None) -> float:
"""For a given exposure, how long should the time between exposures be"""

@abstractmethod
async def arm(
self,
num: int,
trigger: DetectorTrigger = DetectorTrigger.internal,
exposure: Optional[float] = None,
) -> AsyncStatus:
async def prepare(self, trigger_info: TriggerInfo):
"""
Arm detector, do all necessary steps to prepare detector for triggers.
Do all necessary steps to prepare the detector for triggers.
Args:
num: Expected number of frames
trigger: Type of trigger for which to prepare the detector. Defaults to
DetectorTrigger.internal.
exposure: Exposure time with which to set up the detector. Defaults to None
if not applicable or the detector is expected to use its previously-set
exposure time.
trigger_info: This is a Pydantic model which contains
number Expected number of frames.
trigger Type of trigger for which to prepare the detector. Defaults
to DetectorTrigger.internal.
livetime Livetime / Exposure time with which to set up the detector.
Defaults to None
if not applicable or the detector is expected to use its previously-set
exposure time.
deadtime Defaults to None. This is the minimum deadtime between
triggers.
multiplier The number of triggers grouped into a single StreamDatum
index.
"""

Returns:
AsyncStatus: Status representing the arm operation. This function returning
represents the start of the arm. The returned status completing means
the detector is now armed.
@abstractmethod
async def arm(self) -> None:
"""
Arm the detector
"""

@abstractmethod
async def wait_for_idle(self):
"""
This will wait on the internal _arm_status and wait for it to get disarmed/idle
"""

@abstractmethod
@@ -186,7 +198,7 @@ def __init__(
         self._watchers: List[Callable] = []
         self._fly_status: Optional[WatchableAsyncStatus] = None
         self._fly_start: float
-
+        self._iterations_completed: int = 0
         self._intial_frame: int
         self._last_frame: int
         super().__init__(name)
@@ -224,7 +236,7 @@ async def _check_config_sigs(self):
     @AsyncStatus.wrap
     async def unstage(self) -> None:
         # Stop data writing.
-        await self.writer.close()
+        await asyncio.gather(self.writer.close(), self.controller.disarm())
 
     async def read_configuration(self) -> Dict[str, Reading]:
         return await merge_gathered_dicts(sig.read() for sig in self._config_sigs)
@@ -248,15 +260,15 @@ async def trigger(self) -> None:
                 trigger=DetectorTrigger.internal,
                 deadtime=None,
                 livetime=None,
+                frame_timeout=None,
             )
         )
         assert self._trigger_info
         assert self._trigger_info.trigger is DetectorTrigger.internal
         # Arm the detector and wait for it to finish.
         indices_written = await self.writer.get_indices_written()
-        written_status = await self.controller.arm(
-            num=self._trigger_info.number,
-            trigger=self._trigger_info.trigger,
-        )
-        await written_status
+        await self.controller.arm()
+        await self.controller.wait_for_idle()
         end_observation = indices_written + 1
 
         async for index in self.writer.observe_indices_written(
@@ -283,35 +295,35 @@ async def prepare(self, value: TriggerInfo) -> None:
         Args:
             value: TriggerInfo describing how to trigger the detector
         """
-        self._trigger_info = value
         if value.trigger != DetectorTrigger.internal:
             assert (
                 value.deadtime
             ), "Deadtime must be supplied when in externally triggered mode"
         if value.deadtime:
-            required = self.controller.get_deadtime(self._trigger_info.livetime)
+            required = self.controller.get_deadtime(value.livetime)
             assert required <= value.deadtime, (
                 f"Detector {self.controller} needs at least {required}s deadtime, "
                 f"but trigger logic provides only {value.deadtime}s"
             )
+        self._trigger_info = value
         self._initial_frame = await self.writer.get_indices_written()
         self._last_frame = self._initial_frame + self._trigger_info.number
-        self._arm_status = await self.controller.arm(
-            num=self._trigger_info.number,
-            trigger=self._trigger_info.trigger,
-            exposure=self._trigger_info.livetime,
+        self._describe, _ = await asyncio.gather(
+            self.writer.open(value.multiplier), self.controller.prepare(value)
         )
-        self._fly_start = time.monotonic()
-        self._describe = await self.writer.open(value.multiplier)
+        if value.trigger != DetectorTrigger.internal:
+            await self.controller.arm()
+        self._fly_start = time.monotonic()
 
     @AsyncStatus.wrap
     async def kickoff(self):
-        if not self._arm_status:
-            raise Exception("Detector not armed!")
+        assert self._trigger_info, "Prepare must be called before kickoff!"
+        if self._iterations_completed >= self._trigger_info.iteration:
+            raise Exception(f"Kickoff called more than {self._trigger_info.iteration}")
+        self._iterations_completed += 1
 
     @WatchableAsyncStatus.wrap
     async def complete(self):
-        assert self._arm_status, "Prepare not run"
+        assert self._trigger_info
         async for index in self.writer.observe_indices_written(
             self._trigger_info.frame_timeout
@@ -332,6 +344,8 @@ async def complete(self):
             )
             if index >= self._trigger_info.number:
                 break
+        if self._iterations_completed == self._trigger_info.iteration:
+            await self.controller.wait_for_idle()
 
     async def describe_collect(self) -> Dict[str, DataKey]:
         return self._describe
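Taken together: prepare() now opens the writer and configures the controller concurrently, arm() becomes a cheap start operation, and kickoff()/complete() may repeat up to TriggerInfo.iteration times before wait_for_idle() is awaited. A hedged sketch of a controller meeting the new contract (the class, its timings, and the AsyncStatus usage are illustrative assumptions, not part of this commit):

import asyncio
from typing import Optional

from ophyd_async.core import AsyncStatus, DetectorControl, TriggerInfo


class ExampleController(DetectorControl):
    """Hypothetical controller showing the prepare/arm/wait_for_idle split."""

    def __init__(self) -> None:
        self._arm_status: Optional[AsyncStatus] = None
        self._trigger_info: Optional[TriggerInfo] = None

    def get_deadtime(self, exposure: float | None) -> float:
        return 0.001  # illustrative fixed readout deadtime

    async def prepare(self, trigger_info: TriggerInfo) -> None:
        # Do the expensive configuration up front so arm() stays cheap.
        self._trigger_info = trigger_info

    async def arm(self) -> None:
        # Start acquiring; keep the status for wait_for_idle() to await.
        self._arm_status = AsyncStatus(asyncio.sleep(0.1))  # stand-in for real I/O

    async def wait_for_idle(self) -> None:
        if self._arm_status:
            await self._arm_status

    async def disarm(self) -> None:
        self._arm_status = None

Splitting prepare from arm moves the costly setup out of the per-iteration path, which is exactly what the new iteration field exploits.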
12 changes: 12 additions & 0 deletions src/ophyd_async/core/_device_save_loader.py
@@ -7,6 +7,7 @@
 from bluesky.plan_stubs import abs_set, wait
 from bluesky.protocols import Location
 from bluesky.utils import Msg
+from pydantic import BaseModel
 
 from ._device import Device
 from ._signal import SignalRW
@@ -18,6 +19,12 @@ def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.Node:
     )
 
 
+def pydantic_model_abstraction_representer(
+    dumper: yaml.Dumper, model: BaseModel
+) -> yaml.Node:
+    return dumper.represent_data(model.model_dump(mode="python"))
+
+
 class OphydDumper(yaml.Dumper):
     def represent_data(self, data: Any) -> Any:
         if isinstance(data, Enum):
@@ -134,6 +141,11 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None:
"""

yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper)
yaml.add_multi_representer(
BaseModel,
pydantic_model_abstraction_representer,
Dumper=yaml.Dumper,
)

with open(save_path, "w") as file:
yaml.dump(phases, file, Dumper=OphydDumper, default_flow_style=False)
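The effect of the new multi-representer: any pydantic BaseModel reached while saving is serialised as its plain-dict dump. A self-contained sketch (the Pose model is made up for illustration):

import yaml
from pydantic import BaseModel


def pydantic_model_abstraction_representer(
    dumper: yaml.Dumper, model: BaseModel
) -> yaml.Node:
    # Dump the model to built-in Python types, then let yaml represent those.
    return dumper.represent_data(model.model_dump(mode="python"))


class Pose(BaseModel):
    x: float
    theta: float


# add_multi_representer applies to BaseModel and every subclass of it.
yaml.add_multi_representer(
    BaseModel, pydantic_model_abstraction_representer, Dumper=yaml.Dumper
)
print(yaml.dump({"pose": Pose(x=1.0, theta=0.5)}))
# pose:
#   theta: 0.5
#   x: 1.0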
8 changes: 4 additions & 4 deletions src/ophyd_async/core/_mock_signal_backend.py
@@ -1,7 +1,7 @@
 import asyncio
 from functools import cached_property
 from typing import Callable, Optional, Type
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from bluesky.protocols import Descriptor, Reading
 
@@ -46,8 +46,8 @@ async def connect(self, timeout: float = DEFAULT_TIMEOUT) -> None:
         pass
 
     @cached_property
-    def put_mock(self) -> Mock:
-        return Mock(name="put", spec=Callable)
+    def put_mock(self) -> AsyncMock:
+        return AsyncMock(name="put", spec=Callable)
 
     @cached_property
     def put_proceeds(self) -> asyncio.Event:
@@ -56,7 +56,7 @@ def put_proceeds(self) -> asyncio.Event:
         return put_proceeds
 
     async def put(self, value: Optional[T], wait=True, timeout=None):
-        self.put_mock(value, wait=wait, timeout=timeout)
+        await self.put_mock(value, wait=wait, timeout=timeout)
         await self.soft_backend.put(value, wait=wait, timeout=timeout)
 
         if wait:
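Because put_mock is now an AsyncMock, tests can assert on the await itself. A small sketch, assuming get_mock_put and soft_signal_rw are importable from ophyd_async.core as this diff suggests:

import asyncio
from unittest.mock import ANY

from ophyd_async.core import get_mock_put, soft_signal_rw


async def main() -> None:
    sig = soft_signal_rw(int, initial_value=0, name="sig")
    await sig.connect(mock=True)
    await sig.set(5)
    # The AsyncMock assertion helpers now apply to the recorded put;
    # the backend passes wait/timeout through, so match them loosely.
    get_mock_put(sig).assert_awaited_once_with(5, wait=ANY, timeout=ANY)


asyncio.run(main())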
12 changes: 7 additions & 5 deletions src/ophyd_async/core/_mock_signal_utils.py
@@ -1,6 +1,6 @@
 from contextlib import asynccontextmanager, contextmanager
-from typing import Any, Callable, Iterable
-from unittest.mock import Mock
+from typing import Any, Awaitable, Callable, Iterable
+from unittest.mock import AsyncMock
 
 from ._mock_signal_backend import MockSignalBackend
 from ._signal import Signal
@@ -41,7 +41,7 @@ async def mock_puts_blocked(*signals: Signal):
         set_mock_put_proceeds(signal, True)
 
 
-def get_mock_put(signal: Signal) -> Mock:
+def get_mock_put(signal: Signal) -> AsyncMock:
     """Get the mock associated with the put call on the signal."""
     return _get_mock_signal_backend(signal).put_mock
 
@@ -136,12 +136,14 @@ def set_mock_values(


 @contextmanager
-def _unset_side_effect_cm(put_mock: Mock):
+def _unset_side_effect_cm(put_mock: AsyncMock):
     yield
     put_mock.side_effect = None
 
 
-def callback_on_mock_put(signal: Signal[T], callback: Callable[[T], None]):
+def callback_on_mock_put(
+    signal: Signal[T], callback: Callable[[T], None] | Callable[[T], Awaitable[None]]
+):
     """For setting a callback when a backend is put to.
 
     Can either be used in a context, with the callback being
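The widened callback type matters because the callback becomes the AsyncMock's side_effect, and AsyncMock awaits an async side effect, so a coroutine can now drive other mock state on put. A sketch under the same import assumptions as above (the **kwargs soak up the wait/timeout keywords the backend passes through):

import asyncio

from ophyd_async.core import callback_on_mock_put, soft_signal_rw


async def main() -> None:
    sig = soft_signal_rw(float, initial_value=0.0, name="motor")
    await sig.connect(mock=True)

    async def on_put(value: float, **kwargs) -> None:
        await asyncio.sleep(0)  # e.g. update some other mocked signal here
        print(f"moved to {value}")

    callback_on_mock_put(sig, on_put)
    await sig.set(1.5)


asyncio.run(main())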
8 changes: 6 additions & 2 deletions src/ophyd_async/core/_providers.py
@@ -211,7 +211,11 @@ def __call__(self) -> str:
"""Get the name to be used as a data_key in the descriptor document"""


class ShapeProvider(Protocol):
class DatasetDescriber(Protocol):
@abstractmethod
async def __call__(self) -> tuple:
async def np_datatype(self) -> str:
"""Represents the numpy datatype"""

@abstractmethod
async def shape(self) -> tuple[int, ...]:
"""Get the shape of the data collection"""
2 changes: 1 addition & 1 deletion src/ophyd_async/core/_signal.py
@@ -505,7 +505,7 @@ async def wait_for_value(self, signal: SignalR[T], timeout: Optional[float]):
         try:
             await asyncio.wait_for(self._wait_for_value(signal), timeout)
         except asyncio.TimeoutError as e:
-            raise TimeoutError(
+            raise asyncio.TimeoutError(
                 f"{signal.name} didn't match {self._matcher_name} in {timeout}s, "
                 f"last value {self._last_value!r}"
             ) from e
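Callers should now catch asyncio.TimeoutError, which keeps the raise consistent with the asyncio.wait_for it wraps (on Python 3.11+ asyncio.TimeoutError is an alias of the builtin TimeoutError, so the distinction only bites on older interpreters). A short sketch, assuming wait_for_value and soft_signal_rw are exported from ophyd_async.core:

import asyncio

from ophyd_async.core import soft_signal_rw, wait_for_value


async def main() -> None:
    sig = soft_signal_rw(int, initial_value=0, name="sig")
    await sig.connect(mock=True)
    try:
        await wait_for_value(sig, 1, timeout=0.1)  # value never becomes 1
    except asyncio.TimeoutError as e:
        print(f"gave up: {e}")


asyncio.run(main())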
15 changes: 14 additions & 1 deletion src/ophyd_async/core/_signal_backend.py
@@ -1,5 +1,13 @@
 from abc import abstractmethod
-from typing import TYPE_CHECKING, ClassVar, Generic, Literal, Optional, Tuple, Type
+from typing import (
+    TYPE_CHECKING,
+    ClassVar,
+    Generic,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+)
 
 from ._protocol import DataKey, Reading
 from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T
@@ -11,6 +19,11 @@ class SignalBackend(Generic[T]):
     #: Datatype of the signal value
     datatype: Optional[Type[T]] = None
 
+    @classmethod
+    @abstractmethod
+    def datatype_allowed(cls, dtype: type):
+        """Check if a given datatype is acceptable for this signal backend."""
+
     #: Like ca://PV_PREFIX:SIGNAL
     @abstractmethod
     def source(self, name: str) -> str:
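Concrete backends now declare up front which Python types they can serve. A minimal sketch of a subclass satisfying the new classmethod (the allowed set here is purely illustrative):

import numpy as np


class ExampleBackend:
    """Stand-in for a SignalBackend[T] subclass."""

    _ALLOWED_DATATYPES = (int, float, str, bool, np.ndarray)

    @classmethod
    def datatype_allowed(cls, dtype: type) -> bool:
        # Accept the type itself or any subclass of an allowed type.
        return issubclass(dtype, cls._ALLOWED_DATATYPES)


assert ExampleBackend.datatype_allowed(int)
assert not ExampleBackend.datatype_allowed(complex)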