diff --git a/.JuliaFormatter.toml b/.JuliaFormatter.toml index f1f84c1..22c3407 100644 --- a/.JuliaFormatter.toml +++ b/.JuliaFormatter.toml @@ -1,6 +1,5 @@ style = "sciml" whitespace_in_kwargs = false -always_use_return = true margin = 92 indent = 4 format_docstrings = true diff --git a/.buildkite/pipeline.yml b/.buildkite/pipeline.yml index 8feda5f..1e9319d 100644 --- a/.buildkite/pipeline.yml +++ b/.buildkite/pipeline.yml @@ -181,7 +181,32 @@ steps: julia: - "1" + - group: ":julia: oneAPI GPU" + steps: + - label: ":julia: Julia: {{matrix.julia}} + oneAPI" + plugins: + - JuliaCI/julia#v1: + version: "{{matrix.julia}}" + - JuliaCI/julia-test#v1: + test_args: "--quickfail" + - JuliaCI/julia-coverage#v1: + codecov: true + dirs: + - src + - ext + env: + GROUP: "oneAPI" + agents: + queue: "juliagpu" + intel: "*" + if: build.message !~ /\[skip tests\]/ + timeout_in_minutes: 60 + matrix: + setup: + julia: + - "1" + env: - RETESTITEMS_NWORKERS: 2 + RETESTITEMS_NWORKERS: 8 RETESTITEMS_NWORKER_THREADS: 2 SECRET_CODECOV_TOKEN: "PxSr3Y7vdbiwaoX51uGykPsogxmP1IOBt5Z8TwP9GqDxIrvFocEVV2DR4Bebee12G/HYvXtQTyYXH49DpzlsfJ7ri1GQZxd9WRr+aM1DDYmzfDCfpadp4hMoJ5NQvmc/PzeGrNWOOaewaLTUP1eEaG4suygZN0lc5q9BCchIJeqoklGms5DVt/HtfTmwoD/s4wGoIJINi4RoFgnCAkzSh11hTAkyjVerfBGWEi/8E6+WBq3UKwaW4HnT02wG9qFnD4XkHpIpjMxJTpdBn5ufKI+QoJ7qJHlwqgDCtsOCblApccLTjH/BnTahNoSb/b0wdS/cblOTrtdPGzZ5UvmQ4Q==;U2FsdGVkX1/Ji2Nqeq3tqTYCBik6iXILP+rriPRqj/qxhFu4vBWWT3UnlfqDzj6oVdXyuKt0+5e+x33x2S0mBw==" diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index fce13ab..16b0c1b 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -12,16 +12,14 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: ${{ startsWith(github.ref, 'refs/pull/') }} jobs: - test: - name: Julia ${{ matrix.version }} - ${{ matrix.os }} - ${{ github.event_name }} - runs-on: ${{ matrix.os }} + test-general: + name: Julia ${{ matrix.version }} - ubuntu-latest - ${{ github.event_name }} + runs-on: ubuntu-latest strategy: fail-fast: false matrix: version: - "1" - os: - - ubuntu-latest steps: - uses: actions/checkout@v4 - uses: julia-actions/setup-julia@v2 @@ -48,3 +46,40 @@ jobs: token: ${{ secrets.CODECOV_TOKEN }} verbose: true fail_ci_if_error: true + + test-mac-intel: # This is mostly for coverage purposes + name: Julia ${{ matrix.version }} - macos-latest - ${{ github.event_name }} + runs-on: macos-latest + strategy: + fail-fast: false + matrix: + version: + - "1" + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: ${{ matrix.version }} + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1 + - uses: julia-actions/julia-runtest@v1 + env: + GROUP: Metal + - uses: julia-actions/julia-processcoverage@v1 + with: + directories: src,ext + - uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{ secrets.CODECOV_TOKEN }} + verbose: true + fail_ci_if_error: true \ No newline at end of file diff --git a/.github/workflows/FormatCheck.yml b/.github/workflows/FormatCheck.yml index ac75c52..0ddeb4e 100644 --- a/.github/workflows/FormatCheck.yml +++ b/.github/workflows/FormatCheck.yml @@ -1,40 +1,9 @@ -name: FormatCheck +name: Format suggestions -on: - push: - branches: - - 'main' - - 
'release-' - tags: ['*'] - pull_request: +on: [pull_request] jobs: - build: - runs-on: ${{ matrix.os }} - strategy: - matrix: - julia-version: ["1"] - julia-arch: [x86] - os: [ubuntu-latest] + code-style: + runs-on: ubuntu-latest steps: - - uses: julia-actions/setup-julia@latest - with: - version: ${{ matrix.julia-version }} - - - uses: actions/checkout@v4 - - name: Install JuliaFormatter and format - run: | - julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))' - julia -e 'using JuliaFormatter; format(".", verbose=true)' - - name: Format check - run: | - julia -e ' - out = Cmd(`git diff --name-only`) |> read |> String - if out == "" - exit(0) - else - @error "Some files have not been formatted !!!" - write(stdout, out) - exit(1) - end' - \ No newline at end of file + - uses: julia-actions/julia-format@v3 diff --git a/Project.toml b/Project.toml index aadadd7..cd57505 100644 --- a/Project.toml +++ b/Project.toml @@ -1,12 +1,11 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.20" +version = "0.1.21" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" @@ -18,68 +17,80 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" Metal = "dde4c033-4e86-420c-a63e-0dd931031962" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" [extensions] LuxDeviceUtilsAMDGPUExt = "AMDGPU" LuxDeviceUtilsCUDAExt = "CUDA" LuxDeviceUtilsFillArraysExt = "FillArrays" LuxDeviceUtilsGPUArraysExt = "GPUArrays" -LuxDeviceUtilsLuxAMDGPUExt = "LuxAMDGPU" LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" -LuxDeviceUtilsMetalGPUArraysExt = ["GPUArrays", "Metal"] +LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"] LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" +LuxDeviceUtilsReverseDiffExt = "ReverseDiff" LuxDeviceUtilsSparseArraysExt = "SparseArrays" +LuxDeviceUtilsTrackerExt = "Tracker" LuxDeviceUtilsZygoteExt = "Zygote" +LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] [compat] AMDGPU = "0.8.4, 0.9" Adapt = "4" Aqua = "0.8.4" +ArrayInterface = "7.11" CUDA = "5.2" -ChainRulesCore = "1.20" +ChainRulesCore = "1.23" +ChainRulesTestUtils = "1.13.0" ComponentArrays = "0.15.8" ExplicitImports = "1.4.1" -FastClosures = "0.3.2" FillArrays = "1" +ForwardDiff = "0.10.36" Functors = "0.4.4" GPUArrays = "10" -LuxAMDGPU = "0.2.2" LuxCUDA = "0.3.2" LuxCore = "0.1.4" Metal = "1" +Pkg = "1.10" PrecompileTools = "1.2" Preferences = "1.4" Random = "1.10" RecursiveArrayTools = "3.8" +ReverseDiff = "1.15" SafeTestsets = "0.1" SparseArrays = "1.10" Test = "1.10" TestSetExtensions = "3" +Tracker = "0.2.34" Zygote = "0.6.69" julia = "1.10" +oneAPI = "1.5" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" +ChainRulesTestUtils = 
"cdddcdb0-9152-4a09-a978-84456f9df70a" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" -LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" -LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623" -Metal = "dde4c033-4e86-420c-a63e-0dd931031962" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestSetExtensions = "98d24dd4-01ad-11ea-1b02-c9a08f80db04" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "FillArrays", "LuxAMDGPU", "LuxCUDA", "LuxCore", "Metal", "Random", "RecursiveArrayTools", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Zygote"] +test = ["Aqua", "ArrayInterface", "ChainRulesTestUtils", "ComponentArrays", "ExplicitImports", "FillArrays", "ForwardDiff", "LuxCore", "Pkg", "Random", "RecursiveArrayTools", "ReverseDiff", "SafeTestsets", "SparseArrays", "Test", "TestSetExtensions", "Tracker", "Zygote"] diff --git a/README.md b/README.md index 8830b4b..6b67043 100644 --- a/README.md +++ b/README.md @@ -7,7 +7,6 @@ [![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml) [![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/luxdeviceutils-dot-jl) [![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl) -[![Package Downloads](https://shields.io/endpoint?url=https://pkgs.genieframework.com/api/v1/badge/LuxDeviceUtils)](https://pkgs.genieframework.com?packages=LuxDeviceUtils) [![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl) [![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac) @@ -15,3 +14,10 @@ `LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across devices. Most users should directly use [Lux.jl](https://lux.csail.mit.edu/) instead. + +Currently we provide support for the following backends: + +1. `CUDA.jl` for NVIDIA GPUs. +2. `AMDGPU.jl` for AMD ROCM GPUs. +3. `Metal.jl` for Apple Metal GPUs. **(Experimental)** +4. `oneAPI.jl` for Intel GPUs. 
**(Experimental)** diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..0398f92 --- /dev/null +++ b/codecov.yml @@ -0,0 +1,3 @@ +codecov: + notify: + wait_for_ci: false diff --git a/ext/LuxDeviceUtilsAMDGPUExt.jl b/ext/LuxDeviceUtilsAMDGPUExt.jl index c88619a..93a8c84 100644 --- a/ext/LuxDeviceUtilsAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsAMDGPUExt.jl @@ -2,13 +2,35 @@ module LuxDeviceUtilsAMDGPUExt using Adapt: Adapt using AMDGPU: AMDGPU -using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUAdaptor, LuxAMDGPUDevice, LuxCPUAdaptor -using Random: Random, AbstractRNG +using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCPUDevice, reset_gpu_device! +using Random: Random + +__init__() = reset_gpu_device!() + +# This code used to be in `LuxAMDGPU.jl`, but we no longer need that package. +const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing) + +function _check_use_amdgpu!() + USE_AMD_GPU[] === nothing || return + + USE_AMD_GPU[] = AMDGPU.functional() + if USE_AMD_GPU[] && !AMDGPU.functional(:MIOpen) + @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \ + available." maxlog=1 + end + return +end + +LuxDeviceUtils.loaded(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) = true +function LuxDeviceUtils.functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}})::Bool + _check_use_amdgpu!() + return USE_AMD_GPU[] +end function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, ::Nothing) return LuxAMDGPUDevice(nothing) end -function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Int) +function LuxDeviceUtils._with_device(::Type{LuxAMDGPUDevice}, id::Integer) id > length(AMDGPU.devices()) && throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) old_dev = AMDGPU.device() @@ -24,37 +46,36 @@ LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.devic LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng() # Query Device from Array -LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) = LuxAMDGPUDevice(AMDGPU.device(x)) +function LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray) + parent_x = parent(x) + parent_x === x && return LuxAMDGPUDevice(AMDGPU.device(x)) + return LuxDeviceUtils.get_device(parent_x) +end # Set Device function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice) - if !AMDGPU.functional() - @warn "AMDGPU is not functional." 
- return - end - AMDGPU.device!(dev) - return + return AMDGPU.device!(dev) end -function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Int) - LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id]) - return +function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, id::Integer) + return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, AMDGPU.devices()[id]) end -function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Int) +function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, ::Nothing, rank::Integer) id = mod1(rank + 1, length(AMDGPU.devices())) return LuxDeviceUtils.set_device!(LuxAMDGPUDevice, id) end # Device Transfer ## To GPU -Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, x) = AMDGPU.roc(x) -function Adapt.adapt_storage(to::LuxAMDGPUAdaptor, x) +Adapt.adapt_storage(::LuxAMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) +function Adapt.adapt_storage(to::LuxAMDGPUDevice, x::AbstractArray) old_dev = AMDGPU.device() # remember the current device - if !(x isa AMDGPU.AnyROCArray) + dev = LuxDeviceUtils.get_device(x) + if !(dev isa LuxAMDGPUDevice) AMDGPU.device!(to.device) x_new = AMDGPU.roc(x) AMDGPU.device!(old_dev) return x_new - elseif AMDGPU.device_id(AMDGPU.device(x)) == AMDGPU.device_id(to.device) + elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device) return x else AMDGPU.device!(to.device) @@ -63,13 +84,7 @@ function Adapt.adapt_storage(to::LuxAMDGPUAdaptor, x) return x_new end end -Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng -Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng -function Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG) - return AMDGPU.rocrand_rng() -end -Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng() -Adapt.adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() +Adapt.adapt_storage(::LuxCPUDevice, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() end diff --git a/ext/LuxDeviceUtilsCUDAExt.jl b/ext/LuxDeviceUtilsCUDAExt.jl index ae6a45f..29ff65c 100644 --- a/ext/LuxDeviceUtilsCUDAExt.jl +++ b/ext/LuxDeviceUtilsCUDAExt.jl @@ -1,11 +1,12 @@ module LuxDeviceUtilsCUDAExt using Adapt: Adapt -using CUDA: CUDA, CUSPARSE -using LuxDeviceUtils: LuxDeviceUtils, LuxCUDAAdaptor, LuxCUDADevice, LuxCPUAdaptor -using Random: Random, AbstractRNG +using CUDA: CUDA +using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector +using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, LuxCPUDevice +using Random: Random -function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Int) +function LuxDeviceUtils._with_device(::Type{LuxCUDADevice}, id::Integer) id > length(CUDA.devices()) && throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) old_dev = CUDA.device() @@ -25,41 +26,38 @@ LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) + LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng() # Query Device from Array -LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) = LuxCUDADevice(CUDA.device(x)) +function LuxDeviceUtils.get_device(x::CUDA.AnyCuArray) + parent_x = parent(x) + parent_x === x && return LuxCUDADevice(CUDA.device(x)) + return LuxDeviceUtils.get_device(parent_x) +end +function LuxDeviceUtils.get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) + return LuxCUDADevice(CUDA.device(x.nzVal)) +end # Set Device function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice) - if 
!CUDA.functional() - @warn "CUDA is not functional." - return - end - CUDA.device!(dev) - return + return CUDA.device!(dev) end -function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Int) - if !CUDA.functional() - @warn "CUDA is not functional." - return - end - CUDA.device!(id - 1) - return +function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, id::Integer) + return LuxDeviceUtils.set_device!(LuxCUDADevice, collect(CUDA.devices())[id]) end -function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Int) +function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, ::Nothing, rank::Integer) id = mod1(rank + 1, length(CUDA.devices())) return LuxDeviceUtils.set_device!(LuxCUDADevice, id) end # Device Transfer -## To GPU -Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, x) = CUDA.cu(x) -function Adapt.adapt_storage(to::LuxCUDAAdaptor, x) +Adapt.adapt_storage(::LuxCUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) +function Adapt.adapt_storage(to::LuxCUDADevice, x::AbstractArray) old_dev = CUDA.device() # remember the current device - if !(x isa CUDA.AnyCuArray) + dev = LuxDeviceUtils.get_device(x) + if !(dev isa LuxCUDADevice) CUDA.device!(to.device) x_new = CUDA.cu(x) CUDA.device!(old_dev) return x_new - elseif CUDA.device(x) == to.device + elseif dev.device == to.device return x else CUDA.device!(to.device) @@ -68,19 +66,20 @@ function Adapt.adapt_storage(to::LuxCUDAAdaptor, x) return x_new end end -Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::AbstractRNG) = rng -Adapt.adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng -function Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::Random.TaskLocalRNG) - return CUDA.default_rng() -end -Adapt.adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng() -Adapt.adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng() +Adapt.adapt_storage(::LuxCPUDevice, rng::CUDA.RNG) = Random.default_rng() -## To CPU -## FIXME: Use SparseArrays to preserve the sparsity -function Adapt.adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) - return Adapt.adapt(Array, x) +# Defining these as extensions seems to cause precompilation errors +@static if isdefined(CUDA.CUSPARSE, :SparseArrays) + function Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseMatrix) + return CUDA.CUSPARSE.SparseArrays.SparseMatrixCSC(x) + end + function Adapt.adapt_storage(::LuxCPUDevice, x::AbstractCuSparseVector) + return CUDA.CUSPARSE.SparseArrays.SparseVector(x) + end +else + @warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \ an issue in the LuxDeviceUtils.jl repository."
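# Illustrative sketch, not part of the patch above: what the new CUSPARSE rules buy us.
# Moving a sparse GPU array back to the CPU now preserves sparsity instead of densifying.
# Assumes a hypothetical environment where LuxCUDA is loaded and an NVIDIA GPU is functional.
using LuxDeviceUtils, SparseArrays
using LuxCUDA

A = sprand(Float32, 10, 10, 0.1)
A_gpu = A |> gpu_device()      # becomes a CUSPARSE.CuSparseMatrixCSC on such a machine
A_cpu = A_gpu |> cpu_device()  # comes back as a SparseMatrixCSC, not a dense Array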
end end diff --git a/ext/LuxDeviceUtilsFillArraysExt.jl b/ext/LuxDeviceUtilsFillArraysExt.jl index 879d380..b596233 100644 --- a/ext/LuxDeviceUtilsFillArraysExt.jl +++ b/ext/LuxDeviceUtilsFillArraysExt.jl @@ -1,14 +1,10 @@ module LuxDeviceUtilsFillArraysExt using Adapt: Adapt -using FillArrays: FillArrays -using LuxDeviceUtils: LuxDeviceUtils, LuxCPUAdaptor +using FillArrays: FillArrays, AbstractFill +using LuxDeviceUtils: LuxDeviceUtils, LuxCPUDevice, AbstractLuxDevice -Adapt.adapt_structure(::LuxCPUAdaptor, x::FillArrays.AbstractFill) = x - -function Adapt.adapt_structure( - to::LuxDeviceUtils.AbstractLuxDeviceAdaptor, x::FillArrays.AbstractFill) - return Adapt.adapt(to, collect(x)) -end +Adapt.adapt_structure(::LuxCPUDevice, x::AbstractFill) = x +Adapt.adapt_structure(to::AbstractLuxDevice, x::AbstractFill) = Adapt.adapt(to, collect(x)) end diff --git a/ext/LuxDeviceUtilsGPUArraysExt.jl b/ext/LuxDeviceUtilsGPUArraysExt.jl index 7d72484..1e8f9f9 100644 --- a/ext/LuxDeviceUtilsGPUArraysExt.jl +++ b/ext/LuxDeviceUtilsGPUArraysExt.jl @@ -2,9 +2,9 @@ module LuxDeviceUtilsGPUArraysExt using Adapt: Adapt using GPUArrays: GPUArrays -using LuxDeviceUtils: LuxCPUAdaptor +using LuxDeviceUtils: LuxCPUDevice using Random: Random -Adapt.adapt_storage(::LuxCPUAdaptor, rng::GPUArrays.RNG) = Random.default_rng() +Adapt.adapt_storage(::LuxCPUDevice, rng::GPUArrays.RNG) = Random.default_rng() end diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl deleted file mode 100644 index 15fcb9f..0000000 --- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ /dev/null @@ -1,13 +0,0 @@ -module LuxDeviceUtilsLuxAMDGPUExt - -using LuxAMDGPU: LuxAMDGPU -using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, reset_gpu_device! - -__init__() = reset_gpu_device!() - -LuxDeviceUtils.__is_loaded(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) = true -function LuxDeviceUtils.__is_functional(::Union{LuxAMDGPUDevice, <:Type{LuxAMDGPUDevice}}) - return LuxAMDGPU.functional() -end - -end diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl index 4e386ad..4870710 100644 --- a/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -5,8 +5,8 @@ using LuxDeviceUtils: LuxDeviceUtils, LuxCUDADevice, reset_gpu_device! __init__() = reset_gpu_device!() -LuxDeviceUtils.__is_loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true -function LuxDeviceUtils.__is_functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) +LuxDeviceUtils.loaded(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = true +function LuxDeviceUtils.functional(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) return LuxCUDA.functional() end diff --git a/ext/LuxDeviceUtilsMetalExt.jl b/ext/LuxDeviceUtilsMetalExt.jl new file mode 100644 index 0000000..908de28 --- /dev/null +++ b/ext/LuxDeviceUtilsMetalExt.jl @@ -0,0 +1,25 @@ +module LuxDeviceUtilsMetalExt + +using Adapt: Adapt +using GPUArrays: GPUArrays +using LuxDeviceUtils: LuxDeviceUtils, LuxMetalDevice, reset_gpu_device! 
+using Metal: Metal, MtlArray + +__init__() = reset_gpu_device!() + +LuxDeviceUtils.loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true +function LuxDeviceUtils.functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) + return Metal.functional() +end + +# Default RNG +LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) + +# Query Device from Array +LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice() + +# Device Transfer +## To GPU +Adapt.adapt_storage(::LuxMetalDevice, x::AbstractArray) = Metal.mtl(x) + +end diff --git a/ext/LuxDeviceUtilsMetalGPUArraysExt.jl b/ext/LuxDeviceUtilsMetalGPUArraysExt.jl deleted file mode 100644 index 5cdd530..0000000 --- a/ext/LuxDeviceUtilsMetalGPUArraysExt.jl +++ /dev/null @@ -1,30 +0,0 @@ -module LuxDeviceUtilsMetalGPUArraysExt - -using Adapt: Adapt -using GPUArrays: GPUArrays -using LuxDeviceUtils: LuxDeviceUtils, LuxMetalAdaptor, LuxMetalDevice, reset_gpu_device! -using Metal: Metal, MtlArray -using Random: Random, AbstractRNG - -__init__() = reset_gpu_device!() - -LuxDeviceUtils.__is_loaded(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = true -function LuxDeviceUtils.__is_functional(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) - return Metal.functional() -end - -# Default RNG -LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray) - -# Query Device from Array -LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice() - -# Device Transfer -## To GPU -Adapt.adapt_storage(::LuxMetalAdaptor, x) = Metal.mtl(x) -Adapt.adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng -function Adapt.adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) - return GPUArrays.default_rng(MtlArray) -end - -end diff --git a/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl b/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl index 06279e2..78aec5e 100644 --- a/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl +++ b/ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl @@ -1,17 +1,21 @@ module LuxDeviceUtilsRecursiveArrayToolsExt using Adapt: Adapt, adapt -using LuxDeviceUtils: AbstractLuxDeviceAdaptor +using LuxDeviceUtils: LuxDeviceUtils, AbstractLuxDevice using RecursiveArrayTools: VectorOfArray, DiffEqArray # We want to preserve the structure -function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::VectorOfArray) +function Adapt.adapt_structure(to::AbstractLuxDevice, x::VectorOfArray) return VectorOfArray(map(Base.Fix1(adapt, to), x.u)) end -function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::DiffEqArray) +function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray) # Don't move the `time` to the GPU return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t) end +function LuxDeviceUtils.get_device(x::Union{VectorOfArray, DiffEqArray}) + return mapreduce(LuxDeviceUtils.get_device, LuxDeviceUtils.__combine_devices, x.u) +end + end diff --git a/ext/LuxDeviceUtilsReverseDiffExt.jl b/ext/LuxDeviceUtilsReverseDiffExt.jl new file mode 100644 index 0000000..a683b3e --- /dev/null +++ b/ext/LuxDeviceUtilsReverseDiffExt.jl @@ -0,0 +1,13 @@ +module LuxDeviceUtilsReverseDiffExt + +using LuxDeviceUtils: LuxDeviceUtils +using ReverseDiff: ReverseDiff + +@inline function LuxDeviceUtils.get_device(x::ReverseDiff.TrackedArray) + return LuxDeviceUtils.get_device(ReverseDiff.value(x)) +end +@inline function LuxDeviceUtils.get_device(x::AbstractArray{<:ReverseDiff.TrackedReal}) + return LuxDeviceUtils.get_device(ReverseDiff.value.(x)) +end + +end diff --git a/ext/LuxDeviceUtilsSparseArraysExt.jl 
b/ext/LuxDeviceUtilsSparseArraysExt.jl index 2f20e9e..f337d2f 100644 --- a/ext/LuxDeviceUtilsSparseArraysExt.jl +++ b/ext/LuxDeviceUtilsSparseArraysExt.jl @@ -1,9 +1,9 @@ module LuxDeviceUtilsSparseArraysExt using Adapt: Adapt -using LuxDeviceUtils: LuxCPUAdaptor +using LuxDeviceUtils: LuxCPUDevice using SparseArrays: AbstractSparseArray -Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractSparseArray) = x +Adapt.adapt_storage(::LuxCPUDevice, x::AbstractSparseArray) = x end diff --git a/ext/LuxDeviceUtilsTrackerExt.jl b/ext/LuxDeviceUtilsTrackerExt.jl new file mode 100644 index 0000000..6746b9b --- /dev/null +++ b/ext/LuxDeviceUtilsTrackerExt.jl @@ -0,0 +1,26 @@ +module LuxDeviceUtilsTrackerExt + +using Adapt: Adapt +using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, + LuxoneAPIDevice +using Tracker: Tracker + +@inline function LuxDeviceUtils.get_device(x::Tracker.TrackedArray) + return LuxDeviceUtils.get_device(Tracker.data(x)) +end +@inline function LuxDeviceUtils.get_device(x::AbstractArray{<:Tracker.TrackedReal}) + return LuxDeviceUtils.get_device(Tracker.data.(x)) +end + +@inline LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true + +for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, + LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) + @eval function Adapt.adapt_storage(to::$(T), x::AbstractArray{<:Tracker.TrackedReal}) + @warn "AbstractArray{<:Tracker.TrackedReal} is not supported for $(to). Converting \ + to Tracker.TrackedArray." maxlog=1 + return to(Tracker.collect(x)) + end +end + +end diff --git a/ext/LuxDeviceUtilsZygoteExt.jl b/ext/LuxDeviceUtilsZygoteExt.jl index 4f87b22..ae61dc4 100644 --- a/ext/LuxDeviceUtilsZygoteExt.jl +++ b/ext/LuxDeviceUtilsZygoteExt.jl @@ -1,13 +1,10 @@ module LuxDeviceUtilsZygoteExt using Adapt: Adapt -using LuxDeviceUtils: AbstractLuxDeviceAdaptor, LuxCPUAdaptor +using LuxDeviceUtils: AbstractLuxDevice, LuxCPUDevice using Zygote: OneElement -Adapt.adapt_structure(::LuxCPUAdaptor, x::OneElement) = x - -function Adapt.adapt_structure(to::AbstractLuxDeviceAdaptor, x::OneElement) - return Adapt.adapt(to, collect(x)) -end +Adapt.adapt_structure(::LuxCPUDevice, x::OneElement) = x +Adapt.adapt_structure(to::AbstractLuxDevice, x::OneElement) = Adapt.adapt(to, collect(x)) end diff --git a/ext/LuxDeviceUtilsoneAPIExt.jl b/ext/LuxDeviceUtilsoneAPIExt.jl new file mode 100644 index 0000000..00b8faa --- /dev/null +++ b/ext/LuxDeviceUtilsoneAPIExt.jl @@ -0,0 +1,44 @@ +module LuxDeviceUtilsoneAPIExt + +using Adapt: Adapt +using GPUArrays: GPUArrays +using LuxDeviceUtils: LuxDeviceUtils, LuxoneAPIDevice, reset_gpu_device! 
+using oneAPI: oneAPI, oneArray, oneL0 + +const SUPPORTS_FP64 = Dict{oneL0.ZeDevice, Bool}() + +function __init__() + reset_gpu_device!() + for dev in oneAPI.devices() + SUPPORTS_FP64[dev] = oneL0.module_properties(dev).fp64flags & + oneL0.ZE_DEVICE_MODULE_FLAG_FP64 == + oneL0.ZE_DEVICE_MODULE_FLAG_FP64 + end +end + +LuxDeviceUtils.loaded(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = true +function LuxDeviceUtils.functional(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) + return oneAPI.functional() +end + +# Default RNG +LuxDeviceUtils.default_device_rng(::LuxoneAPIDevice) = GPUArrays.default_rng(oneArray) + +# Query Device from Array +LuxDeviceUtils.get_device(::oneArray) = LuxoneAPIDevice() + +# Device Transfer +## To GPU +for (T1, T2) in ((Float64, Float32), (ComplexF64, ComplexF32)) + @eval function Adapt.adapt_storage(::LuxoneAPIDevice, x::AbstractArray{$(T1)}) + if !SUPPORTS_FP64[oneAPI.device()] + @warn LazyString( + "Double type is not supported on this device. Using `", $(T2), "` instead.") + return oneArray{$(T2)}(x) + end + return oneArray(x) + end +end +Adapt.adapt_storage(::LuxoneAPIDevice, x::AbstractArray) = oneArray(x) + +end diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index 775439c..d7b7b40 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -5,7 +5,6 @@ using PrecompileTools: @recompile_invalidations @recompile_invalidations begin using Adapt: Adapt using ChainRulesCore: ChainRulesCore, NoTangent - using FastClosures: @closure using Functors: Functors, fmap using LuxCore: LuxCore using Preferences: @delete_preferences!, @load_preference, @set_preferences! @@ -17,15 +16,40 @@ const CRC = ChainRulesCore export gpu_backend!, supported_gpu_backends, reset_gpu_device! export default_device_rng export gpu_device, cpu_device -export LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice -export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor +export LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice export get_device abstract type AbstractLuxDevice <: Function end abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end -__is_functional(x) = false -__is_loaded(x) = false +""" + functional(x::AbstractLuxDevice) -> Bool + functional(::Type{<:AbstractLuxDevice}) -> Bool + +Checks if the device is functional. This is used to determine if the device can be used for +computation. Note that even if the backend is loaded (as checked via +[`LuxDeviceUtils.loaded`](@ref)), the device may not be functional. + +Note that while this function is not exported, it is considered part of the public API. +""" +@inline functional(x) = false + +Base.@deprecate __is_functional(x) functional(x) + +""" + loaded(x::AbstractLuxDevice) -> Bool + loaded(::Type{<:AbstractLuxDevice}) -> Bool + +Checks if the trigger package for the device is loaded. Trigger packages are as follows: + + - `LuxCUDA.jl` for NVIDIA CUDA Support. + - `AMDGPU.jl` for AMD GPU ROCM Support. + - `Metal.jl` for Apple Metal GPU Support. + - `oneAPI.jl` for Intel oneAPI GPU Support. 
+""" +@inline loaded(x) = false + +Base.@deprecate __is_loaded(x) loaded(x) struct LuxCPUDevice <: AbstractLuxDevice end @kwdef struct LuxCUDADevice{D} <: AbstractLuxGPUDevice @@ -35,43 +59,34 @@ end device::D = nothing end struct LuxMetalDevice <: AbstractLuxGPUDevice end +struct LuxoneAPIDevice <: AbstractLuxGPUDevice end -_with_device(::Type{LuxCPUDevice}, ::Nothing) = LuxCPUDevice() -function _with_device(::Type{LuxCPUDevice}, device_id) - @warn "`device_id` is not applicable for `LuxCPUDevice`." maxlog=1 - return LuxCPUDevice() -end - -_with_device(::Type{LuxMetalDevice}, ::Nothing) = LuxMetalDevice() -function _with_device(::Type{LuxMetalDevice}, device_id) - @warn "`device_id` is not applicable for `LuxMetalDevice`." maxlog=1 - return LuxMetalDevice() +for dev in (LuxCPUDevice, LuxMetalDevice, LuxoneAPIDevice) + @eval begin + _with_device(::Type{$dev}, ::Nothing) = $dev() + function _with_device(::Type{$dev}, device_id) + @warn "`device_id` is not applicable for `$dev`." maxlog=1 + return $dev() + end + end end -__is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true -__is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true - -_get_device_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "CPU" -_get_device_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "CUDA" -_get_device_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "AMDGPU" -_get_device_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" - -_get_triggerpkg_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "" -_get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA" -_get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU" -_get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal" +@inline functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true +@inline loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true -_get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor() -_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device) -_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device) -_get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor() - -_get_device_id(::LuxCPUDevice) = nothing -_get_device_id(::LuxCUDADevice{Nothing}) = nothing -_get_device_id(::LuxAMDGPUDevice{Nothing}) = nothing -_get_device_id(::LuxMetalDevice) = nothing +for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) + tpkg = name === :CPU ? "" : (name == :CUDA ? "Lux$(name)" : string(name)) + ldev = eval(Symbol(:Lux, name, :Device)) + @eval begin + @inline _get_device_name(::Union{$ldev, Type{<:$ldev}}) = $(string(name)) + @inline _get_triggerpkg_name(::Union{$ldev, Type{<:$ldev}}) = $(tpkg) + end +end -Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev)) +for T in (LuxCPUDevice, LuxCUDADevice{Nothing}, + LuxAMDGPUDevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) + @eval @inline _get_device_id(::$(T)) = nothing +end struct LuxDeviceSelectionException <: Exception end @@ -80,7 +95,7 @@ function Base.showerror(io::IO, ::LuxDeviceSelectionException) end # Order is important here -const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice) +const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice) const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) @@ -90,7 +105,7 @@ const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing) Resets the selected GPU device. This is useful when automatic GPU selection needs to be run again. 
""" -reset_gpu_device!() = (GPU_DEVICE[] = nothing) +@inline reset_gpu_device!() = (GPU_DEVICE[] = nothing) """ supported_gpu_backends() -> Tuple{String, ...} @@ -104,13 +119,13 @@ Return a tuple of supported GPU backends. !!! danger - `Metal.jl` support is **extremely** experimental and most things are not expected to - work. + `Metal.jl` and `oneAPI.jl` support is **extremely** experimental and most things are not + expected to work. """ -supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) +@inline supported_gpu_backends() = map(_get_device_name, GPU_DEVICES) """ - gpu_device(device_id::Union{Nothing, Int}=nothing; + gpu_device(device_id::Union{Nothing, Integer}=nothing; force_gpu_usage::Bool=false) -> AbstractLuxDevice() Selects GPU device based on the following criteria: @@ -126,24 +141,24 @@ Selects GPU device based on the following criteria: ## Arguments - - `device_id::Union{Nothing, Int}`: The device id to select. If `nothing`, then we return + - `device_id::Union{Nothing, Integer}`: The device id to select. If `nothing`, then we return the last selected device or if none was selected then we run the autoselection and choose the current device using `CUDA.device()` or `AMDGPU.device()` or similar. If - `Int`, then we select the device with the given id. Note that this is `1`-indexed, in + `Integer`, then we select the device with the given id. Note that this is `1`-indexed, in contrast to the `0`-indexed `CUDA.jl`. For example, `id = 4` corresponds to `CUDA.device!(3)`. !!! warning - `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal` and `CPU` - backends, `device_id` is ignored and a warning is printed. + `device_id` is only applicable for `CUDA` and `AMDGPU` backends. For `Metal`, `oneAPI` + and `CPU` backends, `device_id` is ignored and a warning is printed. ## Keyword Arguments - `force_gpu_usage::Bool`: If `true`, then an error is thrown if no functional GPU device is found. """ -function gpu_device(device_id::Union{Nothing, Int}=nothing; +function gpu_device(device_id::Union{Nothing, <:Integer}=nothing; force_gpu_usage::Bool=false)::AbstractLuxDevice device_id == 0 && throw(ArgumentError("`device_id` is 1-indexed.")) @@ -173,21 +188,21 @@ function _get_gpu_device(; force_gpu_usage::Bool) # If backend set with preferences, use it if backend !== nothing allowed_backends = supported_gpu_backends() - idx = findfirst(isequal(backend), allowed_backends) if backend ∉ allowed_backends @warn "`gpu_backend` preference is set to $backend, which is not a valid \ backend. Valid backends are $allowed_backends. Defaulting to automatic \ GPU Backend selection." maxlog=1 else @debug "Using GPU backend set in preferences: $backend." + idx = findfirst(isequal(backend), allowed_backends) device = GPU_DEVICES[idx] - if !__is_loaded(device) + if !loaded(device) @warn "Trying to use backend: $(_get_device_name(device)) but the trigger \ - package $(device.pkgid) is not loaded. Ignoring the Preferences \ - backend!!! Please load the package and call this function again to \ - respect the Preferences backend." maxlog=1 + package $(_get_triggerpkg_name(device)) is not loaded. Ignoring the \ + Preferences backend!!! Please load the package and call this \ + function again to respect the Preferences backend." maxlog=1 else - if __is_functional(device) + if functional(device) @debug "Using GPU backend: $(_get_device_name(device))." 
return device else @@ -201,16 +216,16 @@ function _get_gpu_device(; force_gpu_usage::Bool) @debug "Running automatic GPU backend selection..." for device in GPU_DEVICES - if __is_loaded(device) + if loaded(device) @debug "Trying backend: $(_get_device_name(device))." - if __is_functional(device) + if functional(device) @debug "Using GPU backend: $(_get_device_name(device))." return device end @debug "GPU backend: $(_get_device_name(device)) is not functional." else @debug "Trigger package for backend ($(_get_device_name(device))): \ - $(_get_trigger_pkgname(device)) not loaded." + $(_get_triggerpkg_name(device)) not loaded." end end @@ -221,9 +236,10 @@ function _get_gpu_device(; force_gpu_usage::Bool) 1. If no GPU is available, nothing needs to be done. 2. If GPU is available, load the corresponding trigger package. - a. LuxCUDA.jl for NVIDIA CUDA Support. - b. LuxAMDGPU.jl for AMD GPU ROCM Support. - c. Metal.jl for Apple Metal GPU Support.""" maxlog=1 + a. `LuxCUDA.jl` for NVIDIA CUDA Support. + b. `AMDGPU.jl` for AMD GPU ROCM Support. + c. `Metal.jl` for Apple Metal GPU Support. (Experimental) + d. `oneAPI.jl` for Intel oneAPI GPU Support. (Experimental)""" maxlog=1 return LuxCPUDevice end end @@ -261,7 +277,9 @@ function gpu_backend!(backend::String) return end - @assert backend in allowed_backends "`gpu_backend` must be one of $(allowed_backends)" + if backend ∉ allowed_backends + throw(ArgumentError("Invalid backend: $backend. Valid backends are $allowed_backends.")) + end @set_preferences!("gpu_backend"=>backend) @info "GPU backend has been set to $backend. Restart Julia to use the new backend." @@ -283,10 +301,11 @@ and states on the device using [WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl). """ function default_device_rng(D::AbstractLuxDevice) - return error("""`default_device_rng` not implemented for $(typeof(D)). This is either because: + return error("""`default_device_rng` not implemented for `$(typeof(D))`. This is \ + either because: 1. The default RNG for this device is not known / officially provided. - 2. The trigger package for the device is not loaded. + 2. The trigger package for the device ($(_get_device_name(D)).jl) is not loaded. """) end default_device_rng(::LuxCPUDevice) = Random.default_rng() @@ -295,38 +314,43 @@ default_device_rng(::LuxCPUDevice) = Random.default_rng() # Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability # For all other types we rely on fmap which means we lose type stability. # For Lux, typically models only has these 3 datastructures so we should be mostly fine. -for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal) +for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) ldev = Symbol("Lux$(dev)Device") @eval begin - function (D::$(ldev))(x::AbstractArray) - ladaptor = _get_adaptor(D) - fn = Base.Fix1(Adapt.adapt, ladaptor) - return _isbitsarray(x) ? fn(x) : map(D, x) + function (D::$(ldev))(x::AbstractArray{T}) where {T} + fn = Base.Fix1(Adapt.adapt, D) + return isbitstype(T) || __special_aos(x) ? 
fn(x) : map(D, x) end (D::$(ldev))(x::Tuple) = map(D, x) (D::$(ldev))(x::NamedTuple{F}) where {F} = NamedTuple{F}(D(values(x))) function (D::$(ldev))(x) - ladaptor = _get_adaptor(D) - _isleaf(x) && return Adapt.adapt(ladaptor, x) - return fmap(Base.Fix1(Adapt.adapt, ladaptor), x; exclude=_isleaf) + Functors.isleaf(x) && return Adapt.adapt(D, x) + return fmap(D, x) end function (::$(ldev))(NN::LuxCore.AbstractExplicitLayer) @warn "Lux layers are stateless and hence don't participate in device \ transfers. Apply this function on the parameters and states generated \ - using `Lux.setup`." maxlog=1 + using `Lux.setup`." return NN end end end +@inline __special_aos(x::AbstractArray) = false + # Query Device from Array """ - get_device(x::AbstractArray) -> AbstractLuxDevice + get_device(x) -> AbstractLuxDevice | Exception | Nothing + +If all arrays (on the leaves of the structure) are on the same device, we return that +device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`. -Returns the device of the array `x`. Trigger Packages must be loaded for this to return the -correct device. +!!! note + + Trigger Packages must be loaded for this to return the correct device. """ -function get_device(x::AbstractArray) +function get_device(x::AbstractArray{T}) where {T} + !isbitstype(T) && return mapreduce(get_device, __combine_devices, x) if hasmethod(parent, Tuple{typeof(x)}) parent_x = parent(x) parent_x === x && return LuxCPUDevice() @@ -334,16 +358,35 @@ function get_device(x::AbstractArray) end return LuxCPUDevice() end +function get_device(x) + dev = Ref{Union{AbstractLuxDevice, Nothing}}(nothing) + _get_device(x) = (dev[] = __combine_devices(dev[], get_device(x))) + fmap(_get_device, x) + return dev[] +end +for T in (Number, AbstractRNG, Val, Symbol, String) + @eval get_device(::$(T)) = nothing +end +get_device(x::Tuple) = mapreduce(get_device, __combine_devices, x) +get_device(x::NamedTuple) = mapreduce(get_device, __combine_devices, values(x)) CRC.@non_differentiable get_device(::Any...) +function __combine_devices(dev1, dev2) + dev1 === nothing && return dev2 + dev2 === nothing && return dev1 + dev1 != dev2 && + throw(ArgumentError("Objects are on different devices: $dev1 and $dev2.")) + return dev1 +end + # Set the device const SET_DEVICE_DOCS = """ Set the device for the given type. This is a no-op for `LuxCPUDevice`. For `LuxCUDADevice` and `LuxAMDGPUDevice`, it prints a warning if the corresponding trigger package is not loaded. -Currently, `LuxMetalDevice` doesn't support setting the device. +Currently, `LuxMetalDevice` and `LuxoneAPIDevice` doesn't support setting the device. """ const SET_DEVICE_DANGER = """ @@ -370,66 +413,68 @@ $SET_DEVICE_DANGER """ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice} T === LuxCUDADevice && - @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." maxlog=1 + @warn "`CUDA.jl` hasn't been loaded. Ignoring the device setting." T === LuxAMDGPUDevice && - @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." maxlog=1 + @warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." T === LuxMetalDevice && - @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." maxlog=1 + @warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." + T === LuxoneAPIDevice && + @warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." 
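# Illustrative sketch, not part of the patch above: the `get_device` semantics documented
# in the docstring above. A single device is returned when all leaf arrays agree, `nothing`
# for device-agnostic objects, and an `ArgumentError` is thrown for mixed structures
# (the CUDA/AMDGPU test files below exercise exactly this behaviour).
using LuxDeviceUtils

x = rand(Float32, 3, 3)
get_device(x)                    # LuxCPUDevice()
get_device((x, view(x, :, 1)))   # all leaves agree -> LuxCPUDevice()
get_device(:some_symbol)         # device-agnostic -> nothing
# With a functional GPU backend loaded, mixing host and device arrays throws:
# get_device((; a=x, b=gpu_device()(x)))   # ArgumentError: objects are on different devices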
T === LuxCPUDevice && - @warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting." maxlog=1 + @warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting." return end """ - set_device!(T::Type{<:AbstractLuxDevice}, ::Nothing, rank::Int) + set_device!(T::Type{<:AbstractLuxDevice}, ::Nothing, rank::Integer) $SET_DEVICE_DOCS ## Arguments - `T::Type{<:AbstractLuxDevice}`: The device type to set. - - `rank::Int`: Local Rank of the process. This is applicable for distributed training and + - `rank::Integer`: Local Rank of the process. This is applicable for distributed training and must be `0`-indexed. $SET_DEVICE_DANGER """ -function set_device!(::Type{T}, ::Nothing, rank::Int) where {T <: AbstractLuxDevice} +function set_device!(::Type{T}, ::Nothing, rank::Integer) where {T <: AbstractLuxDevice} return set_device!(T, rank) end # Adapt Interface -abstract type AbstractLuxDeviceAdaptor end -abstract type AbstractLuxGPUDeviceAdaptor <: AbstractLuxDeviceAdaptor end - -struct LuxCPUAdaptor <: AbstractLuxDeviceAdaptor end -struct LuxCUDAAdaptor{D} <: AbstractLuxGPUDeviceAdaptor - device::D -end -struct LuxAMDGPUAdaptor{D} <: AbstractLuxGPUDeviceAdaptor - device::D +# In older versions we had corresponding Adapt functions, rn we directly dispatch on the +# device type. +for name in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI) + dev = Symbol(:Lux, name, :Device) + adaptor = Symbol(:Lux, name, :Adaptor) + @eval Base.@deprecate_binding $(adaptor) $(dev) true end -struct LuxMetalAdaptor <: AbstractLuxGPUDeviceAdaptor end -Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractRange) = x -Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractArray) = Adapt.adapt(Array, x) -Adapt.adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG) = rng +Adapt.adapt_storage(::LuxCPUDevice, x::AbstractArray) = Adapt.adapt(Array, x) +Adapt.adapt_storage(::LuxCPUDevice, rng::AbstractRNG) = rng + +for T in (LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDevice, LuxoneAPIDevice) + @eval begin + function Adapt.adapt_storage(to::$(T), ::Random.TaskLocalRNG) + return default_device_rng(to) + end + Adapt.adapt_storage(::$(T), rng::AbstractRNG) = rng + end +end +Adapt.adapt_storage(::LuxCPUDevice, x::AbstractRange) = x # Prevent Ambiguity -for T in (LuxAMDGPUAdaptor, LuxCUDAAdaptor, LuxMetalAdaptor) +for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice, + LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice) @eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x)) end -_isbitsarray(::AbstractArray{<:Number}) = true -_isbitsarray(::AbstractArray{T}) where {T} = isbitstype(T) -_isbitsarray(x) = false - -_isleaf(::AbstractRNG) = true -_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x) - # Chain Rules Core -function CRC.rrule( - ::typeof(Adapt.adapt_storage), to::AbstractLuxDeviceAdaptor, x::AbstractArray) - ∇adapt_storage = @closure Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) +function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractLuxDevice, x::AbstractArray) + ∇adapt_storage = let x = x + Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ)) + end return Adapt.adapt_storage(to, x), ∇adapt_storage end diff --git a/test/amdgpu.jl b/test/amdgpu.jl index 9247fdb..159b241 100644 --- a/test/amdgpu.jl +++ b/test/amdgpu.jl @@ -1,23 +1,28 @@ using LuxDeviceUtils, Random +using ArrayInterface: parameterless_type @testset "CPU Fallback" begin + @test !LuxDeviceUtils.functional(LuxAMDGPUDevice) @test cpu_device() isa LuxCPUDevice @test gpu_device() isa 
LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws Exception default_device_rng(LuxAMDGPUDevice(nothing)) + @test_logs (:warn, "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxAMDGPUDevice, nothing, 1) end -using LuxAMDGPU +using AMDGPU @testset "Loaded Trigger Package" begin @test LuxDeviceUtils.GPU_DEVICE[] === nothing - if LuxAMDGPU.functional() - @info "LuxAMDGPU is functional" + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + @info "AMDGPU is functional" @test gpu_device() isa LuxAMDGPUDevice @test gpu_device(; force_gpu_usage=true) isa LuxAMDGPUDevice else - @info "LuxAMDGPU is NOT functional" + @info "AMDGPU is NOT functional" @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) @@ -28,24 +33,33 @@ end using FillArrays, Zygote # Extensions @testset "Data Transfer" begin - ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxAMDGPU.functional() ? ROCArray : Array - rngType = LuxAMDGPU.functional() ? AMDGPU.rocRAND.RNG : Random.AbstractRNG + aType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? ROCArray : Array + rngType = LuxDeviceUtils.functional(LuxAMDGPUDevice) ? AMDGPU.rocRAND.RNG : + Random.AbstractRNG ps_xpu = ps |> device + @test get_device(ps_xpu) isa LuxAMDGPUDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if LuxAMDGPU.functional() + if LuxDeviceUtils.functional(LuxAMDGPUDevice) @test ps_xpu.one_elem isa ROCArray @test ps_xpu.farray isa ROCArray else @@ -54,44 +68,83 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @test ps_cpu.b == ps.b @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if LuxAMDGPU.functional() + if LuxDeviceUtils.functional(LuxAMDGPUDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @test ps_cpu.one_elem isa Zygote.OneElement @test ps_cpu.farray isa Fill end + + ps_mixed = (; a=rand(2), b=device(rand(2))) + @test_throws ArgumentError get_device(ps_mixed) + + dev = gpu_device() + x = rand(Float32, 10, 2) + x_dev = x |> dev + @test get_device(x_dev) isa parameterless_type(typeof(dev)) + + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + dev2 = gpu_device(length(AMDGPU.devices())) + x_dev2 = x_dev |> dev2 + @test get_device(x_dev2) isa typeof(dev2) + end end -if LuxAMDGPU.functional() - ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) - ps_cpu = deepcopy(ps) - cdev = cpu_device() - for idx in 
1:length(AMDGPU.devices()) - amdgpu_device = gpu_device(idx) - @test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice - @test AMDGPU.device_id(amdgpu_device.device) == idx - - global ps = ps |> amdgpu_device - @test ps.weight isa ROCArray - @test ps.bias isa ROCArray - @test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx - @test AMDGPU.device_id(AMDGPU.device(ps.bias)) == idx - @test isequal(cdev(ps.weight), ps_cpu.weight) - @test isequal(cdev(ps.bias), ps_cpu.bias) +@testset "Wrapped Arrays" begin + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + x = rand(10, 10) |> LuxAMDGPUDevice() + @test get_device(x) isa LuxAMDGPUDevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa LuxAMDGPUDevice end +end + +@testset "Multiple Devices AMDGPU" begin + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(AMDGPU.devices()) + amdgpu_device = gpu_device(idx) + @test typeof(amdgpu_device.device) <: AMDGPU.HIPDevice + @test AMDGPU.device_id(amdgpu_device.device) == idx - ps = ps |> cdev - @test ps.weight isa Array - @test ps.bias isa Array + ps = ps |> amdgpu_device + @test ps.weight isa ROCArray + @test ps.bias isa ROCArray + @test AMDGPU.device_id(AMDGPU.device(ps.weight)) == idx + @test AMDGPU.device_id(AMDGPU.device(ps.bias)) == idx + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end + + ps = ps |> cdev + @test ps.weight isa Array + @test ps.bias isa Array + end +end + +@testset "setdevice!" begin + if LuxDeviceUtils.functional(LuxAMDGPUDevice) + for i in 1:10 + @test_nowarn LuxDeviceUtils.set_device!(LuxAMDGPUDevice, nothing, i) + end + end end diff --git a/test/component_arrays.jl b/test/component_arrays.jl deleted file mode 100644 index 3825a22..0000000 --- a/test/component_arrays.jl +++ /dev/null @@ -1,17 +0,0 @@ -using LuxDeviceUtils, ComponentArrays, Random - -@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin - dev = LuxCPUDevice() - ps = (; weight=randn(10, 1), bias=randn(1)) - - ps_ca = ps |> ComponentArray - - ps_ca_dev = ps_ca |> dev - - @test ps_ca_dev isa ComponentArray - - @test ps_ca_dev.weight == ps.weight - @test ps_ca_dev.bias == ps.bias - - @test ps_ca_dev == (ps |> dev |> ComponentArray) -end diff --git a/test/cuda.jl b/test/cuda.jl index e0dc343..8ae7e54 100644 --- a/test/cuda.jl +++ b/test/cuda.jl @@ -1,10 +1,15 @@ -using LuxDeviceUtils, Random +using LuxDeviceUtils, Random, Functors +using ArrayInterface: parameterless_type @testset "CPU Fallback" begin + @test !LuxDeviceUtils.functional(LuxCUDADevice) @test cpu_device() isa LuxCPUDevice @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws Exception default_device_rng(LuxCUDADevice(nothing)) + @test_logs (:warn, "`CUDA.jl` hasn't been loaded. 
Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxCUDADevice, nothing, 1) end using LuxCUDA @@ -12,7 +17,7 @@ using LuxCUDA @testset "Loaded Trigger Package" begin @test LuxDeviceUtils.GPU_DEVICE[] === nothing - if LuxCUDA.functional() + if LuxDeviceUtils.functional(LuxCUDADevice) @info "LuxCUDA is functional" @test gpu_device() isa LuxCUDADevice @test gpu_device(; force_gpu_usage=true) isa LuxCUDADevice @@ -28,24 +33,32 @@ end using FillArrays, Zygote # Extensions @testset "Data Transfer" begin - ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = LuxCUDA.functional() ? CuArray : Array - rngType = LuxCUDA.functional() ? CUDA.RNG : Random.AbstractRNG + aType = LuxDeviceUtils.functional(LuxCUDADevice) ? CuArray : Array + rngType = LuxDeviceUtils.functional(LuxCUDADevice) ? CUDA.RNG : Random.AbstractRNG ps_xpu = ps |> device + @test get_device(ps_xpu) isa LuxCUDADevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if LuxCUDA.functional() + if LuxDeviceUtils.functional(LuxCUDADevice) @test ps_xpu.one_elem isa CuArray @test ps_xpu.farray isa CuArray else @@ -54,44 +67,127 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @test ps_cpu.b == ps.b @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if LuxCUDA.functional() + if LuxDeviceUtils.functional(LuxCUDADevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @test ps_cpu.one_elem isa Zygote.OneElement @test ps_cpu.farray isa Fill end + + struct MyStruct + x::Any + end + + Functors.@functor MyStruct + + data = MyStruct(rand(10)) + @test get_device(data) isa LuxCPUDevice + data_dev = data |> device + if LuxDeviceUtils.functional(LuxCUDADevice) + @test get_device(data_dev) isa LuxCUDADevice + else + @test get_device(data_dev) isa LuxCPUDevice + end + + ps_mixed = (; a=rand(2), c=(rand(2), 1), st=MyStruct(rand(2)), b=device(rand(2))) + @test get_device(ps_mixed.st) isa LuxCPUDevice + @test get_device(ps_mixed.c) isa LuxCPUDevice + @test_throws ArgumentError get_device(ps_mixed) + + dev = gpu_device() + x = rand(Float32, 10, 2) + x_dev = x |> dev + @test get_device(x_dev) isa parameterless_type(typeof(dev)) + + if LuxDeviceUtils.functional(LuxCUDADevice) + dev2 = gpu_device(length(CUDA.devices())) + x_dev2 = x_dev |> dev2 + @test get_device(x_dev2) isa typeof(dev2) + end end -if LuxCUDA.functional() - ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) - ps_cpu = deepcopy(ps) - cdev = cpu_device() - for idx in 1:length(CUDA.devices()) - cuda_device = gpu_device(idx) - @test 
typeof(cuda_device.device) <: CUDA.CuDevice - @test cuda_device.device.handle == (idx - 1) - - global ps = ps |> cuda_device - @test ps.weight isa CuArray - @test ps.bias isa CuArray - @test CUDA.device(ps.weight).handle == idx - 1 - @test CUDA.device(ps.bias).handle == idx - 1 - @test isequal(cdev(ps.weight), ps_cpu.weight) - @test isequal(cdev(ps.bias), ps_cpu.bias) +@testset "Wrapped Arrays" begin + if LuxDeviceUtils.functional(LuxCUDADevice) + x = rand(10, 10) |> LuxCUDADevice() + @test get_device(x) isa LuxCUDADevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa LuxCUDADevice end +end + +@testset "Multiple Devices CUDA" begin + if LuxDeviceUtils.functional(LuxCUDADevice) + ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(CUDA.devices()) + cuda_device = gpu_device(idx) + @test typeof(cuda_device.device) <: CUDA.CuDevice + @test cuda_device.device.handle == (idx - 1) - ps = ps |> cdev - @test ps.weight isa Array - @test ps.bias isa Array + ps = ps |> cuda_device + @test ps.weight isa CuArray + @test ps.bias isa CuArray + @test CUDA.device(ps.weight).handle == idx - 1 + @test CUDA.device(ps.bias).handle == idx - 1 + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end + + ps = ps |> cdev + @test ps.weight isa Array + @test ps.bias isa Array + end +end + +using SparseArrays + +@testset "CUDA Sparse Arrays" begin + if LuxDeviceUtils.functional(LuxCUDADevice) + ps = (; weight=sprand(Float32, 10, 10, 0.1), bias=sprand(Float32, 10, 0.1)) + ps_cpu = deepcopy(ps) + cdev = cpu_device() + for idx in 1:length(CUDA.devices()) + cuda_device = gpu_device(idx) + @test typeof(cuda_device.device) <: CUDA.CuDevice + @test cuda_device.device.handle == (idx - 1) + + ps = ps |> cuda_device + @test ps.weight isa CUSPARSE.CuSparseMatrixCSC + @test ps.bias isa CUSPARSE.CuSparseVector + @test get_device(ps.weight).device.handle == idx - 1 + @test get_device(ps.bias).device.handle == idx - 1 + @test isequal(cdev(ps.weight), ps_cpu.weight) + @test isequal(cdev(ps.bias), ps_cpu.bias) + end + + ps = ps |> cdev + @test ps.weight isa SparseMatrixCSC + @test ps.bias isa SparseVector + end +end + +@testset "setdevice!" 
begin + if LuxDeviceUtils.functional(LuxCUDADevice) + for i in 1:10 + @test_nowarn LuxDeviceUtils.set_device!(LuxCUDADevice, nothing, i) + end + end end diff --git a/test/explicit_imports.jl b/test/explicit_imports.jl index e87484c..6cf767e 100644 --- a/test/explicit_imports.jl +++ b/test/explicit_imports.jl @@ -1,7 +1,6 @@ # Load all trigger packages -import LuxAMDGPU, LuxCUDA, FillArrays, Metal, RecursiveArrayTools, SparseArrays, Zygote +import FillArrays, RecursiveArrayTools, SparseArrays, Zygote using ExplicitImports, LuxDeviceUtils @test check_no_implicit_imports(LuxDeviceUtils) === nothing -@test check_no_stale_explicit_imports( - LuxDeviceUtils; ignore=(:LuxCPUAdaptor, :LuxMetalAdaptor)) === nothing +@test check_no_stale_explicit_imports(LuxDeviceUtils) === nothing diff --git a/test/metal.jl b/test/metal.jl index 96c930e..5c500bf 100644 --- a/test/metal.jl +++ b/test/metal.jl @@ -1,10 +1,12 @@ using LuxDeviceUtils, Random @testset "CPU Fallback" begin + @test !LuxDeviceUtils.functional(LuxMetalDevice) @test cpu_device() isa LuxCPUDevice @test gpu_device() isa LuxCPUDevice @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; force_gpu_usage=true) + @test_throws Exception default_device_rng(LuxMetalDevice()) end using Metal @@ -12,7 +14,7 @@ using Metal @testset "Loaded Trigger Package" begin @test LuxDeviceUtils.GPU_DEVICE[] === nothing - if Metal.functional() + if LuxDeviceUtils.functional(LuxMetalDevice) @info "Metal is functional" @test gpu_device() isa LuxMetalDevice @test gpu_device(; force_gpu_usage=true) isa LuxMetalDevice @@ -28,24 +30,33 @@ end using FillArrays, Zygote # Extensions @testset "Data Transfer" begin - ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, d="string", + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, rng_default=Random.default_rng(), rng=MersenneTwister(), one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) device = gpu_device() - aType = Metal.functional() ? MtlArray : Array - rngType = Metal.functional() ? Metal.GPUArrays.RNG : Random.AbstractRNG + aType = LuxDeviceUtils.functional(LuxMetalDevice) ? MtlArray : Array + rngType = LuxDeviceUtils.functional(LuxMetalDevice) ? 
Metal.GPUArrays.RNG : + Random.AbstractRNG ps_xpu = ps |> device + @test get_device(ps_xpu) isa LuxMetalDevice @test ps_xpu.a.c isa aType @test ps_xpu.b isa aType @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa aType @test ps_xpu.e == ps.e @test ps_xpu.d == ps.d @test ps_xpu.rng_default isa rngType @test ps_xpu.rng == ps.rng - if Metal.functional() + if LuxDeviceUtils.functional(LuxMetalDevice) @test ps_xpu.one_elem isa MtlArray @test ps_xpu.farray isa MtlArray else @@ -54,21 +65,47 @@ using FillArrays, Zygote # Extensions end ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa LuxCPUDevice @test ps_cpu.a.c isa Array @test ps_cpu.b isa Array @test ps_cpu.a.c == ps.a.c @test ps_cpu.b == ps.b @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa Array @test ps_cpu.e == ps.e @test ps_cpu.d == ps.d @test ps_cpu.rng_default isa Random.TaskLocalRNG @test ps_cpu.rng == ps.rng - if Metal.functional() + if LuxDeviceUtils.functional(LuxMetalDevice) @test ps_cpu.one_elem isa Array @test ps_cpu.farray isa Array else @test ps_cpu.one_elem isa Zygote.OneElement @test ps_cpu.farray isa Fill end + + ps_mixed = (; a=rand(2), b=device(rand(2))) + @test_throws ArgumentError get_device(ps_mixed) +end + +@testset "Wrapper Arrays" begin + if LuxDeviceUtils.functional(LuxMetalDevice) + x = rand(Float32, 10, 10) |> LuxMetalDevice() + @test get_device(x) isa LuxMetalDevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa LuxMetalDevice + end +end + +@testset "setdevice!" begin + if LuxDeviceUtils.functional(LuxMetalDevice) + @test_logs (:warn, + "Support for Multi Device Metal hasn't been implemented yet. 
Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxMetalDevice, nothing, 1) + end end diff --git a/test/misc.jl b/test/misc.jl new file mode 100644 index 0000000..681f890 --- /dev/null +++ b/test/misc.jl @@ -0,0 +1,154 @@ +using Adapt, LuxDeviceUtils, ComponentArrays, Random +using ArrayInterface: parameterless_type +using ChainRulesTestUtils: test_rrule +using ReverseDiff, Tracker, ForwardDiff +using SparseArrays, FillArrays, Zygote, RecursiveArrayTools +using LuxCore + +@testset "https://github.com/LuxDL/LuxDeviceUtils.jl/issues/10 patch" begin + dev = LuxCPUDevice() + ps = (; weight=randn(10, 1), bias=randn(1)) + + ps_ca = ps |> ComponentArray + + ps_ca_dev = ps_ca |> dev + + @test ps_ca_dev isa ComponentArray + + @test ps_ca_dev.weight == ps.weight + @test ps_ca_dev.bias == ps.bias + + @test ps_ca_dev == (ps |> dev |> ComponentArray) +end + +@testset "AD Types" begin + x = randn(Float32, 10) + + x_rdiff = ReverseDiff.track(x) + @test get_device(x_rdiff) isa LuxCPUDevice + x_rdiff = ReverseDiff.track.(x) + @test get_device(x_rdiff) isa LuxCPUDevice + + gdev = gpu_device() + + x_tracker = Tracker.param(x) + @test get_device(x_tracker) isa LuxCPUDevice + x_tracker = Tracker.param.(x) + @test get_device(x_tracker) isa LuxCPUDevice + x_tracker_dev = Tracker.param(x) |> gdev + @test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev)) + x_tracker_dev = Tracker.param.(x) |> gdev + @test get_device(x_tracker_dev) isa parameterless_type(typeof(gdev)) + + x_fdiff = ForwardDiff.Dual.(x) + @test get_device(x_fdiff) isa LuxCPUDevice + x_fdiff_dev = ForwardDiff.Dual.(x) |> gdev + @test get_device(x_fdiff_dev) isa parameterless_type(typeof(gdev)) +end + +@testset "CRC Tests" begin + dev = cpu_device() # Other devices don't work with FiniteDifferences.jl + test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=true) + + gdev = gpu_device() + if !(gdev isa LuxMetalDevice) # On intel devices causes problems + x = randn(10) + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, gdev, x) + @test ∂dev === nothing + @test ∂x ≈ ones(10) + + x = randn(10) |> gdev + ∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, cpu_device(), x) + @test ∂dev === nothing + @test ∂x ≈ gdev(ones(10)) + @test get_device(∂x) isa parameterless_type(typeof(gdev)) + end +end + +# The following just test for noops +@testset "NoOps CPU" begin + cdev = cpu_device() + + @test cdev(sprand(10, 10, 0.9)) isa SparseMatrixCSC + @test cdev(1:10) isa AbstractRange + @test cdev(Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4))) isa Zygote.OneElement +end + +@testset "RecursiveArrayTools" begin + gdev = gpu_device() + + diffeqarray = DiffEqArray([rand(10) for _ in 1:10], rand(10)) + @test get_device(diffeqarray) isa LuxCPUDevice + + diffeqarray_dev = diffeqarray |> gdev + @test get_device(diffeqarray_dev) isa parameterless_type(typeof(gdev)) + + vecarray = VectorOfArray([rand(10) for _ in 1:10]) + @test get_device(vecarray) isa LuxCPUDevice + + vecarray_dev = vecarray |> gdev + @test get_device(vecarray_dev) isa parameterless_type(typeof(gdev)) +end + +@testset "CPU default rng" begin + @test default_device_rng(LuxCPUDevice()) isa Random.TaskLocalRNG +end + +@testset "CPU setdevice!" begin + @test_logs (:warn, + "Setting device for `LuxCPUDevice` doesn't make sense. 
Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxCPUDevice, nothing, 1) +end + +@testset "get_device on Arrays" begin + x = rand(10, 10) + x_view = view(x, 1:5, 1:5) + + @test get_device(x) isa LuxCPUDevice + @test get_device(x_view) isa LuxCPUDevice + + struct MyArrayType <: AbstractArray{Float32, 2} + data::Array{Float32, 2} + end + + x_custom = MyArrayType(rand(10, 10)) + + @test get_device(x_custom) isa LuxCPUDevice +end + +@testset "loaded and functional" begin + @test LuxDeviceUtils.loaded(LuxCPUDevice) + @test LuxDeviceUtils.functional(LuxCPUDevice) +end + +@testset "writing to preferences" begin + @test_logs (:info, + "Deleted the local preference for `gpu_backend`. Restart Julia to use the new backend.") gpu_backend!() + + for backend in (:CUDA, :AMDGPU, :oneAPI, :Metal, LuxAMDGPUDevice(), + LuxCUDADevice(), LuxMetalDevice(), LuxoneAPIDevice()) + backend_name = backend isa Symbol ? string(backend) : + LuxDeviceUtils._get_device_name(backend) + @test_logs (:info, + "GPU backend has been set to $(backend_name). Restart Julia to use the new backend.") gpu_backend!(backend) + end + + gpu_backend!(:CUDA) + @test_logs (:info, "GPU backend is already set to CUDA. No action is required.") gpu_backend!(:CUDA) + + @test_throws ArgumentError gpu_backend!("my_backend") +end + +@testset "LuxCore warnings" begin + struct MyCustomLayer <: LuxCore.AbstractExplicitContainerLayer{(:layer,)} + layer::Any + end + + my_layer = MyCustomLayer(rand(10, 10)) + + dev = cpu_device() + @test_logs ( + :warn, "Lux layers are stateless and hence don't participate in device \ + transfers. Apply this function on the parameters and states generated \ + using `Lux.setup`.") dev(my_layer) +end diff --git a/test/oneapi.jl b/test/oneapi.jl new file mode 100644 index 0000000..619ef8d --- /dev/null +++ b/test/oneapi.jl @@ -0,0 +1,111 @@ +using LuxDeviceUtils, Random + +@testset "CPU Fallback" begin + @test !LuxDeviceUtils.functional(LuxoneAPIDevice) + @test cpu_device() isa LuxCPUDevice + @test gpu_device() isa LuxCPUDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) + @test_throws Exception default_device_rng(LuxoneAPIDevice()) +end + +using oneAPI + +@testset "Loaded Trigger Package" begin + @test LuxDeviceUtils.GPU_DEVICE[] === nothing + + if LuxDeviceUtils.functional(LuxoneAPIDevice) + @info "oneAPI is functional" + @test gpu_device() isa LuxoneAPIDevice + @test gpu_device(; force_gpu_usage=true) isa LuxoneAPIDevice + else + @info "oneAPI is NOT functional" + @test gpu_device() isa LuxoneAPIDevice + @test_throws LuxDeviceUtils.LuxDeviceSelectionException gpu_device(; + force_gpu_usage=true) + end + @test LuxDeviceUtils.GPU_DEVICE[] !== nothing +end + +using FillArrays, Zygote # Extensions + +@testset "Data Transfer" begin + ps = (a=(c=zeros(10, 1), d=1), b=ones(10, 1), e=:c, + d="string", mixed=[2.0f0, 3.0, ones(2, 3)], # mixed array types + range=1:10, + rng_default=Random.default_rng(), rng=MersenneTwister(), + one_elem=Zygote.OneElement(2.0f0, (2, 3), (1:3, 1:4)), farray=Fill(1.0f0, (2, 3))) + + device = gpu_device() + aType = LuxDeviceUtils.functional(LuxoneAPIDevice) ? oneArray : Array + rngType = LuxDeviceUtils.functional(LuxoneAPIDevice) ? 
oneAPI.GPUArrays.RNG : + Random.AbstractRNG + + ps_xpu = ps |> device + @test get_device(ps_xpu) isa LuxoneAPIDevice + @test ps_xpu.a.c isa aType + @test ps_xpu.b isa aType + @test ps_xpu.a.d == ps.a.d + @test ps_xpu.mixed isa Vector + @test ps_xpu.mixed[1] isa Float32 + @test ps_xpu.mixed[2] isa Float64 + @test ps_xpu.mixed[3] isa aType + @test ps_xpu.range isa aType + @test ps_xpu.e == ps.e + @test ps_xpu.d == ps.d + @test ps_xpu.rng_default isa rngType + @test ps_xpu.rng == ps.rng + + if LuxDeviceUtils.functional(LuxoneAPIDevice) + @test ps_xpu.one_elem isa oneArray + @test ps_xpu.farray isa oneArray + else + @test ps_xpu.one_elem isa Zygote.OneElement + @test ps_xpu.farray isa Fill + end + + ps_cpu = ps_xpu |> cpu_device() + @test get_device(ps_cpu) isa LuxCPUDevice + @test ps_cpu.a.c isa Array + @test ps_cpu.b isa Array + @test ps_cpu.a.c == ps.a.c + @test ps_cpu.b == ps.b + @test ps_cpu.a.d == ps.a.d + @test ps_cpu.mixed isa Vector + @test ps_cpu.mixed[1] isa Float32 + @test ps_cpu.mixed[2] isa Float64 + @test ps_cpu.mixed[3] isa Array + @test ps_cpu.range isa Array + @test ps_cpu.e == ps.e + @test ps_cpu.d == ps.d + @test ps_cpu.rng_default isa Random.TaskLocalRNG + @test ps_cpu.rng == ps.rng + + if LuxDeviceUtils.functional(LuxoneAPIDevice) + @test ps_cpu.one_elem isa Array + @test ps_cpu.farray isa Array + else + @test ps_cpu.one_elem isa Zygote.OneElement + @test ps_cpu.farray isa Fill + end + + ps_mixed = (; a=rand(2), b=device(rand(2))) + @test_throws ArgumentError get_device(ps_mixed) +end + +@testset "Wrapper Arrays" begin + if LuxDeviceUtils.functional(LuxoneAPIDevice) + x = rand(10, 10) |> LuxoneAPIDevice() + @test get_device(x) isa LuxoneAPIDevice + x_view = view(x, 1:5, 1:5) + @test get_device(x_view) isa LuxoneAPIDevice + end +end + +@testset "setdevice!" begin + if LuxDeviceUtils.functional(LuxoneAPIDevice) + @test_logs (:warn, + "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting.") LuxDeviceUtils.set_device!( + LuxoneAPIDevice, nothing, 1) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 8eba75f..d63a17c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,24 +1,33 @@ +import Pkg using Aqua, SafeTestsets, Test, LuxDeviceUtils, TestSetExtensions const GROUP = get(ENV, "GROUP", "NONE") @testset ExtendedTestSet "LuxDeviceUtils Tests" begin if GROUP == "CUDA" || GROUP == "ALL" + Pkg.add("LuxCUDA") @safetestset "CUDA" include("cuda.jl") end if GROUP == "AMDGPU" || GROUP == "ALL" + Pkg.add("AMDGPU") @safetestset "AMDGPU" include("amdgpu.jl") end if GROUP == "Metal" || GROUP == "ALL" + Pkg.add("Metal") @safetestset "Metal" include("metal.jl") end + if GROUP == "oneAPI" || GROUP == "ALL" + Pkg.add("oneAPI") + @safetestset "oneAPI" include("oneapi.jl") + end + @testset "Others" begin @testset "Aqua Tests" Aqua.test_all(LuxDeviceUtils) - @safetestset "Component Arrays" include("component_arrays.jl") + @safetestset "Misc Tests" include("misc.jl") @safetestset "Explicit Imports" include("explicit_imports.jl") end
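As context for the new testsets above, here is a minimal sketch of the device-transfer workflow they exercise. It is illustrative only and not part of the patch; it assumes LuxDeviceUtils with at least the CPU backend available, and every call mirrors usage already present in the tests (gpu_device, cpu_device, get_device, and the |> transfer syntax).

    using LuxDeviceUtils, Random

    gdev = gpu_device()   # selects a functional GPU backend, or falls back to LuxCPUDevice
    cdev = cpu_device()

    ps = (; weight=rand(Float32, 10), bias=rand(Float32, 10))

    ps_dev = ps |> gdev     # every array leaf is adapted to the selected device
    get_device(ps_dev)      # device of a homogeneous structure; mixed devices throw ArgumentError

    ps_back = ps_dev |> cdev   # round-trip back to plain Arrays

gpu_device(idx) selects the idx-th device of the chosen backend, which is what the "Multiple Devices" testsets exercise, and LuxDeviceUtils.functional(LuxCUDADevice) (with the analogous checks for AMDGPU, Metal, and oneAPI) is how the tests guard GPU-only assertions.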