diff --git a/scabha/basetypes.py b/scabha/basetypes.py index 23eb16bc..f541dd4b 100644 --- a/scabha/basetypes.py +++ b/scabha/basetypes.py @@ -1,11 +1,15 @@ +from __future__ import annotations from dataclasses import field, dataclass from collections import OrderedDict -from typing import List, Union, get_args, get_origin +from typing import List, Union, get_args, get_origin, Any import os.path import re from .exceptions import UnsetError from itertools import zip_longest -from typeguard import check_type, TypeCheckError +from typeguard import ( + check_type, TypeCheckError, TypeCheckerCallable, TypeCheckMemo, checker_lookup_functions +) +from inspect import isclass def EmptyDictDefault(): @@ -118,11 +122,21 @@ def is_file_list_type(dtype): return any(dtype == List[t] for t in FILE_TYPES) -class Skip(object): - def iterate_samples(self, collection): - return () +def check_filelike(value: Any, origin_type: Any, args: tuple[Any, ...], memo: TypeCheckMemo) -> None: + """Custom checker for filelike objects. Currently checks for strings.""" + if not isinstance(value, str): + raise TypeCheckError(f'{value} is not compatible with URI or its subclasses.') +def filelike_lookup(origin_type: Any, args: tuple[Any, ...], extras: tuple[Any, ...]) -> TypeCheckerCallable | None: + """Lookup the custom checker for filelike objects.""" + if isclass(origin_type) and issubclass(origin_type, URI): + return check_filelike + + return None + +checker_lookup_functions.append(filelike_lookup) # Register custom type checker. + def get_filelikes(dtype, value, filelikes=None): """Recursively recover all filelike elements from a composite dtype.""" @@ -152,9 +166,9 @@ def get_filelikes(dtype, value, filelikes=None): return filelikes # This is a special case for tuples of arbitrary - # length i.e. list-like behaviour. - if ... in args: - args = tuple([a for a in args if a != ...]) + # length i.e. list-like behaviour. We can simply + # strip out the Ellipsis. + args = tuple([arg for arg in args if arg != ...]) for dt, v in zip_longest(args, value, fillvalue=args[0]): filelikes = get_filelikes(dt, v, filelikes) @@ -162,11 +176,9 @@ def get_filelikes(dtype, value, filelikes=None): elif origin is Union: for dt in args: - try: - # Do not check collection member types. - check_type(value, dt, collection_check_strategy=Skip()) - except TypeCheckError: + check_type(value, dt) + except TypeCheckError: # Value doesn't match dtype - incorrect branch. continue filelikes = get_filelikes(dt, value, filelikes) diff --git a/tests/scabha_tests/test_filelikes.py b/tests/scabha_tests/test_filelikes.py new file mode 100644 index 00000000..b3697d92 --- /dev/null +++ b/tests/scabha_tests/test_filelikes.py @@ -0,0 +1,42 @@ +from scabha.basetypes import get_filelikes, File, URI, Directory, MS +from typing import Dict, List, Set, Tuple, Union, Optional +import pytest + + +@pytest.fixture(scope="module", params=[File, URI, Directory, MS]) +def templates(request): + + ft = request.param + + TEMPLATES = ( + (Tuple, (), set()), + (Tuple[int, ...], [1, 2], set()), + (Tuple[ft, ...], ("foo", "bar"), {"foo", "bar"}), + (Tuple[ft, str], ("foo", "bar"), {"foo"}), + (Dict[str, int], {"a": 1, "b": 2}, set()), + (Dict[str, ft], {"a": "foo", "b": "bar"}, {"foo", "bar"}), + (Dict[ft, str], {"foo": "a", "bar": "b"}, {"foo", "bar"}), + (List[ft], [], set()), + (List[int], [1, 2], set()), + (List[ft], ["foo", "bar"], {"foo", "bar"}), + (Set[ft], set(), set()), + (Set[int], {1, 2}, set()), + (Set[ft], {"foo", "bar"}, {"foo", "bar"}), + (Union[str, List[ft]], "foo", set()), + (Union[str, List[ft]], ["foo"], {"foo"}), + (Union[str, Tuple[ft]], "foo", set()), + (Union[str, Tuple[ft]], ("foo",), {"foo"}), + (Optional[ft], None, set()), + (Optional[ft], "foo", {"foo"}), + (Optional[Union[ft, int]], 1, set()), + (Optional[Union[ft, int]], "foo", {"foo"}), + (Dict[str, Tuple[ft, str]], {"a": ("foo", "bar")}, {"foo"}) + ) + + return TEMPLATES + + +def test_get_filelikes(templates): + + for dt, v, res in templates: + assert get_filelikes(dt, v) == res, f"Failed for dtype {dt} and value {v}."