diff --git a/src/ophyd_async/core/async_status.py b/src/ophyd_async/core/async_status.py index 474dd1eb15..96113208bb 100644 --- a/src/ophyd_async/core/async_status.py +++ b/src/ophyd_async/core/async_status.py @@ -9,6 +9,7 @@ Awaitable, Callable, Generic, + Optional, SupportsFloat, Type, TypeVar, @@ -106,18 +107,19 @@ class WatchableAsyncStatus(AsyncStatusBase, Generic[T]): def __init__( self, iterator: AsyncIterator[WatcherUpdate[T]], - timeout_s: float = 0.0, + timeout: Optional[SupportsFloat] = None, ): self._watchers: list[Watcher] = [] self._start = time.monotonic() - self._timeout = self._start + timeout_s if timeout_s else None self._last_update: WatcherUpdate[T] | None = None - super().__init__(self._notify_watchers_from(iterator)) + if isinstance(timeout, SupportsFloat): + timeout = float(timeout) + super().__init__( + asyncio.wait_for(self._notify_watchers_from(iterator), timeout) + ) async def _notify_watchers_from(self, iterator: AsyncIterator[WatcherUpdate[T]]): async for update in iterator: - if self._timeout and time.monotonic() > self._timeout: - raise TimeoutError() self._last_update = ( update if update.time_elapsed is None @@ -141,10 +143,9 @@ def watch(self, watcher: Watcher): def wrap( cls: Type[WAS], f: Callable[P, AsyncIterator[WatcherUpdate[T]]], - timeout_s: float = 0.0, ) -> Callable[P, WAS]: """Wrap an AsyncIterator in a WatchableAsyncStatus. If it takes - 'timeout_s' as an argument, this must be a float and it will be propagated + 'timeout' as an argument, this must be a float and it will be propagated to the status.""" @functools.wraps(f) @@ -152,9 +153,8 @@ def wrap_f(*args: P.args, **kwargs: P.kwargs) -> WAS: # We can't type this more properly because Concatenate/ParamSpec doesn't # yet support keywords # https://peps.python.org/pep-0612/#concatenating-keyword-parameters - _timeout = kwargs.get("timeout_s") - assert isinstance(_timeout, SupportsFloat) or _timeout is None - timeout = _timeout or 0.0 - return cls(f(*args, **kwargs), timeout_s=float(timeout)) + timeout = kwargs.get("timeout") + assert isinstance(timeout, SupportsFloat) or timeout is None + return cls(f(*args, **kwargs), timeout=timeout) return cast(Callable[P, WAS], wrap_f) diff --git a/src/ophyd_async/core/detector.py b/src/ophyd_async/core/detector.py index d0f6ddf9cd..100aacc203 100644 --- a/src/ophyd_async/core/detector.py +++ b/src/ophyd_async/core/detector.py @@ -300,10 +300,10 @@ async def _prepare(self, value: T) -> None: exposure=self._trigger_info.livetime, ) - def kickoff(self, timeout_s=0.0): + def kickoff(self, timeout: Optional[float] = None): self._fly_start = time.monotonic() self._fly_status = WatchableAsyncStatus( - self._observe_writer_indicies(self._last_frame), timeout_s + self._observe_writer_indicies(self._last_frame), timeout ) return self._fly_status diff --git a/src/ophyd_async/epics/motion/motor.py b/src/ophyd_async/epics/motion/motor.py index 09345c40a8..a0e2648d14 100644 --- a/src/ophyd_async/epics/motion/motor.py +++ b/src/ophyd_async/epics/motion/motor.py @@ -71,7 +71,7 @@ def move(self, new_position: float, timeout: Optional[float] = None): call_in_bluesky_event_loop(self._move(new_position), timeout) # type: ignore @WatchableAsyncStatus.wrap - async def set(self, new_position: float, timeout_s: float = 0.0): + async def set(self, new_position: float, timeout: float = 0.0): update = await self._move(new_position) start = time.monotonic() async for current_position in observe_value(self.user_readback): diff --git a/tests/core/test_async_status_wrapper.py b/tests/core/test_async_status_wrapper.py index 36fc3ec95a..c62af78df8 100644 --- a/tests/core/test_async_status_wrapper.py +++ b/tests/core/test_async_status_wrapper.py @@ -85,7 +85,7 @@ async def set(self, val): class ASTestDeviceTimeoutSet(ASTestDevice): @WatchableAsyncStatus.wrap - async def set(self, val, timeout_s=0.01): + async def set(self, val, timeout=0.01): assert self._staged await asyncio.sleep(0.01) self.sig._backend._set_value(val - 1) # type: ignore @@ -231,7 +231,7 @@ async def test_asyncstatus_times_out(RE): td = ASTestDeviceTimeoutSet() await td.connect() await td.stage() - st = td.set(6, timeout_s=0.01) + st = td.set(6, timeout=0.01) while not st.done: await asyncio.sleep(0.01) assert not st.success diff --git a/tests/epics/motion/test_motor.py b/tests/epics/motion/test_motor.py index b75adcd6ac..fa52f4da43 100644 --- a/tests/epics/motion/test_motor.py +++ b/tests/epics/motion/test_motor.py @@ -47,7 +47,7 @@ async def wait_for_eq(item, attribute, comparison, timeout): async def test_motor_moving_well(sim_motor: motor.Motor) -> None: set_sim_put_proceeds(sim_motor.user_setpoint, False) set_sim_value(sim_motor.motor_done_move, False) - s = sim_motor.set(0.55, timeout_s=1) + s = sim_motor.set(0.55, timeout=1) watcher = Mock(spec=Watcher) s.watch(watcher) done = Mock()