Slow performance using concat_input_dims=True #165
Comments
Tagging @rabernat and @maxrjones
Hi Chris! Thanks for looking into this. Can you share your performance profiling script so others can reproduce the issue?
That is not a deep copy; it's a very shallow copy. In general, copying an xarray object is extremely fast and cheap. No actual data is ever copied.
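A minimal sketch of that claim, assuming any small in-memory dataset: the shallow copy wraps the same numpy buffers rather than duplicating them.

```python
import numpy as np
import xarray as xr

ds = xr.Dataset({"SST": (("nlat", "nlon"), np.random.rand(30, 250))})
ds_copy = ds.copy(deep=False)  # deep=False is the default; no array data is duplicated

# Both objects wrap the very same buffer.
assert np.shares_memory(ds_copy["SST"].values, ds["SST"].values)
```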
This is an internal function. Its purpose is described in a comment in xbatcher/xbatcher/generators.py, lines 312 to 313 (at 673e3ca).
It was very ad hoc profiling, basically just running, stopping, and checking where the stack trace went a number of times. I can try to get something more official, but that would take a bit of doing. If it's not a deep copy, there's something else very wrong: I'm getting gigs of RAM use training on a few megs of data with concat_input_dims=True.

The docs make it sound like it does something and then undoes it again. As far as I can see, aside from appending "_input" in some circumstances (which is an issue of its own, #164), it doesn't do anything, but that could just be another facet of the "all dims are input dims" problem.
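For pinning down memory growth like that, one option is the standard-library tracemalloc; a rough sketch, assuming `bgen` is the batch generator used in training:

```python
import tracemalloc

tracemalloc.start()
for i, batch in enumerate(bgen):
    if i >= 100:  # sample a limited number of batches
        break
current, peak = tracemalloc.get_traced_memory()
print(f"current: {current / 1e6:.1f} MB, peak: {peak / 1e6:.1f} MB")
tracemalloc.stop()
```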
If you're doing profiling, I highly recommend using snakeviz to visualize where your program is spending time. If you're using more RAM than you expect, it's quite likely a chunk-related issue. Remember that Zarr chunks have to be loaded all at once; it's not possible to do a partial read of a chunk. If you could share a reproducible snippet of code, it would help us help you.
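A sketch of that workflow outside a notebook, assuming snakeviz is installed and `bgen` is already defined at module level:

```python
import cProfile
import pstats

# Profile fetching a single batch and dump the stats to a file.
cProfile.run("next(iter(bgen))", "batch.prof")
pstats.Stats("batch.prof").sort_stats("cumulative").print_stats(10)
# Visualize interactively with:  snakeviz batch.prof
```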
On it. You can see my surface_currents notebook here. As far as chunking goes, everything was chunked by time slice, so I should be good there, right?
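One way to double-check that, as a sketch (`ds` stands in for whatever dataset the notebook opens):

```python
# Dask-backed in-memory chunking, per dimension:
print(ds.chunks)
# On-disk Zarr chunk shapes, per variable (present if the data came from Zarr):
print({name: var.encoding.get("chunks") for name, var in ds.variables.items()})
```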
I am trying to run your notebook now, and I see what you're saying about the performance. I am just trying to profile getting a single batch:

```python
%%prun
for batch in bgen:
    break
```

It has already been running for 10 minutes... 🙄 This is a tiny piece of data. Something is seriously wrong.
It looks like
Here is a minimal reproducer for this issue:

```python
import numpy as np
import xarray as xr
import xbatcher as xb

# original size was 300, 250; scaled down to make debugging faster
nlat, nlon = 30, 250

ds = xr.Dataset(
    {
        "SST": (('nlat', 'nlon'), np.random.rand(nlat, nlon)),
        "SSH": (('nlat', 'nlon'), np.random.rand(nlat, nlon)),
    }
)

bgen = xb.BatchGenerator(
    ds,
    input_dims={'nlon': 3, 'nlat': 3},
    input_overlap={'nlon': 2, 'nlat': 2},
    concat_input_dims=True,
)

%time batch = next(iter(bgen))
```

On the LEAP hub, I'm getting 3.3 s for that very small example and 36.1 s for the original 300 x 250 shape. This is suitable for profiling with snakeviz.
Most of the time is spent on [...]. It would be interesting to compare this to [...]. If not, I would consider trying to bypass xarray completely internally. It's creating lots of overhead that we don't necessarily need.
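To make the "bypass xarray internally" idea concrete, a rough sketch of what plain-numpy slicing of the same data could look like (not how xbatcher is implemented; it reuses `ds` from the reproducer above):

```python
import numpy as np

# Slicing a numpy array returns a view, so generating every overlapping
# 3x3 patch this way avoids constructing an xarray object per batch.
sst = ds["SST"].values
patches = [
    sst[i:i + 3, j:j + 3]
    for i in range(sst.shape[0] - 2)   # step of 1 == overlap of 2
    for j in range(sst.shape[1] - 2)
]
```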
Here's an alternative way to accomplish almost the same thing using xarray rolling.construct:

```python
batch = (
    ds
    .rolling({"nlat": 3, "nlon": 3})
    .construct({"nlat": "nlat_input", "nlon": "nlon_input"})
    .stack({"input_batch": ("nlat", "nlon")}, create_index=False)
)
```

For me this ran in 6.95 ms for the full 300 x 250 input (compared to 30 s for the xbatcher method). With rolling, we don't have the ability to vary [...].
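A quick way to inspect what that produces, reusing `ds` and `batch` from the snippets above; one reason this is only "almost" equivalent is that rolling pads the edge windows with NaN, whereas xbatcher only yields complete windows:

```python
print(batch.sizes)   # roughly {'nlat_input': 3, 'nlon_input': 3, 'input_batch': nlat * nlon}
# The first windows along each edge are partially NaN:
print(batch["SST"].isel(input_batch=0).values)
```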
This post on numpy stride tricks is extremely relevant.
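The same trick is exposed directly in numpy these days; a minimal sketch, assuming numpy >= 1.20 and the `ds` from the reproducer:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

sst = ds["SST"].values
# Every overlapping 3x3 window as a zero-copy view,
# with shape (nlat - 2, nlon - 2, 3, 3):
windows = sliding_window_view(sst, (3, 3))
patches = windows.reshape(-1, 3, 3)  # reshaping copies, but only once
```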
`_drop_input_dims` is hella slow

What is your issue?

I've been doing some ad-hoc performance profiling, and it looks like `_drop_input_dims` is always the culprit for why my batch generation runs slow. I blame the deep copy here. However, in light of #164, I don't really see the point of this subroutine. Can someone explain what this subroutine does?