From 444043379cfaa3ae1ade2e5ab89e2e9004ac4632 Mon Sep 17 00:00:00 2001 From: Dominic Oram Date: Fri, 26 Apr 2024 10:22:13 +0100 Subject: [PATCH] Allow context to deal with new style unions and add tests (#436) Fixes #432 --- src/blueapi/core/context.py | 15 +++++---------- tests/core/test_context.py | 19 +++++++++++++++++++ 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 07dc90b10..a6d5ba2ff 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -4,15 +4,8 @@ from dataclasses import dataclass, field from importlib import import_module from inspect import Parameter, signature -from types import ModuleType -from typing import ( - Any, - Generic, - TypeVar, - get_args, - get_origin, - get_type_hints, -) +from types import ModuleType, UnionType +from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints from bluesky.run_engine import RunEngine, call_in_bluesky_event_loop from pydantic import create_model @@ -264,7 +257,7 @@ def _type_spec_for_function( ) return new_args - def _convert_type(self, typ: type) -> type: + def _convert_type(self, typ: type | Any) -> type: """ Recursively convert a type to something that can be deserialised by pydantic. Bluesky protocols (and types that extend them) are replaced @@ -288,6 +281,8 @@ def _convert_type(self, typ: type) -> type: if args: new_types = tuple(self._convert_type(i) for i in args) root = get_origin(typ) + if root == UnionType: + root = Union return root[new_types] if root else typ return typ diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 5440fbfd5..4e6583209 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -1,5 +1,6 @@ from __future__ import annotations +from typing import Union from unittest.mock import patch import pytest @@ -274,6 +275,24 @@ def test_reference_type_conversion(empty_context: BlueskyContext) -> None: ) +def test_reference_type_conversion_union(empty_context: BlueskyContext) -> None: + movable_ref: type = empty_context._reference(Movable) + assert empty_context._convert_type(Movable) == movable_ref + assert ( + empty_context._convert_type(Union[Movable, int]) == Union[movable_ref, int] # noqa # type: ignore + ) + + +def test_reference_type_conversion_new_style_union( + empty_context: BlueskyContext, +) -> None: + movable_ref: type = empty_context._reference(Movable) + assert empty_context._convert_type(Movable) == movable_ref + assert ( + empty_context._convert_type(Movable | int) == movable_ref | int # type: ignore + ) + + def test_default_device_reference(empty_context: BlueskyContext) -> None: def default_movable(mov: Movable = "demo") -> MsgGenerator: # type: ignore ...