diff --git a/pyproject.toml b/pyproject.toml index d65a1f2a15..0f6a47f2d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,8 @@ dev = [ "sphinx-design", "tox-direct", "types-mock", + "pyyaml", + "types-pyyaml", ] [project.scripts] diff --git a/src/ophyd_async/core/_device/device_save_loader.py b/src/ophyd_async/core/_device/device_save_loader.py index 312bf323e0..73a6bcb06e 100644 --- a/src/ophyd_async/core/_device/device_save_loader.py +++ b/src/ophyd_async/core/_device/device_save_loader.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, Dict, List, Union +from typing import Any, Dict, List import yaml from bluesky import Msg @@ -9,7 +9,7 @@ def get_signal_RWs_from_device( - device: Device, prefix: str, signalRWs: Dict[str, SignalRW] = {} + device: Device, path_prefix: str = "", signalRWs: Dict[str, SignalRW] = {} ) -> Dict[str, SignalRW]: """Get all the signalRW's from a device and store with their dotted attribute paths. Used by the save_device and load_device methods. @@ -19,8 +19,8 @@ def get_signal_RWs_from_device( device: Device Ophyd device to retrieve read write signals from - prefix: Str - Device prefix + path_prefix: Str + For internal use, leave blank when calling method. SignalRWs: Dict A dictionary matching the string attribute path of a SignalRW with the @@ -36,16 +36,16 @@ def get_signal_RWs_from_device( dot = "" # Place a dot inbetween the upper and lower class. # Don't do this for highest level class. - if prefix: + if path_prefix: dot = "." - dot_path = f"{prefix}{dot}{attr_name}" + dot_path = f"{path_prefix}{dot}{attr_name}" if type(attr) is SignalRW: signalRWs[dot_path] = attr - get_signal_RWs_from_device(attr, prefix=dot_path) + get_signal_RWs_from_device(attr, path_prefix=dot_path) return signalRWs -def save_device(device: Device, savename: str, ignore: str = Union[None, List[str]]): +def save_device(device: Device, savename: str, ignore: List[str] = []): """ Plan to save the setup of a device by getting a list of its signals and their readback values @@ -81,9 +81,7 @@ def save_device(device: Device, savename: str, ignore: str = Union[None, List[st phase_dicts: List[Dict[str, SignalRW]] = [] if len(signalRWs): if hasattr(device, "sort_signal_by_phase"): - phase_dicts: List[Dict[str, SignalRW]] = device.sort_signal_by_phase( - signalRWs - ) + phase_dicts = device.sort_signal_by_phase(device, signalRWs) else: phase_dicts.append(signalRWs) @@ -111,6 +109,8 @@ def save_device(device: Device, savename: str, ignore: str = Union[None, List[st for phase in phase_dicts: signal_name_values: Dict[str, Any] = {} for signal_name in phase.keys(): + if signal_name in ignore: + continue signal_name_values[signal_name] = signal_values[signal_value_index] signal_value_index += 1 @@ -119,68 +119,3 @@ def save_device(device: Device, savename: str, ignore: str = Union[None, List[st filename = f"{savename}.yaml" with open(filename, "w") as file: yaml.dump(phase_outputs, file) - - -# async def load_device(device, savename: str): -# """Does an abs_set on each signalRW which has differing values to the savefile""" - -# # Locate all signals to later compare with loaded values, then only -# # change differing values - -# signalRWs: Dict[str, SignalRW] = get_signal_RWs_from_device( -# device, "" -# ) # {'device.subdevice.etc: signalRW} -# signal_name_values = ( -# {} -# ) # we want this to be {'device.subdevice.etc: signal location} -# signals_to_locate = [] -# for sig in signalRWs.values(): -# signals_to_locate.append(sig.locate()) - -# signal_values = await asyncio.gather(*signals_to_locate) - -# # Copy logic from save plan to convert enums and np arrays -# for index, value in enumerate(signal_values): -# if isinstance(value, dict): -# for inner_key, inner_value in value.items(): -# if isinstance(inner_value, ndarray): -# value[inner_key] = inner_value.tolist() -# # Convert enums to their values -# elif isinstance(signal_values[index], Enum): -# signal_values[index] = value.value - -# for index, key in enumerate(signalRWs.keys()): -# signal_name_values[key] = signal_values[index] - -# # Get PV info from yaml file -# filename = f"{savename}.yaml" -# with open(filename, "r") as file: -# data_by_phase: List[Dict[str, Any]] = yaml.full_load(file) - -# """For each phase, find the location of the SignalRW's in that phase, -# load them -# to the correct value, and wait for the load to complete""" -# for phase_number, phase in enumerate(data_by_phase): -# phase_load_statuses: List[AsyncStatus] = [] -# for key, value in phase.items(): -# # If the values are different then do an abs_set -# if signal_name_values[key] != value: -# # Key is subdevices_x.subdevices_x+1.etc.signalname. First get -# # the attribute hierarchy -# components = key.split(".") -# lowest_device = device - -# # If there are subdevices -# if len(components) > 1: -# signal_name: str = components[ -# -1 -# ] # Last string is the signal name -# for attribute in components[:-1]: -# lowest_device = getattr(lowest_device, attribute) -# else: -# signal_name: str = components[0] -# signalRW: SignalRW = getattr(lowest_device, signal_name) - -# phase_load_statuses.append(signalRW.set(value, timeout=5)) - -# await asyncio.gather(*phase_load_statuses) diff --git a/tests/core/_device/test_device_save_loader.py b/tests/core/_device/test_device_save_loader.py index 58fada5fb5..a0faa43463 100644 --- a/tests/core/_device/test_device_save_loader.py +++ b/tests/core/_device/test_device_save_loader.py @@ -1,4 +1,3 @@ -import asyncio from enum import Enum from os import path from typing import Dict, List @@ -8,7 +7,6 @@ import pytest import yaml from bluesky import RunEngine -from numpy import int32, ndarray from ophyd_async.core import Device, SignalR, SignalRW from ophyd_async.core._device.device_save_loader import ( @@ -24,13 +22,21 @@ def __init__(self): self.sig2: SignalR = epics_signal_r(str, "Value2") +class EnumTest(Enum): + VAL1 = "val1" + VAL2 = "val2" + + class DummyDeviceGroup(Device): def __init__(self, name: str): self.child1: DummyChildDevice = DummyChildDevice() self.child2: DummyChildDevice = DummyChildDevice() self.parent_sig1: SignalRW = epics_signal_rw(str, "ParentValue1") - self.parent_sig2: SignalR = epics_signal_r(int, "ParentValue2") - self.position: npt.NDArray[int32] + self.parent_sig2: SignalR = epics_signal_r( + int, "ParentValue2" + ) # Ensure only RW are found + self.parent_sig3: SignalRW = epics_signal_rw(str, "ParentValue3") + self.position: npt.NDArray[np.int32] @pytest.fixture @@ -48,34 +54,57 @@ async def device_with_phases() -> DummyDeviceGroup: def sort_signal_by_phase(self, signalRWs) -> List[Dict[str, SignalRW]]: phase_1 = {} phase_2 = {} - phase_1[0] = "parent.child1.sig1" - phase_2[0] = "parent.child2.sig1" + phase_1["child1.sig1"] = self.child1.sig1 + phase_2["child2.sig1"] = self.child2.sig1 return [phase_1, phase_2] - device.sort_signal_by_phase = sort_signal_by_phase + setattr(device, "sort_signal_by_phase", sort_signal_by_phase) return device def test_get_signal_RWs_from_device(device): - signalRWS = get_signal_RWs_from_device(device, "parent") + signalRWS = get_signal_RWs_from_device(device, "") assert list(signalRWS.keys()) == [ - "parent.child1.sig1", - "parent.child2.sig1", - "parent.parent_sig1", + "child1.sig1", + "child2.sig1", + "parent_sig1", + "parent_sig3", ] assert all(isinstance(signal, SignalRW) for signal in list(signalRWS.values())) -async def test_save_device(device, device_with_phases, RE, tmp_path): +async def test_save_device_no_phase(device, device_with_phases, tmp_path): RE = RunEngine() await device.child1.sig1.set("string") - # mimic tables in devices - table_pv = {"VAL1": np.array([1, 1, 1, 1, 1]), "val2": np.array([1, 1, 1, 1, 1])} - + # Test tables PVs + table_pv = {"VAL1": np.array([1, 1, 1, 1, 1]), "VAL2": np.array([1, 1, 1, 1, 1])} await device.child2.sig1.set(table_pv) + + # Test enum PVs + await device.parent_sig3.set(EnumTest.VAL1) RE(save_device(device, path.join(tmp_path, "test_file"))) with open(path.join(tmp_path, "test_file.yaml"), "r") as file: yaml_content = yaml.safe_load(file) - for line in yaml_content: - #fill this in + assert yaml_content[0] == { + "child1.sig1": "string", + "child2.sig1": { + "VAL1": [1, 1, 1, 1, 1], + "VAL2": [1, 1, 1, 1, 1], + }, + "parent_sig1": "", + "parent_sig3": "val1", + } + + +async def test_save_device_with_phase(device_with_phases, tmp_path): + RE = RunEngine() + await device_with_phases.child1.sig1.set("string") + # mimic tables in devices + table_pv = {"VAL1": np.array([1, 1, 1, 1, 1]), "VAL2": np.array([1, 1, 1, 1, 1])} + await device_with_phases.child2.sig1.set(table_pv) + RE(save_device(device_with_phases, path.join(tmp_path, "test_file"))) + with open(path.join(tmp_path, "test_file.yaml"), "r") as file: + yaml_content = yaml.safe_load(file) + assert yaml_content[0] == {"child1.sig1": "string"} + assert yaml_content[1] == {"child2.sig1": table_pv}