Skip to content

Commit

Permalink
Merge pull request #52 from zwicker-group/types
Browse files Browse the repository at this point in the history
Cleaned up some types
  • Loading branch information
david-zwicker authored Dec 23, 2023
2 parents 7920a15 + 2c95669 commit e69243e
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion modelrunner/storage/backend/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def _read_object(self, loc: Sequence[str]) -> Any:
return self[loc][0]

def _write_object(self, loc: Sequence[str], obj: Any) -> None:
arr = np.empty(1, dtype=object) # encode object in an array
arr: np.ndarray = np.empty(1, dtype=object) # encode object in an array
arr[0] = obj
parent, name = self._get_parent(loc)
parent.array(name, arr, object_codec=self.codec, overwrite=True)
6 changes: 3 additions & 3 deletions modelrunner/storage/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, Iterator, Literal, Optional
from typing import Any, Dict, Iterator, Literal, Optional, Tuple

import numpy as np

Expand Down Expand Up @@ -114,7 +114,7 @@ def append(self, data: Any, time: Optional[float] = None) -> None:
# initialize new trajectory
if isinstance(data, np.ndarray):
dtype = data.dtype
shape = data.shape
shape: Tuple[int, ...] = data.shape
self._item_type = "array"
else:
dtype = object
Expand All @@ -132,7 +132,7 @@ def append(self, data: Any, time: Optional[float] = None) -> None:
if self._item_type == "array":
self._trajectory.extend_dynamic_array("data", data)
elif self._item_type == "object":
arr = np.empty((), dtype=object)
arr: np.ndarray = np.empty((), dtype=object)
arr[...] = data
self._trajectory.extend_dynamic_array("data", arr)
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_result_collections():
assert rc1 == ResultCollection([r1, r2, r3])

# test result dataframes
rc1.dataframe
rc1.as_dataframe()


def test_collection_groupby():
Expand Down

0 comments on commit e69243e

Please sign in to comment.