From 2c956695ac5e1e7a995493cc7c3002e24854261e Mon Sep 17 00:00:00 2001 From: David Zwicker Date: Sat, 23 Dec 2023 18:15:36 +0100 Subject: [PATCH] Cleaned up some types --- modelrunner/storage/backend/zarr.py | 2 +- modelrunner/storage/trajectory.py | 6 +++--- tests/test_results.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/modelrunner/storage/backend/zarr.py b/modelrunner/storage/backend/zarr.py index fe9927b..e43e0d6 100644 --- a/modelrunner/storage/backend/zarr.py +++ b/modelrunner/storage/backend/zarr.py @@ -226,7 +226,7 @@ def _read_object(self, loc: Sequence[str]) -> Any: return self[loc][0] def _write_object(self, loc: Sequence[str], obj: Any) -> None: - arr = np.empty(1, dtype=object) # encode object in an array + arr: np.ndarray = np.empty(1, dtype=object) # encode object in an array arr[0] = obj parent, name = self._get_parent(loc) parent.array(name, arr, object_codec=self.codec, overwrite=True) diff --git a/modelrunner/storage/trajectory.py b/modelrunner/storage/trajectory.py index 91661d4..2fa488e 100644 --- a/modelrunner/storage/trajectory.py +++ b/modelrunner/storage/trajectory.py @@ -13,7 +13,7 @@ from __future__ import annotations from pathlib import Path -from typing import Any, Dict, Iterator, Literal, Optional +from typing import Any, Dict, Iterator, Literal, Optional, Tuple import numpy as np @@ -114,7 +114,7 @@ def append(self, data: Any, time: Optional[float] = None) -> None: # initialize new trajectory if isinstance(data, np.ndarray): dtype = data.dtype - shape = data.shape + shape: Tuple[int, ...] = data.shape self._item_type = "array" else: dtype = object @@ -132,7 +132,7 @@ def append(self, data: Any, time: Optional[float] = None) -> None: if self._item_type == "array": self._trajectory.extend_dynamic_array("data", data) elif self._item_type == "object": - arr = np.empty((), dtype=object) + arr: np.ndarray = np.empty((), dtype=object) arr[...] = data self._trajectory.extend_dynamic_array("data", arr) else: diff --git a/tests/test_results.py b/tests/test_results.py index 9585f4e..49814ea 100644 --- a/tests/test_results.py +++ b/tests/test_results.py @@ -127,7 +127,7 @@ def test_result_collections(): assert rc1 == ResultCollection([r1, r2, r3]) # test result dataframes - rc1.dataframe + rc1.as_dataframe() def test_collection_groupby():