Skip to content

Commit

Permalink
Setup code for oneAPI support
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 5, 2024
1 parent f169142 commit 62f0417
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 50 deletions.
6 changes: 4 additions & 2 deletions ext/LuxDeviceUtilsAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,11 @@ end
Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::AbstractRNG) = rng
Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::AbstractRNG) = rng
function Adapt.adapt_storage(::LuxAMDGPUAdaptor{Nothing}, rng::Random.TaskLocalRNG)
return AMDGPU.rocrand_rng()
return LuxDeviceUtils.default_device_rng(LuxAMDGPUDevice(nothing))
end
# Moving a task-local CPU RNG to an AMDGPU adaptor yields the device's default RNG.
# `default_device_rng` dispatches on a device object, so it must be given an AMDGPU
# device (constructed here without a pinned device id) rather than the incoming `rng`.
function Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG)
    return LuxDeviceUtils.default_device_rng(LuxAMDGPUDevice(nothing))
end
Adapt.adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng()

Adapt.adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng()

Expand Down
8 changes: 6 additions & 2 deletions ext/LuxDeviceUtilsCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,19 @@ end
Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::AbstractRNG) = rng
Adapt.adapt_storage(::LuxCUDAAdaptor, rng::AbstractRNG) = rng
function Adapt.adapt_storage(::LuxCUDAAdaptor{Nothing}, rng::Random.TaskLocalRNG)
return CUDA.default_rng()
return LuxDeviceUtils.default_device_rng(LuxCUDADevice(nothing))
end
# Moving a task-local CPU RNG to a CUDA adaptor: the incoming `rng` is not reused;
# the default RNG registered for a CUDA device (no pinned device id) is returned.
function Adapt.adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG)
    cuda_dev = LuxCUDADevice(nothing)
    return LuxDeviceUtils.default_device_rng(cuda_dev)
end
Adapt.adapt_storage(::LuxCUDAAdaptor, rng::Random.TaskLocalRNG) = CUDA.default_rng()

Adapt.adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng()

## To CPU
## FIXME: Use SparseArrays to preserve the sparsity
function Adapt.adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix)
    # Sparsity is not preserved when moving to the CPU: warn (at most once, via
    # `maxlog=1`) and then adapt the matrix to a plain `Array`.
    @warn "Currently we don't convert CUSPARSE matrices to CPU SparseArrays. Constructing \
           a dense matrix instead." maxlog=1
    dense_host = Adapt.adapt(Array, x)
    return dense_host
end

Expand Down
2 changes: 1 addition & 1 deletion ext/LuxDeviceUtilsMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice()
Adapt.adapt_storage(::LuxMetalAdaptor, x) = Metal.mtl(x)
Adapt.adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng
function Adapt.adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG)
return GPUArrays.default_rng(MtlArray)
return LuxDeviceUtils.default_device_rng(LuxMetalDevice())
end

end
98 changes: 53 additions & 45 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,15 @@ const CRC = ChainRulesCore
export gpu_backend!, supported_gpu_backends, reset_gpu_device!
export default_device_rng
export gpu_device, cpu_device
export LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice
export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor
export LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice
export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor, LuxoneAPIAdaptor
export get_device

abstract type AbstractLuxDevice <: Function end
abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end

__is_functional(x) = false
__is_loaded(x) = false
# Generic fallbacks: any argument without a more specific method is reported as neither
# functional nor loaded.
@inline __is_functional(x) = false
@inline __is_loaded(x) = false

struct LuxCPUDevice <: AbstractLuxDevice end
@kwdef struct LuxCUDADevice{D} <: AbstractLuxGPUDevice
Expand All @@ -36,41 +36,44 @@ end
device::D = nothing
end
struct LuxMetalDevice <: AbstractLuxGPUDevice end
struct LuxoneAPIDevice <: AbstractLuxGPUDevice end

_with_device(::Type{LuxCPUDevice}, ::Nothing) = LuxCPUDevice()
function _with_device(::Type{LuxCPUDevice}, device_id)
@warn "`device_id` is not applicable for `LuxCPUDevice`." maxlog=1
return LuxCPUDevice()
end

_with_device(::Type{LuxMetalDevice}, ::Nothing) = LuxMetalDevice()
function _with_device(::Type{LuxMetalDevice}, device_id)
@warn "`device_id` is not applicable for `LuxMetalDevice`." maxlog=1
return LuxMetalDevice()
# Generate `_with_device` for the device types that take no device id: passing `nothing`
# simply constructs the device, while any other `device_id` is ignored with a one-time
# warning.
for dev in (LuxCPUDevice, LuxMetalDevice, LuxoneAPIDevice)
    # Build the message here, where `dev` is bound. Inside the quoted `@eval` block a
    # `$dev` within a string literal is plain string interpolation that would only be
    # evaluated when the generated method runs, where no variable `dev` exists.
    msg = "`device_id` is not applicable for `$(nameof(dev))`."
    @eval begin
        _with_device(::Type{$dev}, ::Nothing) = $dev()
        function _with_device(::Type{$dev}, device_id)
            @warn $msg maxlog=1
            return $dev()
        end
    end
end

__is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true
__is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true

_get_device_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "CPU"
_get_device_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "CUDA"
_get_device_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "AMDGPU"
_get_device_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal"

_get_triggerpkg_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = ""
_get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA"
_get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU"
_get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal"

_get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor()
_get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device)
_get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device)
_get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor()

_get_device_id(::LuxCPUDevice) = nothing
_get_device_id(::LuxCUDADevice{Nothing}) = nothing
_get_device_id(::LuxAMDGPUDevice{Nothing}) = nothing
_get_device_id(::LuxMetalDevice) = nothing
# The CPU device always reports as functional and loaded. Like the lookups below, these
# methods accept either a device instance or the device type itself.
@inline __is_functional(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true
@inline __is_loaded(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = true

# Short backend name for each device; `supported_gpu_backends()` maps this function over
# `GPU_DEVICES`.
@inline _get_device_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = "CPU"
@inline _get_device_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "CUDA"
@inline _get_device_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "AMDGPU"
@inline _get_device_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal"
@inline _get_device_name(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = "oneAPI"

# Name of the trigger package associated with each backend (empty string for the CPU).
@inline _get_triggerpkg_name(::Union{LuxCPUDevice, Type{<:LuxCPUDevice}}) = ""
@inline _get_triggerpkg_name(::Union{LuxCUDADevice, Type{<:LuxCUDADevice}}) = "LuxCUDA"
@inline _get_triggerpkg_name(::Union{LuxAMDGPUDevice, Type{<:LuxAMDGPUDevice}}) = "LuxAMDGPU"
@inline _get_triggerpkg_name(::Union{LuxMetalDevice, Type{<:LuxMetalDevice}}) = "Metal"
@inline _get_triggerpkg_name(::Union{LuxoneAPIDevice, Type{<:LuxoneAPIDevice}}) = "oneAPI"

# Adaptor matching each device. The CUDA and AMDGPU adaptors receive the device's
# `device` field; the remaining adaptors are constructed without arguments.
@inline _get_adaptor(::LuxCPUDevice) = LuxCPUAdaptor()
@inline _get_adaptor(dev::LuxCUDADevice) = LuxCUDAAdaptor(dev.device)
@inline _get_adaptor(dev::LuxAMDGPUDevice) = LuxAMDGPUAdaptor(dev.device)
@inline _get_adaptor(::LuxMetalDevice) = LuxMetalAdaptor()
@inline _get_adaptor(::LuxoneAPIDevice) = LuxoneAPIAdaptor()

# Device id lookup: `nothing` for CPU, Metal and oneAPI, and for CUDA/AMDGPU devices whose
# type parameter is `Nothing`.
# NOTE(review): no methods for CUDA/AMDGPU devices with a concrete device are visible
# here — presumably they are defined by the backend extensions; confirm.
@inline _get_device_id(::LuxCPUDevice) = nothing
@inline _get_device_id(::LuxCUDADevice{Nothing}) = nothing
@inline _get_device_id(::LuxAMDGPUDevice{Nothing}) = nothing
@inline _get_device_id(::LuxMetalDevice) = nothing
@inline _get_device_id(::LuxoneAPIDevice) = nothing

# Display any Lux device by printing `nameof(dev)` to the given stream.
function Base.show(io::IO, dev::AbstractLuxDevice)
    return print(io, nameof(dev))
end

Expand All @@ -81,7 +84,7 @@ function Base.showerror(io::IO, ::LuxDeviceSelectionException)
end

# Order is important here
const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice)
const GPU_DEVICES = (LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice, LuxoneAPIDevice)

const GPU_DEVICE = Ref{Union{Nothing, AbstractLuxDevice}}(nothing)

Expand All @@ -105,8 +108,8 @@ Return a tuple of supported GPU backends.
!!! danger
`Metal.jl` support is **extremely** experimental and most things are not expected to
work.
`Metal.jl` and `oneAPI.jl` support is **extremely** experimental and most things are not
expected to work.
"""
supported_gpu_backends() = map(_get_device_name, GPU_DEVICES)

Expand Down Expand Up @@ -222,9 +225,10 @@ function _get_gpu_device(; force_gpu_usage::Bool)
1. If no GPU is available, nothing needs to be done.
2. If GPU is available, load the corresponding trigger package.
a. LuxCUDA.jl for NVIDIA CUDA Support.
b. LuxAMDGPU.jl for AMD GPU ROCM Support.
c. Metal.jl for Apple Metal GPU Support.""" maxlog=1
a. `LuxCUDA.jl` for NVIDIA CUDA Support.
b. `LuxAMDGPU.jl` for AMD GPU ROCM Support.
c. `Metal.jl` for Apple Metal GPU Support.
d. `oneAPI.jl` for Intel oneAPI GPU Support.""" maxlog=1
return LuxCPUDevice
end
end
Expand Down Expand Up @@ -284,7 +288,8 @@ and states on the device using
[WeightInitializers.jl](https://github.com/LuxDL/WeightInitializers.jl).
"""
function default_device_rng(D::AbstractLuxDevice)
return error("""`default_device_rng` not implemented for $(typeof(D)). This is either because:
return error("""`default_device_rng` not implemented for `$(typeof(D))`. This is \
either because:
1. The default RNG for this device is not known / officially provided.
2. The trigger package for the device is not loaded.
Expand All @@ -296,7 +301,7 @@ default_device_rng(::LuxCPUDevice) = Random.default_rng()
# Abstract Array / Tuples / NamedTuples have special fast paths to facilitate type stability
# For all other types we rely on fmap which means we lose type stability.
# For Lux, typically models only has these 3 datastructures so we should be mostly fine.
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal)
for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
ldev = Symbol("Lux$(dev)Device")
@eval begin
function (D::$(ldev))(x::AbstractArray)
Expand Down Expand Up @@ -406,6 +411,8 @@ function set_device!(::Type{T}, dev_or_id) where {T <: AbstractLuxDevice}
@warn "`AMDGPU.jl` hasn't been loaded. Ignoring the device setting." maxlog=1
T === LuxMetalDevice &&
@warn "Support for Multi Device Metal hasn't been implemented yet. Ignoring the device setting." maxlog=1
T === LuxoneAPIDevice &&
@warn "Support for Multi Device oneAPI hasn't been implemented yet. Ignoring the device setting." maxlog=1
T === LuxCPUDevice &&
@warn "Setting device for `LuxCPUDevice` doesn't make sense. Ignoring the device setting." maxlog=1
return
Expand Down Expand Up @@ -440,13 +447,14 @@ struct LuxAMDGPUAdaptor{D} <: AbstractLuxGPUDeviceAdaptor
device::D
end
struct LuxMetalAdaptor <: AbstractLuxGPUDeviceAdaptor end
struct LuxoneAPIAdaptor <: AbstractLuxGPUDeviceAdaptor end

# Adaptation rules for the CPU adaptor: ranges and RNGs are returned unchanged, while
# any other array is adapted with `Array` as the target.
function Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractRange)
    return x
end
function Adapt.adapt_storage(::LuxCPUAdaptor, x::AbstractArray)
    return Adapt.adapt(Array, x)
end
function Adapt.adapt_storage(::LuxCPUAdaptor, rng::AbstractRNG)
    return rng
end

# Prevent Ambiguity
for T in (LuxAMDGPUAdaptor, LuxCUDAAdaptor, LuxMetalAdaptor)
for T in (LuxAMDGPUAdaptor, LuxCUDAAdaptor, LuxMetalAdaptor, LuxoneAPIAdaptor)
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
end

Expand Down

0 comments on commit 62f0417

Please sign in to comment.