make gpu(x) = gpu_device()(x) (#2502)
CarloLucibello authored Oct 22, 2024
1 parent 31dccd1 commit c9bab66
Showing 18 changed files with 124 additions and 579 deletions.
5 changes: 1 addition & 4 deletions Project.toml
@@ -29,7 +29,6 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
NCCL = "3fe64909-d7a1-4096-9b7d-7a0f12cf0f6b"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

@@ -40,7 +39,6 @@ FluxCUDAcuDNNExt = ["CUDA", "cuDNN"]
FluxEnzymeExt = "Enzyme"
FluxMPIExt = "MPI"
FluxMPINCCLExt = ["CUDA", "MPI", "NCCL"]
FluxMetalExt = "Metal"

[compat]
AMDGPU = "1"
@@ -50,11 +48,10 @@ ChainRulesCore = "1.12"
Compat = "4.10.0"
Enzyme = "0.12, 0.13"
Functors = "0.4"
MLDataDevices = "1.2.0"
MLDataDevices = "1.4.0"
MLUtils = "0.4"
MPI = "0.20.19"
MacroTools = "0.5"
Metal = "0.5, 1"
NCCL = "0.1.1"
NNlib = "0.9.22"
OneHotArrays = "0.2.4"
1 change: 1 addition & 0 deletions docs/make.jl
@@ -36,6 +36,7 @@ makedocs(
"Flat vs. Nested" => "reference/destructure.md",
"Callback Helpers" => "reference/training/callbacks.md",
"Gradients -- Zygote.jl" => "reference/training/zygote.md",
"Transfer Data to GPU -- MLDataDevices.jl" => "reference/data/mldatadevices.md",
"Batching Data -- MLUtils.jl" => "reference/data/mlutils.md",
"OneHotArrays.jl" => "reference/data/onehot.md",
"Low-level Operations -- NNlib.jl" => "reference/models/nnlib.md",
208 changes: 77 additions & 131 deletions docs/src/guide/gpu.md
@@ -16,68 +16,13 @@ in your code. Notice that for CUDA, explicitly loading also `cuDNN` is not requi
!!! compat "Flux ≤ 0.13"
Old versions of Flux automatically installed CUDA.jl to provide GPU support. Starting from Flux v0.14, CUDA.jl is not a dependency anymore and has to be installed manually.

## Checking GPU Availability

By default, Flux checks your system to see whether it can support GPU functionality. You can check whether Flux identified a valid GPU setup by typing the following:

```julia
julia> using CUDA

julia> CUDA.functional()
true
```

For AMD GPU:

```julia
julia> using AMDGPU

julia> AMDGPU.functional()
true

julia> AMDGPU.functional(:MIOpen)
true
```

For Metal GPU:

```julia
julia> using Metal

julia> Metal.functional()
true
```

## Selecting GPU backend

Available GPU backends are: `CUDA`, `AMDGPU` and `Metal`.

Flux relies on [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl) to select the default GPU backend.

There are two ways you can specify it:

- From the REPL or from code in your project, call `Flux.gpu_backend!("AMDGPU")` and restart the Julia session (if needed) for the change to take effect.
- In the `LocalPreferences.toml` file in your project directory, specify:
```toml
[Flux]
gpu_backend = "AMDGPU"
```

The current GPU backend can be read from the `Flux.GPU_BACKEND` variable:

```julia
julia> Flux.GPU_BACKEND
"CUDA"
```

The current backend affects the behaviour of functions such as `gpu`, described below.

## Basic GPU Usage

Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl), [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl), and [Metal.jl](https://github.com/JuliaGPU/Metal.jl).
Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.

For example, we can use `CUDA.CuArray` (with the `cu` converter) to run our [basic example](@ref man-basics) on an NVIDIA GPU.
For example, we can use `CUDA.CuArray` (with the `CUDA.cu` converter) to run our [basic example](@ref man-basics) on an NVIDIA GPU.

(Note that you need to have CUDA available to use CUDA.CuArray – please see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) instructions for more details.)

@@ -146,6 +91,50 @@ julia> x |> cpu
0.7766742
```

## Using device objects

In Flux, you can create `device` objects that transfer models and data to GPUs, falling back to the CPU if no GPU backend is available.
These features are provided by the [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl) package, which Flux uses internally and re-exports.

Device objects can be automatically created using the [`cpu_device`](@ref MLDataDevices.cpu_device) and [`gpu_device`](@ref MLDataDevices.gpu_device) functions. For instance, the `gpu` and `cpu` functions are just convenience functions defined as

```julia
cpu(x) = cpu_device()(x)
gpu(x) = gpu_device()(x)
```

`gpu_device` performs automatic GPU device selection and returns a device object:
- If no GPU is available, it returns a `CPUDevice` object.
- If a LocalPreferences file is present, then the backend specified in the file is used. To set a backend, use `Flux.gpu_backend!(<backend_name>)`. If the trigger package corresponding to the device is not loaded (e.g. with `using CUDA`), then a warning is displayed.
- If no LocalPreferences option is present, then the first working GPU whose trigger package is loaded is used.

Consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference):

```julia-repl
julia> using Flux, CUDA;
julia> device = gpu_device() # returns handle to an NVIDIA GPU if available
(::CUDADevice{Nothing}) (generic function with 4 methods)
julia> model = Dense(2 => 3);
julia> model.weight # the model initially lives in CPU memory
3×2 Matrix{Float32}:
-0.984794 -0.904345
0.720379 -0.486398
0.851011 -0.586942
julia> model = model |> device # transfer model to the GPU
Dense(2 => 3) # 9 parameters
julia> model.weight
3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
-0.984794 -0.904345
0.720379 -0.486398
0.851011 -0.586942
```
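
Because `gpu_device()` returns a `CPUDevice` when no GPU is available, the same script can run on machines with or without a GPU. The following is a minimal sketch of that pattern, assuming only that Flux is installed; the variable names `device`, `model`, `x`, `y`, and `y_cpu` are illustrative and not part of any API:

```julia
using Flux  # re-exports gpu_device and cpu_device from MLDataDevices.jl

# Select a GPU device if a functional backend is loaded, otherwise a CPU device.
device = gpu_device()

# Move both the model and the input through the same device object.
model = Dense(2 => 3) |> device
x = rand(Float32, 2, 5) |> device

y = model(x)                 # runs on whichever device was selected
y_cpu = y |> cpu_device()    # bring the result back to host memory

@assert size(y_cpu) == (3, 5)
```

Loading a trigger package such as CUDA.jl before calling `gpu_device()` is what allows the GPU path to be taken; without one, this sketch is expected to stay on the CPU.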


## Transferring Training Data

To train the model on the GPU, both the model and the training data have to be transferred to GPU memory. Moving the data can be done in two different ways:
@@ -227,65 +216,8 @@ To select specific devices by device id:
$ export CUDA_VISIBLE_DEVICES='0,1'
```


More information for conditional use of GPUs in CUDA.jl can be found in its [documentation](https://cuda.juliagpu.org/stable/installation/conditional/#Conditional-use), and information about the specific use of the variable is described in the [Nvidia CUDA blog post](https://developer.nvidia.com/blog/cuda-pro-tip-control-gpu-visibility-cuda_visible_devices/).

## Using device objects

As a more convenient syntax, Flux supports GPU `device` objects, which transfer models to GPUs and fall back to the CPU if no GPU backend is available. This syntax has a few advantages, including automatic selection of the GPU backend and type stability of data movement.
These features are provided by the [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl) package, which Flux uses internally and re-exports.

A `device` object can be created using the [`gpu_device`](@ref MLDataDevices.gpu_device) function.
`gpu_device` first checks for a GPU preference, and if possible returns a device for the preference backend. For instance, consider the following example, where we load the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) package to use an NVIDIA GPU (`"CUDA"` is the default preference):

```julia-repl
julia> using Flux, CUDA;
julia> device = gpu_device() # returns handle to an NVIDIA GPU if available
(::CUDADevice{Nothing}) (generic function with 4 methods)
julia> model = Dense(2 => 3);
julia> model.weight # the model initially lives in CPU memory
3×2 Matrix{Float32}:
-0.984794 -0.904345
0.720379 -0.486398
0.851011 -0.586942
julia> model = model |> device # transfer model to the GPU
Dense(2 => 3) # 9 parameters
julia> model.weight
3×2 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
-0.984794 -0.904345
0.720379 -0.486398
0.851011 -0.586942
```

The device preference can also be set via the [`gpu_backend!`](@ref MLDataDevices.gpu_backend!) function. For instance, below we first set our device preference to `"AMDGPU"`:

```julia-repl
julia> gpu_backend!("AMDGPU")
[ Info: GPU backend has been set to AMDGPU. Restart Julia to use the new backend.
```
If no functional GPU backend is available, the device will default to a CPU device.
You can also explicitly request a CPU device by calling the [`cpu_device`](@ref MLDataDevices.cpu_device) function.

```julia-repl
julia> using Flux, MLDataDevices
julia> cdev = cpu_device()
(::CPUDevice{Nothing}) (generic function with 4 methods)
julia> gdev = gpu_device(force=true) # force GPU device, error if no GPU is available
(::CUDADevice{Nothing}) (generic function with 4 methods)
julia> model = Dense(2 => 3); # model in CPU memory
julia> gmodel = model |> gdev; # transfer model to GPU
julia> cmodel = gmodel |> cdev; # transfer model back to CPU
```

## Data movement across GPU devices

@@ -344,24 +276,6 @@ CuDevice(1): NVIDIA TITAN RTX

Due to a limitation in `Metal.jl`, currently this kind of data movement across devices is only supported for `CUDA` and `AMDGPU` backends.

!!! warning "Printing models after moving to a different device"

Due to a limitation in how GPU packages currently work, printing
models on the REPL after moving them to a GPU device which is different
from the current device will lead to an error.


```@docs
MLDataDevices.cpu_device
MLDataDevices.default_device_rng
MLDataDevices.get_device
MLDataDevices.gpu_device
MLDataDevices.gpu_backend!
MLDataDevices.get_device_type
MLDataDevices.reset_gpu_device!
MLDataDevices.supported_gpu_backends
MLDataDevices.DeviceIterator
```

## Distributed data parallel training

@@ -479,3 +393,35 @@ julia> set_preferences!("Flux", "FluxDistributedMPICUDAAware" => true)

We don't run CUDA-aware tests, so you use this option at your own risk.


## Checking GPU Availability

By default, Flux checks your system to see whether it can support GPU functionality. You can check whether Flux identified a valid GPU setup by typing the following:

```julia
julia> using CUDA
julia> CUDA.functional()
true
```

For AMD GPU:

```julia
julia> using AMDGPU
julia> AMDGPU.functional()
true
julia> AMDGPU.functional(:MIOpen)
true
```

For Metal GPU:

```julia
julia> using Metal
julia> Metal.functional()
true
```
19 changes: 19 additions & 0 deletions docs/src/reference/data/mldatadevices.md
@@ -0,0 +1,19 @@
# Transferring data across devices

Flux relies on the [MLDataDevices.jl](https://github.com/LuxDL/MLDataDevices.jl/blob/main/src/public.jl) package to manage devices and transfer data across them. You don't have to explicitly use the package, as Flux re-exports the necessary functions and types.

```@docs
MLDataDevices.cpu_device
MLDataDevices.default_device_rng
MLDataDevices.functional
MLDataDevices.get_device
MLDataDevices.gpu_device
MLDataDevices.gpu_backend!
MLDataDevices.get_device_type
MLDataDevices.isleaf
MLDataDevices.loaded
MLDataDevices.reset_gpu_device!
MLDataDevices.set_device!
MLDataDevices.supported_gpu_backends
MLDataDevices.DeviceIterator
```
31 changes: 2 additions & 29 deletions ext/FluxAMDGPUExt/FluxAMDGPUExt.jl
@@ -3,49 +3,22 @@ module FluxAMDGPUExt
import ChainRulesCore
import ChainRulesCore: NoTangent
import Flux
import Flux: FluxCPUAdaptor, FluxAMDGPUAdaptor, _amd, adapt_storage, fmap
import Flux: adapt_storage, fmap
import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
import NNlib
using MLDataDevices: MLDataDevices
using MLDataDevices
using AMDGPU
using Adapt
using Random
using Zygote

const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat

# Set to boolean on the first call to check_use_amdgpu
const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)


function check_use_amdgpu()
if !isnothing(USE_AMDGPU[])
return
end

USE_AMDGPU[] = AMDGPU.functional()
if USE_AMDGPU[]
if !AMDGPU.functional(:MIOpen)
@warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be available."
end
else
@info """
The AMDGPU function is being called but AMDGPU.jl is not functional.
Defaulting back to the CPU. (No action is required if you want to run on the CPU).
""" maxlog=1
end
return
end

ChainRulesCore.@non_differentiable check_use_amdgpu()

include("functor.jl")
include("batchnorm.jl")
include("conv.jl")

function __init__()
Flux.AMDGPU_LOADED[] = true
end

# TODO
# fail early if input to the model is not on the device (e.g. on the host)