diff --git a/.github/workflows/check_external_links.yml b/.github/workflows/check_external_links.yml index 1fbf0ee0..109446ad 100644 --- a/.github/workflows/check_external_links.yml +++ b/.github/workflows/check_external_links.yml @@ -27,7 +27,7 @@ jobs: - name: Install Sphinx dependencies and package run: | python -m pip install --upgrade pip - python -m pip install -r requirements-doc.txt -r requirements.txt + python -m pip install -r requirements-doc.txt -r requirements.txt -r requirements-opt.txt python -m pip install . - name: Check Sphinx external links run: sphinx-build -b linkcheck ./docs/source ./test_build diff --git a/.github/workflows/run_coverage.yml b/.github/workflows/run_coverage.yml index 142b0868..becadc4c 100644 --- a/.github/workflows/run_coverage.yml +++ b/.github/workflows/run_coverage.yml @@ -48,7 +48,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install -r requirements-dev.txt -r requirements.txt + python -m pip install -r requirements-dev.txt -r requirements.txt -r requirements-opt.txt - name: Install package run: | diff --git a/.readthedocs.yaml b/.readthedocs.yaml index cabf84ab..f57db3ed 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -26,6 +26,7 @@ python: install: - requirements: requirements-doc.txt - requirements: requirements.txt + - requirements: requirements-opt.txt - path: . # path to the package relative to the root # Optionally include all submodules diff --git a/CHANGELOG.md b/CHANGELOG.md index 3ab43847..c9c6e89f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,9 @@ * Fixed error in deploy workflow. @mavaylon1 [#109](https://github.com/hdmf-dev/hdmf-zarr/pull/109) * Fixed build error for ReadtheDocs by degrading numpy for python 3.7 support. @mavaylon1 [#115](https://github.com/hdmf-dev/hdmf-zarr/pull/115) +### New Features +* Added parallel write support for the ``ZarrIO``. 
@CodyCBakerPhD [#118](https://github.com/hdmf-dev/hdmf-zarr/pull/118) + ## 0.3.0 (July 21, 2023) diff --git a/MANIFEST.in b/MANIFEST.in index 783dea68..de5b2302 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ include LICENSE.txt versioneer.py src/hdmf_zarr/_version.py src/hdmf_zarr/_due.py -include requirements.txt requirements-dev.txt requirements-doc.txt +include requirements.txt requirements-dev.txt requirements-doc.txt requirements-opt.txt include test.py tox.ini graft tests diff --git a/requirements-min.txt b/requirements-min.txt index 4695e8f3..c452e4c5 100644 --- a/requirements-min.txt +++ b/requirements-min.txt @@ -4,3 +4,4 @@ numcodecs==0.9.1 pynwb==2.5.0 setuptools importlib_resources;python_version<'3.9' # Remove when python 3.9 becomes the new minimum +threadpoolctl==3.1.0 diff --git a/requirements-opt.txt b/requirements-opt.txt new file mode 100644 index 00000000..101e7d7b --- /dev/null +++ b/requirements-opt.txt @@ -0,0 +1 @@ +tqdm==4.65.0 diff --git a/requirements.txt b/requirements.txt index b6eb9731..edd4c45d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,5 +2,6 @@ hdmf==3.9.0 zarr==2.11.0 pynwb==2.5.0 -numpy==1.24 +numpy==1.24.0 numcodecs==0.11.0 +threadpoolctl==3.2.0 diff --git a/setup.py b/setup.py index 3556fcf9..4ff0eb53 100755 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ 'numcodecs==0.11.0', 'pynwb>=2.5.0', 'setuptools', + 'threadpoolctl>=3.1.0', ] print(reqs) @@ -40,6 +41,7 @@ 'url': 'https://github.com/hdmf-dev/hdmf-zarr', 'license': "BSD", 'install_requires': reqs, + 'extras_require': {"tqdm": ["tqdm>=4.41.0"]}, 'packages': pkgs, 'package_dir': {'': 'src'}, 'package_data': {}, diff --git a/src/hdmf_zarr/backend.py b/src/hdmf_zarr/backend.py index 688d387d..dd8b93ad 100644 --- a/src/hdmf_zarr/backend.py +++ b/src/hdmf_zarr/backend.py @@ -114,7 +114,7 @@ def __init__(self, **kwargs): self.__file = None self.__built = dict() self._written_builders = WriteStatusTracker() # track which builders were written (or read) by this IO object - self.__dci_queue = ZarrIODataChunkIteratorQueue() # a queue of DataChunkIterators that need to be exhausted + self.__dci_queue = None # Will be initialized on call to io.write # Codec class to be used. Alternates, e.g., =numcodecs.JSON self.__codec_cls = numcodecs.pickles.Pickle if object_codec_class is None else object_codec_class source_path = self.__path @@ -188,17 +188,54 @@ def load_namespaces(cls, namespace_catalog, path, namespaces=None): reader = ZarrSpecReader(ns_group) namespace_catalog.load_namespaces('namespace', reader=reader) - @docval({'name': 'container', 'type': Container, 'doc': 'the Container object to write'}, - {'name': 'cache_spec', 'type': bool, 'doc': 'cache specification to file', 'default': True}, - {'name': 'link_data', 'type': bool, - 'doc': 'If not specified otherwise link (True) or copy (False) Datasets', 'default': True}, - {'name': 'exhaust_dci', 'type': bool, - 'doc': 'exhaust DataChunkIterators one at a time. If False, add ' + - 'them to the internal queue self.__dci_queue and exhaust them concurrently at the end', - 'default': True},) + @docval( + {'name': 'container', 'type': Container, 'doc': 'the Container object to write'}, + {'name': 'cache_spec', 'type': bool, 'doc': 'cache specification to file', 'default': True}, + {'name': 'link_data', 'type': bool, + 'doc': 'If not specified otherwise link (True) or copy (False) Datasets', 'default': True}, + {'name': 'exhaust_dci', 'type': bool, + 'doc': 'exhaust DataChunkIterators one at a time. 
If False, add ' + + 'them to the internal queue self.__dci_queue and exhaust them concurrently at the end', + 'default': True}, + { + "name": "number_of_jobs", + "type": int, + "doc": ( + "Number of jobs to use in parallel during write " + "(only works with GenericDataChunkIterator-wrapped datasets)." + ), + "default": 1, + }, + { + "name": "max_threads_per_process", + "type": int, + "doc": ( + "Limits the number of threads used by each process. The default is None (no limits)." + ), + "default": None, + }, + { + "name": "multiprocessing_context", + "type": str, + "doc": ( + "Context for multiprocessing. It can be None (default), 'fork' or 'spawn'. " + "Note that 'fork' is only available on UNIX systems (not Windows)." + ), + "default": None, + }, + ) def write(self, **kwargs): - """Overwrite the write method to add support for caching the specification""" - cache_spec = popargs('cache_spec', kwargs) + """Overwrite the write method to add support for caching the specification and parallelization.""" + cache_spec, number_of_jobs, max_threads_per_process, multiprocessing_context = popargs( + "cache_spec", "number_of_jobs", "max_threads_per_process", "multiprocessing_context", kwargs + ) + + self.__dci_queue = ZarrIODataChunkIteratorQueue( + number_of_jobs=number_of_jobs, + max_threads_per_process=max_threads_per_process, + multiprocessing_context=multiprocessing_context, + ) + super(ZarrIO, self).write(**kwargs) if cache_spec: self.__cache_spec() @@ -225,8 +262,36 @@ def __cache_spec(self): writer = ZarrSpecWriter(ns_group) ns_builder.export('namespace', writer=writer) - @docval(*get_docval(HDMFIO.export), - {'name': 'cache_spec', 'type': bool, 'doc': 'whether to cache the specification to file', 'default': True}) + @docval( + *get_docval(HDMFIO.export), + {'name': 'cache_spec', 'type': bool, 'doc': 'whether to cache the specification to file', 'default': True}, + { + "name": "number_of_jobs", + "type": int, + "doc": ( + "Number of jobs to use in parallel during write " + "(only works with GenericDataChunkIterator-wrapped datasets)." + ), + "default": 1, + }, + { + "name": "max_threads_per_process", + "type": int, + "doc": ( + "Limits the number of threads used by each process. The default is None (no limits)." + ), + "default": None, + }, + { + "name": "multiprocessing_context", + "type": str, + "doc": ( + "Context for multiprocessing. It can be None (default), 'fork' or 'spawn'. " + "Note that 'fork' is only available on UNIX systems (not Windows)." + ), + "default": None, + }, + ) def export(self, **kwargs): """Export data read from a file from any backend to Zarr. See :py:meth:`hdmf.backends.io.HDMFIO.export` for more details. 
@@ -237,6 +302,15 @@ def export(self, **kwargs): src_io = getargs('src_io', kwargs) write_args, cache_spec = popargs('write_args', 'cache_spec', kwargs) + number_of_jobs, max_threads_per_process, multiprocessing_context = popargs( + "number_of_jobs", "max_threads_per_process", "multiprocessing_context", kwargs + ) + + self.__dci_queue = ZarrIODataChunkIteratorQueue( + number_of_jobs=number_of_jobs, + max_threads_per_process=max_threads_per_process, + multiprocessing_context=multiprocessing_context, + ) if not isinstance(src_io, ZarrIO) and write_args.get('link_data', True): raise UnsupportedOperation("Cannot export from non-Zarr backend %s to Zarr with write argument " @@ -286,36 +360,53 @@ def get_builder_disk_path(self, **kwargs): builder_path = os.path.join(basepath, self.__get_path(builder).lstrip("/")) return builder_path - @docval({'name': 'builder', 'type': GroupBuilder, 'doc': 'the GroupBuilder object representing the NWBFile'}, - {'name': 'link_data', 'type': bool, - 'doc': 'If not specified otherwise link (True) or copy (False) Zarr Datasets', 'default': True}, - {'name': 'exhaust_dci', 'type': bool, - 'doc': 'exhaust DataChunkIterators one at a time. If False, add ' + - 'them to the internal queue self.__dci_queue and exhaust them concurrently at the end', - 'default': True}, - {'name': 'export_source', 'type': str, - 'doc': 'The source of the builders when exporting', 'default': None}) + @docval( + {'name': 'builder', 'type': GroupBuilder, 'doc': 'the GroupBuilder object representing the NWBFile'}, + { + 'name': 'link_data', + 'type': bool, + 'doc': 'If not specified otherwise link (True) or copy (False) Zarr Datasets', + 'default': True + }, + { + 'name': 'exhaust_dci', + 'type': bool, + 'doc': ( + 'Exhaust DataChunkIterators one at a time. 
If False, add ' + 'them to the internal queue self.__dci_queue and exhaust them concurrently at the end' + ), + 'default': True, + }, + { + 'name': 'export_source', + 'type': str, + 'doc': 'The source of the builders when exporting', + 'default': None, + }, + ) def write_builder(self, **kwargs): - """Write a builder to disk""" - f_builder, link_data, exhaust_dci, export_source = getargs('builder', - 'link_data', - 'exhaust_dci', - 'export_source', - kwargs) + """Write a builder to disk.""" + f_builder, link_data, exhaust_dci, export_source = getargs( + 'builder', 'link_data', 'exhaust_dci', 'export_source', kwargs + ) for name, gbldr in f_builder.groups.items(): - self.write_group(parent=self.__file, - builder=gbldr, - link_data=link_data, - exhaust_dci=exhaust_dci, - export_source=export_source) + self.write_group( + parent=self.__file, + builder=gbldr, + link_data=link_data, + exhaust_dci=exhaust_dci, + export_source=export_source, + ) for name, dbldr in f_builder.datasets.items(): - self.write_dataset(parent=self.__file, - builder=dbldr, - link_data=link_data, - exhaust_dci=exhaust_dci, - export_source=export_source) + self.write_dataset( + parent=self.__file, + builder=dbldr, + link_data=link_data, + exhaust_dci=exhaust_dci, + export_source=export_source, + ) self.write_attributes(self.__file, f_builder.attributes) # the same as set_attributes in HDMF - self.__dci_queue.exhaust_queue() # Write all DataChunkIterators that have been queued + self.__dci_queue.exhaust_queue() # Write any remaining DataChunkIterators that have been queued self._written_builders.set_written(f_builder) self.logger.debug("Done writing %s '%s' to path '%s'" % (f_builder.__class__.__qualname__, f_builder.name, self.source)) @@ -333,12 +424,10 @@ def write_builder(self, **kwargs): returns='the Group that was created', rtype='Group') def write_group(self, **kwargs): """Write a GroupBuider to file""" - parent, builder, link_data, exhaust_dci, export_source = getargs('parent', - 'builder', - 'link_data', - 'exhaust_dci', - 'export_source', - kwargs) + parent, builder, link_data, exhaust_dci, export_source = getargs( + 'parent', 'builder', 'link_data', 'exhaust_dci', 'export_source', kwargs + ) + if self.get_written(builder): group = parent[builder.name] else: @@ -347,19 +436,23 @@ def write_group(self, **kwargs): subgroups = builder.groups if subgroups: for subgroup_name, sub_builder in subgroups.items(): - self.write_group(parent=group, - builder=sub_builder, - link_data=link_data, - exhaust_dci=exhaust_dci) + self.write_group( + parent=group, + builder=sub_builder, + link_data=link_data, + exhaust_dci=exhaust_dci, + ) datasets = builder.datasets if datasets: for dset_name, sub_builder in datasets.items(): - self.write_dataset(parent=group, - builder=sub_builder, - link_data=link_data, - exhaust_dci=exhaust_dci, - export_source=export_source) + self.write_dataset( + parent=group, + builder=sub_builder, + link_data=link_data, + exhaust_dci=exhaust_dci, + export_source=export_source, + ) # write all links (haven implemented) links = builder.links @@ -379,10 +472,9 @@ def write_group(self, **kwargs): {'name': 'export_source', 'type': str, 'doc': 'The source of the builders when exporting', 'default': None}) def write_attributes(self, **kwargs): - """ - Set (i.e., write) the attributes on a given Zarr Group or Array - """ + """Set (i.e., write) the attributes on a given Zarr Group or Array.""" obj, attributes, export_source = getargs('obj', 'attributes', 'export_source', kwargs) + for key, value in attributes.items(): 
# Case 1: list, set, tuple type attributes if isinstance(value, (set, list, tuple)) or (isinstance(value, np.ndarray) and np.ndim(value) != 0): @@ -723,13 +815,15 @@ def __setup_chunked_dataset__(cls, parent, name, data, options=None): 'doc': 'The source of the builders when exporting', 'default': None}, returns='the Zarr array that was created', rtype=Array) def write_dataset(self, **kwargs): # noqa: C901 - parent, builder, link_data, exhaust_dci, export_source = getargs('parent', - 'builder', - 'link_data', - 'exhaust_dci', - 'export_source', - kwargs) + parent, builder, link_data, exhaust_dci, export_source = getargs( + 'parent', 'builder', 'link_data', 'exhaust_dci', 'export_source', kwargs + ) + force_data = getargs('force_data', kwargs) + + if exhaust_dci and self.__dci_queue is None: + self.__dci_queue = ZarrIODataChunkIteratorQueue() + if self.get_written(builder): return None name = builder.name diff --git a/src/hdmf_zarr/utils.py b/src/hdmf_zarr/utils.py index 9c23aba5..c584451c 100644 --- a/src/hdmf_zarr/utils.py +++ b/src/hdmf_zarr/utils.py @@ -1,20 +1,36 @@ -"""Collection of utility I/O classes for the ZarrIO backend store""" -from zarr.hierarchy import Group -import zarr -import numcodecs -import numpy as np +"""Collection of utility I/O classes for the ZarrIO backend store.""" +import gc +import traceback +import multiprocessing +import math +import json +import logging from collections import deque from collections.abc import Iterable +from typing import Optional, Union, Literal, Tuple, Dict, Any +from concurrent.futures import ProcessPoolExecutor +from threadpoolctl import threadpool_limits +from warnings import warn -import json -import logging +import numcodecs +import zarr +import numpy as np +from zarr.hierarchy import Group + +from hdmf.data_utils import DataIO, GenericDataChunkIterator, DataChunkIterator, AbstractDataChunkIterator +from hdmf.query import HDMFDataset +from hdmf.utils import docval, getargs -from hdmf.data_utils import DataIO -from hdmf.utils import (docval, - getargs) +from hdmf.spec import SpecWriter, SpecReader -from hdmf.spec import (SpecWriter, - SpecReader) + + +# Necessary definitions to avoid parallelization bugs, inherited from SpikeInterface experience # see # https://stackoverflow.com/questions/10117073/how-to-use-initializer-to-set-up-my-multiprocess-pool # the trick is: these 2 variables are global per worker # so they are not shared across worker processes global _worker_context global _operation_to_run class ZarrIODataChunkIteratorQueue(deque): @@ -22,18 +38,37 @@ class ZarrIODataChunkIteratorQueue(deque): Helper class used by ZarrIO to manage the write for DataChunkIterators Each queue element must be a tupple of two elements: 1) the dataset to write to and 2) the AbstractDataChunkIterator with the data + :param number_of_jobs: The number of jobs used to write the datasets. The default is 1. + :type number_of_jobs: integer + :param max_threads_per_process: Limits the number of threads used by each process. The default is None (no limits). + :type max_threads_per_process: integer or None + :param multiprocessing_context: Context for multiprocessing. It can be None (default), "fork" or "spawn". + Note that "fork" is only available on UNIX systems (not Windows). 
+ :type multiprocessing_context: string or None """ - def __init__(self): + def __init__( + self, + number_of_jobs: int = 1, + max_threads_per_process: Union[None, int] = None, + multiprocessing_context: Union[None, Literal["fork", "spawn"]] = None, + ): self.logger = logging.getLogger('%s.%s' % (self.__class__.__module__, self.__class__.__qualname__)) + + self.number_of_jobs = number_of_jobs + self.max_threads_per_process = max_threads_per_process + self.multiprocessing_context = multiprocessing_context + super().__init__() @classmethod - def __write_chunk__(cls, dset, data): + def __write_chunk__(cls, dset: HDMFDataset, data: DataChunkIterator): """ Internal helper function used to read a chunk from the given DataChunkIterator and write it to the given Dataset - :param dset: The Dataset to write to - :param data: The DataChunkIterator to read from + :param dset: The Dataset to write to. + :type dset: HDMFDataset + :param data: The DataChunkIterator to read from. + :type data: DataChunkIterator :return: True of a chunk was written, False otherwise :rtype: bool """ @@ -63,19 +98,126 @@ def __write_chunk__(cls, dset, data): # Write the data dset[chunk_i.selection] = chunk_i.data # Chunk written and we need to continue + return True def exhaust_queue(self): """ - Read and write from any queued DataChunkIterators in a round-robin fashion + Read and write from any queued DataChunkIterators. + + Operates in a round-robin fashion for a single job. + Operates on a single dataset at a time with multiple jobs. """ - # Iterate through our queue and write data chunks in a round-robin fashion until all iterators are exhausted - self.logger.debug("Exhausting DataChunkIterator from queue (length %d)" % len(self)) + self.logger.debug(f"Exhausting DataChunkIterator from queue (length {len(self)})") + + if self.number_of_jobs > 1: + parallelizable_iterators = list() + buffer_map = list() + size_in_MB_per_iteration = list() + + display_progress = False + r_bar_in_MB = ( + "| {n_fmt}/{total_fmt} MB [Elapsed: {elapsed}, " + "Remaining: {remaining}, Rate:{rate_fmt}{postfix}]" + ) + bar_format = "{l_bar}{bar}" + f"{r_bar_in_MB}" + progress_bar_options = dict( + desc=f"Writing Zarr datasets with {self.number_of_jobs} jobs", + position=0, + bar_format=bar_format, + unit="MB", + ) + for (zarr_dataset, iterator) in iter(self): + # Parallel write only works well with GenericDataChunkIterators + # Due to perfect alignment between chunks and buffers + if not isinstance(iterator, GenericDataChunkIterator): + continue + + # Iterator must be pickleable as well, to be sent across jobs + is_iterator_pickleable, reason = self._is_pickleable(iterator=iterator) + if not is_iterator_pickleable: + self.logger.debug( + f"Dataset {zarr_dataset.path} was not pickleable during parallel write.\n\nReason: {reason}" + ) + continue + + # Add this entry to a running list to remove after initial pass (cannot mutate during iteration) + parallelizable_iterators.append((zarr_dataset, iterator)) + + # Disable progress at the iterator level and aggregate enable option + display_progress = display_progress or iterator.display_progress + iterator.display_progress = False + per_iterator_progress_options = { + key: value for key, value in iterator.progress_bar_options.items() + if key not in ["desc", "total", "file"] + } + progress_bar_options.update(**per_iterator_progress_options) + + iterator_itemsize = iterator.dtype.itemsize + for buffer_selection in iterator.buffer_selection_generator: + buffer_map_args = (zarr_dataset.store.path, 
zarr_dataset.path, iterator, buffer_selection) + buffer_map.append(buffer_map_args) + buffer_size_in_MB = math.prod( + [slice_.stop - slice_.start for slice_ in buffer_selection] + ) * iterator_itemsize / 1e6 + size_in_MB_per_iteration.append(buffer_size_in_MB) + progress_bar_options.update( + total=int(sum(size_in_MB_per_iteration)), # int() to round down to nearest integer for better display + ) + + if parallelizable_iterators: # Avoid spinning up ProcessPool if no candidates during this exhaustion + # Remove candidates for parallelization from the queue + for (zarr_dataset, iterator) in parallelizable_iterators: + self.remove((zarr_dataset, iterator)) + + operation_to_run = self._write_buffer_zarr + process_initialization = dict + initialization_arguments = () + with ProcessPoolExecutor( + max_workers=self.number_of_jobs, + initializer=self.initializer_wrapper, + mp_context=multiprocessing.get_context(method=self.multiprocessing_context), + initargs=( + operation_to_run, + process_initialization, + initialization_arguments, + self.max_threads_per_process + ), + ) as executor: + results = executor.map(self.function_wrapper, buffer_map) + + if display_progress: + try: # Import warnings are also issued at the level of the iterator instantiation + from tqdm import tqdm + + results = tqdm(iterable=results, **progress_bar_options) + + # executor map must be iterated to deploy commands over jobs + for size_in_MB, result in zip(size_in_MB_per_iteration, results): + results.update(n=int(size_in_MB)) # int() to round down for better display + except Exception as exception: # pragma: no cover + warn( + message=( + "Unable to setup progress bar due to" + f"\n{type(exception)}: {str(exception)}\n\n{traceback.format_exc()}" + ), + stacklevel=2, + ) + # executor map must be iterated to deploy commands over jobs + for result in results: + pass + else: + # executor map must be iterated to deploy commands over jobs + for result in results: + pass + + # Iterate through remaining queue and write DataChunks in a round-robin fashion until exhausted while len(self) > 0: - dset, data = self.popleft() - if self.__write_chunk__(dset, data): - self.append(dataset=dset, data=data) - self.logger.debug("Exhausted DataChunkIterator from queue (length %d)" % len(self)) + zarr_dataset, iterator = self.popleft() + if self.__write_chunk__(zarr_dataset, iterator): + self.append(dataset=zarr_dataset, data=iterator) + + self.logger.debug(f"Exhausted DataChunkIterator from queue (length {len(self)})") def append(self, dataset, data): """ @@ -87,6 +229,108 @@ def append(self, dataset, data): """ super().append((dataset, data)) + @staticmethod + def _is_pickleable(iterator: AbstractDataChunkIterator) -> Tuple[bool, Optional[str]]: + """ + Determine if the iterator can be pickled. + + Returns both the bool and the reason if False. + """ + try: + dictionary = iterator._to_dict() + iterator._from_dict(dictionary=dictionary) + + return True, None + except Exception as exception: + base_hdmf_not_implemented_messages = ( + "The `._to_dict()` method for pickling has not been defined for this DataChunkIterator!", + "The `._from_dict()` method for pickling has not been defined for this DataChunkIterator!", + ) + + if isinstance(exception, NotImplementedError) and str(exception) in base_hdmf_not_implemented_messages: + reason = "The pickling methods for the iterator have not been defined." 
+ else: + reason = ( + f"The pickling methods for the iterator have been defined but throw the error:\n\n" + f"{type(exception)}: {str(exception)}\n\nwith traceback\n\n{traceback.format_exc()}," + ) + + return False, reason + + @staticmethod + def initializer_wrapper( + operation_to_run: callable, + process_initialization: callable, + initialization_arguments: Iterable, # TODO: eventually standardize with typing.Iterable[typing.Any] + max_threads_per_process: Optional[int] = None + ): # keyword arguments here are just for readability, ProcessPool only takes a tuple + """ + Needed as a part of a bug fix with cloud memory leaks discovered by SpikeInterface team. + + Recommended fix is to have global wrappers for the working initializer that limits the + threads used per process. + """ + global _worker_context + global _operation_to_run + + if max_threads_per_process is None: + _worker_context = process_initialization(*initialization_arguments) + else: + with threadpool_limits(limits=max_threads_per_process): + _worker_context = process_initialization(*initialization_arguments) + _worker_context["max_threads_per_process"] = max_threads_per_process + _operation_to_run = operation_to_run + + @staticmethod + def _write_buffer_zarr( + worker_context: Dict[str, Any], + zarr_store_path: str, + relative_dataset_path: str, + iterator: AbstractDataChunkIterator, + buffer_selection: Tuple[slice, ...], + ): + # TODO, figure out propagation of storage options + zarr_store = zarr.open(store=zarr_store_path, mode="r+") # storage_options=storage_options) + zarr_dataset = zarr_store[relative_dataset_path] + + data = iterator._get_data(selection=buffer_selection) + zarr_dataset[buffer_selection] = data + + # An issue detected in cloud usage by the SpikeInterface team + # Fix memory leak by forcing garbage collection + del data + gc.collect() + + @staticmethod + def function_wrapper(args: Tuple[str, str, AbstractDataChunkIterator, Tuple[slice, ...]]): + """ + Needed as a part of a bug fix with cloud memory leaks discovered by SpikeInterface team. + + Recommended fix is to have a global wrapper for the executor.map level. 
+ """ + zarr_store_path, relative_dataset_path, iterator, buffer_selection = args + global _worker_context + global _operation_to_run + + max_threads_per_process = _worker_context["max_threads_per_process"] + if max_threads_per_process is None: + return _operation_to_run( + _worker_context, + zarr_store_path, + relative_dataset_path, + iterator, + buffer_selection + ) + else: + with threadpool_limits(limits=max_threads_per_process): + return _operation_to_run( + _worker_context, + zarr_store_path, + relative_dataset_path, + iterator, + buffer_selection, + ) + class ZarrSpecWriter(SpecWriter): """ diff --git a/tests/unit/test_parallel_write.py b/tests/unit/test_parallel_write.py new file mode 100644 index 00000000..61aae7ab --- /dev/null +++ b/tests/unit/test_parallel_write.py @@ -0,0 +1,267 @@ +"""Module for testing the parallel write feature for the ZarrIO.""" +import unittest +import platform +from typing import Tuple, Dict +from io import StringIO +from unittest.mock import patch + +import numpy as np +from numpy.testing import assert_array_equal +from hdmf_zarr import ZarrIO +from hdmf.common import DynamicTable, VectorData, get_manager +from hdmf.data_utils import GenericDataChunkIterator, DataChunkIterator + +try: + import tqdm # noqa: F401 + TQDM_INSTALLED = True +except ImportError: + TQDM_INSTALLED = False + + +class PickleableDataChunkIterator(GenericDataChunkIterator): + """Generic data chunk iterator used for specific testing purposes.""" + + def __init__(self, data, **base_kwargs): + self.data = data + + self._base_kwargs = base_kwargs + super().__init__(**base_kwargs) + + def _get_dtype(self) -> np.dtype: + return self.data.dtype + + def _get_maxshape(self) -> tuple: + return self.data.shape + + def _get_data(self, selection: Tuple[slice]) -> np.ndarray: + return self.data[selection] + + def __reduce__(self): + instance_constructor = self._from_dict + initialization_args = (self._to_dict(),) + return (instance_constructor, initialization_args) + + def _to_dict(self) -> Dict: + dictionary = dict() + # Note this is not a recommended way to pickle contents + # ~~ Used for testing purposes only ~~ + dictionary["data"] = self.data + dictionary["base_kwargs"] = self._base_kwargs + + return dictionary + + @staticmethod + def _from_dict(dictionary: dict) -> GenericDataChunkIterator: # TODO: need to investigate the need of base path + data = dictionary["data"] + + iterator = PickleableDataChunkIterator(data=data, **dictionary["base_kwargs"]) + return iterator + + +class NotPickleableDataChunkIterator(GenericDataChunkIterator): + """Generic data chunk iterator used for specific testing purposes.""" + + def __init__(self, data, **base_kwargs): + self.data = data + + self._base_kwargs = base_kwargs + super().__init__(**base_kwargs) + + def _get_dtype(self) -> np.dtype: + return self.data.dtype + + def _get_maxshape(self) -> tuple: + return self.data.shape + + def _get_data(self, selection: Tuple[slice]) -> np.ndarray: + return self.data[selection] + + +def test_parallel_write(tmpdir): + number_of_jobs = 2 + data = np.array([1., 2., 3.]) + column = VectorData(name="TestColumn", description="", data=PickleableDataChunkIterator(data=data)) + dynamic_table = DynamicTable(name="TestTable", description="", id=list(range(3)), columns=[column]) + + zarr_top_level_path = str(tmpdir / "test_parallel_write.zarr") + with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io: + io.write(container=dynamic_table, number_of_jobs=number_of_jobs) + + with ZarrIO(path=zarr_top_level_path, 
manager=get_manager(), mode="r") as io: + dynamic_table_roundtrip = io.read() + data_roundtrip = dynamic_table_roundtrip["TestColumn"].data + assert_array_equal(data_roundtrip, data) + + +def test_mixed_iterator_types(tmpdir): + number_of_jobs = 2 + + generic_iterator_data = np.array([1., 2., 3.]) + generic_iterator_column = VectorData( + name="TestGenericIteratorColumn", + description="", + data=PickleableDataChunkIterator(data=generic_iterator_data) + ) + + classic_iterator_data = np.array([4., 5., 6.]) + classic_iterator_column = VectorData( + name="TestClassicIteratorColumn", + description="", + data=DataChunkIterator(data=classic_iterator_data) + ) + + unwrappped_data = np.array([7., 8., 9.]) + unwrapped_column = VectorData(name="TestUnwrappedColumn", description="", data=unwrappped_data) + dynamic_table = DynamicTable( + name="TestTable", + description="", + id=list(range(3)), + columns=[generic_iterator_column, classic_iterator_column, unwrapped_column], + ) + + zarr_top_level_path = str(tmpdir / "test_mixed_iterator_types.zarr") + with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io: + io.write(container=dynamic_table, number_of_jobs=number_of_jobs) + + with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="r") as io: + dynamic_table_roundtrip = io.read() + generic_iterator_data_roundtrip = dynamic_table_roundtrip["TestGenericIteratorColumn"].data + assert_array_equal(generic_iterator_data_roundtrip, generic_iterator_data) + + classic_iterator_data_roundtrip = dynamic_table_roundtrip["TestClassicIteratorColumn"].data + assert_array_equal(classic_iterator_data_roundtrip, classic_iterator_data) + + generic_iterator_data_roundtrip = dynamic_table_roundtrip["TestUnwrappedColumn"].data + assert_array_equal(generic_iterator_data_roundtrip, unwrappped_data) + + +def test_mixed_iterator_pickleability(tmpdir): + number_of_jobs = 2 + + pickleable_iterator_data = np.array([1., 2., 3.]) + pickleable_iterator_column = VectorData( + name="TestGenericIteratorColumn", + description="", + data=PickleableDataChunkIterator(data=pickleable_iterator_data) + ) + + not_pickleable_iterator_data = np.array([4., 5., 6.]) + not_pickleable_iterator_column = VectorData( + name="TestClassicIteratorColumn", + description="", + data=NotPickleableDataChunkIterator(data=not_pickleable_iterator_data) + ) + + dynamic_table = DynamicTable( + name="TestTable", + description="", + id=list(range(3)), + columns=[pickleable_iterator_column, not_pickleable_iterator_column], + ) + + zarr_top_level_path = str(tmpdir / "test_mixed_iterator_pickleability.zarr") + with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io: + io.write(container=dynamic_table, number_of_jobs=number_of_jobs) + + with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="r") as io: + dynamic_table_roundtrip = io.read() + + pickleable_iterator_data_roundtrip = dynamic_table_roundtrip["TestGenericIteratorColumn"].data + assert_array_equal(pickleable_iterator_data_roundtrip, pickleable_iterator_data) + + not_pickleable_iterator_data_roundtrip = dynamic_table_roundtrip["TestClassicIteratorColumn"].data + assert_array_equal(not_pickleable_iterator_data_roundtrip, not_pickleable_iterator_data) + + +@unittest.skipIf(not TQDM_INSTALLED, "optional tqdm module is not installed") +def test_simple_tqdm(tmpdir): + number_of_jobs = 2 + expected_desc = f"Writing Zarr datasets with {number_of_jobs} jobs" + + zarr_top_level_path = str(tmpdir / "test_simple_tqdm.zarr") + with patch("sys.stderr", 
new=StringIO()) as tqdm_out: + with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io: + column = VectorData( + name="TestColumn", + description="", + data=PickleableDataChunkIterator( + data=np.array([1., 2., 3.]), + display_progress=True, + ) + ) + dynamic_table = DynamicTable(name="TestTable", description="", columns=[column]) + io.write(container=dynamic_table, number_of_jobs=number_of_jobs) + + assert expected_desc in tqdm_out.getvalue() + + +@unittest.skipIf(not TQDM_INSTALLED, "optional tqdm module is not installed") +def test_compound_tqdm(tmpdir): + number_of_jobs = 2 + expected_desc_pickleable = f"Writing Zarr datasets with {number_of_jobs} jobs" + expected_desc_not_pickleable = "Writing non-parallel dataset..." + + zarr_top_level_path = str(tmpdir / "test_compound_tqdm.zarr") + with patch("sys.stderr", new=StringIO()) as tqdm_out: + with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io: + pickleable_column = VectorData( + name="TestPickleableIteratorColumn", + description="", + data=PickleableDataChunkIterator( + data=np.array([1., 2., 3.]), + display_progress=True, + ) + ) + not_pickleable_column = VectorData( + name="TestNotPickleableColumn", + description="", + data=NotPickleableDataChunkIterator( + data=np.array([4., 5., 6.]), + display_progress=True, + progress_bar_options=dict(desc=expected_desc_not_pickleable, position=1) + ) + ) + dynamic_table = DynamicTable( + name="TestTable", description="", columns=[pickleable_column, not_pickleable_column] + ) + io.write(container=dynamic_table, number_of_jobs=number_of_jobs) + + tqdm_out_value = tqdm_out.getvalue() + assert expected_desc_pickleable in tqdm_out_value + assert expected_desc_not_pickleable in tqdm_out_value + + +def test_extra_keyword_argument_propagation(tmpdir): + number_of_jobs = 2 + + column = VectorData(name="TestColumn", description="", data=np.array([1., 2., 3.])) + dynamic_table = DynamicTable(name="TestTable", description="", id=list(range(3)), columns=[column]) + + zarr_top_level_path = str(tmpdir / "test_extra_parallel_write_keyword_arguments.zarr") + + test_keyword_argument_pairs = [ + dict(max_threads_per_process=2, multiprocessing_context=None), + dict(max_threads_per_process=None, multiprocessing_context="spawn"), + dict(max_threads_per_process=2, multiprocessing_context="spawn"), + ] + if platform.system() != "Windows": + test_keyword_argument_pairs.extend( + [ + dict(max_threads_per_process=None, multiprocessing_context="fork"), + dict(max_threads_per_process=2, multiprocessing_context="fork"), + ] + ) + + for test_keyword_argument_pair in test_keyword_argument_pairs: + test_max_threads_per_process = test_keyword_argument_pair["max_threads_per_process"] + test_multiprocessing_context = test_keyword_argument_pair["multiprocessing_context"] + with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io: + io.write( + container=dynamic_table, + number_of_jobs=number_of_jobs, + max_threads_per_process=test_max_threads_per_process, + multiprocessing_context=test_multiprocessing_context + ) + + assert io._ZarrIO__dci_queue.max_threads_per_process == test_max_threads_per_process + assert io._ZarrIO__dci_queue.multiprocessing_context == test_multiprocessing_context diff --git a/tox.ini b/tox.ini index 6934d6e4..720a97f5 100644 --- a/tox.ini +++ b/tox.ini @@ -38,7 +38,7 @@ install_command = python -m pip install {opts} {packages} deps = {[testenv]deps} - # -rrequirements-opt.txt + -rrequirements-opt.txt commands = {[testenv]commands} # Test with 
python 3.11; pinned dev and optional reqs; upgraded run reqs @@ -48,7 +48,7 @@ install_command = python -m pip install -U {opts} {packages} deps = -rrequirements-dev.txt - # -rrequirements-opt.txt + -rrequirements-opt.txt commands = {[testenv]commands} # Test with python 3.11; pinned dev and optional reqs; upgraded, pre-release run reqs @@ -58,7 +58,7 @@ install_command = python -m pip install -U --pre {opts} {packages} deps = -rrequirements-dev.txt - # -rrequirements-opt.txt + -rrequirements-opt.txt commands = {[testenv]commands} # Test with python 3.8; pinned dev reqs; minimum run reqs @@ -95,7 +95,7 @@ commands = {[testenv:build]commands} basepython = python3.11 deps = {[testenv]deps} - # -rrequirements-opt.txt + -rrequirements-opt.txt commands = {[testenv:build]commands} [testenv:build-py311-upgraded] @@ -104,7 +104,7 @@ install_command = python -m pip install -U {opts} {packages} deps = -rrequirements-dev.txt - # -rrequirements-opt.txt + -rrequirements-opt.txt commands = {[testenv:build]commands} [testenv:build-py311-prerelease] @@ -113,7 +113,7 @@ install_command = python -m pip install -U --pre {opts} {packages} deps = -rrequirements-dev.txt - # -rrequirements-opt.txt + -rrequirements-opt.txt commands = {[testenv:build]commands} [testenv:build-py38-minimum] @@ -169,7 +169,7 @@ install_command = deps = -rrequirements-dev.txt -rrequirements-doc.txt - # -rrequirements-opt.txt + -rrequirements-opt.txt commands = {[testenv:gallery]commands} # Test with python 3.11; pinned dev, doc, and optional reqs; pre-release run reqs @@ -180,7 +180,7 @@ install_command = deps = -rrequirements-dev.txt -rrequirements-doc.txt - # -rrequirements-opt.txt + -rrequirements-opt.txt commands = {[testenv:gallery]commands} # Test with python 3.8; pinned dev and doc reqs; minimum run reqs @@ -190,4 +190,5 @@ deps = -rrequirements-dev.txt -rrequirements-min.txt -rrequirements-doc.txt + -rrequirements-opt.txt commands = {[testenv:gallery]commands}
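
Example usage (reviewer note, not part of the patch): the sketch below is distilled from the new tests/unit/test_parallel_write.py and shows the added parallel write path end to end. The output path and array values are illustrative; only `GenericDataChunkIterator` subclasses that can be pickled (here via `_to_dict`/`_from_dict` plus `__reduce__`) are routed to the `ProcessPoolExecutor` branch, while everything else falls back to the serial round-robin queue.

```python
import numpy as np
from numpy.testing import assert_array_equal
from hdmf.common import DynamicTable, VectorData, get_manager
from hdmf.data_utils import GenericDataChunkIterator
from hdmf_zarr import ZarrIO


class PickleableDataChunkIterator(GenericDataChunkIterator):
    """Minimal pickleable iterator, mirroring the one defined in the new test module."""

    def __init__(self, data, **base_kwargs):
        self.data = data
        self._base_kwargs = base_kwargs
        super().__init__(**base_kwargs)

    def _get_dtype(self) -> np.dtype:
        return self.data.dtype

    def _get_maxshape(self) -> tuple:
        return self.data.shape

    def _get_data(self, selection) -> np.ndarray:
        return self.data[selection]

    def _to_dict(self) -> dict:
        return {"data": self.data, "base_kwargs": self._base_kwargs}

    @staticmethod
    def _from_dict(dictionary: dict) -> "PickleableDataChunkIterator":
        return PickleableDataChunkIterator(data=dictionary["data"], **dictionary["base_kwargs"])

    def __reduce__(self):
        return (self._from_dict, (self._to_dict(),))


if __name__ == "__main__":
    data = np.array([1., 2., 3.])
    column = VectorData(name="TestColumn", description="", data=PickleableDataChunkIterator(data=data))
    table = DynamicTable(name="TestTable", description="", id=list(range(3)), columns=[column])

    # number_of_jobs > 1 activates the ProcessPoolExecutor branch of exhaust_queue();
    # max_threads_per_process and multiprocessing_context are forwarded to the queue.
    with ZarrIO(path="parallel_write_example.zarr", manager=get_manager(), mode="w") as io:
        io.write(container=table, number_of_jobs=2)

    with ZarrIO(path="parallel_write_example.zarr", manager=get_manager(), mode="r") as io:
        roundtrip = io.read()
        assert_array_equal(roundtrip["TestColumn"].data, data)
```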
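
A standalone sketch of the per-worker global initializer pattern that `initializer_wrapper` and `function_wrapper` rely on (the SpikeInterface-inspired fix referenced in the utils.py comments). All names below are illustrative, not hdmf-zarr API; the point is that whatever the initializer stores in module-level globals lives once per worker process, so the operation and its context do not have to be re-pickled for every submitted buffer, and threadpoolctl can cap the threads each worker spawns.

```python
# Illustrative sketch only: hypothetical names, not part of hdmf-zarr.
import multiprocessing
from concurrent.futures import ProcessPoolExecutor

from threadpoolctl import threadpool_limits

_worker_context = None  # set once per worker process by the initializer


def _initializer(max_threads_per_process=None):
    # Runs once in each worker process; module-level globals assigned here are
    # private to that worker, so later tasks can read them without re-sending them.
    global _worker_context
    _worker_context = {"max_threads_per_process": max_threads_per_process}


def _write_one_buffer(buffer_index):
    # Each task sees the globals its own worker initialized.
    limit = _worker_context["max_threads_per_process"]
    with threadpool_limits(limits=limit):
        return buffer_index ** 2  # placeholder for writing one buffer selection to the Zarr store


if __name__ == "__main__":
    with ProcessPoolExecutor(
        max_workers=2,
        initializer=_initializer,
        initargs=(1,),
        mp_context=multiprocessing.get_context("spawn"),
    ) as executor:
        # The map must be consumed for the submitted work to complete, which is why
        # exhaust_queue() iterates over the results even when no progress bar is shown.
        print(list(executor.map(_write_one_buffer, range(4))))
```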
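
The aggregated tqdm total in `exhaust_queue()` comes from a simple per-buffer estimate: the number of elements spanned by each buffer selection times the iterator's itemsize, converted to MB. A small worked example of that arithmetic, with a made-up buffer selection and dtype:

```python
# Hypothetical buffer: 1000 rows x 384 columns of float64 (8 bytes per element).
import math

import numpy as np

buffer_selection = (slice(0, 1000), slice(0, 384))
itemsize = np.dtype("float64").itemsize  # 8

buffer_size_in_MB = math.prod(
    [slice_.stop - slice_.start for slice_ in buffer_selection]
) * itemsize / 1e6

print(buffer_size_in_MB)  # 1000 * 384 * 8 / 1e6 = 3.072
```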