Skip to content

Commit

Permalink
Try sparsity aware AD in Enzyme
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Jan 26, 2024
1 parent 4da7756 commit 9d23251
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 11 deletions.
28 changes: 22 additions & 6 deletions ext/OptimizationEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module OptimizationEnzymeExt

import Optimization, Optimization.ArrayInterface
import Optimization, Optimization.ArrayInterface, Optimization.SparseArrays
import Optimization.SciMLBase: OptimizationFunction
import Optimization.LinearAlgebra: I
import Optimization.ADTypes: AutoEnzyme
Expand Down Expand Up @@ -47,11 +47,27 @@ function Optimization.instantiate_function(f::OptimizationFunction{true}, x,
return nothing
end
function hess(res, θ, args...)
vdθ = Tuple((Array(r) for r in eachrow(I(length(θ)) * 1.0)))

= zeros(length(θ))
vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ))

if f.hess_prototype === nothing
vdθ = Tuple((similar(r) for r in eachrow(I(length(θ)) * 1.0)))
= zeros(length(θ))
@show
@show typeof(bθ)
vdbθ = Tuple(zeros(length(θ)) for i in eachindex(θ))
@show vdbθ
@show typeof(vdbθ)

Check warning on line 57 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L50-L57

Added lines #L50 - L57 were not covered by tests
else
θ = SparseArrays.sparse(θ)
@show θ
vdθ = Tuple((similar(SparseArrays.sparse(r)) for r in eachrow(I(length(θ)) * 1.0)))
@show vdθ
@show typeof(vdθ)
= SparseArrays.similar(θ)
@show
@show typeof(bθ)
vdbθ = Tuple(similar(i) for i in eachrow(f.hess_prototype))
@show vdbθ
@show typeof(vdbθ)

Check warning on line 69 in ext/OptimizationEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/OptimizationEnzymeExt.jl#L59-L69

Added lines #L59 - L69 were not covered by tests
end
Enzyme.autodiff(Enzyme.Forward,
g,
Enzyme.BatchDuplicated(θ, vdθ),
Expand Down
22 changes: 17 additions & 5 deletions test/ADtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]]
G2 = Array{Float64}(undef, 2)
H2 = Array{Float64}(undef, 2, 2)

optf = OptimizationFunction(rosenbrock, Optimization.AutoZygote(), cons = con2_c)
optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoZygote(),
optf = OptimizationFunction(rosenbrock, Optimization.AutoEnzyme(), cons = con2_c)
optprob = Optimization.instantiate_function(optf, x0, Optimization.AutoEnzyme(),
nothing, 2)
optprob.grad(G2, x0)
@test G1 == G2
Expand All @@ -167,15 +167,27 @@ H3 = [Array{Float64}(undef, 2, 2), Array{Float64}(undef, 2, 2)]
optprob.cons_h(H3, x0)
H3 == [[2.0 0.0; 0.0 2.0], [-0.0 1.0; 1.0 0.0]]

optf = OptimizationFunction(rosenbrock, Optimization.AutoModelingToolkit(true, true),
optf = OptimizationFunction(rosenbrock, Optimization.AutoEnzyme(),
cons = con2_c)
optprob = Optimization.instantiate_function(optf, x0,
Optimization.AutoModelingToolkit(true, true),
Optimization.AutoEnzyme(),
nothing, 2)
using SparseArrays
sH = sparse([1, 1, 2, 2], [1, 2, 1, 2], zeros(4))
@test findnz(sH)[1:2] == findnz(optprob.hess_prototype)[1:2]
# @test findnz(sH)[1:2] == findnz(optprob.hess_prototype)[1:2]
optprob.hess(sH, x0)

sH = sparse([1, 1, 2, 2], [1, 2, 1, 2], zeros(4))
optf = OptimizationFunction(rosenbrock, Optimization.AutoEnzyme(),
cons = con2_c, hess_prototype = sH)
optprob = Optimization.instantiate_function(optf, x0,
Optimization.AutoEnzyme(),
nothing, 2)
using SparseArrays

# @test findnz(sH)[1:2] == findnz(optprob.hess_prototype)[1:2]
optprob.hess(sH, x0)

@test sH == H2
res = Array{Float64}(undef, 2)
optprob.cons(res, x0)
Expand Down

0 comments on commit 9d23251

Please sign in to comment.