From 443fafa785004c5628944713020739e275f347fa Mon Sep 17 00:00:00 2001 From: Seher Karakuzu Date: Thu, 4 Apr 2024 15:34:05 -0400 Subject: [PATCH] xarray.py typed --- tiled/adapters/xarray.py | 31 +++++++++++++++++++++++-------- tiled/structures/core.py | 2 +- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/tiled/adapters/xarray.py b/tiled/adapters/xarray.py index f8e9088c1..cc44e5040 100644 --- a/tiled/adapters/xarray.py +++ b/tiled/adapters/xarray.py @@ -1,8 +1,10 @@ import collections.abc import itertools +from typing import Any, Iterable, Iterator, Optional, Union import xarray +from ..access_policies import DummyAccessPolicy, SimpleAccessPolicy from ..structures.core import Spec from .array import ArrayAdapter from .mapping import MapAdapter @@ -14,7 +16,13 @@ class DatasetAdapter(MapAdapter): """ @classmethod - def from_dataset(cls, dataset, *, specs=None, access_policy=None): + def from_dataset( + cls, + dataset: Any, + *, + specs: Optional[list[Spec]] = None, + access_policy: Optional[Union[DummyAccessPolicy, SimpleAccessPolicy]] = None, + ) -> "DatasetAdapter": mapping = _DatasetMap(dataset) specs = specs or [] if "xarray_dataset" not in [spec.name for spec in specs]: @@ -26,7 +34,14 @@ def from_dataset(cls, dataset, *, specs=None, access_policy=None): access_policy=access_policy, ) - def __init__(self, mapping, *args, specs=None, access_policy=None, **kwargs): + def __init__( + self, + mapping: Any, + *args: Any, + specs: Optional[list[Spec]] = None, + access_policy: Optional[Union[SimpleAccessPolicy, DummyAccessPolicy]] = None, + **kwargs: Any, + ) -> None: if isinstance(mapping, xarray.Dataset): raise TypeError( "Use DatasetAdapter.from_dataset(...), not DatasetAdapter(...)." @@ -35,24 +50,24 @@ def __init__(self, mapping, *args, specs=None, access_policy=None, **kwargs): mapping, *args, specs=specs, access_policy=access_policy, **kwargs ) - def inlined_contents_enabled(self, depth): + def inlined_contents_enabled(self, depth: int) -> bool: # Tell the server to in-line the description of each array # (i.e. data_vars and coords) to avoid latency of a second # request. return True -class _DatasetMap(collections.abc.Mapping): - def __init__(self, dataset): +class _DatasetMap(collections.abc.Mapping[str, Any]): + def __init__(self, dataset: Any) -> None: self._dataset = dataset - def __len__(self): + def __len__(self) -> int: return len(self._dataset.data_vars) + len(self._dataset.coords) - def __iter__(self): + def __iter__(self) -> Iterator[Any]: yield from itertools.chain(self._dataset.data_vars, self._dataset.coords) - def __getitem__(self, key): + def __getitem__(self, key: str) -> ArrayAdapter: data_array = self._dataset[key] if key in self._dataset.coords: spec = Spec("xarray_coord") diff --git a/tiled/structures/core.py b/tiled/structures/core.py index 2891814ab..a4708b35e 100644 --- a/tiled/structures/core.py +++ b/tiled/structures/core.py @@ -22,7 +22,7 @@ class Spec: name: str version: Optional[str] = None - def __init__(self, name, version=None): + def __init__(self, name, version=None) -> None: # Enable the name to be passed as a position argument. # The setattr stuff is necessary to make this work with a frozen dataclass. object.__setattr__(self, "name", name)