Skip to content

Commit

Permalink
implement require_dataset()
Browse files Browse the repository at this point in the history
  • Loading branch information
magland committed Apr 4, 2024
1 parent 20a8f25 commit 249feda
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 1 deletion.
5 changes: 5 additions & 0 deletions lindi/LindiH5pyFile/LindiH5pyFile.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,11 @@ def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds):
raise Exception("Cannot create dataset in read-only mode")
return self._the_group.create_dataset(name, shape=shape, dtype=dtype, data=data, **kwds)

def require_dataset(self, name, shape, dtype, exact=False, **kwds):
    """Forward a require_dataset() request to this file's underlying group.

    The request is only honoured when the file mode is 'r+'; in any other
    mode an exception is raised.  All arguments are passed through
    unchanged to ``self._the_group.require_dataset``.
    """
    if self._mode in ['r+']:
        group = self._the_group
        return group.require_dataset(name, shape, dtype, exact=exact, **kwds)
    raise Exception("Cannot require dataset in read-only mode")


def _download_file(url: str, filename: str) -> None:
headers = {
Expand Down
6 changes: 6 additions & 0 deletions lindi/LindiH5pyFile/LindiH5pyGroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,12 @@ def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds):
assert self._writer is not None
return self._writer.create_dataset(name, shape=shape, dtype=dtype, data=data, **kwds)

def require_dataset(self, name, shape, dtype, exact=False, **kwds):
    """Delegate require_dataset() to this group's writer.

    Raises an exception when the group is read-only; otherwise the
    arguments are handed unchanged to ``self._writer.require_dataset``.
    """
    if self._readonly:
        raise Exception('Cannot require dataset in read-only mode')
    writer = self._writer
    # A writable group is expected to always carry a writer.
    assert writer is not None
    return writer.require_dataset(name, shape, dtype, exact=exact, **kwds)

def __setitem__(self, name, obj):
if self._readonly:
raise Exception('Cannot set item in read-only mode')
Expand Down
16 changes: 16 additions & 0 deletions lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,22 @@ def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds):
else:
raise Exception(f'Unexpected group object type: {type(self.p._group_object)}')

def require_dataset(self, name, shape, dtype, exact=False, **kwds):
    """Return the dataset called ``name``, creating it if it does not exist.

    If ``name`` is absent from the group, the dataset is created through
    ``create_dataset`` with the given shape, dtype and ``kwds``.  If it is
    present, it is returned only when it is a dataset whose shape matches
    ``shape`` and whose dtype is compatible with ``dtype``.

    Parameters:
        name: key of the dataset within this group.
        shape: required shape.  A bare int or any sequence of ints is
            accepted and compared against the existing shape as a tuple.
        dtype: required dtype.
        exact: if True, the existing dtype must equal ``dtype``; if False,
            the existing dtype only has to be castable to ``dtype``
            according to ``np.can_cast``.
        **kwds: forwarded to ``create_dataset`` when the dataset is created.

    Raises:
        Exception: if ``name`` exists but is not a dataset, or if its shape
            or dtype does not satisfy the request.
    """
    if name not in self.p:
        # Pass the caller's shape through untouched so the creation path
        # behaves exactly as a direct create_dataset() call would.
        return self.create_dataset(name, *(shape, dtype), **kwds)

    ret = self.p[name]
    if not isinstance(ret, LindiH5pyDataset):
        raise Exception(f'Expected a dataset at {name} but got {type(ret)}')

    # Normalise the requested shape so that 3, [3] and (3,) all describe
    # the same shape; a plain tuple-vs-list comparison would never match.
    if shape is None:
        expected_shape = None
    elif isinstance(shape, (int, np.integer)):
        expected_shape = (int(shape),)
    else:
        expected_shape = tuple(shape)
    if ret.shape != expected_shape:
        raise Exception(f'Expected shape {shape} but got {ret.shape}')

    if exact:
        if ret.dtype != dtype:
            raise Exception(f'Expected dtype {dtype} but got {ret.dtype}')
    else:
        # NOTE(review): h5py's Group.require_dataset appears to check the
        # cast in the opposite direction (requested -> existing dtype);
        # confirm that existing -> requested is the intended rule here.
        if not np.can_cast(ret.dtype, dtype):
            raise Exception(f'Cannot cast dtype {ret.dtype} to {dtype}')
    return ret

def __setitem__(self, name, obj):
if isinstance(obj, h5py.SoftLink):
if isinstance(self.p._group_object, h5py.Group):
Expand Down
24 changes: 23 additions & 1 deletion tests/test_zarr_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,28 @@ def test_zarr_write():
compare_example_h5_data(h5f_backed_by_zarr, tmpdir=tmpdir)


def test_require_dataset():
    """Exercise require_dataset() on a zarr-backed LindiH5pyFile opened in 'r+' mode."""
    with tempfile.TemporaryDirectory() as tmpdir:
        dirname = f'{tmpdir}/test.zarr'
        store = zarr.DirectoryStore(dirname)
        # Initialise an empty root group in the store before opening it.
        zarr.group(store=store)
        with lindi.LindiH5pyFile.from_zarr_store(store, mode='r+') as h5f:
            h5f.create_dataset('dset_int8', data=np.array([1, 2, 3], dtype=np.int8))
            h5f.create_dataset('dset_int16', data=np.array([1, 2, 3], dtype=np.int16))

            # Existing dataset with the same shape and dtype is accepted.
            h5f.require_dataset('dset_int8', shape=(3,), dtype=np.int8)

            # A different shape is rejected.
            with pytest.raises(Exception):
                h5f.require_dataset('dset_int8', shape=(4,), dtype=np.int8)

            # With exact=True a different dtype is rejected ...
            with pytest.raises(Exception):
                h5f.require_dataset('dset_int8', shape=(3,), dtype=np.int16, exact=True)
            # ... but with exact=False int8 -> int16 is accepted.
            h5f.require_dataset('dset_int8', shape=(3,), dtype=np.int16, exact=False)

            # int16 -> int8 is rejected even with exact=False.
            with pytest.raises(Exception):
                h5f.require_dataset('dset_int16', shape=(3,), dtype=np.int8, exact=False)

            # A missing dataset is created and can be written to.
            new_ds = h5f.require_dataset('dset_float32', shape=(3,), dtype=np.float32)
            new_ds[:] = np.array([1.1, 2.2, 3.3])

            # The freshly created float32 dataset does not match float64 exactly.
            with pytest.raises(Exception):
                h5f.require_dataset('dset_float32', shape=(3,), dtype=np.float64, exact=True)


def write_example_h5_data(h5f: h5py.File):
h5f.attrs['attr_str'] = 'hello'
h5f.attrs['attr_int'] = 42
Expand Down Expand Up @@ -80,4 +102,4 @@ def compare_example_h5_data(h5f: h5py.File, tmpdir: str):


# Allow running the tests in this module directly as a script.
if __name__ == '__main__':
    test_zarr_write()
    test_require_dataset()

0 comments on commit 249feda

Please sign in to comment.