Skip to content

Commit

Permalink
Add Optimisers
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Feb 25, 2024
1 parent 23b996f commit ab718c6
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 1 deletion.
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Expand All @@ -28,6 +29,7 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ComponentArraysAdaptExt = "Adapt"
ComponentArraysConstructionBaseExt = "ConstructionBase"
ComponentArraysGPUArraysExt = "GPUArrays"
ComponentArraysOptimisersExt = "Optimisers"
ComponentArraysRecursiveArrayToolsExt = "RecursiveArrayTools"
ComponentArraysReverseDiffExt = "ReverseDiff"
ComponentArraysSciMLBaseExt = "SciMLBase"
Expand Down Expand Up @@ -58,9 +60,11 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
21 changes: 21 additions & 0 deletions ext/ComponentArraysOptimisersExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
module ComponentArraysOptimisersExt

using ComponentArrays, Optimisers

# Optimisers can handle componentarrays by default, but we can vectorize the entire
# operation here instead of doing multiple smaller operations
Optimisers.setup(opt::AbstractRule, ps::ComponentArray) = Optimisers.setup(opt, getdata(ps))

function Optimisers.update(tree, ps::ComponentArray, gs::ComponentArray)
gs_flat = ComponentArrays.__value(getdata(gs)) # Safety against ReverseDiff
tree, ps_new = Optimisers.update(tree, getdata(ps), gs_flat)
return tree, ComponentArray(ps_new, getaxes(ps))
end

function Optimisers.update!(tree::Optimisers.Leaf, ps::ComponentArray, gs::ComponentArray)
gs_flat = ComponentArrays.__value(getdata(gs)) # Safety against ReverseDiff
tree, ps_new = Optimisers.update!(tree, getdata(ps), gs_flat)
return tree, ComponentArray(ps_new, getaxes(ps))
end

end
4 changes: 4 additions & 0 deletions ext/ComponentArraysReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,8 @@ function Base.NamedTuple(tca::TrackedComponentArray)
return NamedTuple{props}(getproperty(tca, p) for p in props)
end

@inline ComponentArrays.__value(x::AbstractArray{<:ReverseDiff.TrackedReal}) = ReverseDiff.value.(x)
@inline ComponentArrays.__value(x::ReverseDiff.TrackedArray) = ReverseDiff.value(x)
@inline ComponentArrays.__value(x::TrackedComponentArray) = ComponentArray(ComponentArrays.__value(getdata(x)), getaxes(x))

end
3 changes: 2 additions & 1 deletion ext/ComponentArraysTruncatedStacktracesExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module ComponentArraysTruncatedStacktracesExt

using ComponentArrays, TruncatedStacktraces
using ComponentArrays
import TruncatedStacktraces: @truncate_stacktrace

@truncate_stacktrace ComponentArray 1

Expand Down
2 changes: 2 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,5 @@ recursive_eltype(x::AbstractArray{<:Any}) = isempty(x) ? Base.Bottom : mapreduce
recursive_eltype(x::Dict) = isempty(x) ? Base.Bottom : mapreduce(recursive_eltype, promote_type, values(x))
recursive_eltype(::AbstractArray{T,N}) where {T<:Number, N} = T
recursive_eltype(x) = typeof(x)

@inline __value(x) = x

0 comments on commit ab718c6

Please sign in to comment.