From 7174615054dc7837e2eca2d94e8d081ac91959d7 Mon Sep 17 00:00:00 2001 From: sneakers-the-rat Date: Tue, 26 Dec 2023 23:46:59 -0800 Subject: [PATCH] Use explicit type map instead of stringlike types Expand check for string types Make 'else' a failure condition in `__resolve_dtype_helper__` (rather than implicitly assuming list) --- src/hdmf_zarr/backend.py | 15 ++++++++++++--- tests/unit/base_tests_zarrio.py | 2 +- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/hdmf_zarr/backend.py b/src/hdmf_zarr/backend.py index 4afa008c..fa8cab43 100644 --- a/src/hdmf_zarr/backend.py +++ b/src/hdmf_zarr/backend.py @@ -1023,9 +1023,16 @@ def write_dataset(self, **kwargs): # noqa: C901 new_items.append(tuple(new_item)) # Create dtype for storage, replacing values to match hdmf's hdf5 behavior + # --- + # TODO: Replace with a simple one-liner once __resolve_dtype_helper__ is + # compatible with zarr's need for fixed-length string dtypes. + # dtype = self.__resolve_dtype_helper__(options['dtype']) + new_dtype = [] for field in options['dtype']: - if field['dtype'] is str: + if field['dtype'] is str or field['dtype'] in ( + 'str', 'text', 'utf', 'utf8', 'utf-8', 'isodatetime' + ): new_dtype.append((field['name'], 'U25')) elif isinstance(field['dtype'], dict): # eg. for some references, dtype will be of the form @@ -1033,7 +1040,7 @@ def write_dataset(self, **kwargs): # noqa: C901 # which should just get serialized as an object new_dtype.append((field['name'], 'O')) else: - new_dtype.append((field['name'], field['dtype'])) + new_dtype.append((field['name'], self.__resolve_dtype_helper__(field['dtype']))) dtype = np.dtype(new_dtype) # cast and store compound dataset @@ -1171,8 +1178,10 @@ def __resolve_dtype_helper__(cls, dtype): return cls.__dtypes.get(dtype) elif isinstance(dtype, dict): return cls.__dtypes.get(dtype['reftype']) - else: + elif isinstance(dtype, list): return np.dtype([(x['name'], cls.__resolve_dtype_helper__(x['dtype'])) for x in dtype]) + else: + raise ValueError(f'Cant resolve dtype {dtype}') @classmethod def get_type(cls, data): diff --git a/tests/unit/base_tests_zarrio.py b/tests/unit/base_tests_zarrio.py index 6700b803..81ab9416 100644 --- a/tests/unit/base_tests_zarrio.py +++ b/tests/unit/base_tests_zarrio.py @@ -436,7 +436,7 @@ def test_read_reference_compound(self): read_builder = self.root['ref_dataset'] # ensure the array was written as a compound array - ref_dtype = np.dtype([('id', '