
The Gaussian filter is not behaving lazily as expected and is allocating RAM for the entire array #386

Open
AlexeyPechnikov opened this issue Aug 22, 2024 · 18 comments

@AlexeyPechnikov

The test script below uses about 20 GB of RAM, i.e. enough to hold the entire output array. Despite the small Gaussian kernel, which requires only minimal chunk overlap, dask_image.ndfilters still fails to process the data lazily. The same issue occurs across multiple operating systems (Debian Linux, Ubuntu Linux, macOS) and Python versions (3.10, 3.11), as well as different versions of dask_image.

Processing time: 167.02 seconds
Filtered result: <xarray.DataArray 'filtered_phase' (y: 50000, x: 50000)> Size: 20GB
dask.array<_trim, shape=(50000, 50000), dtype=float64, chunksize=(2048, 2048), chunktype=numpy.ndarray>
Coordinates:
  * y        (y) float64 400kB 0.5 1.5 2.5 3.5 4.5 ... 5e+04 5e+04 5e+04 5e+04
  * x        (x) float64 400kB 0.5 1.5 2.5 3.5 4.5 ... 5e+04 5e+04 5e+04 5e+04
import numpy as np
import xarray as xr
import dask
import dask.array  # needed for dask.array.random.random below
from dask_image.ndfilters import gaussian_filter as dask_gaussian_filter
from dask.distributed import Client
import os
import time

shape = (50000, 50000)
chunk_size = (2048, 2048)
netcdf_chunk_size = (2048, 2048)
netcdf_engine = 'netcdf4'

cutoff = 5.3
sigma_y = 5.40
sigma_x = 20.64

def run_test():
    client = Client()
    # prepare data
    phase = xr.DataArray(
        dask.array.random.random(shape, chunks=chunk_size) + 1j * dask.array.random.random(shape, chunks=chunk_size),
        dims=['y', 'x'],
        coords={
            'y': np.arange(0.5, shape[0] + 0.5),
            'x': np.arange(0.5, shape[1] + 0.5)
        },
        name='phase'
    )
    filename = 'test_gaussian_in.nc'
    if os.path.exists(filename):
        os.remove(filename)
    encoding = {
        'phase_real': {'chunksizes': netcdf_chunk_size},
        'phase_imag': {'chunksizes': netcdf_chunk_size}
    }
    ds = xr.Dataset({
        'phase_real': phase.real,
        'phase_imag': phase.imag
    })    
    ds.to_netcdf(filename, engine=netcdf_engine, encoding=encoding, compute=True)
    ds = xr.open_dataset(filename, engine=netcdf_engine, chunks={'y': chunk_size[0], 'x': chunk_size[1]})
    phase = ds.phase_real + 1j * ds.phase_imag
    # filter data
    start_time = time.time()
    phase_filtered = dask_gaussian_filter(phase.data.real, (sigma_y, sigma_x), mode='reflect', truncate=2)
    phase_filtered = xr.DataArray(phase_filtered, dims=phase.dims, coords=phase.coords, name='phase_filtered')    
    filename = 'test_gaussian_out.nc'
    if os.path.exists(filename):
        os.remove(filename)
    encoding = {'phase_filtered': {'chunksizes': netcdf_chunk_size}}
    delayed = phase_filtered.to_netcdf(filename,
                                       engine=netcdf_engine,
                                       encoding=encoding,
                                       compute=False)
    dask.compute(delayed)
    end_time = time.time()
    # show output
    print(f'Processing time: {end_time - start_time:.2f} seconds')
    print(f'Filtered result: {phase_filtered}')

# Main block to prevent multiprocessing issues
if __name__ == '__main__':
    run_test()
@jakirkham (Member)

Is it possible this line causes a copy to a NumPy array?

phase_filtered = xr.DataArray(phase_filtered, dims=phase.dims, coords=phase.coords, name='phase_filtered')

In other words, if that line and all subsequent lines are commented out, what does memory usage look like?

@AlexeyPechnikov (Author)

Xarray integrates seamlessly with dask arrays, and the example code correctly produces an xarray dask object. At this stage, RAM usage remains minimal:

    phase_filtered = dask_gaussian_filter(phase.data.real, (sigma_y, sigma_x), mode='reflect', truncate=2)
    phase_filtered = xr.DataArray(phase_filtered, dims=phase.dims, coords=phase.coords, name='phase_filtered')
    print ('phase_filtered', phase_filtered)

phase_filtered <xarray.DataArray 'phase_filtered' (y: 50000, x: 50000)> Size: 20GB
dask.array<_trim, shape=(50000, 50000), dtype=float64, chunksize=(2048, 2048), chunktype=numpy.ndarray>
Coordinates:
  * y        (y) float64 400kB 0.5 1.5 2.5 3.5 4.5 ... 5e+04 5e+04 5e+04 5e+04
  * x        (x) float64 400kB 0.5 1.5 2.5 3.5 4.5 ... 5e+04 5e+04 5e+04 5e+04

And a portion of the array processes correctly without excessive memory usage:

phase_filtered[:2048,:2048].compute()

But when the full computation starts, even when lazily saving the data to a NetCDF file, it requires RAM for the entire array. I suspect there’s no proper variable cleanup in the gaussian_filter function, and the garbage collector is unable to free memory while the computation is ongoing.

For reference, in my xarray and dask code, I immediately remove unnecessary objects to avoid memory issues. For example:

        @dask.delayed
        def geo_block(ys_block, xs_block, stackval=None):
            from scipy.interpolate import RegularGridInterpolator

            def nangrid():
                return np.nan * np.zeros((ys_block.size, xs_block.size), dtype=np.float32)

            # use outer variables
            block_grid = data.sel({stackvar: stackval}) if stackval is not None else data
            trans_inv_block = trans_inv.sel(y=ys_block, x=xs_block).compute(n_workers=1)

            # check if the data block exists
            if not (trans_inv_block.y.size>0 and trans_inv_block.x.size>0):
                del block_grid, trans_inv_block
                return nangrid()

            # use trans table subset
            y = trans_inv_block.lt.values.ravel()
            x = trans_inv_block.ll.values.ravel()
            points = np.column_stack([y, x])
            del trans_inv_block

            # get interferogram full grid
            ys = data.lat.values
            xs = data.lon.values

            # this code spends additional time for the checks to exclude warnings
            if np.all(np.isnan(y)):
                del block_grid, ys, xs, points
                return nangrid()

            # calculate trans grid subset extent
            ymin, ymax = np.nanmin(y), np.nanmax(y)
            xmin, xmax = np.nanmin(x), np.nanmax(x)
            del y, x
            # and spacing
            dy = ys[1] - ys[0]
            dx = xs[1] - xs[0]

            # select required interferogram grid subset
            ys_subset = ys[(ys>ymin-dy)&(ys<ymax+dy)]
            xs_subset = xs[(xs>xmin-dx)&(xs<xmax+dx)]
            del ymin, ymax, xmin, xmax, dy, dx, ys, xs

            # for cropped interferogram we can have no valid pixels for the processing
            if ys_subset.size == 0 or xs_subset.size == 0:
                del ys_subset, xs_subset, points, block_grid
                return nangrid()

            values = block_grid.sel(lat=ys_subset, lon=xs_subset).compute(n_workers=1).data.astype(np.float64)
            del block_grid

            # perform interpolation
            interp = RegularGridInterpolator((ys_subset, xs_subset), values, method='nearest', bounds_error=False)
            grid_ra = interp(points).reshape(ys_block.size, xs_block.size).astype(np.float32)
            del ys_subset, xs_subset, points, values
            return grid_ra

At this point, the library seems impractical because it consumes the same amount of RAM as numpy/scipy functions, making it impossible to use on large arrays. There is no way to call it efficiently on a decent-sized array without running into memory issues.

@jakirkham (Member) commented Aug 23, 2024

Would recommend using the visualize method to look at the task graph. There may be some clues in the DAG that indicate what is going on. It can generate a PNG which GitHub supports in comments (so should be doable to include in this thread)

Should add this needs the Python Graphviz library. In Conda this is installable via the python-graphviz package. There are also wheels, but I don't know if they include the Graphviz CLI. So that portion of the dependency tree may need to be satisfied another way
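
For example, something along these lines should write the graph to a PNG (a minimal sketch; the array size, sigma values, and output filename are placeholders):

import dask.array as da
from dask_image.ndfilters import gaussian_filter

# small stand-in for the 50k x 50k phase array from the reproducer
arr = da.random.random((10000, 10000), chunks=(2048, 2048))
filtered = gaussian_filter(arr, sigma=(5.40, 20.64), mode='reflect', truncate=2)

# requires graphviz + python-graphviz; writes the task graph as a PNG
filtered.visualize(filename='gaussian_filter_graph.png')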

@m-albert (Collaborator)

@AlexeyPechnikov next to the suggestion by John, can you confirm that your example code uses the amount of memory you'd expect after replacing the line

phase_filtered = dask_gaussian_filter(phase.data.real, (sigma_y, sigma_x), mode='reflect', truncate=2)

with

phase_filtered = dask.array.map_overlap(lambda x: x, phase.data.real, depth=1)

Given that many components are interacting in your code, this test will tell whether the problem you report is specific to the dask_image function.

@AlexeyPechnikov (Author)

@jakirkham

Would recommend using the visualize method to look at the task graph

It seems ineffective because the image is not readable:
[attached image: mydask.png — the full task graph, unreadable at this size]

The data structure of the variable phase_filtered for a 30k x 30k grid is:
[attached image: screenshot of the phase_filtered dask array structure]

@m-albert Yes, it’s the same, as expected, because dask-image internally calls image.map_overlap. The issue seems to be deeper within the image.map_overlap processing. Interestingly, I haven’t encountered any problems with processing many small overlapping chunks, like 32x32 with a 16-pixel overlap, but it doesn’t perform well when processing a few large overlapping chunks.

@jakirkham (Member)

Albert's suggestion is a good one because Dask-image is a very light wrapper on top of Dask

Is it possible to adjust the reproducer to use Dask alone? If so, we might want to raise this upstream. Possibly improvements are needed to map_overlap

@m-albert (Collaborator)

Interestingly, I haven’t encountered any problems with processing many small overlapping chunks, like 32x32 with a 16-pixel overlap.

Okay, so in the case of small filter sizes dask_image is behaving as you'd expect (i.e. processing lazily with low memory requirements)?

but it doesn’t perform well when processing a few large overlapping chunks.

Could you give us an idea about the chunk sizes for which the function starts to not perform well for you?

Also, could you detail what you mean by "doesn't perform well"? Above you mention that the script uses 20GB, but if I understand you correctly it runs through.

Keep in mind that a large memory footprint is not problematic if you have the RAM available (and not necessarily unexpected). Actually, processing n chunks in parallel will require at least n times the memory required for processing a single chunk (more in the case of overlap). You mention that you're having trouble with large chunksizes: It might actually be fully expected that processing "a few overlapping chunks" in parallel would require 20GB, which is comparable to the size of the full input array.

If you want to actively limit the memory usage by dask, some options would be to

  1. Use dask without parallelism. This can be done e.g. by using the "single-threaded" dask scheduler or by configuring the distributed cluster so that it only uses one thread in a single process (see the sketch after this list)
  2. Use worker resources https://distributed.dask.org/en/stable/resources.html
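
A minimal sketch of the first option (the scheduler name and Client keywords are standard Dask/distributed API; the 8GB memory limit is just an example value):

import dask
from dask.distributed import Client

# Option 1a: run the whole computation in the current process, one task at a time
dask.config.set(scheduler='single-threaded')

# Option 1b: keep the distributed scheduler, but restrict it to a single
# worker process with a single thread (optionally with a memory limit)
client = Client(n_workers=1, threads_per_worker=1, memory_limit='8GB')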

@AlexeyPechnikov (Author)

@m-albert You can see the array and its chunks. For small overlaps, processing a few 2048x2048 chunks in parallel should require just a fraction of the memory compared to the full array size. I can process it on macOS/Linux with unlimited swap, but it breaks in Docker environments with the default swap (1GB or less) and sometimes on free Google Colab, where swap is disabled.

For reference, this command works without any issues and with minimal memory consumption (about 1-2GB, which is hard to even monitor):

phase_filtered = dask.array.map_blocks(lambda x: x, phase.data.real)

However, this command is 4-5 times slower, and requires about 10 times more RAM:

phase_filtered = dask.array.map_overlap(lambda x: x, phase.data.real, depth=1)

For a single-pixel overlap, the difference is extreme. Even with a full-chunk overlap the difference should be at most about 9 times (the original block plus its 8 neighbours), and with a small overlap the neighbouring chunks shouldn't need to be fully materialized for each task at all. From my point of view, the problem lies in the full materialization of all overlapping chunks without immediate resource cleanup.
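
For anyone who wants to reproduce this comparison locally, a sketch along these lines should work with the default threaded scheduler (it assumes zarr, psutil and bokeh are installed, and uses a smaller array than the original report):

import dask.array as da
from dask.diagnostics import ResourceProfiler, visualize

arr = da.random.random((20000, 20000), chunks=(2048, 2048))

with ResourceProfiler(dt=0.25) as prof_blocks:
    da.map_blocks(lambda x: x, arr).to_zarr('blocks.zarr', overwrite=True)

with ResourceProfiler(dt=0.25) as prof_overlap:
    da.map_overlap(lambda x: x, arr, depth=1).to_zarr('overlap.zarr', overwrite=True)

# plots the two memory/CPU traces side by side for comparison
visualize([prof_blocks, prof_overlap])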

@m-albert (Collaborator)

@AlexeyPechnikov Thanks for these tests, it seems we can narrow down the problem to the use of map_overlap.

Here's an issue reporting exactly what you've detailed in your last comment: dask/dask#3671. Probably at this point it'll make sense to visualize the graph produced by map_overlap, as suggested by John and in the linked issue.

@AlexeyPechnikov (Author)

@m-albert As you can see, this issue has persisted in Dask for years and remains unresolved. It seems like this may continue to be a problem indefinitely. I’m not sure why you need the visualization, but here it is:

[attached image: map_overlap task graph for the test array]

In my code, I use a workaround by processing only chunk coordinates and extracting and materializing only the required data:

def block_dask(ylim, xlim, depth):
    ...
    block_data = data.isel(y=slice(ylim[0] - depth[0], ylim[1] + depth[0]),
                           x=slice(xlim[0] - depth[1], xlim[1] + depth[1])).compute(n_workers=1)
    ...

# define blocks
chunks = data.chunks
ychunks, xchunks = chunks[1], chunks[2]
ychunks = np.concatenate([[0], np.cumsum(ychunks)])
xchunks = np.concatenate([[0], np.cumsum(xchunks)])
ylims = [(y1, y2) for y1, y2 in zip(ychunks, ychunks[1:])]
xlims = [(x1, x2) for x1, x2 in zip(xchunks, xchunks[1:])]

for ylim in ylims:
    for xlim in xlims:
        block = dask.array.from_delayed(dask.delayed(block_dask)(date1, date2, ylim, xlim, depth),
                                                 shape=((ylim[1]-ylim[0]), (xlim[1]-xlim[0])), ...)
        ...

However, in my recent tests, I’ve found that Dask’s blockwise can now be used effectively, yielding similar performance. How about developing a universal solution? Currently, dask-image doesn’t offer advantages over standard Dask wrappers for SciPy functions, but a solution for overlapping processing could make a big impact.

By the way, I develop open-source Python InSAR software for satellite interferometry processing that can handle terabytes of data, even on a common laptop like an Apple MacBook Air. You can find the Google Colab links and Docker image at https://github.com/AlexeyPechnikov/pygmtsar. Even the complex examples provided there work in a Docker container with just 4-8GB RAM and 0.5-1GB swap. However, the code is tricky because of Dask-related issues, and it would be great to resolve this hassle completely.

@m-albert (Collaborator)

Indeed it seems like the graphs produced by map_overlap are not always computed efficiently (as @jakirkham also hinted at). I agree that dask_image.ndfilters could be a good place to provide solutions that improve on the behaviour of map_overlap, at least in its current form (there's also dask-image functionality that goes beyond map_overlap).

I’m not sure why you need the visualization, but here it is:

I've also played around with this a bit. Consider the following graph visualization for a 2D map_overlap:

import numpy as np
import dask.array as da

a = da.random.random((6, ) * 2, chunks=(2, ) * 2)
b = a.map_overlap(lambda x: x, depth=1)

b.visualize('map_overlap.png', color='order')

[attached image: map_overlap.png — task graph colored by execution order]

The color represents the order in which the tasks would be computed by the single-threaded scheduler. From the visualization it becomes apparent that each output chunk depends on several input chunks. It also seems that the input slices for each output chunk are all computed at the same time (look at the third layer from the bottom), rather than just after the relevant input chunk has been loaded, with only a slice of it kept in memory. So it could be that input chunk caching is creating the large memory footprint.
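
The ordering used for that coloring can also be inspected programmatically (a small sketch using dask.order.order, which computes the priorities the schedulers work from):

import dask.array as da
from dask.order import order

a = da.random.random((6,) * 2, chunks=(2,) * 2)
b = a.map_overlap(lambda x: x, depth=1)

# mapping from task key to its execution priority; lower numbers run earlier
priorities = order(dict(b.__dask_graph__()))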

In the following code I used inline_functions from dask.optimization to unwind the dependencies between neighboring chunks. Basically, the idea is to (re-)compute the input chunks separately during the calculation of a given output chunk (similar to what the code snippet shared by @AlexeyPechnikov does explicitly).

from dask.optimization import inline_functions
import dask.array
def inline_first_hlg_layer(ar):
    """
    This function inlines the first high-level graph layer of a dask array.
    """

    layer_keys = list(ar.dask.layers.keys())
    no_dep_keys = [k for k, v in ar.dask.dependencies.items() if not len(v)]
    lowest_layer_key = [k for k in no_dep_keys
        if not max([lk in k and lk != k for lk in layer_keys])][0]
    outputs = [k for k in ar.dask
                if k[0].startswith(ar.name)]
    fast_function = ar.dask[[k for k in ar.dask.keys()
                        if k[0].startswith(lowest_layer_key)][0]][0]
    dsk_inlined = inline_functions(ar.dask,
                outputs,
                [fast_function],
                dependencies=ar.dask.get_all_dependencies(),
                )
    res_ar = dask.array.Array(dsk_inlined, ar.name, chunks=ar.chunks, dtype=ar.dtype)
    return res_ar

b_inlined = inline_first_hlg_layer(b)

b_inlined.visualize('map_overlap_inlined.png', color='order')

This yields the following nicely parallelised dask graph:
[attached image: map_overlap_inlined.png — the inlined, nicely parallelised task graph]

How does this modified graph structure affect memory usage? Let's have a look at a larger array (comparable to the code snippets above) and profile it:

import os
from dask.diagnostics import ResourceProfiler, CacheProfiler, ProgressBar

a = da.random.random((20000,) * 2, chunks=(2000,) * 2)
b = a.map_overlap(lambda x: x, depth=1)
b_inlined = inline_first_hlg_layer(b)

os.system('rm -rf zarr1.zarr')
with ResourceProfiler(dt=.25) as rprof_inlined, ProgressBar(), CacheProfiler() as cprof_inlined:
    b_inlined.to_zarr('zarr1.zarr')

os.system('rm -rf zarr2.zarr')
with ResourceProfiler(dt=.25) as rprof, ProgressBar(), CacheProfiler() as cprof:
    b.to_zarr('zarr2.zarr')

from dask.diagnostics import visualize
visualize([rprof_inlined, rprof])
[attached images: two resource-profile screenshots, inlined graph vs. original graph]

It seems that creating a linear and parallel graph structure by inlining the first graph layer

  • reduces the memory footprint and reduces memory peaks
  • leads to a different caching behaviour (that's more scalable?)
  • comes at the cost of increased computation time

In my tests, this works for the threaded single machine scheduler and the distributed scheduler with Client(n_workers=1).

@AlexeyPechnikov Could you try this approach on your data? I.e. applying inline_first_hlg_layer to your output dask array while using Client(n_workers=1).

However, in my recent tests, I’ve found that Dask’s blockwise can now be used effectively, yielding similar performance.

Do you want to add some more detail on this?

@AlexeyPechnikov (Author)

@m-albert Yeah, the graph looks much better, and map_overlap works in about the same time and with similar memory consumption as map_blocks. However, the inline_first_hlg_layer() function produces errors with my data. I've fixed it somewhat for a 30k x 30k array, but it still doesn't work for a 50k x 50k array, so my fix is definitely not right. Please take a look at the error:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Cell In[88], line 25
     21     res_ar = dask.array.Array(dsk_inlined, ar.name, chunks=ar.chunks, dtype=ar.dtype)
     22     return res_ar
---> 25 b_inlined = inline_first_hlg_layer(phase_filtered)
     26 b_inlined

Cell In[88], line 10, in inline_first_hlg_layer(ar)
      8 layer_keys = list(ar.dask.layers.keys())
      9 no_dep_keys = [k for k, v in ar.dask.dependencies.items() if not len(v)]
---> 10 lowest_layer_key = [k for k in no_dep_keys
     11     if not max([lk in k and lk != k for lk in layer_keys])][0]
     12 outputs = [k for k in ar.dask
     13             if k[0].startswith(ar.name)]
     14 fast_function = ar.dask[[k for k in ar.dask.keys()
     15                     if k[0].startswith(lowest_layer_key)][0]][0]

IndexError: list index out of range

I’m not confident with inline_functions; do you believe this approach is robust and can work stably in all cases? I initially saw it as more of a last-mile solution for a specific case, but it would be great if we could use it as a universal solution.

@AlexeyPechnikov (Author)

@m-albert

Do you want to add some more detail on this?

It seems there are no additional advantages to coding as I’ve shown using dask.delayed and merging the blocks; da.blockwise provides the same performance with more compact code:

da.blockwise(
    block_dask,
    'yx',
    data.y, 'y',
    data.x, 'x',
    ...
)
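
A self-contained sketch of the same pattern (block_func, the coordinate arrays and the dtype here are illustrative, not taken from the actual code):

import numpy as np
import dask.array as da

def block_func(ys, xs):
    # hypothetical per-block computation driven only by the block's coordinates;
    # the real block_dask would read and interpolate the required data subset here
    return np.zeros((ys.size, xs.size), dtype=np.float32)

ys = da.from_array(np.arange(50000, dtype=np.float64), chunks=2048)
xs = da.from_array(np.arange(50000, dtype=np.float64), chunks=2048)

result = da.blockwise(
    block_func, 'yx',
    ys, 'y',
    xs, 'x',
    dtype=np.float32,
)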

Years ago, memory consumption was not predictable for da.blockwise on a Dask distributed cluster in some cases. As a result, I ended up using dask.delayed with immediate deletion of unused variables like this:

blocks_total = []
for ylim in ylims:
    blocks = []
    for xlim in xlims:
        block = dask.array.from_delayed(dask.delayed(block_dask)(ylim, xlim),
                                        shape=((ylim[1]-ylim[0]), (xlim[1]-xlim[0])), ...)
        blocks.append(block)
        del block
    blocks_total.append(blocks)
    del blocks
intf = xr.DataArray(dask.array.block(blocks_total), coords={'y': data.y, 'x': data.x})
del blocks_total

It’s not elegant, but it is (and was) predictable. Whether on large workstations or common laptops, we could process huge datasets. Recently, I rechecked, and da.blockwise just works. I hope we don’t need to replace it with a lot of dummy code today.

@jakirkham (Member) commented Aug 29, 2024

This probably changed when HighLevelGraphs were added to Dask (for example: dask/dask#6510)

If there is a way to rewrite map_overlap using the observations of performance here, would recommend pursuing that

We certainly are heavy users of map_overlap in dask-image. Though it would be better if we could solve this for all map_overlap users

@AlexeyPechnikov (Author)

@jakirkham

I found that dask-image does its job much better than dask-ml, for example, which crashes even on small rasters of about 1000x1000 pixels, while the wrapped SciPy functions work well (I reported the issue some years ago, but it seems developers were not interested). What I’m asking for is intended to provide easy tools to operate on terabyte datasets, even on an 8GB RAM host. It’s possible right now, but the code looks complicated, and users who want to use their own filtering or other custom operations (not included in my software) find it challenging. If dask-image could provide easy-to-use functions for multidimensional image operations, it would be a truly unique tool, competing with well-established big data processing platforms (which, in my opinion, are technically terrible but more usable without deep knowledge).

@m-albert (Collaborator)

@jakirkham Agreed that it'd be best to improve map_overlap generally.

@AlexeyPechnikov The function was mainly meant as a quick test of inlining and, as John mentions, I think it isn't compatible with HighLevelGraphs in general.

I think this is relevant to us and directly relates to our issue: https://docs.dask.org/en/stable/order.html. In summary, @TomAugspurger added optional inlining of arrays in dask.array.from_zarr here, solving this issue.

One idea would be to implement the same for map_overlap. For the same reasons as described in the linked docs section, we probably cannot assume that inlining the overlap computation is desirable in every case. So having an additional parameter inline_array in map_overlap could improve performance in specific cases (i.e. in the context of dask-image).

In terms of the implementation, this would probably require inlining here.
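
For reference, the existing precedent looks like this (inline_array is an existing keyword of dask.array.from_zarr / from_array; the analogous keyword on map_overlap is only a proposal at this point, and 'input.zarr' is a placeholder path):

import dask.array as da

# existing behaviour: inline the array-creation tasks into their dependents,
# which embeds the open/read step into each downstream task instead of
# keeping it as a separate, cached graph layer
arr = da.from_zarr('input.zarr', inline_array=True)

# the proposal discussed here would be an analogous (currently hypothetical) knob:
# filtered = arr.map_overlap(some_filter, depth=8, inline_array=True)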

@m-albert (Collaborator)

I was just reading some more about this and found a large discussion thread on map_overlap performance: dask/dask#7404.

Actually, the high level graph layer for array overlap has been implemented by @GenevieveBuckley!

If I understand the discussions correctly, this layer helped delay the creation of the low-level graph ("materializing" it). The inline_array argument suggested above could further help with (optional) inlining of the input array slicing / concatenation tasks.

@AlexeyPechnikov (Author)

@m-albert You’re thinking in terms of Dask implementation… From my perspective, it’s more straightforward to think about an independent implementation and then fit the concept to Dask. There are actually just a few key points needed for an effective map_overlap implementation:

  1. Task Queue for Chunks: Build a queue of tasks for the chunks to be processed, and call a function that internally reads and materializes only the required data extent (i.e., not the full chunks covering the overlapping area, but only the required pixels). This can be done using Dask delayed or Dask map_blocks; a minimal sketch of this pattern follows after this list. There's also chunk caching, which (partially) prevents reading the same chunk multiple times. I've used this method for years because it's the only one that works well with large datasets.

  2. Graph of Dependencies: Add a graph of chunks required for each queued task, and reorganize the execution to read each chunk only once and clean it up right after all dependent jobs are completed. Prefetching with caching isn’t too complicated, and I believe we can achieve this in Dask, though I’m not sure how to separate it from other tasks. I don’t know how to implement this more effectively within the Dask paradigm, but a separate execution graph for map_overlap would be beneficial. There should be mechanisms to achieve this behavior in Dask—it’s somewhat similar to a persist() call but without initiating actual calculations.

  3. Integration with Dask’s Graph: Integrating the graph into the full Dask processing graph is complicated because Dask manages chunks, whereas we need to manage fractional chunks instead. However, I suppose this isn’t significant for performance because if we have only one additional chunk reading for a map_overlap call, it doesn’t affect performance on low-memory hosts (where we can’t hold chunks in cache for long) and is negligible for high-memory hosts (where all read chunks can be cached indefinitely). In other words, do we really need to integrate the map_overlap operation’s graph with the full Dask processing graph? I believe this is the most complex aspect, and it can be omitted.
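
A minimal sketch of the pattern in point 1 (the names filter_block and filter_lazily are illustrative, a 2D Zarr store stands in for the NetCDF source, and the chunk/depth/sigma values are placeholders):

import dask
import dask.array as da
import zarr
from scipy.ndimage import gaussian_filter

def filter_block(path, ylim, xlim, depth, sigma):
    # hypothetical per-block worker: open the on-disk store and read only the
    # pixels needed for this output block (its extent plus the overlap margin)
    src = zarr.open(path, mode='r')
    y0, x0 = max(ylim[0] - depth, 0), max(xlim[0] - depth, 0)
    block = src[y0:ylim[1] + depth, x0:xlim[1] + depth]
    # filter the padded block, then trim the halo (edge handling at the array
    # borders is only approximate in this sketch)
    filtered = gaussian_filter(block, sigma, mode='reflect')
    oy, ox = ylim[0] - y0, xlim[0] - x0
    return filtered[oy:oy + (ylim[1] - ylim[0]), ox:ox + (xlim[1] - xlim[0])]

def filter_lazily(path, shape, dtype, chunk=2048, depth=8, sigma=5.0):
    # build one small delayed task per output chunk; nothing is materialized
    # until the resulting dask array is computed or written out
    rows = []
    for y0 in range(0, shape[0], chunk):
        row = []
        for x0 in range(0, shape[1], chunk):
            ylim = (y0, min(y0 + chunk, shape[0]))
            xlim = (x0, min(x0 + chunk, shape[1]))
            d = dask.delayed(filter_block)(path, ylim, xlim, depth, sigma)
            row.append(da.from_delayed(
                d, shape=(ylim[1] - ylim[0], xlim[1] - xlim[0]), dtype=dtype))
        rows.append(row)
    return da.block(rows)

Calling, say, filter_lazily('phase.zarr', shape=(50000, 50000), dtype='float64') would then return a lazy dask array whose chunks can be computed or written out one at a time.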
