Skip to content

Commit

Permalink
handle references in attrs and datasets
Browse files Browse the repository at this point in the history
  • Loading branch information
magland committed Mar 18, 2024
1 parent 9eeb63c commit 5f937ba
Show file tree
Hide file tree
Showing 10 changed files with 157 additions and 55 deletions.
4 changes: 4 additions & 0 deletions lindi/LindiClient/LindiAttributes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import Union
import zarr
from .LindiReference import LindiReference


class LindiAttributes:
Expand All @@ -10,6 +11,9 @@ def get(self, key, default=None):
return self._object.attrs.get(key, default)

def __getitem__(self, key):
val = self._object.attrs[key]
if isinstance(val, dict) and "_REFERENCE" in val:
return LindiReference(val["_REFERENCE"])
return self._object.attrs[key]

def __setitem__(self, key, value):
Expand Down
27 changes: 27 additions & 0 deletions lindi/LindiClient/LindiClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fsspec.implementations.reference import ReferenceFileSystem
from zarr.storage import Store
from .LindiGroup import LindiGroup
from .LindiReference import LindiReference


class LindiClient(LindiGroup):
Expand Down Expand Up @@ -53,6 +54,32 @@ def from_reference_file_system(data: dict) -> "LindiClient":
fs = ReferenceFileSystem(data).get_mapper(root="/")
return LindiClient.from_zarr_store(fs)

def __getitem__(self, key): # type: ignore
if isinstance(key, str):
if key.startswith('/'):
key = key[1:]
parts = key.split("/")
if len(parts) == 1:
return super().__getitem__(key)
else:
g = self
for part in parts:
g = g[part]
return g
elif isinstance(key, LindiReference):
if key._source != '.':
raise Exception(f'For now, source of reference must be ".", got "{key._source}"')
if key._source_object_id is not None:
if key._source_object_id != self._zarr_group.attrs.get("object_id"):
raise Exception(f'Mismatch in source object_id: "{key._source_object_id}" and "{self._zarr_group.attrs.get("object_id")}"')
target = self[key._path]
if key._object_id is not None:
if key._object_id != target.attrs.get("object_id"):
raise Exception(f'Mismatch in object_id: "{key._object_id}" and "{target.attrs.get("object_id")}"')
return target
else:
raise Exception(f'Cannot use key "{key}" of type "{type(key)}" to index into a LindiClient')


def _download_file(url: str, filename: str) -> None:
headers = {
Expand Down
10 changes: 9 additions & 1 deletion lindi/LindiClient/LindiDataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import h5py
import remfile
from .LindiAttributes import LindiAttributes
from .LindiReference import LindiReference


class LindiDataset:
Expand Down Expand Up @@ -98,7 +99,11 @@ def __getitem__(self, selection):
# Find the index of this field in the compound dtype
ind = self._compound_dtype.names.index(selection)
# Get the dtype of this field
dtype = np.dtype(self._compound_dtype[ind])
dt = self._compound_dtype[ind]
if dt == 'object':
dtype = h5py.Reference
else:
dtype = np.dtype(dt)
# Return a new object that can be sliced further
# It's important that the return type is Any here, because otherwise we get linter problems
ret: Any = LindiDatasetCompoundFieldSelection(
Expand Down Expand Up @@ -144,6 +149,9 @@ def __init__(self, *, dataset: LindiDataset, ind: int, dtype: np.dtype):
# Prepare the data in memory
za = self._dataset._zarr_array
d = [za[i][self._ind] for i in range(len(za))]
if self._dtype == h5py.Reference:
# Convert to LindiReference
d = [LindiReference(x['_REFERENCE']) for x in d]
self._data = np.array(d, dtype=self._dtype)

def __len__(self):
Expand Down
26 changes: 3 additions & 23 deletions lindi/LindiClient/LindiGroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@ def get(self, key, default=None):
return default

def __getitem__(self, key):
if isinstance(key, dict):
# might be a reference
if "_REFERENCE" in key:
return LindiReference(key['_REFERENCE'])
if not isinstance(key, str):
raise Exception(f'Cannot use key "{key}" of type "{type(key)}" to index into a LindiGroup, at path "{self._zarr_group.name}"')
raise Exception(
f'Cannot use key "{key}" of type "{type(key)}" to index into a LindiGroup, at path "{self._zarr_group.name}"'
)
if key in self._zarr_group.keys():
x = self._zarr_group[key]
if isinstance(x, zarr.Group):
Expand All @@ -46,21 +44,3 @@ def __getitem__(self, key):
def __iter__(self):
for k in self.keys():
yield k


class LindiReference:
def __init__(self, obj: dict):
self._object_id = obj["object_id"]
self._path = obj["path"]
self._source = obj["source"]
self._source_object_id = obj["source_object_id"]

@property
def name(self):
return self._path

def __repr__(self):
return f"LindiReference({self._source}, {self._path})"

def __str__(self):
return f"LindiReference({self._source}, {self._path})"
16 changes: 16 additions & 0 deletions lindi/LindiClient/LindiReference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
class LindiReference:
def __init__(self, obj: dict):
self._object_id = obj["object_id"]
self._path = obj["path"]
self._source = obj["source"]
self._source_object_id = obj["source_object_id"]

@property
def name(self):
return self._path

def __repr__(self):
return f"LindiReference({self._source}, {self._path})"

def __str__(self):
return f"LindiReference({self._source}, {self._path})"
2 changes: 1 addition & 1 deletion lindi/LindiClient/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
from .LindiGroup import LindiGroup # noqa: F401
from .LindiDataset import LindiDataset # noqa: F401
from .LindiAttributes import LindiAttributes # noqa: F401
from .LindiGroup import LindiReference # noqa: F401
from .LindiReference import LindiReference # noqa: F401
20 changes: 13 additions & 7 deletions lindi/LindiH5Store/_h5_attr_to_zarr_attr.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,23 @@ def _h5_attr_to_zarr_attr(attr: Any, *, label: str = '', h5f: h5py.File):
Otherwise, raise NotImplementedError
"""
if isinstance(attr, bytes):
return attr.decode('utf-8') # is this reversible?
return attr.decode('utf-8')
elif isinstance(attr, (int, float, str)):
return attr
elif np.issubdtype(type(attr), np.integer):
return int(attr)
elif np.issubdtype(type(attr), np.floating):
return float(attr)
elif np.issubdtype(type(attr), np.bool_):
return bool(attr)
elif np.issubdtype(type(attr), np.bytes_):
return attr.decode('utf-8')
elif isinstance(attr, h5py.Reference):
return _h5_ref_to_zarr_attr(attr, label=label, h5f=h5f)
elif isinstance(attr, list):
return [_h5_attr_to_zarr_attr(x, label=label, h5f=h5f) for x in attr]
elif isinstance(attr, dict):
return {k: _h5_attr_to_zarr_attr(v, label=label, h5f=h5f) for k, v in attr.items()}
elif isinstance(attr, h5py.Reference):
return _h5_ref_to_zarr_attr(attr, label=label, h5f=h5f)
elif np.issubdtype(type(attr), np.integer):
return int(attr) # possible loss of precision?
elif np.issubdtype(type(attr), np.floating):
return float(attr) # possible loss of precision?
elif isinstance(attr, np.ndarray):
return _h5_attr_to_zarr_attr(attr.tolist(), label=label, h5f=h5f)
else:
Expand Down Expand Up @@ -62,6 +66,8 @@ def _h5_ref_to_zarr_attr(ref: h5py.Reference, *, label: str = '', h5f: h5py.File
# is to do an initial pass through the file and build a map of object IDs to
# paths. This would need to happen elsewhere in the code.
deref_objname = h5py.h5r.get_name(ref, file_id)
if deref_objname is None:
raise ValueError(f"Could not dereference object with reference {ref}")
deref_objname = deref_objname.decode("utf-8")

dref_obj = h5f[deref_objname]
Expand Down
2 changes: 1 addition & 1 deletion lindi/LindiH5Store/_h5_filters_to_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
# https://github.com/fsspec/kerchunk
# Copyright (c) 2020 Intake
# MIT License
def _h5_filters_to_codecs(h5obj: h5py.Dataset) -> Union[List[Codec], None]:
def _h5_filters_to_codecs_kerchunk(h5obj: h5py.Dataset) -> Union[List[Codec], None]:
"""Decode HDF5 filters to numcodecs filters."""
if h5obj.scaleoffset:
raise RuntimeError(
Expand Down
40 changes: 19 additions & 21 deletions lindi/LindiH5Store/_zarr_info_for_h5_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import h5py
from numcodecs.abc import Codec
from ._h5_filters_to_codecs import _h5_filters_to_codecs
from ._h5_attr_to_zarr_attr import _h5_ref_to_zarr_attr


@dataclass
Expand Down Expand Up @@ -111,17 +112,13 @@ def _zarr_info_for_h5_dataset(h5_dataset: h5py.Dataset) -> ZarrInfoForH5Dataset:
object_codec = numcodecs.JSON()
data = h5_dataset[:]
data_vec_view = data.ravel()
_warning_reference_in_dataset_printed = False
for i, val in enumerate(data_vec_view):
if isinstance(val, bytes):
data_vec_view[i] = val.decode()
elif isinstance(val, str):
data_vec_view[i] = val
elif isinstance(val, h5py.h5r.Reference):
if not _warning_reference_in_dataset_printed:
print(f'Warning: reference in dataset {h5_dataset.name} not handled')
_warning_reference_in_dataset_printed = True
data_vec_view[i] = None
elif isinstance(val, h5py.Reference):
data_vec_view[i] = _h5_ref_to_zarr_attr(val, label=f'{h5_dataset.name}[{i}]', h5f=h5_dataset.file)
else:
raise Exception(f'Cannot handle dataset {h5_dataset.name} with dtype {dtype} and shape {shape}')
inline_data = json.dumps(data.tolist() + ['|O', list(shape)], separators=(',', ':')).encode('utf-8')
Expand All @@ -137,25 +134,20 @@ def _zarr_info_for_h5_dataset(h5_dataset: h5py.Dataset) -> ZarrInfoForH5Dataset:
elif dtype.kind in 'SU': # byte string or unicode string
raise Exception(f'Not yet implemented (2): dataset {h5_dataset.name} with dtype {dtype} and shape {shape}')
elif dtype.kind == 'V': # void (i.e. compound)
# This is an array representing the compound type
# For example: [['x', 'uint32'], ['y', 'uint32'], ['weight', 'float32']]
compound_dtype = [
[name, str(dtype[name])]
for name in dtype.names
]
if h5_dataset.ndim == 1:
# for now we only handle the case of a 1D compound dataset
data = h5_dataset[:]
# Create an array that would be for example like this
# [[3, 4, 5.3], [2, 1, 7.1], ...]
# dtype = np.dtype([('x', np.float64), ('y', np.int32), ('weight', np.float64)])
# array_list = [[3, 4, 5.3], [2, 1, 7.1], ...]
# where the first entry corresponds to x in the example above, the second to y, and the third to weight
# This is a more compact representation than [{'x': ...}]
# The _COMPOUND_DTYPE attribute will be set on the dataset in the zarr store
# which will be used to interpret the data
array_list = [
[
_json_serialize(data[name][i], type_str)
for name, type_str in compound_dtype
_json_serialize(data[name][i], dtype[name], h5_dataset)
for name in dtype.names
]
for i in range(h5_dataset.shape[0])
]
Expand All @@ -177,15 +169,21 @@ def _zarr_info_for_h5_dataset(h5_dataset: h5py.Dataset) -> ZarrInfoForH5Dataset:
raise Exception(f'Not yet implemented (3): dataset {h5_dataset.name} with dtype {dtype} and shape {shape}')


def _json_serialize(val: Any, type_str: str) -> Any:
if type_str.startswith('uint'):
def _json_serialize(val: Any, dtype: np.dtype, h5_dataset: h5py.Dataset) -> Any:
if dtype.kind in ['i', 'u']: # integer, unsigned integer
return int(val)
elif type_str.startswith('int'):
return int(val)
elif type_str.startswith('float'):
elif dtype.kind == 'f': # float
return float(val)
elif dtype.kind == 'b': # boolean
return bool(val)
elif dtype.kind == 'S': # byte string
return val.decode()
elif dtype.kind == 'U': # unicode string
return val
elif dtype == h5py.Reference:
return _h5_ref_to_zarr_attr(val, label=f'{h5_dataset.name}', h5f=h5_dataset.file)
else:
raise Exception(f'Unable to serialize {val} with type {type_str}')
raise Exception(f'Cannot serialize item {val} with dtype {dtype} when serializing dataset {h5_dataset.name} with compound dtype.')


def _get_numeric_format_str(dtype: Any) -> Union[str, None]:
Expand Down
65 changes: 64 additions & 1 deletion tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
import h5py
import tempfile
from lindi import LindiH5Store, LindiClient, LindiDataset, LindiGroup
from lindi import LindiH5Store, LindiClient, LindiDataset, LindiGroup, LindiReference
import pytest


Expand Down Expand Up @@ -223,6 +223,69 @@ def test_nan_inf_attr():
assert X2.attrs["ninf"] == '-Infinity'


def test_reference_attributes():
print("Testing reference attributes")
with tempfile.TemporaryDirectory() as tmpdir:
filename = f"{tmpdir}/test.h5"
with h5py.File(filename, "w") as f:
X_ds = f.create_dataset("X", data=[1, 2, 3])
Y_ds = f.create_dataset("Y", data=[4, 5, 6])
X_ds.attrs["ref"] = Y_ds.ref
h5f = h5py.File(filename, "r")
with LindiH5Store.from_file(filename, url=filename) as store:
rfs = store.to_reference_file_system()
client = LindiClient.from_reference_file_system(rfs)

X1 = h5f["X"]
assert isinstance(X1, h5py.Dataset)
X2 = client["X"]
assert isinstance(X2, LindiDataset)

ref1 = X1.attrs["ref"]
assert isinstance(ref1, h5py.Reference)
ref2 = X2.attrs["ref"]
assert isinstance(ref2, LindiReference)

target1 = h5f[ref1]
assert isinstance(target1, h5py.Dataset)
target2 = client[ref2]
assert isinstance(target2, LindiDataset)

assert _check_equal(target1[:], target2[:])


def test_reference_in_compound_dtype():
print("Testing reference in dataset with compound dtype")
with tempfile.TemporaryDirectory() as tmpdir:
filename = f"{tmpdir}/test.h5"
with h5py.File(filename, "w") as f:
compound_dtype = np.dtype([("x", "i4"), ("y", h5py.special_dtype(ref=h5py.Reference))])
Y_ds = f.create_dataset("Y", data=[1, 2, 3])
f.create_dataset("X", data=[(1, Y_ds.ref), (2, Y_ds.ref)], dtype=compound_dtype)
h5f = h5py.File(filename, "r")
with LindiH5Store.from_file(filename, url=filename) as store:
rfs = store.to_reference_file_system()
client = LindiClient.from_reference_file_system(rfs)

X1 = h5f["X"]
assert isinstance(X1, h5py.Dataset)
X2 = client["X"]
assert isinstance(X2, LindiDataset)

assert _check_equal(X1["x"][:], X2["x"][:])
ref1 = X1["y"][0]
assert isinstance(ref1, h5py.Reference)
ref2 = X2["y"][0]
assert isinstance(ref2, LindiReference)

target1 = h5f[ref1]
assert isinstance(target1, h5py.Dataset)
target2 = client[ref2]
assert isinstance(target2, LindiDataset)

assert _check_equal(target1[:], target2[:])


def _check_equal(a, b):
# allow comparison of bytes and strings
if isinstance(a, str):
Expand Down

0 comments on commit 5f937ba

Please sign in to comment.