Skip to content

Commit

Permalink
Remove unwanted deps
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Jun 6, 2024
1 parent 1b949a3 commit d801954
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 9 deletions.
4 changes: 0 additions & 4 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,7 @@ version = "0.1.21"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Expand Down Expand Up @@ -42,12 +40,10 @@ LuxDeviceUtilsoneAPIExt = ["GPUArrays", "oneAPI"]
AMDGPU = "0.8.4, 0.9"
Adapt = "4"
Aqua = "0.8.4"
ArgCheck = "2.3"
CUDA = "5.2"
ChainRulesCore = "1.20"
ComponentArrays = "0.15.8"
ExplicitImports = "1.4.1"
FastClosures = "0.3.2"
FillArrays = "1"
Functors = "0.4.4"
GPUArrays = "10"
Expand Down
13 changes: 8 additions & 5 deletions src/LuxDeviceUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ using PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using Adapt: Adapt
using ArgCheck: @argcheck
using ChainRulesCore: ChainRulesCore, NoTangent
using FastClosures: @closure
using Functors: Functors, fmap
using LuxCore: LuxCore
using Preferences: @delete_preferences!, @load_preference, @set_preferences!
Expand Down Expand Up @@ -280,7 +278,9 @@ function gpu_backend!(backend::String)
return
end

@argcheck backend in allowed_backends
if backend allowed_backends
throw(ArgumentError("Invalid backend: $backend. Valid backends are $allowed_backends."))
end

@set_preferences!("gpu_backend"=>backend)
@info "GPU backend has been set to $backend. Restart Julia to use the new backend."
Expand Down Expand Up @@ -378,7 +378,8 @@ __combine_devices(dev1) = dev1
function __combine_devices(dev1, dev2)
dev1 === nothing && return dev2
dev2 === nothing && return dev1
@argcheck dev1 == dev2
dev1 != dev2 &&
throw(ArgumentError("Objects are on different devices: $dev1 and $dev2."))
return dev1
end
function __combine_devices(dev1, dev2, rem_devs...)
Expand Down Expand Up @@ -484,7 +485,9 @@ end

# Chain Rules Core
function CRC.rrule(::typeof(Adapt.adapt_storage), to::AbstractLuxDevice, x::AbstractArray)
∇adapt_storage = @closure Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ))
∇adapt_storage = let x = x
Δ -> (NoTangent(), NoTangent(), (get_device(x))(Δ))
end
return Adapt.adapt_storage(to, x), ∇adapt_storage
end

Expand Down

0 comments on commit d801954

Please sign in to comment.