Skip to content

Commit

Permalink
Merge pull request #9 from LuxDL/ap/nopkgid
Browse files Browse the repository at this point in the history
Use `__is_functional` & `__is_loaded` instead of PkgIDs
  • Loading branch information
avik-pal authored Jul 26, 2023
2 parents 10ccab1 + f2d4d62 commit db95b0a
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 36 deletions.
4 changes: 1 addition & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxDeviceUtils"
uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.4"
version = "0.1.5"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -12,7 +12,6 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"

[weakdeps]
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Expand All @@ -39,7 +38,6 @@ LuxCore = "0.1.4"
Metal = "0.4, 0.5"
PackageExtensionCompat = "1"
Preferences = "1"
TruncatedStacktraces = "1"
Zygote = "0.6"
julia = "1.6"

Expand Down
3 changes: 3 additions & 0 deletions ext/LuxDeviceUtilsLuxAMDGPUExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import ChainRulesCore as CRC

__init__() = reset_gpu_device!()

LuxDeviceUtils.__is_loaded(::LuxAMDGPUDevice) = true
LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional()

# Device Transfer
## To GPU
adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x)
Expand Down
3 changes: 3 additions & 0 deletions ext/LuxDeviceUtilsLuxCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import ChainRulesCore as CRC

__init__() = reset_gpu_device!()

LuxDeviceUtils.__is_loaded(::LuxCUDADevice) = true
LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional()

# Device Transfer
## To GPU
adapt_storage(::LuxCUDAAdaptor, x) = cu(x)
Expand Down
3 changes: 3 additions & 0 deletions ext/LuxDeviceUtilsMetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ import ChainRulesCore as CRC

__init__() = reset_gpu_device!()

LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true
LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional()

# Device Transfer
## To GPU
adapt_storage(::LuxMetalAdaptor, x) = mtl(x)
Expand Down
57 changes: 24 additions & 33 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ module LuxDeviceUtils

using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays
import Adapt: adapt, adapt_storage
import Base: PkgId, UUID
import TruncatedStacktraces

using PackageExtensionCompat
function __init__()
Expand All @@ -17,41 +15,33 @@ export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor
abstract type AbstractLuxDevice <: Function end
abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end

__is_functional(::AbstractLuxDevice) = false
__is_loaded(::AbstractLuxDevice) = false

struct LuxCPUDevice <: AbstractLuxDevice end
struct LuxCUDADevice <: AbstractLuxGPUDevice end
struct LuxAMDGPUDevice <: AbstractLuxGPUDevice end
struct LuxMetalDevice <: AbstractLuxGPUDevice end

Base.@kwdef struct LuxCUDADevice <: AbstractLuxGPUDevice
name::String = "CUDA"
pkgid::PkgId = PkgId(UUID("d0bbae9a-e099-4d5b-a835-1c6931763bda"), "LuxCUDA")
end
__is_functional(::LuxCPUDevice) = true
__is_loaded(::LuxCPUDevice) = true

Base.@kwdef struct LuxAMDGPUDevice <: AbstractLuxGPUDevice
name::String = "AMDGPU"
pkgid::PkgId = PkgId(UUID("83120cb1-ca15-4f04-bf3b-6967d2e6b60b"), "LuxAMDGPU")
end
_get_device_name(::LuxCPUDevice) = "CPU"
_get_device_name(::LuxCUDADevice) = "CUDA"
_get_device_name(::LuxAMDGPUDevice) = "AMDGPU"
_get_device_name(::LuxMetalDevice) = "Metal"

Base.@kwdef struct LuxMetalDevice <: AbstractLuxGPUDevice
name::String = "Metal"
pkgid::PkgId = PkgId(UUID("dde4c033-4e86-420c-a63e-0dd931031962"), "Metal")
end
_get_triggerpkg_name(::LuxCPUDevice) = ""
_get_triggerpkg_name(::LuxCUDADevice) = "LuxCUDA"
_get_triggerpkg_name(::LuxAMDGPUDevice) = "LuxAMDGPU"
_get_triggerpkg_name(::LuxMetalDevice) = "Metal"

Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev))

struct LuxDeviceSelectionException <: Exception end

function Base.showerror(io::IO, e::LuxDeviceSelectionException)
print(io, "LuxDeviceSelectionException(No functional GPU device found!!)")
if !TruncatedStacktraces.VERBOSE[]
println(io, TruncatedStacktraces.VERBOSE_MSG)
end
end

@generated function _get_device_name(t::T) where {T <: AbstractLuxDevice}
return hasfield(T, :name) ? :(t.name) : :("")
end

@generated function _get_trigger_pkgid(t::T) where {T <: AbstractLuxDevice}
return hasfield(T, :pkgid) ? :(t.pkgid) :
:(PkgId(UUID("b2108857-7c20-44ae-9111-449ecde12c47"), "Lux"))
return print(io, "LuxDeviceSelectionException(No functional GPU device found!!)")
end

# Order is important here
Expand Down Expand Up @@ -125,32 +115,33 @@ function _get_gpu_device(; force_gpu_usage::Bool)
else
@debug "Using GPU backend set in preferences: $backend."
device = GPU_DEVICES[idx]
if !haskey(Base.loaded_modules, device.pkgid)
if !__is_loaded(device)
@warn """Trying to use backend: $(_get_device_name(device)) but the trigger package $(device.pkgid) is not loaded.
Ignoring the Preferences backend!!!
Please load the package and call this function again to respect the Preferences backend.""" maxlog=1
else
if getproperty(Base.loaded_modules[device.pkgid], :functional)()
if __is_functional(device)
@debug "Using GPU backend: $(_get_device_name(device))."
return device
else
@warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional. Defaulting to automatic GPU Backend selection." maxlog=1
@warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional.
Defaulting to automatic GPU Backend selection." maxlog=1
end
end
end
end

@debug "Running automatic GPU backend selection..."
for device in GPU_DEVICES
if haskey(Base.loaded_modules, device.pkgid)
if __is_loaded(device)
@debug "Trying backend: $(_get_device_name(device))."
if getproperty(Base.loaded_modules[device.pkgid], :functional)()
if __is_functional(device)
@debug "Using GPU backend: $(_get_device_name(device))."
return device
end
@debug "GPU backend: $(_get_device_name(device)) is not functional."
else
@debug "Trigger package for backend ($(_get_device_name(device))): $((device.pkgid)) not loaded."
@debug "Trigger package for backend ($(_get_device_name(device))): $(_get_trigger_pkgname(device)) not loaded."
end
end

Expand Down

2 comments on commit db95b0a

@avik-pal
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/88419

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.1.5 -m "<description of version>" db95b0a949a839f0b0e06d1981533f56e2e09b4a
git push origin v0.1.5

Please sign in to comment.