Skip to content

Commit

Permalink
Allow context to deal with new style unions and add tests (#436)
Browse files Browse the repository at this point in the history
Fixes #432
  • Loading branch information
DominicOram authored Apr 26, 2024
1 parent 4d52871 commit 4440433
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 10 deletions.
15 changes: 5 additions & 10 deletions src/blueapi/core/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,8 @@
from dataclasses import dataclass, field
from importlib import import_module
from inspect import Parameter, signature
from types import ModuleType
from typing import (
Any,
Generic,
TypeVar,
get_args,
get_origin,
get_type_hints,
)
from types import ModuleType, UnionType
from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints

from bluesky.run_engine import RunEngine, call_in_bluesky_event_loop
from pydantic import create_model
Expand Down Expand Up @@ -264,7 +257,7 @@ def _type_spec_for_function(
)
return new_args

def _convert_type(self, typ: type) -> type:
def _convert_type(self, typ: type | Any) -> type:
"""
Recursively convert a type to something that can be deserialised by
pydantic. Bluesky protocols (and types that extend them) are replaced
Expand All @@ -288,6 +281,8 @@ def _convert_type(self, typ: type) -> type:
if args:
new_types = tuple(self._convert_type(i) for i in args)
root = get_origin(typ)
if root == UnionType:
root = Union
return root[new_types] if root else typ
return typ

Expand Down
19 changes: 19 additions & 0 deletions tests/core/test_context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from typing import Union
from unittest.mock import patch

import pytest
Expand Down Expand Up @@ -274,6 +275,24 @@ def test_reference_type_conversion(empty_context: BlueskyContext) -> None:
)


def test_reference_type_conversion_union(empty_context: BlueskyContext) -> None:
movable_ref: type = empty_context._reference(Movable)
assert empty_context._convert_type(Movable) == movable_ref
assert (
empty_context._convert_type(Union[Movable, int]) == Union[movable_ref, int] # noqa # type: ignore
)


def test_reference_type_conversion_new_style_union(
empty_context: BlueskyContext,
) -> None:
movable_ref: type = empty_context._reference(Movable)
assert empty_context._convert_type(Movable) == movable_ref
assert (
empty_context._convert_type(Movable | int) == movable_ref | int # type: ignore
)


def test_default_device_reference(empty_context: BlueskyContext) -> None:
def default_movable(mov: Movable = "demo") -> MsgGenerator: # type: ignore
...
Expand Down

0 comments on commit 4440433

Please sign in to comment.