From 2f37253a807307dcc2219dcd32d01ee1580c039c Mon Sep 17 00:00:00 2001
From: CodyCBakerPhD
Date: Wed, 23 Aug 2023 21:01:32 +0000
Subject: [PATCH] more tests; more debugs

---
 src/hdmf_zarr/utils.py            |  76 ++++++++++++++-------
 tests/unit/test_parallel_write.py | 108 +++++++++++++++++++++++++++---
 2 files changed, 150 insertions(+), 34 deletions(-)

diff --git a/src/hdmf_zarr/utils.py b/src/hdmf_zarr/utils.py
index 40599c3a..67c5e125 100644
--- a/src/hdmf_zarr/utils.py
+++ b/src/hdmf_zarr/utils.py
@@ -1,20 +1,22 @@
-"""Collection of utility I/O classes for the ZarrIO backend store"""
-from zarr.hierarchy import Group
-import zarr
-import numcodecs
+"""Collection of utility I/O classes for the ZarrIO backend store."""
 import gc
 import traceback
-import numpy as np
 import multiprocessing
+import math
+import json
+import logging
 from collections import deque
 from collections.abc import Iterable
 from typing import Optional, Union, Literal, Tuple
 from pathlib import Path
 from concurrent.futures import ProcessPoolExecutor
 from threadpoolctl import threadpool_limits
+from warnings import warn
 
-import json
-import logging
+import numcodecs
+import zarr
+import numpy as np
+from zarr.hierarchy import Group
 
 from hdmf.data_utils import DataIO, GenericDataChunkIterator, DataChunkIterator, DataChunk, AbstractDataChunkIterator
 from hdmf.query import HDMFDataset
@@ -83,7 +85,7 @@ def _is_pickleable(iterator: AbstractDataChunkIterator) -> Tuple[bool, Optional[
     """
     try:
         dictionary = iterator._to_dict()
-        iterator._from_dict(dictionary=test_pickle)
+        iterator._from_dict(dictionary=dictionary)
 
         return True, None
     except Exception as exception:
@@ -195,10 +197,19 @@ def exhaust_queue(
         """
         self.logger.debug("Exhausting DataChunkIterator from queue (length %d)" % len(self))
         if number_of_jobs > 1:
+            parallelizable_iterators = list()
             buffer_map = list()
+            size_in_MB_per_iteration = list()
             display_progress = False
-            parallelized_iterators = list()
+            r_bar_in_MB = "| {n_fmt}/{total_fmt} MB [Elapsed: {elapsed}, Remaining: {remaining}, Rate:{rate_fmt}{postfix}]"
+            bar_format = "{l_bar}{bar}" + r_bar_in_MB
+            progress_bar_options = dict(
+                desc=f"Writing Zarr datasets with {number_of_jobs} jobs",
+                position=0,
+                bar_format=bar_format,
+                unit="MB",
+            )
             for (zarr_dataset, iterator) in iter(self):
                 # Parallel write only works well with GenericDataChunkIterators
                 # Due to perfect alignment between chunks and buffers
                 if not isinstance(iterator, GenericDataChunkIterator):
                     continue
 
                 is_iterator_pickleable, reason = _is_pickleable(iterator=iterator)
                 if not is_iterator_pickleable:
                     self.logger.debug(f"Dataset {zarr_dataset.path} was not pickleable during parallel write.\n\nReason: {reason}")
                     continue
 
+                # Add this entry to a running list to remove after the initial pass (the queue cannot be mutated during iteration)
-                parallelized_iterators.append((zarr_dataset, iterator))
+                parallelizable_iterators.append((zarr_dataset, iterator))
+
+                # Disable progress display at the iterator level and aggregate the enable option across all iterators
                 display_progress = display_progress or iterator.display_progress
                 iterator.display_progress = False
+                per_iterator_progress_options = {
+                    key: value for key, value in iterator.progress_bar_options.items() if key not in ["desc", "total", "file"]
+                }
+                progress_bar_options.update(**per_iterator_progress_options)
 
+                iterator_itemsize = iterator.dtype.itemsize
                 for buffer_selection in iterator.buffer_selection_generator:
                     buffer_map_args = (zarr_dataset.store.path, zarr_dataset.path, iterator, buffer_selection)
                     buffer_map.append(buffer_map_args)
+                    buffer_size_in_MB = (
+                        math.prod([slice_.stop - slice_.start for slice_ in buffer_selection])
+                        * iterator_itemsize
+                        / 1e6
+                    )
+                    size_in_MB_per_iteration.append(buffer_size_in_MB)
 
+            progress_bar_options.update(
+                total=int(sum(size_in_MB_per_iteration)),  # int() rounds down to the nearest integer for a cleaner display
+            )
+
             # Remove candidates for parallelization from the queue
-            for (zarr_dataset, iterator) in parallelized_iterators:
+            for (zarr_dataset, iterator) in parallelizable_iterators:
                 self.remove((zarr_dataset, iterator))
 
             operation_to_run = _write_buffer_zarr
             process_initialization = dict
             initialization_arguments = ()
-
             with ProcessPoolExecutor(
                 max_workers=number_of_jobs,
                 initializer=initializer_wrapper,
@@ -238,23 +259,30 @@
             ) as executor:
                 results = executor.map(function_wrapper, buffer_map)
 
                 if display_progress:
-                    from tqdm import tqdm
-
-                    results = tqdm(results, desc="Writing in parallel with Zarr", total=len(buffer_map), position=0)
+                    try:
+                        from tqdm import tqdm
+
+                        results = tqdm(iterable=results, **progress_bar_options)
+
+                        # The executor map must be iterated to deploy the commands over the jobs
+                        for size_in_MB, result in zip(size_in_MB_per_iteration, results):
+                            results.update(n=int(size_in_MB))  # int() rounds down to the nearest integer for a cleaner display
+                    except Exception as exception:  # Import warnings are also issued at the level of the iterator instantiation
+                        warn(f"Unable to set up the progress bar due to\n{type(exception)}: {str(exception)}\n\n{traceback.format_exc()}")
+
+                        # The executor map must be iterated to deploy the commands over the jobs
+                        for result in results:
+                            pass
+                else:
+                    # The executor map must be iterated to deploy the commands over the jobs
+                    for result in results:
+                        pass
-                for result in results:
-                    pass
-                    # if self.handle_returns:
-                    #     returns.append(res)
-                    # if self.gather_func is not None:
-                    #     self.gather_func(res)
-                #else:
 
         # Iterate through our remaining queue and write DataChunks in a round-robin fashion until all iterators are exhausted
         while len(self) > 0:
             zarr_dataset, iterator = self.popleft()
             if self.__write_chunk__(zarr_dataset, iterator):
                 self.append(dataset=zarr_dataset, data=iterator)
-        self.logger.debug("Exhausted DataChunkIterator from queue (length %d)" % len(self))
+        self.logger.debug(f"Exhausted DataChunkIterator from queue (length {len(self)})")
 
     def append(self, dataset, data):
         """
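
The hunks above lean on a standard-library pattern that is easy to miss when reading the diff: ProcessPoolExecutor's initializer runs once in every worker process to set up per-process state (here, the operation to run), and the lazy generator returned by executor.map must be consumed before the submitted work is guaranteed to complete, which is why every branch drains results. The following self-contained sketch illustrates just that stdlib pattern; the names _init, _run, and square are illustrative and are not part of hdmf_zarr.

    # Illustrative only; mirrors the initializer/initargs/map usage in the diff above.
    from concurrent.futures import ProcessPoolExecutor

    _worker_state = dict()


    def _init(operation):
        # Runs once per worker process, before any work items arrive
        _worker_state["operation"] = operation


    def _run(item):
        return _worker_state["operation"](item)


    def square(x):
        return x * x


    if __name__ == "__main__":
        with ProcessPoolExecutor(max_workers=2, initializer=_init, initargs=(square,)) as executor:
            # As with the buffer map above, the map generator must be iterated to drive the jobs
            print(list(executor.map(_run, range(5))))  # [0, 1, 4, 9, 16]
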
diff --git a/tests/unit/test_parallel_write.py b/tests/unit/test_parallel_write.py
index 0f479642..bd6859a6 100644
--- a/tests/unit/test_parallel_write.py
+++ b/tests/unit/test_parallel_write.py
@@ -1,12 +1,21 @@
 """Module for testing the parallel write feature for the ZarrIO."""
+import unittest
 from pathlib import Path
 from typing import Tuple, Dict
+from io import StringIO
+from unittest.mock import patch
 
 import numpy as np
+from numpy.testing import assert_array_equal
 
 from hdmf_zarr import ZarrIO
-from hdmf.common import DynamicTable, VectorData
+from hdmf.common import DynamicTable, VectorData, get_manager
 from hdmf.data_utils import GenericDataChunkIterator, DataChunkIterator
 
+try:
+    import tqdm  # noqa: F401
+
+    TQDM_INSTALLED = True
+except ImportError:
+    TQDM_INSTALLED = False
+
 
 class PickleableDataChunkIterator(GenericDataChunkIterator):
     """Generic data chunk iterator used for specific testing purposes."""
@@ -42,7 +51,7 @@ def _to_dict(self) -> Dict:
     @staticmethod
     def _from_dict(dictionary: dict) -> GenericDataChunkIterator:
         # TODO: need to investigate the need of base path
-        source_type = dictionary["data"]
+        data = dictionary["data"]
 
         iterator = PickleableDataChunkIterator(data=data, **dictionary["base_kwargs"])
         return iterator
@@ -67,12 +76,20 @@ def _get_data(self, selection: Tuple[slice]) -> np.ndarray:
 
 def test_parallel_write(tmpdir):
     number_of_jobs = 2
-    column = VectorData(name="TestColumn", description="", data=PickleableDataChunkIterator(data=np.array([1., 2., 3.])))
+    data = np.array([1., 2., 3.])
+    column = VectorData(name="TestColumn", description="", data=PickleableDataChunkIterator(data=data))
     dynamic_table = DynamicTable(name="TestTable", description="", columns=[column])
 
-    zarr_top_level_path = str(tmpdir / f"example_parallel_zarr_{number_of_jobs}.zarr")
+    zarr_top_level_path = str(tmpdir / "test_parallel_write.zarr")
     with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io:
         io.write(dynamic_table, number_of_jobs=number_of_jobs)
+
+    # TODO: roundtrip currently fails due to read error
+    # with ZarrIO(path=zarr_top_level_path, mode="r") as io:
+    #     dynamic_table_roundtrip = io.read()
+    #     data_roundtrip = dynamic_table_roundtrip["TestColumn"].data
+    #     assert_array_equal(data_roundtrip, data)
 
 
 def test_mixed_iterator_types(tmpdir):
     number_of_jobs = 2
     generic_column = VectorData(name="TestGenericColumn", description="", data=PickleableDataChunkIterator(data=np.array([1., 2., 3.])))
     classic_column = VectorData(name="TestClassicColumn", description="", data=DataChunkIterator(data=np.array([4., 5., 6.])))
     unwrapped_column = VectorData(name="TestUnwrappedColumn", description="", data=np.array([7., 8., 9.]))
     dynamic_table = DynamicTable(name="TestTable", description="", columns=[generic_column, classic_column, unwrapped_column])
 
-    zarr_top_level_path = str(tmpdir / f"example_parallel_zarr_{number_of_jobs}.zarr")
+    zarr_top_level_path = str(tmpdir / "test_mixed_iterator_types.zarr")
     with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io:
         io.write(dynamic_table, number_of_jobs=number_of_jobs)
-    # TODO: ensure can write a Zarr file with three datasets, one wrapped in a Generic iterator, one wrapped in DataChunkIterator, one not wrapped at all
+
+    # TODO: roundtrip currently fails
+    # with ZarrIO(path=zarr_top_level_path, mode="r") as io:
+    #     dynamic_table_roundtrip = io.read()
+    #     data_roundtrip = dynamic_table_roundtrip["TestGenericColumn"].data
+    #     assert_array_equal(data_roundtrip, np.array([1., 2., 3.]))
 
 
 def test_mixed_iterator_pickleability(tmpdir):
-    pass  # TODO: ensure can write a Zarr file with two datasets, one wrapped in pickleable one wrapped in not-pickleable
+    number_of_jobs = 2
+    pickleable_column = VectorData(name="TestGenericColumn", description="", data=PickleableDataChunkIterator(data=np.array([1., 2., 3.])))
+    not_pickleable_column = VectorData(name="TestClassicColumn", description="", data=NotPickleableDataChunkIterator(data=np.array([4., 5., 6.])))
+    dynamic_table = DynamicTable(name="TestTable", description="", columns=[pickleable_column, not_pickleable_column])
+
+    zarr_top_level_path = str(tmpdir / "test_mixed_iterator_pickleability.zarr")
+    with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io:
+        io.write(dynamic_table, number_of_jobs=number_of_jobs)
+
+    # TODO: roundtrip currently fails due to read error
+    # with ZarrIO(path=zarr_top_level_path, mode="r") as io:
+    #     dynamic_table_roundtrip = io.read()
+    #     data_roundtrip = dynamic_table_roundtrip["TestGenericColumn"].data
+    #     assert_array_equal(data_roundtrip, np.array([1., 2., 3.]))
 
 
+@unittest.skipIf(not TQDM_INSTALLED, "optional tqdm module is not installed")
+def test_simple_tqdm(tmpdir):
+    number_of_jobs = 2
+    expected_desc = f"Writing Zarr datasets with {number_of_jobs} jobs"
+
+    zarr_top_level_path = str(tmpdir / "test_simple_tqdm.zarr")
+    with patch("sys.stderr", new=StringIO()) as tqdm_out, ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io:
+        column = VectorData(
+            name="TestColumn",
+            description="",
+            data=PickleableDataChunkIterator(
+                data=np.array([1., 2., 3.]),
+                display_progress=True,
+                # progress_bar_options=dict(file=tqdm_out),
+            ),
+        )
+        dynamic_table = DynamicTable(name="TestTable", description="", columns=[column])
+        io.write(dynamic_table, number_of_jobs=number_of_jobs)
+
+    assert expected_desc in tqdm_out.getvalue()
+
+
+@unittest.skipIf(not TQDM_INSTALLED, "optional tqdm module is not installed")
+def test_compound_tqdm(tmpdir):
+    number_of_jobs = 2
+    expected_desc_pickleable = f"Writing Zarr datasets with {number_of_jobs} jobs"
+    expected_desc_not_pickleable = "Writing non-parallel dataset..."
+
+    zarr_top_level_path = str(tmpdir / "test_compound_tqdm.zarr")
+    with patch("sys.stderr", new=StringIO()) as tqdm_out, ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io:
+        pickleable_column = VectorData(
+            name="TestGenericColumn",
+            description="",
+            data=PickleableDataChunkIterator(
+                data=np.array([1., 2., 3.]),
+                display_progress=True,
+            ),
+        )
+        not_pickleable_column = VectorData(
+            name="TestClassicColumn",
+            description="",
+            data=NotPickleableDataChunkIterator(
+                data=np.array([4., 5., 6.]),
+                display_progress=True,
+                progress_bar_options=dict(desc=expected_desc_not_pickleable, position=1),
+            ),
+        )
+        dynamic_table = DynamicTable(name="TestTable", description="", columns=[pickleable_column, not_pickleable_column])
+        io.write(dynamic_table, number_of_jobs=number_of_jobs)
+
+    tqdm_out_value = tqdm_out.getvalue()
+    assert expected_desc_pickleable in tqdm_out_value
+    assert expected_desc_not_pickleable in tqdm_out_value
 
-def test_tqdm(tmpdir):
-    pass  # TODO: grab stdout with dispaly_progress enabled and ensure it looks as expected (consult HDMF generic iterator tests)
 
 def test_extra_args(tmpdir):
     pass  # TODO? Should we test if the other arguments like thread count can be passed?
-    # I mean, anything _can_ be passed, but how to test if it was actually used? Seems difficult...
+    # I mean, anything _can_ be passed due to dynamic **kwargs, but how to test if it was actually used? Seems difficult...
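
Taken together, the patch makes parallel write available to any GenericDataChunkIterator that implements the _to_dict/_from_dict round-trip contract checked by _is_pickleable. The following end-to-end sketch shows the intended usage under that assumption; InMemoryIterator is a hypothetical stand-in that mirrors the PickleableDataChunkIterator test helper above and is not part of hdmf_zarr.

    import numpy as np
    from hdmf.common import DynamicTable, VectorData, get_manager
    from hdmf.data_utils import GenericDataChunkIterator
    from hdmf_zarr import ZarrIO


    class InMemoryIterator(GenericDataChunkIterator):
        """Pickleable iterator over an in-memory array (hypothetical; mirrors the test helper)."""

        def __init__(self, data, **base_kwargs):
            self.data = np.asarray(data)
            self._base_kwargs = base_kwargs
            super().__init__(**base_kwargs)

        def _get_dtype(self) -> np.dtype:
            return self.data.dtype

        def _get_maxshape(self):
            return self.data.shape

        def _get_data(self, selection):
            return self.data[selection]

        def _to_dict(self) -> dict:
            # The contract exercised by _is_pickleable: serialize the state out...
            return dict(data=self.data, base_kwargs=self._base_kwargs)

        @staticmethod
        def _from_dict(dictionary: dict) -> "InMemoryIterator":
            # ...and reconstruct it inside each worker process
            return InMemoryIterator(data=dictionary["data"], **dictionary["base_kwargs"])


    column = VectorData(name="TestColumn", description="", data=InMemoryIterator(data=np.arange(10.0)))
    dynamic_table = DynamicTable(name="TestTable", description="", columns=[column])
    with ZarrIO(path="example_parallel.zarr", manager=get_manager(), mode="w") as io:
        # number_of_jobs > 1 routes pickleable GenericDataChunkIterators through the process pool
        io.write(dynamic_table, number_of_jobs=2)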