-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor!: rename package to DeviceUtils.jl
BREAKING CHANGE: All "Lux" prefixes have been dropped for wider adoption
- Loading branch information
Showing
33 changed files
with
625 additions
and
604 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,5 @@ | ||
Manifest.toml | ||
*.cov | ||
generated | ||
build | ||
.vscode | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
name = "LuxDeviceUtils" | ||
name = "DeviceUtils" | ||
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" | ||
authors = ["Avik Pal <[email protected]> and contributors"] | ||
version = "0.1.26" | ||
|
@@ -17,28 +17,28 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e" | |
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" | ||
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" | ||
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" | ||
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda" | ||
Metal = "dde4c033-4e86-420c-a63e-0dd931031962" | ||
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" | ||
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" | ||
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" | ||
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" | ||
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" | ||
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd" | ||
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" | ||
|
||
[extensions] | ||
LuxDeviceUtilsAMDGPUExt = "AMDGPU" | ||
LuxDeviceUtilsCUDAExt = "CUDA" | ||
LuxDeviceUtilsFillArraysExt = "FillArrays" | ||
LuxDeviceUtilsGPUArraysExt = "GPUArrays" | ||
LuxDeviceUtilsLuxCUDAExt = "LuxCUDA" | ||
LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"] | ||
LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" | ||
LuxDeviceUtilsReverseDiffExt = "ReverseDiff" | ||
LuxDeviceUtilsSparseArraysExt = "SparseArrays" | ||
LuxDeviceUtilsTrackerExt = "Tracker" | ||
LuxDeviceUtilsZygoteExt = "Zygote" | ||
LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] | ||
DeviceUtilsAMDGPUExt = "AMDGPU" | ||
DeviceUtilsCUDAExt = "CUDA" | ||
DeviceUtilsFillArraysExt = "FillArrays" | ||
DeviceUtilsGPUArraysExt = "GPUArrays" | ||
DeviceUtilsMetalExt = ["GPUArrays", "Metal"] | ||
DeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools" | ||
DeviceUtilsReverseDiffExt = "ReverseDiff" | ||
DeviceUtilsSparseArraysExt = "SparseArrays" | ||
DeviceUtilsTrackerExt = "Tracker" | ||
DeviceUtilsZygoteExt = "Zygote" | ||
DeviceUtilscuDNNExt = ["CUDA", "cuDNN"] | ||
DeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"] | ||
|
||
[compat] | ||
AMDGPU = "0.9.6" | ||
|
@@ -54,7 +54,6 @@ FillArrays = "1" | |
ForwardDiff = "0.10.36" | ||
Functors = "0.4.8" | ||
GPUArrays = "10" | ||
LuxCUDA = "0.3.2" | ||
LuxCore = "0.1.4" | ||
Metal = "1" | ||
Pkg = "1.10" | ||
|
@@ -68,9 +67,11 @@ Test = "1.10" | |
Tracker = "0.2.34" | ||
UnrolledUtilities = "0.1.2" | ||
Zygote = "0.6.69" | ||
cuDNN = "1.3" | ||
julia = "1.10" | ||
oneAPI = "1.5" | ||
|
||
|
||
[extras] | ||
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" | ||
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
module DeviceUtilsAMDGPUExt | ||
|
||
using Adapt: Adapt | ||
using AMDGPU: AMDGPU | ||
using DeviceUtils: DeviceUtils, AMDGPUDevice, CPUDevice, reset_gpu_device! | ||
using Random: Random | ||
|
||
__init__() = reset_gpu_device!() | ||
|
||
# This code used to be in `LuxAMDGPU.jl`, but we no longer need that package. | ||
const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing) | ||
|
||
function _check_use_amdgpu!() | ||
USE_AMD_GPU[] === nothing || return | ||
|
||
USE_AMD_GPU[] = AMDGPU.functional() | ||
if USE_AMD_GPU[] && !AMDGPU.functional(:MIOpen) | ||
@warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \ | ||
available." maxlog=1 | ||
end | ||
return | ||
end | ||
|
||
DeviceUtils.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true | ||
function DeviceUtils.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool | ||
_check_use_amdgpu!() | ||
return USE_AMD_GPU[] | ||
end | ||
|
||
function DeviceUtils._with_device(::Type{AMDGPUDevice}, ::Nothing) | ||
return AMDGPUDevice(nothing) | ||
end | ||
function DeviceUtils._with_device(::Type{AMDGPUDevice}, id::Integer) | ||
id > length(AMDGPU.devices()) && | ||
throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))")) | ||
old_dev = AMDGPU.device() | ||
AMDGPU.device!(AMDGPU.devices()[id]) | ||
device = AMDGPUDevice(AMDGPU.device()) | ||
AMDGPU.device!(old_dev) | ||
return device | ||
end | ||
|
||
DeviceUtils._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device) | ||
|
||
# Default RNG | ||
DeviceUtils.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng() | ||
|
||
# Query Device from Array | ||
function DeviceUtils._get_device(x::AMDGPU.AnyROCArray) | ||
parent_x = parent(x) | ||
parent_x === x && return AMDGPUDevice(AMDGPU.device(x)) | ||
return DeviceUtils._get_device(parent_x) | ||
end | ||
|
||
DeviceUtils._get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice | ||
|
||
# Set Device | ||
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice) | ||
return AMDGPU.device!(dev) | ||
end | ||
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, id::Integer) | ||
return DeviceUtils.set_device!(AMDGPUDevice, AMDGPU.devices()[id]) | ||
end | ||
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer) | ||
id = mod1(rank + 1, length(AMDGPU.devices())) | ||
return DeviceUtils.set_device!(AMDGPUDevice, id) | ||
end | ||
|
||
# Device Transfer | ||
## To GPU | ||
Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x) | ||
function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray) | ||
old_dev = AMDGPU.device() # remember the current device | ||
dev = DeviceUtils.get_device(x) | ||
if !(dev isa AMDGPUDevice) | ||
AMDGPU.device!(to.device) | ||
x_new = AMDGPU.roc(x) | ||
AMDGPU.device!(old_dev) | ||
return x_new | ||
elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device) | ||
return x | ||
else | ||
AMDGPU.device!(to.device) | ||
x_new = copy(x) | ||
AMDGPU.device!(old_dev) | ||
return x_new | ||
end | ||
end | ||
|
||
Adapt.adapt_storage(::CPUDevice, rng::AMDGPU.rocRAND.RNG) = Random.default_rng() | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
module DeviceUtilsCUDAExt | ||
|
||
using Adapt: Adapt | ||
using CUDA: CUDA | ||
using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector | ||
using DeviceUtils: DeviceUtils, CUDADevice, CPUDevice | ||
using Random: Random | ||
|
||
function DeviceUtils._with_device(::Type{CUDADevice}, id::Integer) | ||
id > length(CUDA.devices()) && | ||
throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))")) | ||
old_dev = CUDA.device() | ||
CUDA.device!(id - 1) | ||
device = CUDADevice(CUDA.device()) | ||
CUDA.device!(old_dev) | ||
return device | ||
end | ||
|
||
function DeviceUtils._with_device(::Type{CUDADevice}, ::Nothing) | ||
return CUDADevice(nothing) | ||
end | ||
|
||
DeviceUtils._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1 | ||
|
||
# Default RNG | ||
DeviceUtils.default_device_rng(::CUDADevice) = CUDA.default_rng() | ||
|
||
# Query Device from Array | ||
function DeviceUtils._get_device(x::CUDA.AnyCuArray) | ||
parent_x = parent(x) | ||
parent_x === x && return CUDADevice(CUDA.device(x)) | ||
return DeviceUtils.get_device(parent_x) | ||
end | ||
function DeviceUtils._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray) | ||
return CUDADevice(CUDA.device(x.nzVal)) | ||
end | ||
|
||
function DeviceUtils._get_device_type(::Union{ | ||
<:CUDA.AnyCuArray, <:CUDA.CUSPARSE.AbstractCuSparseArray}) | ||
return CUDADevice | ||
end | ||
|
||
# Set Device | ||
function DeviceUtils.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice) | ||
return CUDA.device!(dev) | ||
end | ||
function DeviceUtils.set_device!(::Type{CUDADevice}, id::Integer) | ||
return DeviceUtils.set_device!(CUDADevice, collect(CUDA.devices())[id]) | ||
end | ||
function DeviceUtils.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer) | ||
id = mod1(rank + 1, length(CUDA.devices())) | ||
return DeviceUtils.set_device!(CUDADevice, id) | ||
end | ||
|
||
# Device Transfer | ||
Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x) | ||
function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray) | ||
old_dev = CUDA.device() # remember the current device | ||
dev = DeviceUtils.get_device(x) | ||
if !(dev isa CUDADevice) | ||
CUDA.device!(to.device) | ||
x_new = CUDA.cu(x) | ||
CUDA.device!(old_dev) | ||
return x_new | ||
elseif dev.device == to.device | ||
return x | ||
else | ||
CUDA.device!(to.device) | ||
x_new = copy(x) | ||
CUDA.device!(old_dev) | ||
return x_new | ||
end | ||
end | ||
|
||
Adapt.adapt_storage(::CPUDevice, rng::CUDA.RNG) = Random.default_rng() | ||
|
||
# Defining as extensions seems to case precompilation errors | ||
@static if isdefined(CUDA.CUSPARSE, :SparseArrays) | ||
function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseMatrix) | ||
return CUDA.CUSPARSE.SparseArrays.SparseMatrixCSC(x) | ||
end | ||
function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseVector) | ||
return CUDA.CUSPARSE.SparseArrays.SparseVector(x) | ||
end | ||
else | ||
@warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \ | ||
an issue in DeviceUtils.jl repository." | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
module DeviceUtilsFillArraysExt | ||
|
||
using Adapt: Adapt | ||
using FillArrays: FillArrays, AbstractFill | ||
using DeviceUtils: DeviceUtils, CPUDevice, AbstractDevice | ||
|
||
Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x | ||
Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x)) | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
module DeviceUtilsGPUArraysExt | ||
|
||
using Adapt: Adapt | ||
using GPUArrays: GPUArrays | ||
using DeviceUtils: CPUDevice | ||
using Random: Random | ||
|
||
Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng() | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
module DeviceUtilsMetalExt | ||
|
||
using Adapt: Adapt | ||
using GPUArrays: GPUArrays | ||
using DeviceUtils: DeviceUtils, MetalDevice, reset_gpu_device! | ||
using Metal: Metal, MtlArray | ||
|
||
__init__() = reset_gpu_device!() | ||
|
||
DeviceUtils.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true | ||
function DeviceUtils.functional(::Union{MetalDevice, Type{<:MetalDevice}}) | ||
return Metal.functional() | ||
end | ||
|
||
# Default RNG | ||
DeviceUtils.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray) | ||
|
||
# Query Device from Array | ||
DeviceUtils._get_device(::MtlArray) = MetalDevice() | ||
|
||
DeviceUtils._get_device_type(::MtlArray) = MetalDevice | ||
|
||
# Device Transfer | ||
## To GPU | ||
Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x) | ||
|
||
end |
12 changes: 6 additions & 6 deletions
12
ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl → ext/DeviceUtilsRecursiveArrayToolsExt.jl
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
module DeviceUtilsReverseDiffExt | ||
|
||
using DeviceUtils: DeviceUtils | ||
using ReverseDiff: ReverseDiff | ||
|
||
for op in (:_get_device, :_get_device_type) | ||
@eval begin | ||
function DeviceUtils.$op(x::ReverseDiff.TrackedArray) | ||
return DeviceUtils.$op(ReverseDiff.value(x)) | ||
end | ||
function DeviceUtils.$op(x::AbstractArray{<:ReverseDiff.TrackedReal}) | ||
return DeviceUtils.$op(ReverseDiff.value.(x)) | ||
end | ||
end | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
module DeviceUtilsSparseArraysExt | ||
|
||
using Adapt: Adapt | ||
using DeviceUtils: CPUDevice | ||
using SparseArrays: AbstractSparseArray | ||
|
||
Adapt.adapt_storage(::CPUDevice, x::AbstractSparseArray) = x | ||
|
||
end |
Oops, something went wrong.