refactor!: rename package to DeviceUtils.jl
BREAKING CHANGE: All "Lux" prefixes have been dropped for wider adoption
avik-pal committed Jul 16, 2024
1 parent a3c57a6 commit 1e7a438
Showing 33 changed files with 625 additions and 604 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1,4 +1,5 @@
Manifest.toml
*.cov
generated
build
.vscode
31 changes: 16 additions & 15 deletions Project.toml
@@ -1,4 +1,4 @@
name = "LuxDeviceUtils"
name = "DeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.26"
@@ -17,28 +17,28 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
oneAPI = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b"

[extensions]
-LuxDeviceUtilsAMDGPUExt = "AMDGPU"
-LuxDeviceUtilsCUDAExt = "CUDA"
-LuxDeviceUtilsFillArraysExt = "FillArrays"
-LuxDeviceUtilsGPUArraysExt = "GPUArrays"
-LuxDeviceUtilsLuxCUDAExt = "LuxCUDA"
-LuxDeviceUtilsMetalExt = ["GPUArrays", "Metal"]
-LuxDeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools"
-LuxDeviceUtilsReverseDiffExt = "ReverseDiff"
-LuxDeviceUtilsSparseArraysExt = "SparseArrays"
-LuxDeviceUtilsTrackerExt = "Tracker"
-LuxDeviceUtilsZygoteExt = "Zygote"
-LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"]
+DeviceUtilsAMDGPUExt = "AMDGPU"
+DeviceUtilsCUDAExt = "CUDA"
+DeviceUtilsFillArraysExt = "FillArrays"
+DeviceUtilsGPUArraysExt = "GPUArrays"
+DeviceUtilsMetalExt = ["GPUArrays", "Metal"]
+DeviceUtilsRecursiveArrayToolsExt = "RecursiveArrayTools"
+DeviceUtilsReverseDiffExt = "ReverseDiff"
+DeviceUtilsSparseArraysExt = "SparseArrays"
+DeviceUtilsTrackerExt = "Tracker"
+DeviceUtilsZygoteExt = "Zygote"
+DeviceUtilscuDNNExt = ["CUDA", "cuDNN"]
+DeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"]

[compat]
AMDGPU = "0.9.6"
@@ -54,7 +54,6 @@ FillArrays = "1"
ForwardDiff = "0.10.36"
Functors = "0.4.8"
GPUArrays = "10"
LuxCUDA = "0.3.2"
LuxCore = "0.1.4"
Metal = "1"
Pkg = "1.10"
@@ -68,9 +67,11 @@ Test = "1.10"
Tracker = "0.2.34"
UnrolledUtilities = "0.1.2"
Zygote = "0.6.69"
cuDNN = "1.3"
julia = "1.10"
oneAPI = "1.5"


[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
16 changes: 8 additions & 8 deletions README.md
@@ -1,19 +1,19 @@
-# LuxDeviceUtils
+# DeviceUtils

[![Join the chat at https://julialang.zulipchat.com #machine-learning](https://img.shields.io/static/v1?label=Zulip&message=chat&color=9558b2&labelColor=389826)](https://julialang.zulipchat.com/#narrow/stream/machine-learning)
-[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/LuxDeviceUtils)
-[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/LuxDeviceUtils)
+[![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://lux.csail.mit.edu/dev/api/Accelerator_Support/DeviceUtils)
+[![Stable Docs](https://img.shields.io/badge/docs-stable-blue.svg)](https://lux.csail.mit.edu/stable/api/Accelerator_Support/DeviceUtils)

-[![CI](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/LuxDeviceUtils.jl/actions/workflows/CI.yml)
-[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/luxdeviceutils-dot-jl)
-[![codecov](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/LuxDeviceUtils.jl)
+[![CI](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml/badge.svg)](https://github.com/LuxDL/DeviceUtils.jl/actions/workflows/CI.yml)
+[![Buildkite](https://badge.buildkite.com/b098d6387b2c69bd0ab684293ff66332047b219e1b8f9bb486.svg?branch=main)](https://buildkite.com/julialang/DeviceUtils-dot-jl)
+[![codecov](https://codecov.io/gh/LuxDL/DeviceUtils.jl/branch/main/graph/badge.svg?token=1ZY0A2NPEM)](https://codecov.io/gh/LuxDL/DeviceUtils.jl)
[![Aqua QA](https://raw.githubusercontent.com/JuliaTesting/Aqua.jl/master/badge.svg)](https://github.com/JuliaTesting/Aqua.jl)

[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)
[![SciML Code Style](https://img.shields.io/static/v1?label=code%20style&message=SciML&color=9558b2&labelColor=389826)](https://github.com/SciML/SciMLStyle)

-`LuxDeviceUtils.jl` is a lightweight package defining rules for transferring data across
-devices. Most users should directly use [Lux.jl](https://lux.csail.mit.edu/) instead.
+`DeviceUtils.jl` is a lightweight package defining rules for transferring data across
+devices. It is used in deep learning frameworks such as [Lux.jl](https://lux.csail.mit.edu/).
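
As a quick illustration of those rules, here is a minimal sketch. It is hedged: it assumes the callable-device API of the former LuxDeviceUtils (`gpu_device`/`cpu_device`, devices applied as functions) carries over unchanged after the rename.

```julia
using DeviceUtils, CUDA  # load a GPU backend package so its extension activates

dev = gpu_device()            # first functional GPU backend, else falls back to CPU
x = rand(Float32, 3, 4)
x_dev = dev(x)                # devices are callable; transfer goes through the Adapt rules
x_back = cpu_device()(x_dev)  # round-trip back to the host
```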

Currently we provide support for the following backends:

92 changes: 92 additions & 0 deletions ext/DeviceUtilsAMDGPUExt.jl
@@ -0,0 +1,92 @@
module DeviceUtilsAMDGPUExt

using Adapt: Adapt
using AMDGPU: AMDGPU
using DeviceUtils: DeviceUtils, AMDGPUDevice, CPUDevice, reset_gpu_device!
using Random: Random

__init__() = reset_gpu_device!()

# This code used to be in `LuxAMDGPU.jl`, but we no longer need that package.
const USE_AMD_GPU = Ref{Union{Nothing, Bool}}(nothing)

function _check_use_amdgpu!()
USE_AMD_GPU[] === nothing || return

USE_AMD_GPU[] = AMDGPU.functional()
if USE_AMD_GPU[] && !AMDGPU.functional(:MIOpen)
@warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be \
available." maxlog=1
end
return
end

DeviceUtils.loaded(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}}) = true
function DeviceUtils.functional(::Union{AMDGPUDevice, <:Type{AMDGPUDevice}})::Bool
_check_use_amdgpu!()
return USE_AMD_GPU[]
end

function DeviceUtils._with_device(::Type{AMDGPUDevice}, ::Nothing)
return AMDGPUDevice(nothing)
end
function DeviceUtils._with_device(::Type{AMDGPUDevice}, id::Integer)
id > length(AMDGPU.devices()) &&
throw(ArgumentError("id = $id > length(AMDGPU.devices()) = $(length(AMDGPU.devices()))"))
old_dev = AMDGPU.device()
AMDGPU.device!(AMDGPU.devices()[id])
device = AMDGPUDevice(AMDGPU.device())
AMDGPU.device!(old_dev)
return device
end

DeviceUtils._get_device_id(dev::AMDGPUDevice) = AMDGPU.device_id(dev.device)

# Default RNG
DeviceUtils.default_device_rng(::AMDGPUDevice) = AMDGPU.rocrand_rng()

# Query Device from Array
function DeviceUtils._get_device(x::AMDGPU.AnyROCArray)
parent_x = parent(x)
parent_x === x && return AMDGPUDevice(AMDGPU.device(x))
return DeviceUtils._get_device(parent_x)
end

DeviceUtils._get_device_type(::AMDGPU.AnyROCArray) = AMDGPUDevice

# Set Device
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, dev::AMDGPU.HIPDevice)
return AMDGPU.device!(dev)
end
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, id::Integer)
return DeviceUtils.set_device!(AMDGPUDevice, AMDGPU.devices()[id])
end
function DeviceUtils.set_device!(::Type{AMDGPUDevice}, ::Nothing, rank::Integer)
id = mod1(rank + 1, length(AMDGPU.devices()))
return DeviceUtils.set_device!(AMDGPUDevice, id)
end

# Device Transfer
## To GPU
Adapt.adapt_storage(::AMDGPUDevice{Nothing}, x::AbstractArray) = AMDGPU.roc(x)
function Adapt.adapt_storage(to::AMDGPUDevice, x::AbstractArray)
old_dev = AMDGPU.device() # remember the current device
dev = DeviceUtils.get_device(x)
if !(dev isa AMDGPUDevice)
AMDGPU.device!(to.device)
x_new = AMDGPU.roc(x)
AMDGPU.device!(old_dev)
return x_new
elseif AMDGPU.device_id(dev.device) == AMDGPU.device_id(to.device)
return x
else
AMDGPU.device!(to.device)
x_new = copy(x)
AMDGPU.device!(old_dev)
return x_new
end
end

Adapt.adapt_storage(::CPUDevice, rng::AMDGPU.rocRAND.RNG) = Random.default_rng()

end
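
To make the cross-device branch of `adapt_storage` above concrete: the rule temporarily switches to the target device, copies, and then restores the caller's active device. A hedged sketch (assumes a machine with at least two functional AMD GPUs; `_with_device` is the internal helper defined in this file):

```julia
using Adapt, AMDGPU, DeviceUtils

x = AMDGPU.roc(rand(Float32, 4))                  # allocated on the current device
dev2 = DeviceUtils._with_device(AMDGPUDevice, 2)  # handle to the second GPU
y = Adapt.adapt_storage(dev2, x)                  # copy placed on device 2
AMDGPU.device()                                   # unchanged: the rule restored it
```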
90 changes: 90 additions & 0 deletions ext/DeviceUtilsCUDAExt.jl
@@ -0,0 +1,90 @@
module DeviceUtilsCUDAExt

using Adapt: Adapt
using CUDA: CUDA
using CUDA.CUSPARSE: AbstractCuSparseMatrix, AbstractCuSparseVector
using DeviceUtils: DeviceUtils, CUDADevice, CPUDevice
using Random: Random

function DeviceUtils._with_device(::Type{CUDADevice}, id::Integer)
id > length(CUDA.devices()) &&
throw(ArgumentError("id = $id > length(CUDA.devices()) = $(length(CUDA.devices()))"))
old_dev = CUDA.device()
CUDA.device!(id - 1)
device = CUDADevice(CUDA.device())
CUDA.device!(old_dev)
return device
end

function DeviceUtils._with_device(::Type{CUDADevice}, ::Nothing)
return CUDADevice(nothing)
end

DeviceUtils._get_device_id(dev::CUDADevice) = CUDA.deviceid(dev.device) + 1
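
Note the index convention here: `CUDA.device!` takes a zero-based ordinal while this API is one-based, hence the `id - 1` on the way in and the `+ 1` on the way out. A small sketch (assumes at least one CUDA GPU):

```julia
using CUDA, DeviceUtils

dev = DeviceUtils._with_device(CUDADevice, 1)  # one-based: the first visible GPU
CUDA.deviceid(dev.device)                      # 0, CUDA's zero-based ordinal
DeviceUtils._get_device_id(dev)                # 1, converted back to one-based
```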

# Default RNG
DeviceUtils.default_device_rng(::CUDADevice) = CUDA.default_rng()

# Query Device from Array
function DeviceUtils._get_device(x::CUDA.AnyCuArray)
parent_x = parent(x)
parent_x === x && return CUDADevice(CUDA.device(x))
return DeviceUtils._get_device(parent_x)
end
function DeviceUtils._get_device(x::CUDA.CUSPARSE.AbstractCuSparseArray)
return CUDADevice(CUDA.device(x.nzVal))
end

function DeviceUtils._get_device_type(::Union{
<:CUDA.AnyCuArray, <:CUDA.CUSPARSE.AbstractCuSparseArray})
return CUDADevice
end

# Set Device
function DeviceUtils.set_device!(::Type{CUDADevice}, dev::CUDA.CuDevice)
return CUDA.device!(dev)
end
function DeviceUtils.set_device!(::Type{CUDADevice}, id::Integer)
return DeviceUtils.set_device!(CUDADevice, collect(CUDA.devices())[id])
end
function DeviceUtils.set_device!(::Type{CUDADevice}, ::Nothing, rank::Integer)
id = mod1(rank + 1, length(CUDA.devices()))
return DeviceUtils.set_device!(CUDADevice, id)
end
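
The rank-based method distributes processes over the visible GPUs round-robin via `mod1`, which suits one-process-per-GPU distributed launches. A hedged sketch (MPI is purely illustrative here, not a dependency of this package):

```julia
using MPI, DeviceUtils
using CUDA: CUDA

MPI.Init()
rank = MPI.Comm_rank(MPI.COMM_WORLD)                # zero-based process rank
DeviceUtils.set_device!(CUDADevice, nothing, rank)  # with 4 GPUs: ranks 0-3 get GPUs 1-4, rank 4 wraps to GPU 1
```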

# Device Transfer
Adapt.adapt_storage(::CUDADevice{Nothing}, x::AbstractArray) = CUDA.cu(x)
function Adapt.adapt_storage(to::CUDADevice, x::AbstractArray)
old_dev = CUDA.device() # remember the current device
dev = DeviceUtils.get_device(x)
if !(dev isa CUDADevice)
CUDA.device!(to.device)
x_new = CUDA.cu(x)
CUDA.device!(old_dev)
return x_new
elseif dev.device == to.device
return x
else
CUDA.device!(to.device)
x_new = copy(x)
CUDA.device!(old_dev)
return x_new
end
end

Adapt.adapt_storage(::CPUDevice, rng::CUDA.RNG) = Random.default_rng()

# Defining as extensions seems to cause precompilation errors
@static if isdefined(CUDA.CUSPARSE, :SparseArrays)
function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseMatrix)
return CUDA.CUSPARSE.SparseArrays.SparseMatrixCSC(x)
end
function Adapt.adapt_storage(::CPUDevice, x::AbstractCuSparseVector)
return CUDA.CUSPARSE.SparseArrays.SparseVector(x)
end
else
@warn "CUDA.CUSPARSE seems to have removed SparseArrays as a dependency. Please open \
an issue in DeviceUtils.jl repository."
end

end
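
One consequence of the CUSPARSE rules above: moving a sparse CUDA array back to the CPU preserves sparsity rather than densifying it. A minimal sketch (hedged; it calls `adapt_storage` directly instead of going through a public device functor):

```julia
using Adapt, CUDA, SparseArrays
using CUDA.CUSPARSE: CuSparseMatrixCSC
using DeviceUtils: CPUDevice

A = CuSparseMatrixCSC(sprand(Float32, 8, 8, 0.25))  # sparse matrix on the GPU
A_cpu = Adapt.adapt_storage(CPUDevice(), A)         # a SparseMatrixCSC, not a dense Matrix
```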
10 changes: 10 additions & 0 deletions ext/DeviceUtilsFillArraysExt.jl
@@ -0,0 +1,10 @@
module DeviceUtilsFillArraysExt

using Adapt: Adapt
using FillArrays: FillArrays, AbstractFill
using DeviceUtils: DeviceUtils, CPUDevice, AbstractDevice

Adapt.adapt_structure(::CPUDevice, x::AbstractFill) = x
Adapt.adapt_structure(to::AbstractDevice, x::AbstractFill) = Adapt.adapt(to, collect(x))

end
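
The asymmetry above is deliberate: a lazy `Fill` costs O(1) memory, so it is kept as-is on the CPU but materialized with `collect` before moving to an accelerator, where kernels need a real device buffer. Sketch:

```julia
using Adapt, FillArrays
using DeviceUtils: CPUDevice

x = Fill(1.0f0, 1024, 1024)            # lazy: no 1024×1024 buffer is allocated
Adapt.adapt_structure(CPUDevice(), x)  # returned unchanged, still lazy
# adapting to a GPU device would first `collect(x)` into a dense array
```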
10 changes: 10 additions & 0 deletions ext/DeviceUtilsGPUArraysExt.jl
@@ -0,0 +1,10 @@
module DeviceUtilsGPUArraysExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using DeviceUtils: CPUDevice
using Random: Random

Adapt.adapt_storage(::CPUDevice, rng::GPUArrays.RNG) = Random.default_rng()

end
27 changes: 27 additions & 0 deletions ext/DeviceUtilsMetalExt.jl
@@ -0,0 +1,27 @@
module DeviceUtilsMetalExt

using Adapt: Adapt
using GPUArrays: GPUArrays
using DeviceUtils: DeviceUtils, MetalDevice, reset_gpu_device!
using Metal: Metal, MtlArray

__init__() = reset_gpu_device!()

DeviceUtils.loaded(::Union{MetalDevice, Type{<:MetalDevice}}) = true
function DeviceUtils.functional(::Union{MetalDevice, Type{<:MetalDevice}})
return Metal.functional()
end

# Default RNG
DeviceUtils.default_device_rng(::MetalDevice) = GPUArrays.default_rng(MtlArray)

# Query Device from Array
DeviceUtils._get_device(::MtlArray) = MetalDevice()

DeviceUtils._get_device_type(::MtlArray) = MetalDevice

# Device Transfer
## To GPU
Adapt.adapt_storage(::MetalDevice, x::AbstractArray) = Metal.mtl(x)

end
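
Metal defines no device-RNG type of its own here, so the default RNG comes through GPUArrays, and transfer is a plain `Metal.mtl` conversion. A sketch (requires an Apple-silicon host):

```julia
using Adapt, Metal, DeviceUtils

dev = MetalDevice()
x = Adapt.adapt_storage(dev, rand(Float32, 4))  # an MtlArray, via Metal.mtl
rng = DeviceUtils.default_device_rng(dev)       # a GPUArrays RNG for MtlArray
```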
ext/LuxDeviceUtilsRecursiveArrayToolsExt.jl → ext/DeviceUtilsRecursiveArrayToolsExt.jl
@@ -1,23 +1,23 @@
-module LuxDeviceUtilsRecursiveArrayToolsExt
+module DeviceUtilsRecursiveArrayToolsExt

using Adapt: Adapt, adapt
-using LuxDeviceUtils: LuxDeviceUtils, AbstractLuxDevice
+using DeviceUtils: DeviceUtils, AbstractDevice
using RecursiveArrayTools: VectorOfArray, DiffEqArray

# We want to preserve the structure
-function Adapt.adapt_structure(to::AbstractLuxDevice, x::VectorOfArray)
+function Adapt.adapt_structure(to::AbstractDevice, x::VectorOfArray)
return VectorOfArray(map(Base.Fix1(adapt, to), x.u))
end

-function Adapt.adapt_structure(to::AbstractLuxDevice, x::DiffEqArray)
+function Adapt.adapt_structure(to::AbstractDevice, x::DiffEqArray)
# Don't move the `time` to the GPU
return DiffEqArray(map(Base.Fix1(adapt, to), x.u), x.t)
end

for op in (:_get_device, :_get_device_type)
-@eval function LuxDeviceUtils.$op(x::Union{VectorOfArray, DiffEqArray})
+@eval function DeviceUtils.$op(x::Union{VectorOfArray, DiffEqArray})
length(x.u) == 0 && return $(op == :_get_device ? nothing : Nothing)
-return mapreduce(LuxDeviceUtils.$op, LuxDeviceUtils.__combine_devices, x.u)
+return mapreduce(DeviceUtils.$op, DeviceUtils.__combine_devices, x.u)
end
end
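
The `@eval` loop forwards device queries to each wrapped array and reduces with `__combine_devices`, so a container reports the single device its components share. A hedged sketch (assumes the public `get_device_type` wrapper dispatches to these internals):

```julia
using RecursiveArrayTools, DeviceUtils

voa = VectorOfArray([rand(Float32, 3) for _ in 1:4])
DeviceUtils.get_device_type(voa)  # CPUDevice, reduced over all four components
```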

17 changes: 17 additions & 0 deletions ext/DeviceUtilsReverseDiffExt.jl
@@ -0,0 +1,17 @@
module DeviceUtilsReverseDiffExt

using DeviceUtils: DeviceUtils
using ReverseDiff: ReverseDiff

for op in (:_get_device, :_get_device_type)
@eval begin
function DeviceUtils.$op(x::ReverseDiff.TrackedArray)
return DeviceUtils.$op(ReverseDiff.value(x))
end
function DeviceUtils.$op(x::AbstractArray{<:ReverseDiff.TrackedReal})
return DeviceUtils.$op(ReverseDiff.value.(x))
end
end
end

end
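
Tracked containers keep their data behind `ReverseDiff.value`, so these overloads let device queries see through the AD wrapper. Sketch (again assuming the public `get_device_type` wrapper):

```julia
using ReverseDiff, DeviceUtils

x = ReverseDiff.track(rand(Float32, 3))  # TrackedArray wrapping host data
DeviceUtils.get_device_type(x)           # CPUDevice, via ReverseDiff.value(x)
```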
9 changes: 9 additions & 0 deletions ext/DeviceUtilsSparseArraysExt.jl
@@ -0,0 +1,9 @@
module DeviceUtilsSparseArraysExt

using Adapt: Adapt
using DeviceUtils: CPUDevice
using SparseArrays: AbstractSparseArray

Adapt.adapt_storage(::CPUDevice, x::AbstractSparseArray) = x

end