Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Write references in compound datasets as an array #146

Merged
merged 9 commits into from
Feb 2, 2024
50 changes: 42 additions & 8 deletions src/hdmf_zarr/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,18 +1017,50 @@ def write_dataset(self, **kwargs): # noqa: C901
type_str.append(self.__serial_dtype__(t)[0])

if len(refs) > 0:
dset = parent.require_dataset(name,
shape=(len(data), ),
dtype=object,
object_codec=self.__codec_cls(),
**options['io_settings'])

self._written_builders.set_written(builder) # record that the builder has been written
dset.attrs['zarr_dtype'] = type_str

# gather items to write
new_items = []
for j, item in enumerate(data):
new_item = list(item)
for i in refs:
new_item[i] = self.__get_ref(item[i], export_source=export_source)
dset[j] = new_item
new_items.append(tuple(new_item))

# Create dtype for storage, replacing values to match hdmf's hdf5 behavior
# ---
# TODO: Replace with a simple one-liner once __resolve_dtype_helper__ is
# compatible with zarr's need for fixed-length string dtypes.
# dtype = self.__resolve_dtype_helper__(options['dtype'])

new_dtype = []
for field in options['dtype']:
if field['dtype'] is str or field['dtype'] in (
'str', 'text', 'utf', 'utf8', 'utf-8', 'isodatetime'
mavaylon1 marked this conversation as resolved.
Show resolved Hide resolved
):
# Zarr does not support variable length strings
new_dtype.append((field['name'], 'O'))
elif isinstance(field['dtype'], dict):
# e.g. for some references, dtype will be of the form
# {'target_type': 'Baz', 'reftype': 'object'}
# which should just get serialized as an object
new_dtype.append((field['name'], 'O'))
else:
new_dtype.append((field['name'], self.__resolve_dtype_helper__(field['dtype'])))
dtype = np.dtype(new_dtype)

# cast and store compound dataset
arr = np.array(new_items, dtype=dtype)
dset = parent.require_dataset(
name,
shape=(len(arr),),
dtype=dtype,
object_codec=self.__codec_cls(),
**options['io_settings']
)
dset.attrs['zarr_dtype'] = type_str
dset[...] = arr
else:
# write a compound datatype
dset = self.__list_fill__(parent, name, data, options)
Expand Down Expand Up @@ -1153,8 +1185,10 @@ def __resolve_dtype_helper__(cls, dtype):
return cls.__dtypes.get(dtype)
elif isinstance(dtype, dict):
return cls.__dtypes.get(dtype['reftype'])
else:
elif isinstance(dtype, list):
return np.dtype([(x['name'], cls.__resolve_dtype_helper__(x['dtype'])) for x in dtype])
else:
raise ValueError(f'Cant resolve dtype {dtype}')

@classmethod
def get_type(cls, data):
Expand Down
5 changes: 5 additions & 0 deletions tests/unit/base_tests_zarrio.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,11 @@ def test_read_reference_compound(self):
self.read()
builder = self.createReferenceCompoundBuilder()['ref_dataset']
read_builder = self.root['ref_dataset']

# ensure the array was written as a compound array
ref_dtype = np.dtype([('id', '<i4'), ('name', 'O'), ('reference', 'O')])
self.assertEqual(read_builder.data.dataset.dtype, ref_dtype)

# Load the elements of each entry in the compound dataset and compare the index, string, and referenced array
for i, v in enumerate(read_builder['data']):
self.assertEqual(v[0], builder['data'][i][0]) # Compare index value from compound tuple
Expand Down
Loading