Skip to content

Commit

Permalink
improve testing
Browse files Browse the repository at this point in the history
  • Loading branch information
olliesilvester committed Sep 20, 2023
1 parent 2c5862a commit 128b294
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 93 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ dev = [
"sphinx-design",
"tox-direct",
"types-mock",
"pyyaml",
"types-pyyaml",
]

[project.scripts]
Expand Down
87 changes: 11 additions & 76 deletions src/ophyd_async/core/_device/device_save_loader.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import Any, Dict, List, Union
from typing import Any, Dict, List

import yaml
from bluesky import Msg
Expand All @@ -9,7 +9,7 @@


def get_signal_RWs_from_device(
device: Device, prefix: str, signalRWs: Dict[str, SignalRW] = {}
device: Device, path_prefix: str = "", signalRWs: Dict[str, SignalRW] = {}
) -> Dict[str, SignalRW]:
"""Get all the signalRW's from a device and store with their dotted attribute paths.
Used by the save_device and load_device methods.
Expand All @@ -19,8 +19,8 @@ def get_signal_RWs_from_device(
device: Device
Ophyd device to retrieve read write signals from
prefix: Str
Device prefix
path_prefix: Str
For internal use, leave blank when calling method.
SignalRWs: Dict
A dictionary matching the string attribute path of a SignalRW with the
Expand All @@ -36,16 +36,16 @@ def get_signal_RWs_from_device(
dot = ""
# Place a dot inbetween the upper and lower class.
# Don't do this for highest level class.
if prefix:
if path_prefix:
dot = "."
dot_path = f"{prefix}{dot}{attr_name}"
dot_path = f"{path_prefix}{dot}{attr_name}"
if type(attr) is SignalRW:
signalRWs[dot_path] = attr
get_signal_RWs_from_device(attr, prefix=dot_path)
get_signal_RWs_from_device(attr, path_prefix=dot_path)
return signalRWs


def save_device(device: Device, savename: str, ignore: str = Union[None, List[str]]):
def save_device(device: Device, savename: str, ignore: List[str] = []):
"""
Plan to save the setup of a device by getting a list of its signals and their
readback values
Expand Down Expand Up @@ -81,9 +81,7 @@ def save_device(device: Device, savename: str, ignore: str = Union[None, List[st
phase_dicts: List[Dict[str, SignalRW]] = []
if len(signalRWs):
if hasattr(device, "sort_signal_by_phase"):
phase_dicts: List[Dict[str, SignalRW]] = device.sort_signal_by_phase(
signalRWs
)
phase_dicts = device.sort_signal_by_phase(device, signalRWs)
else:
phase_dicts.append(signalRWs)

Expand Down Expand Up @@ -111,6 +109,8 @@ def save_device(device: Device, savename: str, ignore: str = Union[None, List[st
for phase in phase_dicts:
signal_name_values: Dict[str, Any] = {}
for signal_name in phase.keys():
if signal_name in ignore:
continue
signal_name_values[signal_name] = signal_values[signal_value_index]
signal_value_index += 1

Expand All @@ -119,68 +119,3 @@ def save_device(device: Device, savename: str, ignore: str = Union[None, List[st
filename = f"{savename}.yaml"
with open(filename, "w") as file:
yaml.dump(phase_outputs, file)


# async def load_device(device, savename: str):
# """Does an abs_set on each signalRW which has differing values to the savefile"""

# # Locate all signals to later compare with loaded values, then only
# # change differing values

# signalRWs: Dict[str, SignalRW] = get_signal_RWs_from_device(
# device, ""
# ) # {'device.subdevice.etc: signalRW}
# signal_name_values = (
# {}
# ) # we want this to be {'device.subdevice.etc: signal location}
# signals_to_locate = []
# for sig in signalRWs.values():
# signals_to_locate.append(sig.locate())

# signal_values = await asyncio.gather(*signals_to_locate)

# # Copy logic from save plan to convert enums and np arrays
# for index, value in enumerate(signal_values):
# if isinstance(value, dict):
# for inner_key, inner_value in value.items():
# if isinstance(inner_value, ndarray):
# value[inner_key] = inner_value.tolist()
# # Convert enums to their values
# elif isinstance(signal_values[index], Enum):
# signal_values[index] = value.value

# for index, key in enumerate(signalRWs.keys()):
# signal_name_values[key] = signal_values[index]

# # Get PV info from yaml file
# filename = f"{savename}.yaml"
# with open(filename, "r") as file:
# data_by_phase: List[Dict[str, Any]] = yaml.full_load(file)

# """For each phase, find the location of the SignalRW's in that phase,
# load them
# to the correct value, and wait for the load to complete"""
# for phase_number, phase in enumerate(data_by_phase):
# phase_load_statuses: List[AsyncStatus] = []
# for key, value in phase.items():
# # If the values are different then do an abs_set
# if signal_name_values[key] != value:
# # Key is subdevices_x.subdevices_x+1.etc.signalname. First get
# # the attribute hierarchy
# components = key.split(".")
# lowest_device = device

# # If there are subdevices
# if len(components) > 1:
# signal_name: str = components[
# -1
# ] # Last string is the signal name
# for attribute in components[:-1]:
# lowest_device = getattr(lowest_device, attribute)
# else:
# signal_name: str = components[0]
# signalRW: SignalRW = getattr(lowest_device, signal_name)

# phase_load_statuses.append(signalRW.set(value, timeout=5))

# await asyncio.gather(*phase_load_statuses)
63 changes: 46 additions & 17 deletions tests/core/_device/test_device_save_loader.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import asyncio
from enum import Enum
from os import path
from typing import Dict, List
Expand All @@ -8,7 +7,6 @@
import pytest
import yaml
from bluesky import RunEngine
from numpy import int32, ndarray

from ophyd_async.core import Device, SignalR, SignalRW
from ophyd_async.core._device.device_save_loader import (
Expand All @@ -24,13 +22,21 @@ def __init__(self):
self.sig2: SignalR = epics_signal_r(str, "Value2")


class EnumTest(Enum):
VAL1 = "val1"
VAL2 = "val2"


class DummyDeviceGroup(Device):
def __init__(self, name: str):
self.child1: DummyChildDevice = DummyChildDevice()
self.child2: DummyChildDevice = DummyChildDevice()
self.parent_sig1: SignalRW = epics_signal_rw(str, "ParentValue1")
self.parent_sig2: SignalR = epics_signal_r(int, "ParentValue2")
self.position: npt.NDArray[int32]
self.parent_sig2: SignalR = epics_signal_r(
int, "ParentValue2"
) # Ensure only RW are found
self.parent_sig3: SignalRW = epics_signal_rw(str, "ParentValue3")
self.position: npt.NDArray[np.int32]


@pytest.fixture
Expand All @@ -48,34 +54,57 @@ async def device_with_phases() -> DummyDeviceGroup:
def sort_signal_by_phase(self, signalRWs) -> List[Dict[str, SignalRW]]:
phase_1 = {}
phase_2 = {}
phase_1[0] = "parent.child1.sig1"
phase_2[0] = "parent.child2.sig1"
phase_1["child1.sig1"] = self.child1.sig1
phase_2["child2.sig1"] = self.child2.sig1
return [phase_1, phase_2]

device.sort_signal_by_phase = sort_signal_by_phase
setattr(device, "sort_signal_by_phase", sort_signal_by_phase)
return device


def test_get_signal_RWs_from_device(device):
signalRWS = get_signal_RWs_from_device(device, "parent")
signalRWS = get_signal_RWs_from_device(device, "")
assert list(signalRWS.keys()) == [
"parent.child1.sig1",
"parent.child2.sig1",
"parent.parent_sig1",
"child1.sig1",
"child2.sig1",
"parent_sig1",
"parent_sig3",
]
assert all(isinstance(signal, SignalRW) for signal in list(signalRWS.values()))


async def test_save_device(device, device_with_phases, RE, tmp_path):
async def test_save_device_no_phase(device, device_with_phases, tmp_path):
RE = RunEngine()
await device.child1.sig1.set("string")
# mimic tables in devices
table_pv = {"VAL1": np.array([1, 1, 1, 1, 1]), "val2": np.array([1, 1, 1, 1, 1])}

# Test tables PVs
table_pv = {"VAL1": np.array([1, 1, 1, 1, 1]), "VAL2": np.array([1, 1, 1, 1, 1])}
await device.child2.sig1.set(table_pv)

# Test enum PVs
await device.parent_sig3.set(EnumTest.VAL1)
RE(save_device(device, path.join(tmp_path, "test_file")))

with open(path.join(tmp_path, "test_file.yaml"), "r") as file:
yaml_content = yaml.safe_load(file)
for line in yaml_content:
#fill this in
assert yaml_content[0] == {
"child1.sig1": "string",
"child2.sig1": {
"VAL1": [1, 1, 1, 1, 1],
"VAL2": [1, 1, 1, 1, 1],
},
"parent_sig1": "",
"parent_sig3": "val1",
}


async def test_save_device_with_phase(device_with_phases, tmp_path):
RE = RunEngine()
await device_with_phases.child1.sig1.set("string")
# mimic tables in devices
table_pv = {"VAL1": np.array([1, 1, 1, 1, 1]), "VAL2": np.array([1, 1, 1, 1, 1])}
await device_with_phases.child2.sig1.set(table_pv)
RE(save_device(device_with_phases, path.join(tmp_path, "test_file")))
with open(path.join(tmp_path, "test_file.yaml"), "r") as file:
yaml_content = yaml.safe_load(file)
assert yaml_content[0] == {"child1.sig1": "string"}
assert yaml_content[1] == {"child2.sig1": table_pv}

0 comments on commit 128b294

Please sign in to comment.