Rudimentary Flux functionality crawling #159
-
@rgobbel Can you share the changes to Flux.jl and the data you used so that we can try this out?
-
Somewhere in my fumbling around to get it working, the backward-compatibility code disappeared. I've recreated it and done a minimal check that I can run FluxCLI.jl on both 1.8.5 and 1.9.0-rc2. In addition to restoring the backward-compatibility code, I changed a misleading message that mentions the GPU when it's really just CUDA. Before running anything with Metal, you need to run a setup step first (see the sketch below).
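A minimal sketch of what that setup might look like, assuming Flux's Preferences-based `gpu_backend!` switch; the patched Flux.jl in this thread may use a different mechanism:

```julia
# Hypothetical setup, assuming Flux's `gpu_backend!` preference switch.
# The patched Flux.jl discussed here may differ.
using Flux, Metal

Flux.gpu_backend!("Metal")   # writes a LocalPreferences.toml entry;
                             # takes effect after restarting Julia

# After a restart:
using Flux, Metal
@assert Metal.functional()   # sanity check that the Metal backend is usable
m = gpu(Dense(4 => 4))       # parameters should now live in MtlArrays
```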
-
Apologies for the long radio silence. I finally got most of the life/work stuff out of the way, so I'm starting to have more time to investigate this. The first thing I noticed is that a lot of the slowness was from Metal/GPU overhead, which is somewhat expected. For example, if we apply an 8x8 Dense layer to an 8x8 input, the GPU code is about 1000x slower:

```julia
julia> a = Dense(8 => 8); da = gpu(a);

julia> x = rand(Float32, 8, 8); dx = gpu(x);

julia> @btime a(x);
  199.583 ns (2 allocations: 672 bytes)

julia> @btime Metal.@sync da(dx);
  209.709 μs (355 allocations: 9.10 KiB)
```

However, if the input sizes are increased to 1024x1024, the GPU is now ahead:

```julia
julia> a = Dense(1024 => 1024); da = gpu(a);

julia> x = rand(Float32, 1024, 1024); dx = gpu(x);

julia> @btime a(x);
  3.388 ms (4 allocations: 8.00 MiB)

julia> @btime Metal.@sync da(dx);
  1.302 ms (364 allocations: 9.24 KiB)
```

These tests were done on my local machine. I also stepped through the forward-pass execution line-by-line-ish and didn't see anything obviously wrong (the correct temporary arrays are created).

However, I did recall seeing a lot of time spent in metallib compilation when profiling the full FluxCLI training loop. One possibility is that Zygote is messing up kernel caching, so kernels get recompiled over and over.
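One way to test that hypothesis (a hypothetical sketch, not from the thread) is to compare cold and warm timings of both the forward pass and the Zygote gradient; if compilation results are cached properly, only the first call should be slow:

```julia
# Hypothetical check (not from the thread): how much of the cost is one-time
# kernel compilation? Assumes a working Flux + Metal setup as above.
using Flux, Metal

da = gpu(Dense(1024 => 1024))
dx = gpu(rand(Float32, 1024, 1024))

# Forward pass: the first call includes metallib compilation; a warm call should not.
@time Metal.@sync da(dx)
@time Metal.@sync da(dx)

# Gradient: if Zygote-generated code defeats the kernel cache,
# the second call will stay slow instead of dropping to steady state.
loss(m, x) = sum(abs2, m(x))
@time Metal.@sync Flux.gradient(loss, da, dx)
@time Metal.@sync Flux.gradient(loss, da, dx)
```

If the second `gradient` call stays roughly as slow as the first, something in the Zygote-generated code is likely forcing recompilation on every call.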
-
With some fairly simple changes to Flux.jl, I got a very simple program working (a sketch of that kind of program is below). It's extremely slow (CPU-only must be at least 100 times faster), but it does reduce the error, and the error keeps dropping a little later in training. At this point it's no more than a proof of concept (no layers other than a basic Dense layer, and obviously far from optimized), but it's something. Suggestions welcomed.
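A hypothetical reconstruction of the kind of program described, assuming a single Dense layer and Flux's explicit-gradient training API (not the author's actual code):

```julia
# Hypothetical minimal example: train one Dense layer via Flux's `gpu`.
# Not the author's code; assumes the patched Flux.jl routes `gpu` to MtlArrays.
using Flux, Metal

# Toy regression problem: learn a fixed linear map.
W_true = rand(Float32, 4, 8)
X = rand(Float32, 8, 256)
Y = W_true * X

model = gpu(Dense(8 => 4))
dX, dY = gpu(X), gpu(Y)

loss(m, x, y) = Flux.mse(m(x), y)
opt = Flux.setup(Descent(0.1f0), model)

for epoch in 1:100
    grads = Flux.gradient(m -> loss(m, dX, dY), model)
    Flux.update!(opt, model, grads[1])
    epoch % 20 == 0 && println("epoch $epoch: loss = ", loss(model, dX, dY))
end
```

Even a loop this small exercises the full forward pass, Zygote backward pass, and optimizer update on the Metal device, so it should reproduce both the error reduction and the slowness described above.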