Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement faster get_device_type #54

Merged
merged 8 commits into from
Jul 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.24"
version = "0.1.25"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -10,6 +10,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
UnrolledUtilities = "0fe1646c-419e-43be-ac14-22321958931b"

[weakdeps]
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
Expand Down Expand Up @@ -51,7 +52,7 @@ ComponentArrays = "0.15.8"
ExplicitImports = "1.9.0"
FillArrays = "1"
ForwardDiff = "0.10.36"
Functors = "0.4.4"
Functors = "0.4.8"
GPUArrays = "10"
LuxCUDA = "0.3.2"
LuxCore = "0.1.4"
Expand All @@ -65,6 +66,7 @@ SafeTestsets = "0.1"
SparseArrays = "1.10"
Test = "1.10"
Tracker = "0.2.34"
UnrolledUtilities = "0.1.2"
Zygote = "0.6.69"
julia = "1.10"
oneAPI = "1.5"
Expand Down
6 changes: 4 additions & 2 deletions ext/LuxDeviceUtilsAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ LuxDeviceUtils._get_device_id(dev::LuxAMDGPUDevice) = AMDGPU.device_id(dev.devic
LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng()

# Query Device from Array
function LuxDeviceUtils.get_device(x::AMDGPU.AnyROCArray)
function LuxDeviceUtils._get_device(x::AMDGPU.AnyROCArray)
parent_x = parent(x)
parent_x === x && return LuxAMDGPUDevice(AMDGPU.device(x))
return LuxDeviceUtils.get_device(parent_x)
return LuxDeviceUtils._get_device(parent_x)
end

LuxDeviceUtils._get_device_type(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice

# Set Device
function LuxDeviceUtils.set_device!(::Type{LuxAMDGPUDevice}, dev::AMDGPU.HIPDevice)
return AMDGPU.device!(dev)
Expand Down
9 changes: 7 additions & 2 deletions ext/LuxDeviceUtilsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,20 @@ LuxDeviceUtils._get_device_id(dev::LuxCUDADevice) = CUDA.deviceid(dev.device) +
LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng()

# Query Device from Array
function LuxDeviceUtils.get_device(x::CUDA.AnyCuArray)
# Resolve the `LuxCUDADevice` for a CUDA array.
#
# `x` may wrap another array, in which case `parent(x)` is a different object. The wrapper
# chain is followed until an array that is its own parent is reached, and the device is
# read from that array via `CUDA.device`.
function LuxDeviceUtils._get_device(x::CUDA.AnyCuArray)
    parent_x = parent(x)
    # An array that is its own parent is not a wrapper: query its device directly.
    parent_x === x && return LuxCUDADevice(CUDA.device(x))
    # Recurse through the internal `_get_device` entry point (as the AMDGPU extension
    # does) instead of the public `get_device`, which would only perform an extra
    # `hasmethod` lookup before landing on the same method.
    return LuxDeviceUtils._get_device(parent_x)
end
function LuxDeviceUtils.get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray)
function LuxDeviceUtils._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray)
return LuxCUDADevice(CUDA.device(x.nzVal))
end

# Device *type* query for CUDA-backed arrays. Both `CUDA.AnyCuArray` and CUSPARSE arrays
# map to `LuxCUDADevice`; the argument is matched by type only and never inspected.
function LuxDeviceUtils._get_device_type(
        ::Union{CUDA.AnyCuArray, CUDA.CUSPARSE.AbstractCuSparseArray})
    return LuxCUDADevice
end

# Set Device
function LuxDeviceUtils.set_device!(::Type{LuxCUDADevice}, dev::CUDA.CuDevice)
return CUDA.device!(dev)
Expand Down
4 changes: 3 additions & 1 deletion ext/LuxDeviceUtilsMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ end
LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray)

# Query Device from Array
LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice()
LuxDeviceUtils._get_device(::MtlArray) = LuxMetalDevice()

LuxDeviceUtils._get_device_type(::MtlArray) = LuxMetalDevice

# Device Transfer
## To GPU
Expand Down
7 changes: 5 additions & 2 deletions ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray)
return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t)
end

function LuxDeviceUtils.get_device(x::Union{VectorOfArray, DiffEqArray})
return mapreduce(LuxDeviceUtils.get_device, LuxDeviceUtils.__combine_devices, x.u)
# Generate `_get_device` and `_get_device_type` methods for `VectorOfArray` / `DiffEqArray`.
# Each method reduces the same query over the entries of `x.u` with `__combine_devices`.
for query in (:_get_device, :_get_device_type)
    # Value returned when `x.u` has no entries: `nothing` for the device query and
    # `Nothing` for the device-type query.
    empty_result = query == :_get_device ? nothing : Nothing
    @eval function LuxDeviceUtils.$query(x::Union{VectorOfArray, DiffEqArray})
        iszero(length(x.u)) && return $(empty_result)
        return mapreduce(LuxDeviceUtils.$query, LuxDeviceUtils.__combine_devices, x.u)
    end
end

end
14 changes: 9 additions & 5 deletions ext/LuxDeviceUtilsReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,15 @@ module LuxDeviceUtilsReverseDiffExt
using LuxDeviceUtils: LuxDeviceUtils
using ReverseDiff: ReverseDiff

@inline function LuxDeviceUtils.get_device(x::ReverseDiff.TrackedArray)
return LuxDeviceUtils.get_device(ReverseDiff.value(x))
end
@inline function LuxDeviceUtils.get_device(x::AbstractArray{<:ReverseDiff.TrackedReal})
return LuxDeviceUtils.get_device(ReverseDiff.value.(x))
# Generate `_get_device` and `_get_device_type` methods for ReverseDiff-tracked arrays.
# Both forward the query to the untracked values obtained through `ReverseDiff.value`.
for query in (:_get_device, :_get_device_type)
    # A `TrackedArray` wraps a single value array.
    @eval function LuxDeviceUtils.$query(x::ReverseDiff.TrackedArray)
        untracked = ReverseDiff.value(x)
        return LuxDeviceUtils.$query(untracked)
    end
    # An array of `TrackedReal`s is untracked element-wise before the query.
    @eval function LuxDeviceUtils.$query(x::AbstractArray{<:ReverseDiff.TrackedReal})
        untracked = ReverseDiff.value.(x)
        return LuxDeviceUtils.$query(untracked)
    end
end

end
14 changes: 8 additions & 6 deletions ext/LuxDeviceUtilsTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,16 @@ using LuxDeviceUtils: LuxDeviceUtils, LuxAMDGPUDevice, LuxCUDADevice, LuxMetalDe
LuxoneAPIDevice
using Tracker: Tracker

@inline function LuxDeviceUtils.get_device(x::Tracker.TrackedArray)
return LuxDeviceUtils.get_device(Tracker.data(x))
end
@inline function LuxDeviceUtils.get_device(x::AbstractArray{<:Tracker.TrackedReal})
return LuxDeviceUtils.get_device(Tracker.data.(x))
# Generate `_get_device` and `_get_device_type` methods for Tracker-tracked arrays.
# Both forward the query to the plain data obtained through `Tracker.data`.
for query in (:_get_device, :_get_device_type)
    # A `TrackedArray` wraps a single data array.
    @eval function LuxDeviceUtils.$query(x::Tracker.TrackedArray)
        return LuxDeviceUtils.$query(Tracker.data(x))
    end
    # An array of `TrackedReal`s is unwrapped element-wise before the query.
    @eval function LuxDeviceUtils.$query(x::AbstractArray{<:Tracker.TrackedReal})
        plain = Tracker.data.(x)
        return LuxDeviceUtils.$query(plain)
    end
end

@inline LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true
LuxDeviceUtils.__special_aos(::AbstractArray{<:Tracker.TrackedReal}) = true

for T in (LuxAMDGPUDevice, LuxAMDGPUDevice{Nothing}, LuxCUDADevice,
LuxCUDADevice{Nothing}, LuxMetalDevice, LuxoneAPIDevice)
Expand Down
4 changes: 3 additions & 1 deletion ext/LuxDeviceUtilsoneAPIExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,9 @@ end
LuxDeviceUtils.default_device_rng(::LuxoneAPIDevice) = GPUArrays.default_rng(oneArray)

# Query Device from Array
LuxDeviceUtils.get_device(::oneArray) = LuxoneAPIDevice()
LuxDeviceUtils._get_device(::oneArray) = LuxoneAPIDevice()

LuxDeviceUtils._get_device_type(::oneArray) = LuxoneAPIDevice

# Device Transfer
## To GPU
Expand Down
109 changes: 77 additions & 32 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,19 @@ module LuxDeviceUtils

using Adapt: Adapt
using ChainRulesCore: ChainRulesCore, NoTangent
using Functors: Functors, fmap
using Functors: Functors, fmap, fleaves
using LuxCore: LuxCore
using Preferences: @delete_preferences!, @load_preference, @set_preferences!
using Random: AbstractRNG, Random
using UnrolledUtilities: unrolled_mapreduce

const CRC = ChainRulesCore

export gpu_backend!, supported_gpu_backends, reset_gpu_device!
export default_device_rng
export gpu_device, cpu_device
export LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice
export get_device
export get_device, get_device_type

abstract type AbstractLuxDevice <: Function end
abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end
Expand Down Expand Up @@ -335,54 +336,98 @@ end

@inline __special_aos(x::AbstractArray) = false

const GET_DEVICE_ADMONITIONS = """
!!! note

Trigger Packages must be loaded for this to return the correct device.

!!! warning

RNG types currently don't participate in device determination. We will remove this
restriction in the future.
"""

# Query Device from Array
"""
get_device(x) -> AbstractLuxDevice | Exception | Nothing
get_device(x) -> dev::AbstractLuxDevice | Exception | nothing

If all arrays (on the leaves of the structure) are on the same device, we return that
device. Otherwise, we throw an error. If the object is device agnostic, we return `nothing`.

!!! note
$(GET_DEVICE_ADMONITIONS)

See also [`get_device_type`](@ref) for a faster alternative that can be used for dispatch
based on device type.
"""
function get_device end

Trigger Packages must be loaded for this to return the correct device.
"""
function get_device(x::AbstractArray{T}) where {T}
!isbitstype(T) && return mapreduce(get_device, __combine_devices, x)
if hasmethod(parent, Tuple{typeof(x)})
parent_x = parent(x)
parent_x === x && return LuxCPUDevice()
return get_device(parent_x)
get_device_type(x) -> Type{<:AbstractLuxDevice} | Exception | Type{Nothing}

Similar to [`get_device`](@ref) but returns the type of the device instead of the device
itself. This value is often a compile time constant and is recommended to be used instead
of [`get_device`](@ref) wherever dispatches are defined based on the device type.

$(GET_DEVICE_ADMONITIONS)
"""
function get_device_type end

for op in (:get_device, :get_device_type)
_op = Symbol("_", op)
cpu_ret_val = op == :get_device ? LuxCPUDevice() : LuxCPUDevice
@eval begin
function $(op)(x)
hasmethod($(_op), Tuple{typeof(x)}) && return $(_op)(x)
return mapreduce($(_op), __combine_devices, fleaves(x))
end

CRC.@non_differentiable $op(::Any)

function $(_op)(x::AbstractArray{T}) where {T}
__recursible_array_eltype(T) && return mapreduce($(op), __combine_devices, x)
if hasmethod(parent, Tuple{typeof(x)})
parent_x = parent(x)
parent_x === x && return $(cpu_ret_val)
return $(_op)(parent_x)
end
return $(cpu_ret_val)
end

function $(_op)(x::Union{Tuple, NamedTuple})
length(x) == 0 && return $(op == :get_device ? nothing : Nothing)
return unrolled_mapreduce($(op), __combine_devices, values(x))
end
end
return LuxCPUDevice()
end
function get_device(x)
dev = Ref{Union{AbstractLuxDevice, Nothing}}(nothing)
_get_device(x) = (dev[] = __combine_devices(dev[], get_device(x)))
fmap(_get_device, x)
return dev[]
end
for T in (Number, AbstractRNG, Val, Symbol, String)
@eval get_device(::$(T)) = nothing
end
get_device(x::Tuple) = mapreduce(get_device, __combine_devices, x)
get_device(x::NamedTuple) = mapreduce(get_device, __combine_devices, values(x))

CRC.@non_differentiable get_device(::Any...)
for T in (Number, AbstractRNG, Val, Symbol, String)
@eval $(_op)(::$(T)) = $(op == :get_device ? nothing : Nothing)
end
end

function __combine_devices(dev1, dev2)
dev1 === nothing && return dev2
dev2 === nothing && return dev1
dev1 != dev2 &&
throw(ArgumentError("Objects are on different devices: $dev1 and $dev2."))
return dev1
__recursible_array_eltype(::Type{T}) where {T} = !isbitstype(T) && !(T <: Number)

# Binary reduction operator for device queries. Two families of methods exist: one over
# device *instances* (with `nothing` as the "no device" value) and one over device
# *types* (with `Nothing` as the "no device" value). In both families the "no device"
# value yields to the other argument, equal devices combine to themselves, and
# mismatched devices raise an `ArgumentError`.

# Device instances
__combine_devices(::Nothing, ::Nothing) = nothing
__combine_devices(::Nothing, dev::AbstractLuxDevice) = dev
__combine_devices(dev::AbstractLuxDevice, ::Nothing) = dev
function __combine_devices(dev1::AbstractLuxDevice, dev2::AbstractLuxDevice)
    dev1 == dev2 ||
        throw(ArgumentError("Objects are on different devices: $(dev1) and $(dev2)."))
    return dev1
end

# Device types
__combine_devices(::Type{Nothing}, ::Type{Nothing}) = Nothing
__combine_devices(::Type{Nothing}, ::Type{T}) where {T <: AbstractLuxDevice} = T
__combine_devices(::Type{T}, ::Type{Nothing}) where {T <: AbstractLuxDevice} = T
__combine_devices(::Type{T}, ::Type{T}) where {T <: AbstractLuxDevice} = T
# Fallback for two distinct device types: always an error.
function __combine_devices(
        ::Type{T1}, ::Type{T2}) where {T1 <: AbstractLuxDevice, T2 <: AbstractLuxDevice}
    throw(ArgumentError("Objects are on devices with different types: $(T1) and $(T2)."))
end

# Set the device
const SET_DEVICE_DOCS = """
Set the device for the given type. This is a no-op for `LuxCPUDevice`. For `LuxCUDADevice`
and `LuxAMDGPUDevice`, it prints a warning if the corresponding trigger package is not
loaded.

Currently, `LuxMetalDevice` and `LuxoneAPIDevice` don't support setting the device.
"""

Expand Down
17 changes: 17 additions & 0 deletions test/amdgpu_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ using FillArrays, Zygote # Extensions

ps_xpu = ps |> device
@test get_device(ps_xpu) isa LuxAMDGPUDevice
@test get_device_type(ps_xpu) <: LuxAMDGPUDevice
@test ps_xpu.a.c isa aType
@test ps_xpu.b isa aType
@test ps_xpu.a.d == ps.a.d
Expand All @@ -69,6 +70,7 @@ using FillArrays, Zygote # Extensions

ps_cpu = ps_xpu |> cpu_device()
@test get_device(ps_cpu) isa LuxCPUDevice
@test get_device_type(ps_cpu) <: LuxCPUDevice
@test ps_cpu.a.c isa Array
@test ps_cpu.b isa Array
@test ps_cpu.a.c == ps.a.c
Expand Down Expand Up @@ -99,20 +101,35 @@ using FillArrays, Zygote # Extensions
x = rand(Float32, 10, 2)
x_dev = x |> dev
@test get_device(x_dev) isa parameterless_type(typeof(dev))
@test get_device_type(x_dev) <: parameterless_type(typeof(dev))

if LuxDeviceUtils.functional(LuxAMDGPUDevice)
dev2 = gpu_device(length(AMDGPU.devices()))
x_dev2 = x_dev |> dev2
@test get_device(x_dev2) isa typeof(dev2)
@test get_device_type(x_dev2) <: parameterless_type(typeof(dev2))
end

@testset "get_device_type compile constant" begin
    x = rand(10, 10) |> device
    ps = (; weight=x, bias=x, d=(x, x))

    # Lifting the device type into a `Val` is only inferable when that type is a
    # compile-time constant, so a passing `@inferred` demonstrates exactly that.
    type_as_val(x) = Val(get_device_type(x))
    @test @inferred(type_as_val(ps)) isa Val{parameterless_type(typeof(device))}

    # The same construction on the device instance is expected to fail inference,
    # which `@inferred` reports as an `ErrorException`.
    instance_as_val(x) = Val(get_device(x))
    @test_throws ErrorException @inferred(instance_as_val(ps))
end
end

@testset "Wrapped Arrays" begin
if LuxDeviceUtils.functional(LuxAMDGPUDevice)
x = rand(10, 10) |> LuxAMDGPUDevice()
@test get_device(x) isa LuxAMDGPUDevice
@test get_device_type(x) <: LuxAMDGPUDevice
x_view = view(x, 1:5, 1:5)
@test get_device(x_view) isa LuxAMDGPUDevice
@test get_device_type(x_view) <: LuxAMDGPUDevice
end
end

Expand Down
Loading
Loading