Skip to content

Commit

Permalink
Merge pull request #597 from SciML/performancetuning
Browse files Browse the repository at this point in the history
Eliminate some runtime dispatch and other things
  • Loading branch information
Vaibhavdixit02 authored Sep 28, 2023
2 parents ff678ed + 57c2f39 commit 1c3349e
Showing 1 changed file with 51 additions and 24 deletions.
75 changes: 51 additions & 24 deletions ext/OptimizationSparseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -561,16 +561,18 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
cons_jac_prototype = f.cons_jac_prototype
cons_jac_colorvec = f.cons_jac_colorvec
if cons !== nothing && f.cons_j === nothing
cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
zeros(eltype(x), num_cons),
x)
cons_jac_colorvec = matrix_colors(cons_jac_prototype)
jaccache = ForwardColorJacCache(cons, x;
colorvec = cons_jac_colorvec,
sparsity = cons_jac_prototype,
dx = zeros(eltype(x), num_cons))
cons_j = function (J, θ)
forwarddiff_color_jacobian!(J, cons, θ, jaccache)
jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), SparseDiffTools.SymbolicsSparsityDetection(), cons_oop, x, fx = zeros(eltype(x), num_cons))
# let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons
# ForwardColorJacCache(cons, θ;
# colorvec = cons_jac_colorvec,
# sparsity = cons_jac_prototype,
# dx = zeros(eltype(θ), num_cons))
# end
cons_jac_prototype = jaccache.jac_prototype
cons_jac_colorvec = jaccache.coloring
cons_j = function (J, θ, args...;cons = cons, cache = jaccache.cache)
forwarddiff_color_jacobian!(J, cons, θ, cache)
return
end
else
cons_j = (J, θ) -> f.cons_j(J, θ, p)
Expand All @@ -592,7 +594,7 @@ function Optimization.instantiate_function(f, x, adtype::AutoSparseReverseDiff,
end
gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons]
jaccfgs = [ForwardColorJacCache(gs[i], x; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons]
cons_h = function (res, θ)
cons_h = function (res, θ, args...)
for i in 1:num_cons
SparseDiffTools.forwarddiff_color_jacobian!(res[i], gs[i], θ, jaccfgs[i])
end
Expand Down Expand Up @@ -692,23 +694,32 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
if f.cons === nothing
cons = nothing
else
cons = (res, θ) -> f.cons(res, θ, cache.p)
cons = function (res, θ)
f.cons(res, θ, cache.p)
return
end
cons_oop = (x) -> (_res = zeros(eltype(x), num_cons); cons(_res, x); _res)
end

cons_jac_prototype = f.cons_jac_prototype
cons_jac_colorvec = f.cons_jac_colorvec
if cons !== nothing && f.cons_j === nothing
cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
zeros(eltype(cache.u0), num_cons),
cache.u0)
cons_jac_colorvec = matrix_colors(cons_jac_prototype)
jaccache = ForwardColorJacCache(cons, cache.u0;
colorvec = cons_jac_colorvec,
sparsity = cons_jac_prototype,
dx = zeros(eltype(cache.u0), num_cons))
# cons_jac_prototype = Symbolics.jacobian_sparsity(cons,
# zeros(eltype(cache.u0), num_cons),
# cache.u0)
# cons_jac_colorvec = matrix_colors(cons_jac_prototype)
jaccache = SparseDiffTools.sparse_jacobian_cache(AutoSparseForwardDiff(), SparseDiffTools.SymbolicsSparsityDetection(), cons_oop, cache.u0, fx = zeros(eltype(cache.u0), num_cons))
# let cons = cons, θ = cache.u0, cons_jac_colorvec = cons_jac_colorvec, cons_jac_prototype = cons_jac_prototype, num_cons = num_cons
# ForwardColorJacCache(cons, θ;
# colorvec = cons_jac_colorvec,
# sparsity = cons_jac_prototype,
# dx = zeros(eltype(θ), num_cons))
# end
cons_jac_prototype = jaccache.jac_prototype
cons_jac_colorvec = jaccache.coloring
cons_j = function (J, θ)
forwarddiff_color_jacobian!(J, cons, θ, jaccache)
forwarddiff_color_jacobian!(J, cons, θ, jaccache.cache)
return
end
else
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
Expand All @@ -717,8 +728,18 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
conshess_sparsity = f.cons_hess_prototype
conshess_colors = f.cons_hess_colorvec
if cons !== nothing && f.cons_h === nothing
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
conshess_sparsity = Symbolics.hessian_sparsity.(fncs, Ref(cache.u0))
fncs = map(1:num_cons) do i
function (x)
res = zeros(eltype(x), num_cons)
f.cons(res, x, cache.p)
return res[i]
end
end
conshess_sparsity = map(1:num_cons) do i
let fnc = fncs[i], θ = cache.u0
Symbolics.hessian_sparsity(fnc, θ)
end
end
conshess_colors = SparseDiffTools.matrix_colors.(conshess_sparsity)
if adtype.compile
T = ForwardDiff.Tag(OptimizationSparseReverseTag(),eltype(cache.u0))
Expand All @@ -728,7 +749,13 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
function grad_cons(res1, θ, htape)
ReverseDiff.gradient!(res1, htape, θ)
end
gs = [(res1, x) -> grad_cons(res1, x, conshtapes[i]) for i in 1:num_cons]
gs = let conshtapes = conshtapes
map(1:num_cons) do i
function (res1, x)
grad_cons(res1, x, conshtapes[i])
end
end
end
jaccfgs = [ForwardColorJacCache(gs[i], cache.u0; tag = typeof(T), colorvec = conshess_colors[i], sparsity = conshess_sparsity[i]) for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
Expand Down

0 comments on commit 1c3349e

Please sign in to comment.