From 1eec835669560d3fbb729e1edcce5d2700e0f235 Mon Sep 17 00:00:00 2001 From: Oliver Silvester Date: Tue, 19 Sep 2023 09:57:11 +0100 Subject: [PATCH] improve testing --- .../core/_device/device_save_loader.py | 2 +- tests/core/_device/test_device_save_loader.py | 53 ++++++++++++++----- 2 files changed, 42 insertions(+), 13 deletions(-) diff --git a/src/ophyd_async/core/_device/device_save_loader.py b/src/ophyd_async/core/_device/device_save_loader.py index 312bf323e0..cdeed306ee 100644 --- a/src/ophyd_async/core/_device/device_save_loader.py +++ b/src/ophyd_async/core/_device/device_save_loader.py @@ -82,7 +82,7 @@ def save_device(device: Device, savename: str, ignore: str = Union[None, List[st if len(signalRWs): if hasattr(device, "sort_signal_by_phase"): phase_dicts: List[Dict[str, SignalRW]] = device.sort_signal_by_phase( - signalRWs + device, signalRWs ) else: phase_dicts.append(signalRWs) diff --git a/tests/core/_device/test_device_save_loader.py b/tests/core/_device/test_device_save_loader.py index 58fada5fb5..29ce6edbd1 100644 --- a/tests/core/_device/test_device_save_loader.py +++ b/tests/core/_device/test_device_save_loader.py @@ -1,4 +1,3 @@ -import asyncio from enum import Enum from os import path from typing import Dict, List @@ -8,7 +7,6 @@ import pytest import yaml from bluesky import RunEngine -from numpy import int32, ndarray from ophyd_async.core import Device, SignalR, SignalRW from ophyd_async.core._device.device_save_loader import ( @@ -24,13 +22,21 @@ def __init__(self): self.sig2: SignalR = epics_signal_r(str, "Value2") +class EnumTest(Enum): + VAL1 = "val1" + VAL2 = "val2" + + class DummyDeviceGroup(Device): def __init__(self, name: str): self.child1: DummyChildDevice = DummyChildDevice() self.child2: DummyChildDevice = DummyChildDevice() self.parent_sig1: SignalRW = epics_signal_rw(str, "ParentValue1") - self.parent_sig2: SignalR = epics_signal_r(int, "ParentValue2") - self.position: npt.NDArray[int32] + self.parent_sig2: SignalR = epics_signal_r( + int, "ParentValue2" + ) # Ensure only RW are found + self.parent_sig3: SignalRW = epics_signal_rw(str, "ParentValue3") + self.position: npt.NDArray[np.int32] @pytest.fixture @@ -48,8 +54,8 @@ async def device_with_phases() -> DummyDeviceGroup: def sort_signal_by_phase(self, signalRWs) -> List[Dict[str, SignalRW]]: phase_1 = {} phase_2 = {} - phase_1[0] = "parent.child1.sig1" - phase_2[0] = "parent.child2.sig1" + phase_1["child1.sig1"] = self.child1.sig1 + phase_2["child2.sig1"] = self.child2.sig1 return [phase_1, phase_2] device.sort_signal_by_phase = sort_signal_by_phase @@ -62,20 +68,43 @@ def test_get_signal_RWs_from_device(device): "parent.child1.sig1", "parent.child2.sig1", "parent.parent_sig1", + "parent.parent_sig3", ] assert all(isinstance(signal, SignalRW) for signal in list(signalRWS.values())) -async def test_save_device(device, device_with_phases, RE, tmp_path): +async def test_save_device_no_phase(device, device_with_phases, tmp_path): RE = RunEngine() await device.child1.sig1.set("string") - # mimic tables in devices - table_pv = {"VAL1": np.array([1, 1, 1, 1, 1]), "val2": np.array([1, 1, 1, 1, 1])} - + # Test tables PVs + table_pv = {"VAL1": np.array([1, 1, 1, 1, 1]), "VAL2": np.array([1, 1, 1, 1, 1])} await device.child2.sig1.set(table_pv) + + # Test enum PVs + await device.parent_sig3.set(EnumTest.VAL1) RE(save_device(device, path.join(tmp_path, "test_file"))) with open(path.join(tmp_path, "test_file.yaml"), "r") as file: yaml_content = yaml.safe_load(file) - for line in yaml_content: - #fill this in + assert yaml_content[0] == { + "child1.sig1": "string", + "child2.sig1": { + "VAL1": [1, 1, 1, 1, 1], + "VAL2": [1, 1, 1, 1, 1], + }, + "parent_sig1": "", + "parent_sig3": "val1", + } + + +async def test_save_device_with_phase(device_with_phases, tmp_path): + RE = RunEngine() + await device_with_phases.child1.sig1.set("string") + # mimic tables in devices + table_pv = {"VAL1": np.array([1, 1, 1, 1, 1]), "VAL2": np.array([1, 1, 1, 1, 1])} + await device_with_phases.child2.sig1.set(table_pv) + RE(save_device(device_with_phases, path.join(tmp_path, "test_file"))) + with open(path.join(tmp_path, "test_file.yaml"), "r") as file: + yaml_content = yaml.safe_load(file) + assert yaml_content[0] == {"child1.sig1": "string"} + assert yaml_content[1] == {"child2.sig1": table_pv}