diff --git a/tiled/adapters/awkward.py b/tiled/adapters/awkward.py index eea558f99..e0a661eb4 100644 --- a/tiled/adapters/awkward.py +++ b/tiled/adapters/awkward.py @@ -5,6 +5,7 @@ from ..structures.awkward import AwkwardStructure from ..structures.core import StructureFamily +from .array import ArrayAdapter class AwkwardAdapter: diff --git a/tiled/adapters/awkward_buffers.py b/tiled/adapters/awkward_buffers.py index 5d1bd8514..3977f91c6 100644 --- a/tiled/adapters/awkward_buffers.py +++ b/tiled/adapters/awkward_buffers.py @@ -13,19 +13,20 @@ from ..server.pydantic_awkward import AwkwardStructure from ..structures.core import StructureFamily from ..utils import path_from_uri +from .array import ArrayAdapter from .awkward import AwkwardAdapter -class DirectoryContainer(collections.abc.MutableMapping[str, JSON]): +class DirectoryContainer(collections.abc.MutableMapping[str, bytes]): def __init__(self, directory: Path, form: Any): self.directory = directory self.form = form - def __getitem__(self, form_key: str) -> JSON: + def __getitem__(self, form_key: str) -> bytes: with open(self.directory / form_key, "rb") as file: return file.read() - def __setitem__(self, form_key: str, value: Buffer) -> None: + def __setitem__(self, form_key: str, value: bytes) -> None: with open(self.directory / form_key, "wb") as file: file.write(value) diff --git a/tiled/adapters/mapping.py b/tiled/adapters/mapping.py index da7e2d9e6..e13dc1cbc 100644 --- a/tiled/adapters/mapping.py +++ b/tiled/adapters/mapping.py @@ -4,8 +4,10 @@ import operator from collections import Counter from datetime import datetime -from typing import Optional, Union +from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union, cast +from _pydatetime import timedelta +from fastapi import APIRouter from type_alliases import JSON, Spec from ..access_policies import DummyAccessPolicy, SimpleAccessPolicy @@ -24,12 +26,13 @@ StructureFamilyQuery, ) from ..query_registration import QueryTranslationRegistry +from ..server.schemas import NodeStructure, SortingItem from ..structures.core import StructureFamily -from ..utils import UNCHANGED +from ..utils import UNCHANGED, Sentinel from .utils import IndexersMixin -class MapAdapter(collections.abc.Mapping, IndexersMixin): +class MapAdapter(collections.abc.Mapping[str, Any], IndexersMixin): """ Adapt any mapping (dictionary-like object) to Tiled. """ @@ -56,17 +59,17 @@ class MapAdapter(collections.abc.Mapping, IndexersMixin): def __init__( self, - mapping: dict, + mapping: dict[str, Any], *, - structure=None, + structure: Optional[NodeStructure] = None, metadata: Optional[JSON] = None, - sorting=None, + sorting: Optional[List[SortingItem]] = None, specs: Optional[list[Spec]] = None, access_policy: Optional[Union[SimpleAccessPolicy, DummyAccessPolicy]] = None, - entries_stale_after=None, - metadata_stale_after=None, - must_revalidate=True, - ): + entries_stale_after: Optional[timedelta] = None, + metadata_stale_after: Optional[timedelta] = None, + must_revalidate: bool = True, + ) -> None: """ Create a simple Adapter from any mapping (e.g. dict, OneShotCachedMap). @@ -96,36 +99,38 @@ def __init__( # This is a special case that means, "the given ordering". # By giving that a name ("_") we enable requests to asking for the # last N by requesting the sorting ("_", -1). - sorting = [("_", 1)] + sorting = [SortingItem("_", 1)] self._sorting = sorting self._metadata = metadata or {} self.specs = specs or [] self._access_policy = access_policy self._must_revalidate = must_revalidate - self.include_routers = [] - self.background_tasks = [] + self.include_routers: list[APIRouter] = [] + self.background_tasks: list[Any] = [] self.entries_stale_after = entries_stale_after self.metadata_stale_after = metadata_stale_after self.specs = specs or [] super().__init__() @property - def must_revalidate(self): + def must_revalidate(self) -> bool: return self._must_revalidate @must_revalidate.setter - def must_revalidate(self, value): + def must_revalidate(self, value: bool) -> None: self._must_revalidate = value @property - def access_policy(self): + def access_policy(self) -> Optional[Union[SimpleAccessPolicy, DummyAccessPolicy]]: return self._access_policy @access_policy.setter - def access_policy(self, value): + def access_policy( + self, value: Union[SimpleAccessPolicy, DummyAccessPolicy] + ) -> None: self._access_policy = value - def metadata(self): + def metadata(self) -> JSON: "Metadata about this Adapter." # Ensure this is immutable (at the top level) to help the user avoid # getting the wrong impression that editing this would update anything @@ -133,56 +138,56 @@ def metadata(self): return self._metadata @property - def sorting(self): + def sorting(self) -> list[SortingItem]: return list(self._sorting) - def __repr__(self): + def __repr__(self) -> str: return ( f"<{type(self).__name__}({{{', '.join(repr(k) for k in self._mapping)}}})>" ) - def __getitem__(self, key): + def __getitem__(self, key: str) -> Any: return self._mapping[key] - def __iter__(self): + def __iter__(self) -> Iterator[str]: yield from self._mapping - def __len__(self): + def __len__(self) -> int: return len(self._mapping) - def keys(self): + def keys(self) -> KeysView: # type: ignore return KeysView(lambda: len(self), self._keys_slice) - def values(self): + def values(self) -> ValuesView: # type: ignore return ValuesView(lambda: len(self), self._items_slice) - def items(self): + def items(self) -> ItemsView: # type: ignore return ItemsView(lambda: len(self), self._items_slice) - def structure(self): + def structure(self) -> None: return None @property - def metadata_stale_at(self): + def metadata_stale_at(self) -> Optional[timedelta]: if self.metadata_stale_after is None: - return - return self.metadata_stale_after + datetime.utcnow() + return None + return self.metadata_stale_after + datetime.now() @property - def entries_stale_at(self): + def entries_stale_at(self) -> Optional[timedelta]: if self.entries_stale_after is None: - return - return self.entries_stale_after + datetime.utcnow() + return None + return self.entries_stale_after + datetime.now() def new_variation( self, - *args, - mapping=UNCHANGED, - metadata=UNCHANGED, - sorting=UNCHANGED, - must_revalidate=UNCHANGED, - **kwargs, - ): + *args: Any, + mapping: Union[Sentinel, dict[str, Any]] = UNCHANGED, + metadata: Union[Sentinel, JSON] = UNCHANGED, + sorting: Union[Sentinel, list[SortingItem]] = UNCHANGED, + must_revalidate: Union[Sentinel, bool] = UNCHANGED, + **kwargs: Any, + ) -> "MapAdapter": if mapping is UNCHANGED: mapping = self._mapping if metadata is UNCHANGED: @@ -192,19 +197,19 @@ def new_variation( if must_revalidate is UNCHANGED: must_revalidate = self.must_revalidate return type(self)( - *args, - mapping=mapping, - sorting=sorting, - metadata=self._metadata, + # *args, + mapping=cast(dict[str, Any], mapping), + sorting=cast(list[SortingItem], sorting), + metadata=cast(JSON, self._metadata), specs=self.specs, access_policy=self.access_policy, entries_stale_after=self.entries_stale_after, metadata_stale_after=self.entries_stale_after, - must_revalidate=must_revalidate, + must_revalidate=cast(bool, must_revalidate), **kwargs, ) - def read(self, fields=None): + def read(self, fields: Optional[str] = None) -> "MapAdapter": if fields is not None: new_mapping = {} for field in fields: @@ -212,14 +217,21 @@ def read(self, fields=None): return self.new_variation(mapping=new_mapping) return self - def search(self, query): + def search(self, query: Any) -> Any: """ Return a Adapter with a subset of the mapping. """ return self.query_registry(query, self) - def get_distinct(self, metadata, structure_families, specs, counts): - data = {} + def get_distinct( + self, + metadata: JSON, + structure_families: StructureFamily, + specs: list[Spec], + counts: int, + ) -> dict[str, Any]: + data: dict[str, Any] = {} + # data: dict[str, list[dict[str, Any]]] = {} if metadata: data["metadata"] = {} @@ -240,7 +252,7 @@ def get_distinct(self, metadata, structure_families, specs, counts): return data - def sort(self, sorting): + def sort(self, sorting: SortingItem) -> "MapAdapter": mapping = copy.copy(self._mapping) for key, direction in reversed(sorting): if key == "_": @@ -265,21 +277,27 @@ def sort(self, sorting): # The following two methods are used by keys(), values(), items(). - def _keys_slice(self, start, stop, direction): + def _keys_slice( + self, start: int, stop: int, direction: int + ) -> Union[Iterator[str], list[str]]: if direction > 0: yield from itertools.islice(self._mapping.keys(), start, stop) else: - keys_to_slice = reversed( - list( - itertools.islice( - self._mapping.keys(), 0, len(self._mapping) - start + keys_to_slice = list( + reversed( + list( + itertools.islice( + self._mapping.keys(), 0, len(self._mapping) - start + ) ) ) ) keys = keys_to_slice[start:stop] return keys - def _items_slice(self, start, stop, direction): + def _items_slice( + self, start: int, stop: int, direction: int + ) -> Iterator[Tuple[str, NodeStructure]]: # A goal of this implementation is to avoid iterating over # self._mapping.values() because self._mapping may be a OneShotCachedMap which # only constructs its values at access time. With this in mind, we @@ -290,7 +308,7 @@ def _items_slice(self, start, stop, direction): ) -def walk_string_values(tree, node=None): +def walk_string_values(tree: MapAdapter, node: Optional[Any] = None) -> Iterator[str]: """ >>> list( ... walk_string_values( @@ -316,7 +334,7 @@ def walk_string_values(tree, node=None): yield item -def counter_to_dict(counter, counts): +def counter_to_dict(counter: dict[str, Any], counts: Any) -> list[dict[str, Any]]: if counts: data = [{"value": k, "count": v} for k, v in counter.items() if k is not None] else: @@ -325,7 +343,9 @@ def counter_to_dict(counter, counts): return data -def iter_child_metadata(query_key, tree): +def iter_child_metadata( + query_key: Any, tree: MapAdapter +) -> Iterator[Tuple[str, Any, Any]]: for key, value in tree.items(): term = value.metadata() for subkey in query_key.split("."): @@ -337,7 +357,7 @@ def iter_child_metadata(query_key, tree): yield key, value, term -def full_text_search(query, tree): +def full_text_search(query: Any, tree: MapAdapter) -> MapAdapter: matches = {} text = query.text query_words = set(text.split()) @@ -358,7 +378,7 @@ def full_text_search(query, tree): MapAdapter.register_query(FullText, full_text_search) -def regex(query, tree): +def regex(query: Any, tree: MapAdapter) -> MapAdapter: import re matches = {} @@ -374,7 +394,7 @@ def regex(query, tree): MapAdapter.register_query(Regex, regex) -def eq(query, tree): +def eq(query: Any, tree: MapAdapter) -> MapAdapter: matches = {} for key, value, term in iter_child_metadata(query.key, tree): if term == query.value: @@ -385,7 +405,7 @@ def eq(query, tree): MapAdapter.register_query(Eq, eq) -def noteq(query, tree): +def noteq(query: Any, tree: MapAdapter) -> MapAdapter: matches = {} for key, value, term in iter_child_metadata(query.key, tree): if term != query.value: @@ -396,7 +416,7 @@ def noteq(query, tree): MapAdapter.register_query(NotEq, noteq) -def contains(query, tree): +def contains(query: Any, tree: MapAdapter) -> MapAdapter: matches = {} for key, value, term in iter_child_metadata(query.key, tree): if ( @@ -411,7 +431,7 @@ def contains(query, tree): MapAdapter.register_query(Contains, contains) -def comparison(query, tree): +def comparison(query: Any, tree: MapAdapter) -> MapAdapter: matches = {} for key, value, term in iter_child_metadata(query.key, tree): if query.operator not in {"le", "lt", "ge", "gt"}: @@ -425,7 +445,7 @@ def comparison(query, tree): MapAdapter.register_query(Comparison, comparison) -def _in(query, tree): +def _in(query: Any, tree: MapAdapter) -> MapAdapter: matches = {} for key, value, term in iter_child_metadata(query.key, tree): if term in query.value: @@ -436,7 +456,7 @@ def _in(query, tree): MapAdapter.register_query(In, _in) -def notin(query, tree): +def notin(query: Any, tree: MapAdapter) -> MapAdapter: matches = {} for key, value, term in iter_child_metadata(query.key, tree): if term not in query.value: @@ -447,7 +467,7 @@ def notin(query, tree): MapAdapter.register_query(NotIn, notin) -def specs(query, tree): +def specs(query: Any, tree: MapAdapter) -> MapAdapter: matches = {} include = set(query.include) exclude = set(query.exclude) @@ -463,7 +483,7 @@ def specs(query, tree): MapAdapter.register_query(SpecsQuery, specs) -def structure_family(query, tree): +def structure_family(query: Any, tree: MapAdapter) -> MapAdapter: matches = {} for key, value in tree.items(): if value.structure_family == query.value: @@ -475,7 +495,7 @@ def structure_family(query, tree): MapAdapter.register_query(StructureFamilyQuery, structure_family) -def keys_filter(query, tree): +def keys_filter(query: Any, tree: MapAdapter) -> MapAdapter: matches = {} for key, value in tree.items(): if key in query.keys: @@ -491,10 +511,10 @@ class _HIGH_SORTER_CLASS: Enables sort to work when metadata is sparse """ - def __lt__(self, other): + def __lt__(self, other: "_HIGH_SORTER_CLASS") -> bool: return False - def __gt__(self, other): + def __gt__(self, other: "_HIGH_SORTER_CLASS") -> bool: return True diff --git a/tiled/adapters/zarr.py b/tiled/adapters/zarr.py index a00d32bcb..d97005a27 100644 --- a/tiled/adapters/zarr.py +++ b/tiled/adapters/zarr.py @@ -15,6 +15,7 @@ from ..access_policies import DummyAccessPolicy, SimpleAccessPolicy from ..adapters.utils import IndexersMixin from ..iterviews import ItemsView, KeysView, ValuesView +from ..server.schemas import NodeStructure from ..structures.array import ArrayStructure from ..structures.core import StructureFamily from ..utils import node_repr, path_from_uri @@ -24,7 +25,7 @@ def read_zarr( - data_uri: Union[str, list[str]], structure: Optional[ArrayStructure], **kwargs: Any + data_uri: Union[str, list[str]], structure: Optional[NodeStructure], **kwargs: Any ) -> Union["ZarrGroupAdapter", "ZarrArrayAdapter"]: filepath = path_from_uri(data_uri) zarr_obj = zarr.open(filepath) # Group or Array @@ -116,7 +117,7 @@ def __init__( self, node: Any, *, - structure: Optional[ArrayStructure] = None, + structure: Optional[Union[NodeStructure, ArrayStructure]] = None, metadata: Optional[JSON] = None, specs: Optional[list[Spec]] = None, access_policy: Optional[Union[SimpleAccessPolicy, DummyAccessPolicy]] = None, diff --git a/tiled/query_registration.py b/tiled/query_registration.py index 7181ebb69..ec4bca9b7 100644 --- a/tiled/query_registration.py +++ b/tiled/query_registration.py @@ -6,6 +6,7 @@ """ import inspect from dataclasses import fields +from typing import Any from .utils import DictView, UnsupportedQueryType @@ -84,11 +85,11 @@ def inner(cls): class QueryTranslationRegistry: - def __init__(self): + def __init__(self) -> None: self._lookup = {} self._lazy = {} - def register(self, class_, translator): + def register(self, class_, translator) -> Any: self._lookup[class_] = translator return translator diff --git a/tiled/utils.py b/tiled/utils.py index c7cfd917d..969d9638c 100644 --- a/tiled/utils.py +++ b/tiled/utils.py @@ -398,17 +398,17 @@ def tree(tree, max_lines=20): class Sentinel: - def __init__(self, name): + def __init__(self, name: str) -> None: self.name = name - def __repr__(self): + def __repr__(self) -> str: return f"<{self.name}>" - def __copy__(self): + def __copy__(self) -> "Sentinel": # The goal here is to make copy.copy(sentinel) == sentinel return self - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: "Sentinel") -> "Sentinel": # The goal here is to make copy.deepcopy(sentinel) == sentinel return self