Skip to content

Commit

Permalink
more tests; more debugs
Browse files Browse the repository at this point in the history
  • Loading branch information
CodyCBakerPhD committed Aug 23, 2023
1 parent 9ff7a17 commit 2f37253
Show file tree
Hide file tree
Showing 2 changed files with 150 additions and 34 deletions.
76 changes: 52 additions & 24 deletions src/hdmf_zarr/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
"""Collection of utility I/O classes for the ZarrIO backend store"""
from zarr.hierarchy import Group
import zarr
import numcodecs
"""Collection of utility I/O classes for the ZarrIO backend store."""
import gc
import traceback
import numpy as np
import multiprocessing
import math
import json
import logging
from collections import deque
from collections.abc import Iterable
from typing import Optional, Union, Literal, Tuple
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor
from threadpoolctl import threadpool_limits
from warnings import warn

import json
import logging
import numcodecs
import zarr
import numpy as np
from zarr.hierarchy import Group

from hdmf.data_utils import DataIO, GenericDataChunkIterator, DataChunkIterator, DataChunk, AbstractDataChunkIterator
from hdmf.query import HDMFDataset
Expand Down Expand Up @@ -83,7 +85,7 @@ def _is_pickleable(iterator: AbstractDataChunkIterator) -> Tuple[bool, Optional[
"""
try:
dictionary = iterator._to_dict()
iterator._from_dict(dictionary=test_pickle)
iterator._from_dict(dictionary=dictionary)

return True, None
except Exception as exception:
Expand Down Expand Up @@ -195,10 +197,19 @@ def exhaust_queue(
"""
self.logger.debug("Exhausting DataChunkIterator from queue (length %d)" % len(self))
if number_of_jobs > 1:
parallelizable_iterators = list()
buffer_map = list()
size_in_MB_per_iteration = list()

display_progress = False
parallelized_iterators = list()
r_bar_in_MB = "| {n_fmt}/{total_fmt} MB [Elapsed: {elapsed}, Remaining: {remaining}, Rate:{rate_fmt}{postfix}]"
bar_format = "{l_bar}{bar}" + f"{r_bar_in_MB}"
progress_bar_options = dict(
desc=f"Writing Zarr datasets with {number_of_jobs} jobs",
position=0,
bar_format=bar_format,
unit="MB",
)
for (zarr_dataset, iterator) in iter(self):
# Parallel write only works well with GenericDataChunkIterators
# Due to perfect alignment between chunks and buffers
Expand All @@ -210,25 +221,35 @@ def exhaust_queue(
if not is_iterator_pickleable:
self.logger.debug(f"Dataset {zarr_dataset.path} was not pickleable during parallel write.\n\nReason: {reason}")
continue

# Add this entry to a running list to remove after initial pass (cannot mutate during iteration)
parallelized_iterators.append((zarr_dataset, iterator))
parallelizable_iterators.append((zarr_dataset, iterator))

# Disable progress at the iterator level and aggregate enable option
display_progress = display_progress or iterator.display_progress
iterator.display_progress = False
per_iterator_progress_options = {
key: value for key, value in iterator.progress_bar_options.items() if key not in ["desc", "total", "file"]
}
progress_bar_options.update(**per_iterator_progress_options)

iterator_itemsize = iterator.dtype.itemsize
for buffer_selection in iterator.buffer_selection_generator:
buffer_map_args = (zarr_dataset.store.path, zarr_dataset.path, iterator, buffer_selection)
buffer_map.append(buffer_map_args)
buffer_size_in_MB = math.prod([slice_.stop - slice_.start for slice_ in buffer_selection]) * iterator_itemsize / 1e6
size_in_MB_per_iteration.append(buffer_size_in_MB)
progress_bar_options.update(
total=int(sum(size_in_MB_per_iteration)), # int() to round down to nearest integer for better display
)

# Remove candidates for parallelization from the queue
for (zarr_dataset, iterator) in parallelized_iterators:
for (zarr_dataset, iterator) in parallelizable_iterators:
self.remove((zarr_dataset, iterator))

operation_to_run = _write_buffer_zarr
process_initialization = dict
initialization_arguments = ()

with ProcessPoolExecutor(
max_workers=number_of_jobs,
initializer=initializer_wrapper,
Expand All @@ -238,23 +259,30 @@ def exhaust_queue(
results = executor.map(function_wrapper, buffer_map)

if display_progress:
from tqdm import tqdm

results = tqdm(results, desc="Writing in parallel with Zarr", total=len(buffer_map), position=0)
try:
from tqdm import tqdm

results = tqdm(iterable=results, **progress_bar_options)

# exector map must be iterated to deploy commands over jobs
for size_in_MB, result in zip(size_in_MB_per_iteration, results):
results.update(n=int(size_in_MB)) # int() to round down to nearest integer for better display
except Exception as exception: # Import warnings are also issued at the level of the iterator instantiation
warn(f"Unable to setup progress bar due to\ntype(exception): str(exception)\n\n{traceback.format_exc()}")
# exector map must be iterated to deploy commands over jobs
for result in results:
pass
else:
# exector map must be iterated to deploy commands over jobs
for result in results:
pass

for result in results:
pass
# if self.handle_returns:
# returns.append(res)
# if self.gather_func is not None:
# self.gather_func(res)
#else:
# Iterate through our remaining queue and write DataChunks in a round-robin fashion until all iterators are exhausted
while len(self) > 0:
zarr_dataset, iterator = self.popleft()
if self.__write_chunk__(zarr_dataset, iterator):
self.append(dataset=zarr_dataset, data=iterator)
self.logger.debug("Exhausted DataChunkIterator from queue (length %d)" % len(self))
self.logger.debug(f"Exhausted DataChunkIterator from queue (length {len(self)})")

def append(self, dataset, data):
"""
Expand Down
108 changes: 98 additions & 10 deletions tests/unit/test_parallel_write.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
"""Module for testing the parallel write feature for the ZarrIO."""
import unittest
from pathlib import Path
from typing import Tuple, Dict
from io import StringIO
from unittest.mock import patch

import numpy as np
from numpy.testing import assert_array_equal
from hdmf_zarr import ZarrIO
from hdmf.common import DynamicTable, VectorData
from hdmf.common import DynamicTable, VectorData, get_manager
from hdmf.data_utils import GenericDataChunkIterator, DataChunkIterator

try:
import tqdm # noqa: F401
TQDM_INSTALLED = True
except ImportError:
TQDM_INSTALLED = False

class PickleableDataChunkIterator(GenericDataChunkIterator):
"""Generic data chunk iterator used for specific testing purposes."""
Expand Down Expand Up @@ -42,7 +51,7 @@ def _to_dict(self) -> Dict:

@staticmethod
def _from_dict(dictionary: dict) -> GenericDataChunkIterator: # TODO: need to investigate the need of base path
source_type = dictionary["data"]
data = dictionary["data"]

iterator = PickleableDataChunkIterator(data=data, **dictionary["base_kwargs"])
return iterator
Expand All @@ -67,12 +76,20 @@ def _get_data(self, selection: Tuple[slice]) -> np.ndarray:

def test_parallel_write(tmpdir):
number_of_jobs = 2
column = VectorData(name="TestColumn", description="", data=PickleableDataChunkIterator(data=np.array([1., 2., 3.])))
data = np.array([1., 2., 3.])
column = VectorData(name="TestColumn", description="", data=PickleableDataChunkIterator(data=data))
dynamic_table = DynamicTable(name="TestTable", description="", columns=[column])

zarr_top_level_path = str(tmpdir / f"example_parallel_zarr_{number_of_jobs}.zarr")
zarr_top_level_path = str(tmpdir / "test_parallel_write.zarr")
with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io:
io.write(dynamic_table, number_of_jobs=number_of_jobs)

# TODO: roundtrip currently fails due to read error
#with ZarrIO(path=zarr_top_level_path, mode="r") as io:
# dynamic_table_roundtrip = io.read()
# data_roundtrip = dynamic_table_roundtrip["TestColumn"].data
# assert_array_equal(data_roundtrip, data)


def test_mixed_iterator_types(tmpdir):
number_of_jobs = 2
Expand All @@ -81,17 +98,88 @@ def test_mixed_iterator_types(tmpdir):
unwrapped_column = VectorData(name="TestUnwrappedColumn", description="", data=np.array([7., 8., 9.]))
dynamic_table = DynamicTable(name="TestTable", description="", columns=[generic_column, classic_column, unwrapped_column])

zarr_top_level_path = str(tmpdir / f"example_parallel_zarr_{number_of_jobs}.zarr")
zarr_top_level_path = str(tmpdir / "test_mixed_iterator_types.zarr")
with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io:
io.write(dynamic_table, number_of_jobs=number_of_jobs)
# TODO: ensure can write a Zarr file with three datasets, one wrapped in a Generic iterator, one wrapped in DataChunkIterator, one not wrapped at all

# TODO: roundtrip currently fails
#with ZarrIO(path=zarr_top_level_path, mode="r") as io:
# dynamic_table_roundtrip = io.read()
# data_roundtrip = dynamic_table_roundtrip["TestColumn"].data
# assert_array_equal(data_roundtrip, data)

def test_mixed_iterator_pickleability(tmpdir):
pass # TODO: ensure can write a Zarr file with two datasets, one wrapped in pickleable one wrapped in not-pickleable
number_of_jobs = 2
pickleable_column = VectorData(name="TestGenericColumn", description="", data=PickleableDataChunkIterator(data=np.array([1., 2., 3.])))
not_pickleable_column = VectorData(name="TestClassicColumn", description="", data=NotPickleableDataChunkIterator(data=np.array([4., 5., 6.])))
dynamic_table = DynamicTable(name="TestTable", description="", columns=[pickleable_column, not_pickleable_column])

zarr_top_level_path = str(tmpdir / "test_mixed_iterator_pickleability.zarr")
with ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io:
io.write(dynamic_table, number_of_jobs=number_of_jobs)

# TODO: roundtrip currently fails due to read error
#with ZarrIO(path=zarr_top_level_path, mode="r") as io:
# dynamic_table_roundtrip = io.read()
# data_roundtrip = dynamic_table_roundtrip["TestColumn"].data
# assert_array_equal(data_roundtrip, data)


@unittest.skipIf(not TQDM_INSTALLED, "optional tqdm module is not installed")
def test_simple_tqdm(tmpdir):
number_of_jobs = 2
expected_desc = f"Writing Zarr datasets with {number_of_jobs} jobs"

zarr_top_level_path = str(tmpdir / "test_simple_tqdm.zarr")
with patch("sys.stderr", new=StringIO()) as tqdm_out, ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io:
column = VectorData(
name="TestColumn",
description="",
data=PickleableDataChunkIterator(
data=np.array([1., 2., 3.]),
display_progress=True,
#progress_bar_options=dict(file=tqdm_out),
)
)
dynamic_table = DynamicTable(name="TestTable", description="", columns=[column])
io.write(dynamic_table, number_of_jobs=number_of_jobs)

assert expected_desc in tqdm_out.getvalue()


@unittest.skipIf(not TQDM_INSTALLED, "optional tqdm module is not installed")
def test_compound_tqdm(tmpdir):
number_of_jobs = 2
expected_desc_pickleable = f"Writing Zarr datasets with {number_of_jobs} jobs"
expected_desc_not_pickleable = "Writing non-parallel dataset..."

zarr_top_level_path = str(tmpdir / "test_compound_tqdm.zarr")
with patch("sys.stderr", new=StringIO()) as tqdm_out, ZarrIO(path=zarr_top_level_path, manager=get_manager(), mode="w") as io:
pickleable_column = VectorData(
name="TestGenericColumn",
description="",
data=PickleableDataChunkIterator(
data=np.array([1., 2., 3.]),
display_progress=True,
)
)
not_pickleable_column = VectorData(
name="TestClassicColumn",
description="",
data=NotPickleableDataChunkIterator(
data=np.array([4., 5., 6.]),
display_progress=True,
progress_bar_options=dict(desc=expected_desc_not_pickleable, position=1)
)
)
dynamic_table = DynamicTable(name="TestTable", description="", columns=[pickleable_column, not_pickleable_column])
io.write(dynamic_table, number_of_jobs=number_of_jobs)

tqdm_out_value = tqdm_out.getvalue()
assert expected_desc_pickleable in tqdm_out_value
assert expected_desc_not_pickleable in tqdm_out_value

def test_tqdm(tmpdir):
pass # TODO: grab stdout with dispaly_progress enabled and ensure it looks as expected (consult HDMF generic iterator tests)

def test_extra_args(tmpdir):
pass # TODO? Should we test if the other arguments like thread count can be passed?
# I mean, anything _can_ be passed, but how to test if it was actually used? Seems difficult...
# I mean, anything _can_ be passed due to dynamic **kwargs, but how to test if it was actually used? Seems difficult...

0 comments on commit 2f37253

Please sign in to comment.