exclude zattrs when empty #16

Merged
merged 3 commits on Mar 17, 2024
Changes from 2 commits
32 changes: 25 additions & 7 deletions lindi/LindiH5Store/LindiH5Store.py
@@ -54,10 +54,9 @@ def __init__(
         self._opts = _opts
 
         # Some datasets do not correspond to traditional chunked datasets. For
-        # those datasets, we need to store the inline data so that we can
-        # return it when the chunk is requested. We store the inline data in a
-        # dictionary with the dataset name as the key. The values are of type
-        # bytes.
+        # those datasets, we need to store the inline data so that we can return
+        # it when the chunk is requested. We store the inline data in a
+        # dictionary with the dataset name as the key. The values are the bytes.
         self._inline_data_for_arrays: Dict[str, bytes] = {}
 
         self._external_array_links: Dict[str, Union[dict, None]] = {}
@@ -114,7 +113,9 @@ def __getitem__(self, key):
         key_name = parts[-1]
         key_parent = "/".join(parts[:-1])
         if key_name == ".zattrs":
-            # Get the attributes of a group or dataset
+            # Get the attributes of a group or dataset. We return this even if
+            # it is empty, but we exclude it when writing out the reference file
+            # system.
             return self._get_zattrs_bytes(key_parent)
         elif key_name == ".zgroup":
             # Get the .zgroup JSON text for a group
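A quick illustration of the behavior this new comment describes, as a sketch (it assumes a `store` opened with `LindiH5Store.from_file` and a hypothetical attribute-less group at path `some_group`):

    import json

    # .zattrs is always served by __getitem__, even when there are no
    # attributes; an attribute-less group comes back as empty JSON.
    zattrs = store["some_group/.zattrs"]  # "some_group" is a hypothetical path
    assert json.loads(zattrs.decode("utf-8")) == {}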
@@ -138,16 +139,27 @@ def __contains__(self, key):
         key_parent = "/".join(parts[:-1])
         if key_name == ".zattrs":
             h5_item = _get_h5_item(self._h5f, key_parent)
-            return isinstance(h5_item, h5py.Group)
+            if not h5_item:
+                return False
+            # We always return True here even if the attributes are going to be
+            # empty, because it's not worth including the logic. But when we
+            # write out the ref file system, we exclude it there.
+            return isinstance(h5_item, h5py.Group) or isinstance(h5_item, h5py.Dataset)
         elif key_name == ".zgroup":
             h5_item = _get_h5_item(self._h5f, key_parent)
+            if not h5_item:
+                return False
             return isinstance(h5_item, h5py.Group)
         elif key_name == ".zarray":
             h5_item = _get_h5_item(self._h5f, key_parent)
+            if not h5_item:
+                return False
             return isinstance(h5_item, h5py.Dataset)
         else:
             # a chunk file
             h5_item = _get_h5_item(self._h5f, key_parent)
+            if not h5_item:
+                return False
             if not isinstance(h5_item, h5py.Dataset):
                 return False
             external_array_link = self._get_external_array_link(key_parent, h5_item)
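A consequence of the permissive `__contains__` above: membership in the store and presence in the written-out reference file system can disagree when the attributes are empty. A sketch, assuming the store exposes the reference file system through a `to_reference_file_system()` method (the method name is an assumption here):

    assert "some_group/.zattrs" in store             # __contains__ is permissive
    rfs = store.to_reference_file_system()           # method name assumed
    assert "some_group/.zattrs" not in rfs["refs"]   # excluded when empty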
@@ -182,6 +194,9 @@ def _get_zattrs_bytes(self, parent_key: str):
         if self._h5f is None:
             raise Exception("Store is closed")
         h5_item = _get_h5_item(self._h5f, parent_key)
+        assert isinstance(h5_item, h5py.Group) or isinstance(h5_item, h5py.Dataset), (
+            f"Item {parent_key} is not a group or dataset in _get_zattrs_bytes"
+        )
         # We create a dummy zarr group and copy the attributes to it. That way
         # we know that zarr has accepted them and they are serialized in the
         # correct format.
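For context on the dummy-group trick described in this comment: in zarr v2 a plain dict can serve as a store, so copying attributes into a throwaway in-memory group makes zarr validate them and serialize them in its own format. A minimal sketch (zarr v2 API assumed):

    import json
    import zarr

    memory_store = {}                      # a plain dict is a valid zarr v2 store
    dummy_group = zarr.group(store=memory_store)
    dummy_group.attrs["foo"] = "bar"       # zarr validates and serializes this
    zattrs_json = memory_store[".zattrs"]  # JSON bytes, possibly with whitespace
    assert json.loads(zattrs_json)["foo"] == "bar"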
@@ -411,7 +426,9 @@ def _add_ref(key: str, content: Union[bytes, None]):
 
         def _process_group(key, item: h5py.Group):
             if isinstance(item, h5py.Group):
-                _add_ref(_join(key, ".zattrs"), self.get(_join(key, ".zattrs")))
+                zattrs_bytes = self.get(_join(key, ".zattrs"))
+                if zattrs_bytes != b"{}":  # don't include empty zattrs
+                    _add_ref(_join(key, ".zattrs"), self.get(_join(key, ".zattrs")))
                 _add_ref(_join(key, ".zgroup"), self.get(_join(key, ".zgroup")))
                 for k in item.keys():
                     subitem = item[k]
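The `!= b"{}"` comparison works because `_get_zattrs_bytes` returns compact, whitespace-free JSON (see `_reformat_json` below), so an empty attribute set serializes to exactly `b"{}"`. A parse-based check, as in this hypothetical helper, would be equivalent and insensitive to formatting (incidentally, the second `self.get(...)` call above could reuse `zattrs_bytes`):

    import json

    def _is_empty_zattrs(zattrs_bytes: bytes) -> bool:
        # Hypothetical helper: parse instead of comparing raw bytes, so the
        # check does not depend on how the JSON happens to be formatted.
        return json.loads(zattrs_bytes.decode("utf-8")) == {}

    assert _is_empty_zattrs(b"{}")
    assert not _is_empty_zattrs(b'{"foo": "bar"}')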
@@ -485,6 +502,7 @@ def _get_chunk_names_for_dataset(chunk_coords_shape: List[int]) -> List[str]:
 
 
 def _reformat_json(x: Union[bytes, None]) -> Union[bytes, None]:
+    """Reformat to not include whitespace and to encode NaN, Inf, and -Inf as strings."""
     if x is None:
         return None
     a = json.loads(x.decode("utf-8"))
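The diff adds only the docstring; for readers, here is a minimal sketch of a function matching that docstring (an illustration under stated assumptions, not the actual lindi implementation):

    import json
    import math
    from typing import Any, Union

    def _reformat_json_sketch(x: Union[bytes, None]) -> Union[bytes, None]:
        if x is None:
            return None

        def _encode(v: Any) -> Any:
            # Python's json module parses NaN/Infinity tokens into floats;
            # re-encode non-finite floats as strings so the output is strict JSON.
            if isinstance(v, float) and not math.isfinite(v):
                if math.isnan(v):
                    return "NaN"
                return "Inf" if v > 0 else "-Inf"
            if isinstance(v, dict):
                return {k: _encode(u) for k, u in v.items()}
            if isinstance(v, list):
                return [_encode(u) for u in v]
            return v

        a = json.loads(x.decode("utf-8"))
        # separators=(",", ":") drops all whitespace from the output
        return json.dumps(_encode(a), separators=(",", ":")).encode("utf-8")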
2 changes: 1 addition & 1 deletion lindi/LindiH5Store/_util.py
@@ -5,7 +5,7 @@
 
 def _get_h5_item(h5f: h5py.File, key: str):
     """Get an item from the h5 file, given its key."""
-    return h5f['/' + key]
+    return h5f.get('/' + key, None)
 
 
 def _read_bytes(file: IO, offset: int, count: int):
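Why this one-line change matters: subscripting an h5py file with a missing path raises `KeyError`, whereas `File.get` mirrors `dict.get` and returns the default, which is what lets the new `if not h5_item` guards in `__contains__` work. A small sketch (hypothetical file name):

    import h5py

    with h5py.File("example.h5", "w") as f:      # hypothetical file
        f.create_group("present")
        assert f.get("/present") is not None     # existing path -> the group
        assert f.get("/missing", None) is None   # missing path -> default, no KeyError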
13 changes: 12 additions & 1 deletion tests/test_core.py
@@ -11,7 +11,8 @@ def test_scalar_datasets():
     with tempfile.TemporaryDirectory() as tmpdir:
         filename = f"{tmpdir}/test.h5"
         with h5py.File(filename, "w") as f:
-            f.create_dataset("X", data=val)
+            ds = f.create_dataset("X", data=val)
+            ds.attrs["foo"] = "bar"
         with LindiH5Store.from_file(
             filename, url=filename
         ) as store:  # set url so that a reference file system can be created
@@ -25,6 +26,16 @@
             if not _check_equal(X1[()], X2[()]):
                 print(f"WARNING: {X1} ({type(X1)}) != {X2} ({type(X2)})")
                 raise ValueError("Scalar datasets are not equal")
+            assert '.zgroup' in store
+            assert '.zarray' not in rfs['refs']
+            assert '.zarray' not in store
+            assert '.zattrs' in store  # it's in the store but not in the ref file system -- see notes in LindiH5Store source code
+            assert '.zattrs' not in rfs['refs']
+            assert 'X/.zgroup' not in store
+            assert 'X/.zattrs' in store  # foo is set to bar
+            assert store['X/.zattrs'] == rfs['refs']['X/.zattrs'].encode()
+            assert 'X/.zarray' in rfs['refs']
+            assert store['X/.zarray'] == rfs['refs']['X/.zarray'].encode()
 
 
 def test_numpy_arrays():
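For orientation, the reference file system these assertions inspect would look roughly like this for the scalar-dataset test (a sketch only; the exact `.zarray` metadata and chunk encoding depend on `val`'s dtype):

    # Hypothetical shape of rfs for this test; values abbreviated.
    rfs_sketch = {
        "refs": {
            ".zgroup": '{"zarr_format":2}',
            # no top-level ".zattrs" entry: it would be empty, so it is excluded
            "X/.zattrs": '{"foo":"bar"}',
            "X/.zarray": "...",  # dtype/shape/compressor metadata
            "X/0": "...",        # inline (e.g. base64) or referenced chunk data
        }
    }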
2 changes: 1 addition & 1 deletion tests/test_with_real_data.py
@@ -298,5 +298,5 @@ def test_with_real_data():
     for k in top_level_keys:
         assert k in top_level_keys_2
 
-    root = zarr.open(store)
+    root = zarr.open(store, mode="r")
     _hdf5_visit_items(h5f, lambda key, item: _compare_item_2(item, root[key]))
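A note on this one-line change: `zarr.open` defaults to `mode="a"`, which can attempt to write metadata into the store, so the read-only LindiH5Store should be opened explicitly read-only:

    import zarr

    # Explicit read-only open; the default mode="a" may try to create
    # metadata keys in the store.
    root = zarr.open(store, mode="r")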