diff --git a/lindi/LindiH5pyFile/LindiH5pyFile.py b/lindi/LindiH5pyFile/LindiH5pyFile.py
index eba0e81..d25f2fb 100644
--- a/lindi/LindiH5pyFile/LindiH5pyFile.py
+++ b/lindi/LindiH5pyFile/LindiH5pyFile.py
@@ -331,6 +331,16 @@ def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds):
             raise Exception("Cannot create dataset in read-only mode")
         return self._the_group.create_dataset(name, shape=shape, dtype=dtype, data=data, **kwds)
 
+    def require_dataset(self, name, shape, dtype, exact=False, **kwds):
+        """Return the dataset `name`, creating it if it does not exist yet.
+
+        Refuses to run unless the file mode is 'r+'; otherwise forwards all
+        arguments to `self._the_group.require_dataset`.
+        """
+        if self._mode not in ['r+']:
+            raise Exception("Cannot require dataset in read-only mode")
+        return self._the_group.require_dataset(name, shape, dtype, exact=exact, **kwds)
+
 
 def _download_file(url: str, filename: str) -> None:
     headers = {
diff --git a/lindi/LindiH5pyFile/LindiH5pyGroup.py b/lindi/LindiH5pyFile/LindiH5pyGroup.py
index e9f430d..1fbe8be 100644
--- a/lindi/LindiH5pyFile/LindiH5pyGroup.py
+++ b/lindi/LindiH5pyFile/LindiH5pyGroup.py
@@ -183,6 +183,17 @@ def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds):
         assert self._writer is not None
         return self._writer.create_dataset(name, shape=shape, dtype=dtype, data=data, **kwds)
 
+    def require_dataset(self, name, shape, dtype, exact=False, **kwds):
+        """Return the dataset `name`, creating it if it does not exist yet.
+
+        Raises if the group is read-only; otherwise forwards all arguments
+        to the group writer's `require_dataset`.
+        """
+        if self._readonly:
+            raise Exception('Cannot require dataset in read-only mode')
+        assert self._writer is not None
+        return self._writer.require_dataset(name, shape, dtype, exact=exact, **kwds)
+
     def __setitem__(self, name, obj):
         if self._readonly:
             raise Exception('Cannot set item in read-only mode')
diff --git a/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py b/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py
index 5bfa58e..a6ca6bf 100644
--- a/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py
+++ b/lindi/LindiH5pyFile/writers/LindiH5pyGroupWriter.py
@@ -83,6 +83,42 @@ def create_dataset(self, name, shape=None, dtype=None, data=None, **kwds):
         else:
             raise Exception(f'Unexpected group object type: {type(self.p._group_object)}')
 
+    def require_dataset(self, name, shape, dtype, exact=False, **kwds):
+        """Return the dataset `name` if it is compatible, else create it.
+
+        An existing object must be a LindiH5pyDataset whose shape equals
+        `shape` (an int or list is treated like the matching tuple). With
+        `exact=True` its dtype must equal `dtype`; otherwise the existing
+        dtype only has to pass `np.can_cast(existing, dtype)`. When nothing
+        exists at `name`, a new dataset is created from `shape`, `dtype`
+        and `**kwds`.
+
+        Raises Exception on any mismatch.
+        """
+        if name not in self.p:
+            return self.create_dataset(name, *(shape, dtype), **kwds)
+        ret = self.p[name]
+        if not isinstance(ret, LindiH5pyDataset):
+            raise Exception(f'Expected a dataset at {name} but got {type(ret)}')
+        # Normalize so that shape=3 and shape=[3] compare equal to (3,);
+        # a bare int is also accepted by h5py's require_dataset.
+        if isinstance(shape, (int, np.integer)):
+            expected_shape = (shape,)
+        elif isinstance(shape, list):
+            expected_shape = tuple(shape)
+        else:
+            expected_shape = shape
+        if ret.shape != expected_shape:
+            raise Exception(f'Expected shape {shape} but got {ret.shape}')
+        # NOTE(review): h5py checks np.can_cast(dtype, existing dtype), i.e.
+        # the opposite direction -- confirm this direction is intended.
+        if exact:
+            if ret.dtype != dtype:
+                raise Exception(f'Expected dtype {dtype} but got {ret.dtype}')
+        elif not np.can_cast(ret.dtype, dtype):
+            raise Exception(f'Cannot cast dtype {ret.dtype} to {dtype}')
+        return ret
+
     def __setitem__(self, name, obj):
         if isinstance(obj, h5py.SoftLink):
             if isinstance(self.p._group_object, h5py.Group):
diff --git a/tests/test_zarr_write.py b/tests/test_zarr_write.py
index 7512605..a205502 100644
--- a/tests/test_zarr_write.py
+++ b/tests/test_zarr_write.py
@@ -20,6 +20,38 @@ def test_zarr_write():
             compare_example_h5_data(h5f_backed_by_zarr, tmpdir=tmpdir)
 
 
+def test_require_dataset():
+    """Check require_dataset on a zarr-backed file opened in 'r+' mode."""
+    with tempfile.TemporaryDirectory() as tmpdir:
+        dirname = f'{tmpdir}/test.zarr'
+        store = zarr.DirectoryStore(dirname)
+        zarr.group(store=store)
+        with lindi.LindiH5pyFile.from_zarr_store(store, mode='r+') as h5f_backed_by_zarr:
+            h5f_backed_by_zarr.create_dataset('dset_int8', data=np.array([1, 2, 3], dtype=np.int8))
+            h5f_backed_by_zarr.create_dataset('dset_int16', data=np.array([1, 2, 3], dtype=np.int16))
+            # Matching shape and dtype: the existing dataset is accepted.
+            h5f_backed_by_zarr.require_dataset('dset_int8', shape=(3,), dtype=np.int8)
+            # An int or list shape is equivalent to the matching tuple.
+            h5f_backed_by_zarr.require_dataset('dset_int8', shape=3, dtype=np.int8)
+            h5f_backed_by_zarr.require_dataset('dset_int8', shape=[3], dtype=np.int8)
+            # A shape mismatch is rejected.
+            with pytest.raises(Exception):
+                h5f_backed_by_zarr.require_dataset('dset_int8', shape=(4,), dtype=np.int8)
+            # exact=True demands an identical dtype.
+            with pytest.raises(Exception):
+                h5f_backed_by_zarr.require_dataset('dset_int8', shape=(3,), dtype=np.int16, exact=True)
+            # exact=False accepts int8 -> int16 but rejects int16 -> int8.
+            h5f_backed_by_zarr.require_dataset('dset_int8', shape=(3,), dtype=np.int16, exact=False)
+            with pytest.raises(Exception):
+                h5f_backed_by_zarr.require_dataset('dset_int16', shape=(3,), dtype=np.int8, exact=False)
+            # A missing dataset is created and can be written to.
+            ds = h5f_backed_by_zarr.require_dataset('dset_float32', shape=(3,), dtype=np.float32)
+            ds[:] = np.array([1.1, 2.2, 3.3])
+            # The new float32 dataset does not satisfy an exact float64 request.
+            with pytest.raises(Exception):
+                h5f_backed_by_zarr.require_dataset('dset_float32', shape=(3,), dtype=np.float64, exact=True)
+
+
 def write_example_h5_data(h5f: h5py.File):
     h5f.attrs['attr_str'] = 'hello'
     h5f.attrs['attr_int'] = 42
@@ -80,4 +112,5 @@ def compare_example_h5_data(h5f: h5py.File, tmpdir: str):
 
 
 if __name__ == '__main__':
     test_zarr_write()
+    test_require_dataset()