Slow performance using concat_input_dims=True #165

Open
cmdupuis3 opened this issue Jan 31, 2023 · 10 comments
Labels
bug Something isn't working

Comments

@cmdupuis3

What is your issue?

I've been doing some ad-hoc performance profiling, and it looks like _drop_input_dims is always the culprit for my slow batch generation. I blame the deep copy here.

However, in light of #164, I don't really see the point of this subroutine. Can someone explain what this subroutine does?

@cmdupuis3
Author

Tagging @rabernat and @maxrjones

@rabernat
Contributor

Hi Chris! Thanks for looking into this. Can you share your performance profiling script so others can reproduce the issue?

I blame the deep copy here.

That is not a deep copy; it's a very shallow copy. In general, copying xarray objects is extremely fast and cheap. No actual data is ever copied.
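To illustrate, here is a rough sketch with a toy dataset showing that a shallow copy shares the underlying numpy buffers:

import numpy as np
import xarray as xr

ds = xr.Dataset({"SST": (("nlat", "nlon"), np.random.rand(4, 5))})

shallow = ds.copy(deep=False)   # copies containers and metadata only
print(np.shares_memory(ds["SST"].values, shallow["SST"].values))  # True: same buffer

deep = ds.copy(deep=True)       # duplicates the underlying arrays
print(np.shares_memory(ds["SST"].values, deep["SST"].values))     # False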

Can someone explain what this subroutine does?

This is an internal function. Its purpose is described in a comment:

# remove input_dims coordinates from datasets, rename the dimensions
# then put input_dims back in as coordinates
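
For a rough idea of what that pattern looks like in practice, here is a hypothetical sketch (not the actual xbatcher source) against a toy dataset:

import numpy as np
import xarray as xr

def drop_and_rename_input_dims(ds, input_dims, suffix="_input"):
    # hypothetical helper illustrating the comment above, not the xbatcher implementation
    saved_coords = {}
    renamed = {}
    for dim in input_dims:
        if dim in ds.coords:
            saved_coords[dim] = ds[dim].values  # remember the coordinate values
            ds = ds.drop_vars(dim)              # remove the input_dims coordinate
        renamed[dim] = dim + suffix             # e.g. "nlat" -> "nlat_input"
    ds = ds.rename(renamed)
    # put the original coordinates back under the renamed dimensions
    for dim, values in saved_coords.items():
        ds = ds.assign_coords({dim + suffix: (dim + suffix, values)})
    return ds

ds = xr.Dataset(
    {"SST": (("nlat",), np.arange(4.0))},
    coords={"nlat": np.arange(4)},
)
print(drop_and_rename_input_dims(ds, ["nlat"]))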

@cmdupuis3
Author

It was very ad hoc profiling, basically just running, stopping, and checking where the stack trace went a number of times. I can try to get something more official but that would take a bit of doing.

If it's not a deep copy, there's something else very wrong. I'm seeing gigabytes of RAM usage while training on a few megabytes of data with concat_input_dims=True. That line got the most hits, and the next most common was line 317.

The docs make it sound like it does something and then undoes it again. As far as I can see, aside from appending "_input" in some circumstances (which is an issue of its own #164), it doesn't do anything, but that could just be another facet of the "all dims are input dims" problem.

@rabernat
Contributor

If you're doing profiling, I highly recommend using snakeviz to visualize where your program is spending time.

If you're using more RAM than you expect, it's quite likely that it's a chunk-related issue. Remember that Zarr chunks have to be loaded all at once. It's not possible to do a partial read of a chunk.
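
For example (a generic sketch with a placeholder store path, not from the notebook), the chunk layout can be checked directly before deciding on batch sizes:

import xarray as xr

ds = xr.open_zarr("path/to/store.zarr")      # placeholder path

print(ds.chunks)                             # chunk sizes per dimension
print(ds["SST"].encoding.get("chunks"))      # on-disk Zarr chunk shape recorded in encoding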

If you could share a reproducible snippet of code, it would help us help you.

@cmdupuis3
Author

On it. You can see my surface_currents notebook here.

As far as chunking, everything was chunked by time slice, so I should be good there, right?

@rabernat
Contributor

I am trying to run your notebook now and I see what you're saying about the performance.

I am just trying to profile getting a single batch:

%%prun
for batch in bgen:
    break

It has already been running for 10 minutes... 🙄 This is a tiny piece of data. Something is seriously wrong.
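
Roughly the same measurement can be done outside a notebook with cProfile (a sketch, assuming the bgen built in the notebook):

import cProfile
import pstats

profiler = cProfile.Profile()
profiler.enable()
batch = next(iter(bgen))   # bgen as constructed in the notebook
profiler.disable()

pstats.Stats(profiler).sort_stats("cumulative").print_stats(20)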

@rabernat
Contributor

It looks like concat_input_dims=True has a major impact on the speed here. Basically, what is happening is a huge and very inefficient reshaping of the data. This is almost exactly what xarray's rolling(...).construct(...) does.

@rabernat
Contributor

Here is a minimal reproducer for this issue:

import numpy as np
import xarray as xr
import xbatcher as xb


# original size was 300, 250; scaled down to make debugging faster
nlat, nlon = 30, 250

ds = xr.Dataset(
    {
        "SST": (('nlat', 'nlon'), np.random.rand(nlat, nlon)),
        "SSH": (('nlat', 'nlon'), np.random.rand(nlat, nlon))
    }
)

bgen = xb.BatchGenerator(
    ds,
    input_dims={'nlon': 3, 'nlat': 3},
    input_overlap={'nlon': 2, 'nlat': 2},
    concat_input_dims=True
)

%time batch = next(iter(bgen))

On the LEAP hub, I'm getting 3.3 s for that very small size example and 36.1 s for the original 300, 250 shape.

This is suitable for profiling with snakeviz.

%%snakeviz
batch = next(iter(bgen))
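
(The %%snakeviz cell magic requires loading the extension first with %load_ext snakeviz.)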

(snakeviz profile screenshot, Jan 30, 2023)

Most of the time is spent on concat, although _drop_input_dims is also significant.

It would be interesting to compare this to rolling.construct to see if it's any more efficient.

If not, I would consider trying to bypass xarray completely internally. It's creating lots of overhead that we don't necessarily need.

@rabernat
Contributor

rabernat commented Jan 31, 2023

Here's an alternative way to accomplish almost the same thing using xarray rolling.construct:

batch = (
    ds
    .rolling({"nlat": 3, "nlon": 3})
    .construct({"nlat": "nlat_input", "nlon": "nlon_input"})
    .stack({"input_batch": ("nlat", "nlon")}, create_index=False)
)

For me this ran in 6.95 ms for the full 300 x 250 input (compared to 30 s for the xbatcher method).

With rolling, we don't have the ability to vary input_overlap explicitly. I'd be more than willing to give that up for a 5000x performance boost. 🚀 😉
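
If I remember correctly, construct also accepts a stride argument, which at least gives a coarse way to thin the overlapping windows. Something like this untested sketch against the reproducer above:

batch_strided = (
    ds
    .rolling({"nlat": 3, "nlon": 3})
    .construct({"nlat": "nlat_input", "nlon": "nlon_input"}, stride=2)  # keep every other window
    .stack({"input_batch": ("nlat", "nlon")}, create_index=False)
)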

@rabernat
Contributor

This post on numpy stride tricks is extremely relevant.
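
For reference, a pure-numpy sketch of the same overlapping-window construction using sliding_window_view (assuming the 300 x 250 example above):

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

nlat, nlon = 300, 250
sst = np.random.rand(nlat, nlon)

# every overlapping 3x3 patch, as a zero-copy strided view of the original array
patches = sliding_window_view(sst, (3, 3))   # shape (298, 248, 3, 3)

# flattening the leading dims materializes a copy, but only at this point
batch = patches.reshape(-1, 3, 3)            # shape (73904, 3, 3)
print(batch.shape)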

maxrjones changed the title from "_drop_input_dims is hella slow" to "Slow performance using concat_input_dims=True" on Feb 15, 2023
weiji14 added the bug (Something isn't working) label on Feb 19, 2023