Fix issues with recent versions of Dask and Numba #313

Merged: 9 commits into NSLS-II:master on Jan 18, 2024

Conversation

@dmgav (Contributor) commented Jan 18, 2024

This PR contains fixes for two issues, one with Dask (1) and one with Numba/NumPy (2):

  1. Issue with Dask. Starting with Dask v2023.9.3, the workaround used to force Dask workers to close HDF5 files stopped working. An issue was opened in the distributed package: "HDF5 file remains open after computations with Dask array based on HDF5 dataset is completed" dask/distributed#8452. A working solution (also a workaround) can be found in https://gist.github.com/dmgav/9fa69d1e507eff46e8098f082b0a1611; it overrides the default serializers and deserializers for h5py files and datasets. The code is also included in this PR description below.
#############  File 'deserializers.py'  ####################

# Importing this module registers the default h5py serializers first,
# so the custom registrations below override them.
import distributed.protocol.h5py
from distributed.protocol.serialize import dask_serialize, dask_deserialize

# Files opened during deserialization, kept so they can be closed later.
deserialized_files = set()


def serialize_h5py_file(f):
    # Serialize an open h5py file by sending only its filename.
    if f and (f.mode != "r"):
        raise ValueError("Can only serialize read-only h5py files")
    filename = f.filename if f else None
    return {"filename": filename}, []


def serialize_h5py_dataset(x):
    # Serialize a dataset (or group) as its filename plus its path within the file.
    header, _ = serialize_h5py_file(x.file if x else None)
    header["name"] = x.name if x else None
    return header, []


def deserialize_h5py_file(header, frames):
    import h5py

    filename = header["filename"]
    if filename:
        file = h5py.File(filename, mode="r")
        deserialized_files.add(file)
    else:
        file = None
    return file


def deserialize_h5py_dataset(header, frames):
    file = deserialize_h5py_file(header, frames)
    name = header["name"]
    dset = file[name] if (file and name) else None
    return dset


def set_custom_serializers():
    import h5py

    dask_serialize.register((h5py.Group, h5py.Dataset), serialize_h5py_dataset)
    dask_serialize.register(h5py.File, serialize_h5py_file)
    dask_deserialize.register((h5py.Group, h5py.Dataset), deserialize_h5py_dataset)
    dask_deserialize.register(h5py.File, deserialize_h5py_file)


def close_all_files():
    # Close every file opened during deserialization on this worker.
    while deserialized_files:
        file = deserialized_files.pop()
        if file:
            file.close()

#############  Example script  ####################

import h5py
import dask
import dask.array as da
import distributed
from dask.distributed import Client
import numpy as np

import logging

logger = logging.getLogger(__name__)


def run_example():
    from deserializers import set_custom_serializers, close_all_files

    print(f"Version of Dask: {dask.__version__}")
    print(f"Version of Distributed: {distributed.__version__}")
    print("===============================")

    # Create HDF5 file
    print("Creating HDF5 file")
    fln = "test.h5"
    with h5py.File(fln, "w") as f:
        f.create_dataset("data", data=np.random.random(size=(100, 100)), chunks=(10, 10), dtype="float64")

    print("Creating client")
    client = Client()

    # Register the custom serializers on every worker and in the local process
    client.run(set_custom_serializers)
    set_custom_serializers()

    # Process the file
    print("Loading and processing data")
    with h5py.File(fln, "r") as f:

        data = da.from_array(f["data"], chunks=(10, 10))
        sm_fut = da.sum(data, axis=0).persist(scheduler=client)
        sm = sm_fut.compute(scheduler=client)
        print(f"sm={sm}")

    # Close the HDF5 files opened during deserialization, on the workers and locally
    client.run(close_all_files)
    close_all_files()

    # Try to open file for writing
    print("Attempting to open file for writing")
    try:
        with h5py.File(fln, "r+") as f:
            print("File was opened for writing!!!")
    except OSError as ex:
        logger.exception("Failed to open file for writing: %s", ex)

    print("Closing client")
    client.close()

if __name__ == "__main__":
    run_example()
  2. Issue with NumPy/Numba. The function that implements the 'snip' method for background subtraction kept failing with a 'List index out of range' error when run with Numba JIT (it worked correctly without Numba JIT). The issue was traced to the use of the numpy.convolve function. The part of the function using convolution was reimplemented using an explicit loop, multiplication and subtraction. In initial tests, the new solution appears to run at least as fast as the old one when compiled with Numba JIT.

Original code based on np.convolve:

    A = s.sum()
    background = np.convolve(background, s) / A
    # Trim 'background' array to imitate the np.convolve option 'mode="same"'
    mg = len(s) - 1
    n_beg = mg // 2
    n_end = n_beg - mg  # Negative
    background = background[n_beg:n_end]
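As a quick sanity check (my sketch, not part of the PR), the trimming above reproduces exactly what np.convolve returns with mode="same", shown here with an arbitrary odd-length symmetric kernel:

```python
import numpy as np

background = np.random.random(50)
s = np.ones(5)  # hypothetical odd-length symmetric smoothing kernel
A = s.sum()

full = np.convolve(background, s) / A  # full convolution, length 50 + 5 - 1
mg = len(s) - 1
n_beg = mg // 2
n_end = n_beg - mg  # negative, trims from the end
trimmed = full[n_beg:n_end]

same = np.convolve(background, s, mode="same") / A
assert trimmed.shape == background.shape
assert np.allclose(trimmed, same)
```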

Replacement code:

    def convolve(background, s):
        # In-place equivalent of np.convolve(background, s, mode="same") / s.sum()
        # for a symmetric kernel s, written as an explicit loop so Numba JIT can compile it.
        s_len = len(s)
        n_beg = (s_len - 1) // 2
        A = s.sum()
        # Zero-pad a copy of 'background' so every sliding window stays in bounds
        source = np.hstack(
            (
                np.zeros(n_beg, dtype=background.dtype),
                background,
                np.zeros(s_len - n_beg, dtype=background.dtype),
            )
        )
        for n in range(len(background)):
            background[n] = np.sum(source[n : n + s_len] * s) / A

    convolve(background, s)
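To sanity-check the replacement (my test harness, not from the PR), a non-mutating copy of the loop can be compared against np.convolve with mode="same". Note that the sliding-window sum computes a correlation, so the two agree only for symmetric kernels such as the smoothing kernel used by the 'snip' method:

```python
import numpy as np

def convolve_loop(background, s):
    # Copy of the PR's explicit-loop replacement, returning a new array
    # instead of mutating 'background' so the input can be reused below.
    s_len = len(s)
    n_beg = (s_len - 1) // 2
    A = s.sum()
    # Zero-pad so every sliding window stays in bounds
    source = np.hstack(
        (
            np.zeros(n_beg, dtype=background.dtype),
            background,
            np.zeros(s_len - n_beg, dtype=background.dtype),
        )
    )
    out = np.empty_like(background)
    for n in range(len(background)):
        out[n] = np.sum(source[n : n + s_len] * s) / A
    return out

rng = np.random.default_rng(0)
bg = rng.random(100)
s = np.ones(7)  # symmetric kernel
expected = np.convolve(bg, s, mode="same") / s.sum()
result = convolve_loop(bg, s)
assert np.allclose(result, expected)
```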

@dmgav dmgav merged commit 690a308 into NSLS-II:master Jan 18, 2024
13 checks passed
@dmgav dmgav deleted the dask-fix branch January 18, 2024 14:32