diff --git a/Project.toml b/Project.toml index a18f3c3..b6b6eb6 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "LuxDeviceUtils" uuid = "34f89e08-e1d5-43b4-8944-0b49ac560553" authors = ["Avik Pal and contributors"] -version = "0.1.4" +version = "0.1.5" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -12,7 +12,6 @@ PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" -TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" [weakdeps] FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" @@ -39,7 +38,6 @@ LuxCore = "0.1.4" Metal = "0.4, 0.5" PackageExtensionCompat = "1" Preferences = "1" -TruncatedStacktraces = "1" Zygote = "0.6" julia = "1.6" diff --git a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl index c22fd03..e9e2fa4 100644 --- a/ext/LuxDeviceUtilsLuxAMDGPUExt.jl +++ b/ext/LuxDeviceUtilsLuxAMDGPUExt.jl @@ -6,6 +6,9 @@ import ChainRulesCore as CRC __init__() = reset_gpu_device!() +LuxDeviceUtils.__is_loaded(::LuxAMDGPUDevice) = true +LuxDeviceUtils.__is_functional(::LuxAMDGPUDevice) = LuxAMDGPU.functional() + # Device Transfer ## To GPU adapt_storage(::LuxAMDGPUAdaptor, x) = roc(x) diff --git a/ext/LuxDeviceUtilsLuxCUDAExt.jl b/ext/LuxDeviceUtilsLuxCUDAExt.jl index c61e00a..b3525a1 100644 --- a/ext/LuxDeviceUtilsLuxCUDAExt.jl +++ b/ext/LuxDeviceUtilsLuxCUDAExt.jl @@ -6,6 +6,9 @@ import ChainRulesCore as CRC __init__() = reset_gpu_device!() +LuxDeviceUtils.__is_loaded(::LuxCUDADevice) = true +LuxDeviceUtils.__is_functional(::LuxCUDADevice) = LuxCUDA.functional() + # Device Transfer ## To GPU adapt_storage(::LuxCUDAAdaptor, x) = cu(x) diff --git a/ext/LuxDeviceUtilsMetalExt.jl b/ext/LuxDeviceUtilsMetalExt.jl index 505107d..9f6218f 100644 --- a/ext/LuxDeviceUtilsMetalExt.jl +++ b/ext/LuxDeviceUtilsMetalExt.jl @@ -6,6 +6,9 @@ import ChainRulesCore as CRC __init__() = reset_gpu_device!() +LuxDeviceUtils.__is_loaded(::LuxMetalDevice) = true +LuxDeviceUtils.__is_functional(::LuxMetalDevice) = Metal.functional() + # Device Transfer ## To GPU adapt_storage(::LuxMetalAdaptor, x) = mtl(x) diff --git a/src/LuxDeviceUtils.jl b/src/LuxDeviceUtils.jl index dbab572..ca439dd 100644 --- a/src/LuxDeviceUtils.jl +++ b/src/LuxDeviceUtils.jl @@ -2,8 +2,6 @@ module LuxDeviceUtils using ChainRulesCore, Functors, LuxCore, Preferences, Random, SparseArrays import Adapt: adapt, adapt_storage -import Base: PkgId, UUID -import TruncatedStacktraces using PackageExtensionCompat function __init__() @@ -17,41 +15,33 @@ export LuxCPUAdaptor, LuxCUDAAdaptor, LuxAMDGPUAdaptor, LuxMetalAdaptor abstract type AbstractLuxDevice <: Function end abstract type AbstractLuxGPUDevice <: AbstractLuxDevice end +__is_functional(::AbstractLuxDevice) = false +__is_loaded(::AbstractLuxDevice) = false + struct LuxCPUDevice <: AbstractLuxDevice end +struct LuxCUDADevice <: AbstractLuxGPUDevice end +struct LuxAMDGPUDevice <: AbstractLuxGPUDevice end +struct LuxMetalDevice <: AbstractLuxGPUDevice end -Base.@kwdef struct LuxCUDADevice <: AbstractLuxGPUDevice - name::String = "CUDA" - pkgid::PkgId = PkgId(UUID("d0bbae9a-e099-4d5b-a835-1c6931763bda"), "LuxCUDA") -end +__is_functional(::LuxCPUDevice) = true +__is_loaded(::LuxCPUDevice) = true -Base.@kwdef struct LuxAMDGPUDevice <: AbstractLuxGPUDevice - name::String = "AMDGPU" - pkgid::PkgId = PkgId(UUID("83120cb1-ca15-4f04-bf3b-6967d2e6b60b"), "LuxAMDGPU") -end +_get_device_name(::LuxCPUDevice) = "CPU" +_get_device_name(::LuxCUDADevice) = "CUDA" +_get_device_name(::LuxAMDGPUDevice) = "AMDGPU" +_get_device_name(::LuxMetalDevice) = "Metal" -Base.@kwdef struct LuxMetalDevice <: AbstractLuxGPUDevice - name::String = "Metal" - pkgid::PkgId = PkgId(UUID("dde4c033-4e86-420c-a63e-0dd931031962"), "Metal") -end +_get_triggerpkg_name(::LuxCPUDevice) = "" +_get_triggerpkg_name(::LuxCUDADevice) = "LuxCUDA" +_get_triggerpkg_name(::LuxAMDGPUDevice) = "LuxAMDGPU" +_get_triggerpkg_name(::LuxMetalDevice) = "Metal" Base.show(io::IO, dev::AbstractLuxDevice) = print(io, nameof(dev)) struct LuxDeviceSelectionException <: Exception end function Base.showerror(io::IO, e::LuxDeviceSelectionException) - print(io, "LuxDeviceSelectionException(No functional GPU device found!!)") - if !TruncatedStacktraces.VERBOSE[] - println(io, TruncatedStacktraces.VERBOSE_MSG) - end -end - -@generated function _get_device_name(t::T) where {T <: AbstractLuxDevice} - return hasfield(T, :name) ? :(t.name) : :("") -end - -@generated function _get_trigger_pkgid(t::T) where {T <: AbstractLuxDevice} - return hasfield(T, :pkgid) ? :(t.pkgid) : - :(PkgId(UUID("b2108857-7c20-44ae-9111-449ecde12c47"), "Lux")) + return print(io, "LuxDeviceSelectionException(No functional GPU device found!!)") end # Order is important here @@ -125,16 +115,17 @@ function _get_gpu_device(; force_gpu_usage::Bool) else @debug "Using GPU backend set in preferences: $backend." device = GPU_DEVICES[idx] - if !haskey(Base.loaded_modules, device.pkgid) + if !__is_loaded(device) @warn """Trying to use backend: $(_get_device_name(device)) but the trigger package $(device.pkgid) is not loaded. Ignoring the Preferences backend!!! Please load the package and call this function again to respect the Preferences backend.""" maxlog=1 else - if getproperty(Base.loaded_modules[device.pkgid], :functional)() + if __is_functional(device) @debug "Using GPU backend: $(_get_device_name(device))." return device else - @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional. Defaulting to automatic GPU Backend selection." maxlog=1 + @warn "GPU backend: $(_get_device_name(device)) set via Preferences.jl is not functional. + Defaulting to automatic GPU Backend selection." maxlog=1 end end end @@ -142,15 +133,15 @@ function _get_gpu_device(; force_gpu_usage::Bool) @debug "Running automatic GPU backend selection..." for device in GPU_DEVICES - if haskey(Base.loaded_modules, device.pkgid) + if __is_loaded(device) @debug "Trying backend: $(_get_device_name(device))." - if getproperty(Base.loaded_modules[device.pkgid], :functional)() + if __is_functional(device) @debug "Using GPU backend: $(_get_device_name(device))." return device end @debug "GPU backend: $(_get_device_name(device)) is not functional." else - @debug "Trigger package for backend ($(_get_device_name(device))): $((device.pkgid)) not loaded." + @debug "Trigger package for backend ($(_get_device_name(device))): $(_get_trigger_pkgname(device)) not loaded." end end