Skip to content

Commit

Permalink
Add a get_device function
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 18, 2024
1 parent 5961172 commit 9a28dc2
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 58 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.14"
version = "0.1.15"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
23 changes: 4 additions & 19 deletions ext/LuxDeviceUtilsLuxAMDGPUExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module LuxDeviceUtilsLuxAMDGPUExt

using ChainRulesCore, LuxAMDGPU, LuxDeviceUtils, Random
using LuxAMDGPU, LuxDeviceUtils, Random
import Adapt: adapt_storage, adapt
import ChainRulesCore as CRC

__init__() = reset_gpu_device!()

Expand All @@ -12,6 +11,9 @@ LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional()
# Default RNG
LuxDeviceUtils.default_device_rng(::LuxAMDGPUDevice) = AMDGPU.rocrand_rng()

# Query Device from Array
LuxDeviceUtils.get_device(::AMDGPU.AnyROCArray) = LuxAMDGPUDevice()

# Device Transfer
## To GPU
adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x)
Expand All @@ -20,21 +22,4 @@ adapt_storage(::LuxAMDGPUAdaptor, rng::Random.TaskLocalRNG) = AMDGPU.rocrand_rng

adapt_storage(::LuxCPUAdaptor, rng::AMDGPU.rocRAND.RNG) = Random.default_rng()

## Chain Rules
CRC.rrule(::Type{Array}, x::ROCArray) = Array(x), Δ -> (NoTangent(), roc(Δ))

function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::AMDGPU.AnyROCArray)
function ∇adapt_storage(Δ)
return (NoTangent(), NoTangent(), adapt_storage(LuxAMDGPUAdaptor(), Δ))
end
return adapt_storage(to, x), ∇adapt_storage
end

function CRC.rrule(::typeof(adapt_storage), to::LuxAMDGPUAdaptor, x::Array)
function ∇adapt_storage(Δ)
return (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ))
end
return adapt_storage(to, x), ∇adapt_storage
end

end
23 changes: 4 additions & 19 deletions ext/LuxDeviceUtilsLuxCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module LuxDeviceUtilsLuxCUDAExt

using ChainRulesCore, LuxCUDA, LuxDeviceUtils, Random
using LuxCUDA, LuxDeviceUtils, Random
import Adapt: adapt_storage, adapt
import ChainRulesCore as CRC

__init__() = reset_gpu_device!()

Expand All @@ -12,6 +11,9 @@ LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional()
# Default RNG
LuxDeviceUtils.default_device_rng(::LuxCUDADevice) = CUDA.default_rng()

# Query Device from Array
LuxDeviceUtils.get_device(::CUDA.AnyCuArray) = LuxCUDADevice()

# Device Transfer
## To GPU
adapt_storage(::LuxCUDAAdaptor, x) = cu(x)
Expand All @@ -23,21 +25,4 @@ adapt_storage(::LuxCPUAdaptor, rng::CUDA.RNG) = Random.default_rng()
## To CPU
adapt_storage(::LuxCPUAdaptor, x::CUSPARSE.AbstractCuSparseMatrix) = adapt(Array, x)

## Chain Rules
CRC.rrule(::Type{Array}, x::CuArray) = Array(x), Δ -> (NoTangent(), cu(Δ))

function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::CUDA.AnyCuArray)
function ∇adapt_storage(Δ)
return (NoTangent(), NoTangent(), adapt_storage(LuxCUDAAdaptor(), Δ))
end
return adapt_storage(to, x), ∇adapt_storage
end

function CRC.rrule(::typeof(adapt_storage), to::LuxCUDAAdaptor, x::Array)
function ∇adapt_storage(Δ)
return (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ))
end
return adapt_storage(to, x), ∇adapt_storage
end

end
23 changes: 4 additions & 19 deletions ext/LuxDeviceUtilsMetalGPUArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
module LuxDeviceUtilsMetalGPUArraysExt

using ChainRulesCore, GPUArrays, LuxDeviceUtils, Metal, Random
using GPUArrays, LuxDeviceUtils, Metal, Random
import Adapt: adapt_storage, adapt
import ChainRulesCore as CRC

__init__() = reset_gpu_device!()

Expand All @@ -12,27 +11,13 @@ LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional()
# Default RNG
LuxDeviceUtils.default_device_rng(::LuxMetalDevice) = GPUArrays.default_rng(MtlArray)

# Query Device from Array
LuxDeviceUtils.get_device(::MtlArray) = LuxMetalDevice()

# Device Transfer
## To GPU
adapt_storage(::LuxMetalAdaptor, x) = mtl(x)
adapt_storage(::LuxMetalAdaptor, rng::AbstractRNG) = rng
adapt_storage(::LuxMetalAdaptor, rng::Random.TaskLocalRNG) = GPUArrays.default_rng(MtlArray)

## Chain Rules
CRC.rrule(::Type{Array}, x::MtlArray) = Array(x), Δ -> (NoTangent(), MtlArray(Δ))

function CRC.rrule(::typeof(adapt_storage), to::LuxCPUAdaptor, x::MtlArray)
function ∇adapt_storage(Δ)
return (NoTangent(), NoTangent(), adapt_storage(LuxMetalAdaptor(), Δ))
end
return adapt_storage(to, x), ∇adapt_storage
end

function CRC.rrule(::typeof(adapt_storage), to::LuxMetalAdaptor, x::Array)
function ∇adapt_storage(Δ)
return (NoTangent(), NoTangent(), adapt_storage(LuxCPUAdaptor(), Δ))
end
return adapt_storage(to, x), ∇adapt_storage
end

end
20 changes: 20 additions & 0 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@ import PrecompileTools: @recompile_invalidations
@recompile_invalidations begin
using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays
import Adapt: adapt, adapt_storage
import ChainRulesCore as CRC
end

export gpu_backend!, supported_gpu_backends, reset_gpu_device!
export default_device_rng
export gpu_device, cpu_device, LuxCPUDevice, LuxCUDADevice, LuxAMDGPUDevice, LuxMetalDevice
export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor
export get_device

abstract type AbstractLuxDevice <: Function end
abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end
Expand Down Expand Up @@ -255,6 +257,15 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal)
end
end

# Query Device from Array
"""
get_device(x::AbstractArray) -> AbstractLuxDevice
Returns the device of the array `x`. Trigger Packages must be loaded for this to return the
correct device.
"""
get_device(x::AbstractArray) = LuxCPUDevice()

# Adapt Interface
abstract type AbstractLuxDeviceAdaptor end

Expand All @@ -277,4 +288,13 @@ _isbitsarray(x) = false
_isleaf(::AbstractRNG) = true
_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)

# Chain Rules Core
function CRC.rrule(::typeof(adapt_storage), to::AbstractLuxDeviceAdaptor, x::AbstractArray)
function ∇adapt_storage(Δ)
dev = get_device(x)
return (NoTangent(), NoTangent(), dev(Δ))
end
return adapt_storage(to, x), ∇adapt_storage
end

end

0 comments on commit 9a28dc2

Please sign in to comment.