handle data movement with MLDataDevices.jl #2492

Merged
merged 21 commits into from
Oct 11, 2024

Conversation

CarloLucibello
Member

@CarloLucibello CarloLucibello commented Oct 10, 2024

Implements the move to MLDataDevices.jl. Closes #2482

Thanks to the extended functionality given by MLDataDevices.jl we can also easily fix #2490

The gpu/cpu logic could also be offloaded to MLDataDevices.jl; I will do that in a follow-up PR.
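
For readers unfamiliar with MLDataDevices.jl, here is a minimal sketch of the device-object style of data movement that this PR moves to. It is not code from the PR: it assumes MLDataDevices.jl is installed, and `params` is a hypothetical stand-in for model parameters.

```julia
using MLDataDevices  # assumed to be installed; no GPU backend is needed for this sketch

# cpu_device() always returns a CPU device object.
cdev = cpu_device()

# gpu_device() returns a functional GPU device when a backend such as CUDA.jl
# is loaded, and otherwise falls back to the CPU device.
gdev = gpu_device()

# Hypothetical nested container standing in for model parameters.
params = (weight = rand(Float32, 3, 2), bias = zeros(Float32, 3))

# Device objects are callable and move the arrays inside the container.
params_dev = gdev(params)
params_cpu = cdev(params_dev)

# A round trip through the device leaves the values unchanged.
@assert params_cpu.weight == params.weight
```

Without a GPU backend loaded, both calls resolve to the CPU, so the round trip is a no-op and the sketch still runs.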

Unfortunately this PR has become rather large, partly because of some unrelated changes needed to fix the CI, which was broken because:

  • the expected array memory sizes printed in some of the docstrings changed in Julia v1.11
  • some second-derivative tests for Zygote now fail, possibly due to recent changes in Zygote or to Julia v1.11
  • some gradient tests involving Flux.params() (e.g. for L2 regularization) now fail, possibly due to recent changes in Zygote or to Julia v1.11
  • Enzyme tests seem to be completely broken (cc @wsmoses) [edit: Enzyme doesn't work on v1.11 yet]

It is worth opening separate issues to keep track of these.

PR Checklist

  • Tests are added
  • Entry in NEWS.md
  • Documentation, if applicable

@CarloLucibello CarloLucibello changed the title from "Cl/mldata" to "handle data movement with MLDataDevices.jl" Oct 10, 2024
@wsmoses
Contributor

wsmoses commented Oct 11, 2024

I assume the error is due to Enzyme not yet supporting 1.11. Does it work on the LTS?

@CarloLucibello
Member Author

Ok I think this is ready to go.

Buildkite passes on CUDA but fails on Metal and AMDGPU, apparently because it fails to load the preference.
The Metal tests pass locally on my laptop.
This can be solved later, since this PR already fixes a ton of CI breakage and I'd rather get it merged sooner rather than later.
We also want to address #2490 as soon as possible.

@CarloLucibello
Member Author

> I assume the error is due to Enzyme not yet supporting 1.11, does it work on the LTS?

Ah ok, I didn't know. It does indeed work on 1.10.

@wsmoses
Contributor

wsmoses commented Oct 11, 2024

Maybe adding an LTS CI run would also be helpful here?

(Now that it's released, we're starting to adapt to the new intrinsics and array operators [sadly, this was a more significant set of changes for us])

@wsmoses
Contributor

wsmoses commented Oct 11, 2024

I imagine that happened for most other packages too, since the release only came out the other day.

@darsnack
Member

Should probably replace Flux.rng_from_array with get_device and default_device_rng.
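
A rough sketch of what that replacement could look like, using the two MLDataDevices.jl functions named above. This is an illustration under the assumption that MLDataDevices.jl is installed, not the change that was eventually made in Flux; `x` and `noise` are hypothetical names.

```julia
using MLDataDevices, Random  # MLDataDevices.jl is assumed to be installed

x = rand(Float32, 4)

# get_device inspects the array and returns the device it lives on;
# for a plain Array this is the CPU device.
dev = get_device(x)

# default_device_rng returns the default random number generator for that device;
# on the CPU this is Random.default_rng().
rng = default_device_rng(dev)

# Draw noise with an RNG that matches the device of x.
noise = rand(rng, Float32, size(x))
```

The appeal of this pattern is that the RNG is chosen from the device object rather than from the array type, so the same lines would apply to GPU arrays once the matching backend is loaded.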

@ToucheSir
Member

One note for the follow-up: the MLDataDevices path does not flip conv kernels as is required for the AMDGPU backend:

# CPU -> GPU
function Adapt.adapt_structure(to::FluxAMDGPUAdaptor, m::CPU_CONV)
    flipped_weight = reverse(m.weight; dims=ntuple(i -> i, ndims(m.weight) - 2))
    _conv_basetype(m)(
        Adapt.adapt(to, m.σ),
        Adapt.adapt(to, flipped_weight),
        Adapt.adapt(to, m.bias),
        _other_args(m)...)
end
# Don't adapt again.
Adapt.adapt_structure(to::FluxAMDGPUAdaptor, m::AMDGPU_CONV) = m
# GPU -> CPU
function Adapt.adapt_structure(to::FluxCPUAdaptor, m::AMDGPU_CONV)
    dims = ntuple(i -> i, ndims(m.weight) - 2)
    _conv_basetype(m)(
        Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims),
        Adapt.adapt(to, m.bias), _other_args(m)...)
end

We should add this functionality before cutting any releases.
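
To make the flip concrete, here is a small Base-only sketch of the `reverse` call used in the snippet above, applied to a hypothetical 2-D conv weight. It only illustrates the array operation; the weight shape and values are made up for the example.

```julia
# Hypothetical 2-D conv weight laid out as (width, height, in_channels, out_channels).
w = reshape(collect(Float32, 1:3*3*2*4), 3, 3, 2, 4)

# The spatial dimensions are all but the last two, as in the snippet above.
spatial_dims = ntuple(i -> i, ndims(w) - 2)

# Flip the kernel along every spatial dimension (the CPU -> GPU direction).
flipped = reverse(w; dims = spatial_dims)

# The first spatial entry now holds what used to be the last one.
@assert flipped[1, 1, 1, 1] == w[3, 3, 1, 1]

# Flipping again (the GPU -> CPU direction) restores the original weight.
@assert reverse(flipped; dims = spatial_dims) == w
```

Because the flip is its own inverse, a model moved to the device and back ends up with its original weights, which is why the two `adapt_structure` methods mirror each other.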

CarloLucibello added a commit to CarloLucibello/LuxDeviceUtils.jl that referenced this pull request Oct 13, 2024
After FluxML/Flux.jl#2492 also Flux relies on MLDataDevices.
avik-pal pushed a commit to LuxDL/MLDataDevices.jl that referenced this pull request Oct 13, 2024
After FluxML/Flux.jl#2492 also Flux relies on MLDataDevices.
Successfully merging this pull request may close these issues.

  • The dependency error about Flux->FluxMPIExt occurs when updating to Julia 1.11
  • use MLDataDevices.jl?
5 participants