Skip to content

Commit

Permalink
xarray.py typed
Browse files Browse the repository at this point in the history
  • Loading branch information
Seher Karakuzu authored and Seher Karakuzu committed Apr 5, 2024
1 parent f4d0f64 commit 443fafa
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 9 deletions.
31 changes: 23 additions & 8 deletions tiled/adapters/xarray.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import collections.abc
import itertools
from typing import Any, Iterable, Iterator, Optional, Union

import xarray

from ..access_policies import DummyAccessPolicy, SimpleAccessPolicy
from ..structures.core import Spec
from .array import ArrayAdapter
from .mapping import MapAdapter
Expand All @@ -14,7 +16,13 @@ class DatasetAdapter(MapAdapter):
"""

@classmethod
def from_dataset(cls, dataset, *, specs=None, access_policy=None):
def from_dataset(
cls,
dataset: Any,
*,
specs: Optional[list[Spec]] = None,
access_policy: Optional[Union[DummyAccessPolicy, SimpleAccessPolicy]] = None,
) -> "DatasetAdapter":
mapping = _DatasetMap(dataset)
specs = specs or []
if "xarray_dataset" not in [spec.name for spec in specs]:
Expand All @@ -26,7 +34,14 @@ def from_dataset(cls, dataset, *, specs=None, access_policy=None):
access_policy=access_policy,
)

def __init__(self, mapping, *args, specs=None, access_policy=None, **kwargs):
def __init__(
self,
mapping: Any,
*args: Any,
specs: Optional[list[Spec]] = None,
access_policy: Optional[Union[SimpleAccessPolicy, DummyAccessPolicy]] = None,
**kwargs: Any,
) -> None:
if isinstance(mapping, xarray.Dataset):
raise TypeError(
"Use DatasetAdapter.from_dataset(...), not DatasetAdapter(...)."
Expand All @@ -35,24 +50,24 @@ def __init__(self, mapping, *args, specs=None, access_policy=None, **kwargs):
mapping, *args, specs=specs, access_policy=access_policy, **kwargs
)

def inlined_contents_enabled(self, depth):
def inlined_contents_enabled(self, depth: int) -> bool:
# Tell the server to in-line the description of each array
# (i.e. data_vars and coords) to avoid latency of a second
# request.
return True


class _DatasetMap(collections.abc.Mapping):
def __init__(self, dataset):
class _DatasetMap(collections.abc.Mapping[str, Any]):
def __init__(self, dataset: Any) -> None:
self._dataset = dataset

def __len__(self):
def __len__(self) -> int:
return len(self._dataset.data_vars) + len(self._dataset.coords)

def __iter__(self):
def __iter__(self) -> Iterator[Any]:
yield from itertools.chain(self._dataset.data_vars, self._dataset.coords)

def __getitem__(self, key):
def __getitem__(self, key: str) -> ArrayAdapter:
data_array = self._dataset[key]
if key in self._dataset.coords:
spec = Spec("xarray_coord")
Expand Down
2 changes: 1 addition & 1 deletion tiled/structures/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ class Spec:
name: str
version: Optional[str] = None

def __init__(self, name, version=None):
def __init__(self, name, version=None) -> None:
# Enable the name to be passed as a position argument.
# The setattr stuff is necessary to make this work with a frozen dataclass.
object.__setattr__(self, "name", name)
Expand Down

0 comments on commit 443fafa

Please sign in to comment.