Merge pull request #63 from zwicker-group/trajectory
* Much improved handling of trajectories
* Better error messages in many places
* Track whether storages have been closed
* Better access checking when items are read from groups
* Improved support for binary data in HDF files
david-zwicker authored Apr 9, 2024
2 parents b624131 + 6bbfb56 commit 645a898
Showing 15 changed files with 239 additions and 46 deletions.
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -1,4 +1,4 @@
-r ../requirements.txt
-r ../requirements_full.txt
Sphinx>=4
sphinx-autodoc-annotation>=1.0
sphinx-gallery>=0.6
1 change: 1 addition & 0 deletions docs/source/conf.py
@@ -240,6 +240,7 @@
intersphinx_mapping = {
"h5py": ("https://docs.h5py.org/en/latest", None),
"matplotlib": ("https://matplotlib.org/stable", None),
"modelrunner": ("https://py-modelrunner.readthedocs.io/en/latest", None),
"napari": ("https://napari.org/", None),
"numba": ("https://numba.pydata.org/numba-doc/latest/", None),
"numpy": ("https://numpy.org/doc/stable", None),
2 changes: 1 addition & 1 deletion modelrunner/__init__.py
@@ -5,7 +5,7 @@
# determine the package version
try:
# try reading version of the automatically generated module
from ._version import __version__ # type: ignore
from ._version import __version__
except ImportError:
# determine version automatically from VCS information
from importlib.metadata import PackageNotFoundError, version
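The hunk is truncated here. For context, the usual shape of this fallback is sketched below; the `except`-branch body is an assumption, not part of the diff:

```python
# Hedged sketch of the complete version-detection idiom; everything after
# the `version` import is assumed, since the hunk above is cut off.
try:
    # try reading version of the automatically generated module
    from ._version import __version__
except ImportError:
    # determine version automatically from VCS information
    from importlib.metadata import PackageNotFoundError, version

    try:
        __version__ = version("modelrunner")  # package name assumed
    except PackageNotFoundError:
        __version__ = "unknown"  # fallback value assumed
```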
5 changes: 5 additions & 0 deletions modelrunner/storage/access_modes.py
@@ -56,6 +56,8 @@ def parse(cls, obj_or_name: str | AccessMode) -> AccessMode:
if isinstance(obj_or_name, AccessMode):
return obj_or_name
elif isinstance(obj_or_name, str):
if obj_or_name == "closed":
raise ValueError("Cannot use `closed` access mode.")
try:
return cls._defined[obj_or_name]
except KeyError:
@@ -67,6 +69,9 @@ def parse(cls, obj_or_name: str | AccessMode) -> AccessMode:


# define default access modes
_access_closed = AccessMode(
name="closed", description="Does not allow anything", file_mode="r", read=False
)
access_read = AccessMode(
name="read",
description="Only allows reading",
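A short usage sketch of the new guard, assuming named modes register themselves in `AccessMode._defined` as the lookup in `parse` suggests:

```python
from modelrunner.storage.access_modes import AccessMode

mode = AccessMode.parse("read")        # look up a predefined mode by name
assert AccessMode.parse(mode) is mode  # AccessMode instances pass through

try:
    AccessMode.parse("closed")  # the internal sentinel is explicitly rejected
except ValueError as err:
    print(err)  # Cannot use `closed` access mode.
```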
46 changes: 34 additions & 12 deletions modelrunner/storage/backend/hdf.py
@@ -72,9 +72,9 @@ def close(self) -> None:
# shorten dynamic arrays to correct size
for hdf_path, size in self._dynamic_array_size.items():
self._file[hdf_path].resize(size, axis=0)

if self._close:
self._file.close()
super().close()

def _get_hdf_path(self, loc: Sequence[str]) -> str:
return "/" + "/".join(loc)
@@ -120,7 +120,12 @@ def __getitem__(self, loc: Sequence[str]) -> Any:
return self._file
else:
parent, name = self._get_parent(loc)
return parent[name]
try:
return parent[name]
except ValueError as e:
raise ValueError(
f"Invalid location `{name}` in path `{parent.name}`"
) from e

def keys(self, loc: Sequence[str] | None = None) -> Collection[str]:
if loc:
@@ -133,7 +138,10 @@ def is_group(self, loc: Sequence[str]) -> bool:

def _create_group(self, loc: Sequence[str]):
parent, name = self._get_parent(loc)
return parent.create_group(name)
try:
return parent.create_group(name)
except ValueError as e:
raise ValueError(f"Cannot create group `{name}`") from e

def _read_attrs(self, loc: Sequence[str]) -> AttrsLike:
return self[loc].attrs # type: ignore
@@ -152,7 +160,16 @@ def _read_array(
# decode potentially binary data
attrs = self._read_attrs(loc)
if attrs.get("__pickled__", False):
arr_like = decode_binary(np.asarray(arr_like).item())
# data has been pickled inside the array
if np.issubdtype(arr_like.dtype, "O"):
# array of object dtype
arr_like = np.frompyfunc(decode_binary, nin=1, nout=1)(arr_like)
elif np.issubdtype(arr_like.dtype, np.uint8):
arr_like = decode_binary(arr_like)
else:
data = np.asarray(arr_like).item()
arr_like = decode_binary(data)

elif not isinstance(arr_like, (h5py.Dataset, np.ndarray, np.generic)):
raise RuntimeError(
f"Found {arr_like.__class__} at location `/{'/'.join(loc)}`"
@@ -173,16 +190,18 @@ def _write_array(self, loc: Sequence[str], arr: np.ndarray) -> None:
# for this operation need to be checked by the caller!
dataset = parent[name]
if dataset.attrs.get("__pickled__", None) == encode_attr(True):
arr_str = encode_binary(arr, binary=True)
dataset[...] = np.void(arr_str)
arr_bin = encode_binary(arr, binary=True)
assert isinstance(arr_bin, bytes)
dataset[...] = np.void(arr_bin)
else:
dataset[...] = arr

else:
# create a new data set
if arr.dtype == object:
arr_str = encode_binary(arr, binary=True)
dataset = parent.create_dataset(name, data=np.void(arr_str))
arr_bin = encode_binary(arr, binary=True)
assert isinstance(arr_bin, bytes)
dataset = parent.create_dataset(name, data=np.void(arr_bin))
dataset.attrs["__pickled__"] = True
else:
args = {"compression": "gzip"} if self.compression else {}
@@ -200,11 +219,13 @@ def _create_dynamic_array(
record_array: bool = False,
) -> None:
parent, name = self._get_parent(loc)
if dtype == object:
dt = h5py.special_dtype(vlen=np.dtype("uint8"))
if np.issubdtype(dtype, "O"):
try:
dataset = parent.create_dataset(
name, shape=(1,) + shape, maxshape=(None,) + shape, dtype=dt
name,
shape=(1,) + shape,
maxshape=(None,) + shape,
dtype=h5py.vlen_dtype(np.uint8),
)
except ValueError:
raise RuntimeError(f"Array `{'/'.join(loc)}` already exists")
@@ -250,7 +271,8 @@ def _extend_dynamic_array(self, loc: Sequence[str], arr: ArrayLike) -> None:

if dataset.attrs.get("__pickled__", False):
arr_bin = encode_binary(arr, binary=True)
dataset[size] = np.frombuffer(arr_bin, dtype="uint8")
assert isinstance(arr_bin, bytes)
dataset[size] = np.frombuffer(arr_bin, dtype=np.uint8)
else:
dataset[size] = arr

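The changes above store pickled objects either as an opaque scalar (`np.void`) or, for dynamic arrays, as variable-length `uint8` rows. A self-contained sketch of that technique in plain `h5py`, with `pickle` standing in for `encode_binary`/`decode_binary` (file name and object are made up):

```python
import pickle

import h5py
import numpy as np

obj = {"state": [1, 2, 3]}  # an arbitrary picklable object

with h5py.File("demo.h5", "w") as f:
    # fixed item: pickled bytes wrapped in np.void become an opaque scalar
    ds = f.create_dataset("blob", data=np.void(pickle.dumps(obj)))
    ds.attrs["__pickled__"] = True

    # dynamic array: one variable-length uint8 row per appended item
    dyn = f.create_dataset(
        "items", shape=(1,), maxshape=(None,), dtype=h5py.vlen_dtype(np.uint8)
    )
    dyn[0] = np.frombuffer(pickle.dumps(obj), dtype=np.uint8)

with h5py.File("demo.h5", "r") as f:
    restored = pickle.loads(f["blob"][()].tobytes())  # decode opaque scalar
    first = pickle.loads(f["items"][0].tobytes())     # decode one uint8 row
    assert restored == obj and first == obj
```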
1 change: 1 addition & 0 deletions modelrunner/storage/backend/text_base.py
@@ -85,6 +85,7 @@ def flush(self) -> None:
def close(self) -> None:
"""close the file and write the data to the file"""
self.flush()
super().close()

def to_text(self, simplify: bool | None = None) -> str:
"""serialize the data and return it as a string
2 changes: 1 addition & 1 deletion modelrunner/storage/backend/utils.py
@@ -24,7 +24,7 @@ def simplify_data(data):
elif isinstance(data, np.ndarray):
if np.isscalar(data):
data = data.item()
elif data.dtype == object and data.size == 1:
elif np.issubdtype(data.dtype, "O") and data.size == 1:
data = [simplify_data(data.item())]
elif data.size <= 100:
# for less than ~100 items a list is actually more efficient to store
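A quick check of the replacement test: `np.issubdtype` accepts the `"O"` type code and matches object dtypes just like the old `dtype == object` comparison:

```python
import numpy as np

assert np.issubdtype(np.array([{"a": 1}], dtype=object).dtype, "O")
assert not np.issubdtype(np.zeros(3).dtype, "O")
```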
1 change: 1 addition & 0 deletions modelrunner/storage/backend/zarr.py
@@ -96,6 +96,7 @@ def close(self) -> None:
if self._close:
self._store.close()
self._root = None
super().close()

def _get_parent(self, loc: Sequence[str]) -> tuple[zarr.Group, str]:
"""get the parent group for a particular location
11 changes: 9 additions & 2 deletions modelrunner/storage/base.py
@@ -31,7 +31,7 @@
import numpy as np
from numpy.typing import ArrayLike, DTypeLike

from .access_modes import AccessError, AccessMode, ModeType
from .access_modes import AccessError, AccessMode, ModeType, _access_closed
from .attributes import Attrs, AttrsLike, decode_attrs, encode_attr
from .utils import encode_class

@@ -46,6 +46,8 @@ class StorageBase(metaclass=ABCMeta):
"""list of str: all file extensions supported by this storage"""
default_codec = numcodecs.Pickle()
""":class:`numcodecs.Codec`: the default codec used for encoding binary data"""
mode: AccessMode
""":class:`~modelrunner.storage.access_modes.AccessMode`: access mode"""

_codec: numcodecs.abc.Codec
""":class:`numcodecs.Codec`: the specific codec used for encoding binary data"""
@@ -62,7 +64,12 @@ def __init__(self, *, mode: ModeType = "read"):

def close(self) -> None:
"""closes the storage, potentially writing data to a persistent place"""
...
self.mode = _access_closed

@property
def closed(self) -> bool:
"""bool: determines whether the storage has been closed"""
return self.mode is _access_closed

@property
def can_update(self) -> bool:
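All backends now chain their `close` to `StorageBase.close` (see the `hdf.py`, `text_base.py`, and `zarr.py` hunks above), which swaps the mode for the `_access_closed` sentinel, so `closed` reduces to an identity check. A minimal sketch of the pattern for a hypothetical backend:

```python
from modelrunner.storage.base import StorageBase


class MyStorage(StorageBase):
    # a real backend also implements the abstract read/write interface

    def close(self) -> None:
        # ... flush and release backend-specific resources first ...
        super().close()  # swaps self.mode for the `_access_closed` sentinel


# afterwards the state is queryable without touching the backend:
#   storage.close(); assert storage.closed
```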
14 changes: 11 additions & 3 deletions modelrunner/storage/group.py
@@ -30,7 +30,7 @@ def __init__(self, storage: StorageBase | StorageGroup, loc: Location = None):
loc (str or list of str):
Denotes the location (path) of the group within the storage
"""
self.loc = []
self.loc = [] # initialize empty location, since `loc` is relative to root
self.loc = self._get_loc(loc)

if isinstance(storage, StorageBase):
@@ -43,6 +43,14 @@ def __init__(self, storage: StorageBase | StorageGroup, loc: Location = None):
f"Cannot interprete `storage` of type `{storage.__class__}`"
)

assert isinstance(self._storage, StorageBase)
if self._storage.closed:
raise RuntimeError("Cannot access group in closed storage")
if self.loc not in self._storage:
raise RuntimeError(
f'"/{"/".join(self.loc)}" is not in storage. Available root items are: '
f"{list(self._storage.keys(loc=[]))}"
)
if not self.is_group():
raise RuntimeError(f'"/{"/".join(self.loc)}" is not a group')

@@ -172,10 +180,10 @@ def read_item(self, loc: Location, *, use_class: bool = True) -> Any:
# read the item using the generic classes
obj_type = self._storage._read_attrs(loc_list).get("__type__")
if obj_type in {"array", "dynamic_array"}:
arr = self._storage._read_array(loc_list, copy=True)
arr = self._storage.read_array(loc_list)
return Array(arr, attrs=self._storage.read_attrs(loc_list))
elif obj_type == "object":
return self._storage._read_object(loc_list)
return self._storage.read_object(loc_list)
else:
raise RuntimeError(f"Cannot read objects of type `{obj_type}`")

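A hedged illustration of the stricter checks (file name hypothetical, import path assumed): a wrong location or a closed storage now fails at construction time instead of at the first read.

```python
from modelrunner.storage import StorageGroup, open_storage  # path assumed

root = open_storage("results.hdf5", mode="read")  # hypothetical existing file
grp = StorageGroup(root, "data")  # fine if "/data" exists and is a group

try:
    StorageGroup(root, "dta")  # typo in the location
except RuntimeError as err:
    print(err)  # '"/dta" is not in storage. Available root items are: ...'

root.close()
try:
    StorageGroup(root, "data")
except RuntimeError as err:
    print(err)  # Cannot access group in closed storage
```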
17 changes: 15 additions & 2 deletions modelrunner/storage/tools.py
@@ -9,6 +9,7 @@
from pathlib import Path
from typing import Union

from .access_modes import AccessMode
from .backend import AVAILABLE_STORAGE, MemoryStorage
from .base import StorageBase
from .group import StorageGroup
@@ -123,19 +124,31 @@

if store_obj is None:
raise TypeError(f"Unsupported store type {storage.__class__.__name__}")
assert isinstance(store_obj, StorageBase)

super().__init__(store_obj, loc=loc)
self._closed = False

def close(self) -> None:
"""close the storage (and flush all data to persistent storage if necessary)"""
if self._close:
self._storage.close()
else:
self._storage.flush()
self._closed = True

@property
def closed(self) -> bool:
"""bool: determines whether the storage group has been closed"""
return self._closed

@property
def mode(self) -> AccessMode:
""":class:`~modelrunner.storage.access_modes.AccessMode`: access mode"""
return self._storage.mode

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
if self.close:
self.close()
self.close()
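Usage sketch for the corrected `__exit__`: the old `if self.close:` tested the bound method, which is always truthy, so the guard was meaningless; the call now happens directly. File name and attributes are made up, and the import path is assumed:

```python
from modelrunner.storage import open_storage  # import path assumed

with open_storage("results.hdf5", mode="insert") as storage:
    storage.write_attrs(attrs={"note": "demo"})  # work with the open storage

assert storage.closed  # __exit__ called close()
print(storage.mode)    # forwards the wrapped storage's AccessMode
```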
40 changes: 25 additions & 15 deletions modelrunner/storage/trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from modelrunner.storage.access_modes import ModeType

from .access_modes import AccessError
from .base import StorageBase
from .group import StorageGroup
from .tools import open_storage
from .utils import Location, storage_actions
@@ -53,15 +55,15 @@ class TrajectoryWriter:

def __init__(
self,
storage,
storage: str | Path | StorageGroup | StorageBase,
loc: Location = "trajectory",
*,
attrs: dict[str, Any] | None = None,
mode: ModeType | None = None,
):
"""
Args:
store (MutableMapping or string):
storage:
Store or path to directory in file system or name of zip file.
loc (str or list of str):
The location in the storage where the trajectory data is written.
@@ -84,22 +86,29 @@ def __init__(
)
mode = "full"

storage = open_storage(storage, mode=mode)
self._storage = open_storage(storage, mode=mode)

if storage._storage.mode.insert:
self._trajectory = storage.create_group(loc, cls=Trajectory)
if self._storage.mode.insert:
self._trajectory = self._storage.create_group(loc, cls=Trajectory)
elif self._storage.mode.dynamic_append:
self._trajectory = StorageGroup(self._storage, loc)
else:
self._trajectory = StorageGroup(storage, loc)
raise AccessError("Cannot insert data. Open storage with write access")

# make sure we don't overwrite data
if "times" in self._trajectory or "data" in self._trajectory:
if not storage._storage.mode.dynamic_append:
if not self._storage.mode.dynamic_append:
raise OSError("Storage already contains data and we cannot append")
self._item_type = self._trajectory.attrs["item_type"]

if attrs is not None:
self._trajectory.write_attrs(attrs=attrs)

@property
def times(self) -> np.ndarray:
""":class:`~numpy.ndarray`: Time points written so far"""
return self._trajectory.read_array("time")

def append(self, data: Any, time: float | None = None) -> None:
"""append data to the trajectory
Expand Down Expand Up @@ -140,7 +149,7 @@ def append(self, data: Any, time: float | None = None) -> None:
self._trajectory.extend_dynamic_array("time", time)

def close(self):
self._trajectory._storage.close()
self._storage.close()

def __enter__(self):
return self
@@ -170,21 +179,22 @@ def __init__(self, storage: StorageGroup, loc: Location = "trajectory"):
The location in the storage where the trajectory data is read.
"""
# open the storage
storage = open_storage(storage, mode="read")
self._trajectory = StorageGroup(storage, loc)
self._storage = open_storage(storage, mode="read")
self._loc = self._storage._get_loc(loc)

# read some initial data from storage
self._item_type = self._trajectory.attrs["item_type"]
self.times = self._trajectory.read_array("time")
self.attrs = self._trajectory.read_attrs()
self.attrs = self._storage.read_attrs(self._loc)
self._item_type = self.attrs["item_type"]
self.times = self._storage.read_array(self._loc + ["time"])

# check temporal ordering
if np.any(np.diff(self.times) < 0):
raise ValueError(f"Times are not monotonously increasing: {self.times}")

def close(self) -> None:
"""close the openend storage"""
self._trajectory._storage.close()
self._storage.close()

def __len__(self) -> int:
return len(self.times)
@@ -207,7 +217,7 @@ def _get_item(self, t_index: int) -> Any:
if not 0 <= t_index < len(self):
raise IndexError("Time index out of range")

res = self._trajectory.read_array("data", index=t_index)
res = self._storage.read_array(self._loc + ["data"], index=t_index)
if self._item_type == "array":
return res
elif self._item_type == "object":
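An end-to-end sketch of the reworked trajectory API, combining `TrajectoryWriter` and `Trajectory` (file name made up; `Trajectory.__getitem__` is assumed to delegate to `_get_item` shown above):

```python
import numpy as np

from modelrunner.storage.trajectory import Trajectory, TrajectoryWriter

# write three states; the writer manages its own storage and closes it on exit
with TrajectoryWriter("trajectory.hdf5", attrs={"model": "demo"}) as writer:
    for t in range(3):
        writer.append(np.full(4, float(t)), time=float(t))

# read everything back
traj = Trajectory("trajectory.hdf5")
assert len(traj) == 3
print(traj.times)            # [0. 1. 2.], checked to increase monotonically
state = traj[len(traj) - 1]  # last stored item
traj.close()
```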