ReverseDiff compile mode completed
Vaibhavdixit02 committed Sep 16, 2023
1 parent 3d25be3 commit 09d7cf9
Showing 1 changed file with 108 additions and 31 deletions: ext/OptimizationReverseDiffExt.jl
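Below is a minimal usage sketch of the compiled mode this commit completes. It is not part of the commit itself: the Rosenbrock objective, the OptimizationOptimJL solver choice, and the variable names are illustrative; the sketch just follows the standard Optimization.jl API.

```julia
using Optimization, OptimizationOptimJL, ReverseDiff

# Illustrative objective (not from this commit)
rosenbrock(u, p) = (p[1] - u[1])^2 + p[2] * (u[2] - u[1]^2)^2
u0 = zeros(2)
p = [1.0, 100.0]

# compile = true routes through the compiled-tape code paths added here:
# gradient, Hessian, constraint Jacobian, and constraint Hessians.
optf = OptimizationFunction(rosenbrock, Optimization.AutoReverseDiff(compile = true))
prob = OptimizationProblem(optf, u0, p)
sol = solve(prob, Newton())  # a Hessian-using solver exercises the compiled Hessian path
```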
@@ -7,29 +7,19 @@ import Optimization.ADTypes: AutoReverseDiff
isdefined(Base, :get_extension) ? (using ReverseDiff, ReverseDiff.ForwardDiff) :
(using ..ReverseDiff, ..ReverseDiff.ForwardDiff)

struct OptimizationReverseDiffTag end

function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
p = SciMLBase.NullParameters(),
num_cons = 0)
_f = (θ, args...) -> first(f.f(θ, p, args...))

if f.grad === nothing

if adtype.compile
_tape = ReverseDiff.GradientTape((x,)) do θ
_f(θ)
end
_tape = ReverseDiff.GradientTape(_f, x)
tape = ReverseDiff.compile(_tape)

grad = function (res, θ, args...)
tθ = ReverseDiff.input_hook(tape)[1]
output = ReverseDiff.output_hook(tape)
ReverseDiff.unseed!(tθ) # clear any "leftover" derivatives from previous calls
ReverseDiff.value!(tθ, θ)
ReverseDiff.forward_pass!(tape)
ReverseDiff.increment_deriv!(output, one(eltype(θ)))
ReverseDiff.reverse_pass!(tape)
copyto!(res, ReverseDiff.deriv(tθ))
nothing
ReverseDiff.gradient!(res, tape, θ)
end
else
cfg = ReverseDiff.GradientConfig(x)
@@ -40,8 +30,23 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
end

if f.hess === nothing
hess = function (res, θ, args...)
ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
if adtype.compile
T = ForwardDiff.Tag(OptimizationReverseDiffTag(),eltype(x))
xdual = ForwardDiff.Dual{typeof(T),eltype(x),length(x)}.(x, Ref(ForwardDiff.Partials((ones(eltype(x), length(x))...,))))
h_tape = ReverseDiff.GradientTape(_f, xdual)
htape = ReverseDiff.compile(h_tape)
function g(θ)
res1 = zeros(eltype(θ), length(θ))
ReverseDiff.gradient!(res1, htape, θ)
end
jaccfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk(x), T)
hess = function (res, θ, args...)
ForwardDiff.jacobian!(res, g, θ, jaccfg, Val{false}())
end
else
hess = function (res, θ, args...)
ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
end
end
else
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
@@ -66,19 +71,43 @@ function Optimization.instantiate_function(f, x, adtype::AutoReverseDiff,
end

if cons !== nothing && f.cons_j === nothing
cjconfig = ReverseDiff.JacobianConfig(x)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
if adtype.compile
_jac_tape = ReverseDiff.JacobianTape(cons_oop, x)
jac_tape = ReverseDiff.compile(_jac_tape)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, jac_tape, θ)
end
else
cjconfig = ReverseDiff.JacobianConfig(x)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
end
end
else
cons_j = (J, θ) -> f.cons_j(J, θ, p)
end

if cons !== nothing && f.cons_h === nothing
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
ReverseDiff.hessian!(res[i], fncs[i], θ)
if adtype.compile
consh_tapes = ReverseDiff.GradientTape.(fncs, Ref(xdual))
conshtapes = ReverseDiff.compile.(consh_tapes)
function grad_cons(θ, htape)
res1 = zeros(eltype(θ), length(θ))
ReverseDiff.gradient!(res1, htape, θ)
end
gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
jaccfgs = [ForwardDiff.JacobianConfig(gs[i], x, ForwardDiff.Chunk(x), T) for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}())
end
end
else
cons_h = function (res, θ)
for i in 1:num_cons
ReverseDiff.hessian!(res[i], fncs[i], θ)
end
end
end
else
@@ -103,14 +132,38 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
_f = (θ, args...) -> first(f.f(θ, cache.p, args...))

if f.grad === nothing
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ)
if adtype.compile
_tape = ReverseDiff.GradientTape(_f, cache.u0)
tape = ReverseDiff.compile(_tape)
grad = function (res, θ, args...)
ReverseDiff.gradient!(res, tape, θ)
end
else
cfg = ReverseDiff.GradientConfig(cache.u0)
grad = (res, θ, args...) -> ReverseDiff.gradient!(res, x -> _f(x, args...), θ, cfg)
end
else
grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
end

if f.hess === nothing
hess = function (res, θ, args...)
ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
if adtype.compile
T = ForwardDiff.Tag(OptimizationReverseDiffTag(),eltype(cache.u0))
xdual = ForwardDiff.Dual{typeof(T),eltype(cache.u0),length(cache.u0)}.(cache.u0, Ref(ForwardDiff.Partials((ones(eltype(cache.u0), length(cache.u0))...,))))
h_tape = ReverseDiff.GradientTape(_f, xdual)
htape = ReverseDiff.compile(h_tape)
function g(θ)
res1 = zeros(eltype(θ), length(θ))
ReverseDiff.gradient!(res1, htape, θ)
end
jaccfg = ForwardDiff.JacobianConfig(g, cache.u0, ForwardDiff.Chunk(cache.u0), T)
hess = function (res, θ, args...)
ForwardDiff.jacobian!(res, g, θ, jaccfg, Val{false}())
end
else
hess = function (res, θ, args...)
ReverseDiff.hessian!(res, x -> _f(x, args...), θ)
end
end
else
hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
@@ -135,19 +188,43 @@ function Optimization.instantiate_function(f, cache::Optimization.ReInitCache,
end

if cons !== nothing && f.cons_j === nothing
cjconfig = ReverseDiff.JacobianConfig(cache.u0)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
if adtype.compile
_jac_tape = ReverseDiff.JacobianTape(cons_oop, cache.u0)
jac_tape = ReverseDiff.compile(_jac_tape)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, jac_tape, θ)
end
else
cjconfig = ReverseDiff.JacobianConfig(cache.u0)
cons_j = function (J, θ)
ReverseDiff.jacobian!(J, cons_oop, θ, cjconfig)
end
end
else
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
end

if cons !== nothing && f.cons_h === nothing
fncs = [(x) -> cons_oop(x)[i] for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
ReverseDiff.hessian!(res[i], fncs[i], θ)
if adtype.compile
consh_tapes = ReverseDiff.GradientTape.(fncs, Ref(xdual))
conshtapes = ReverseDiff.compile.(consh_tapes)
function grad_cons(θ, htape)
res1 = zeros(eltype(θ), length(θ))
ReverseDiff.gradient!(res1, htape, θ)
end
gs = [x -> grad_cons(x, conshtapes[i]) for i in 1:num_cons]
jaccfgs = [ForwardDiff.JacobianConfig(gs[i], cache.u0, ForwardDiff.Chunk(cache.u0), T) for i in 1:num_cons]
cons_h = function (res, θ)
for i in 1:num_cons
ForwardDiff.jacobian!(res[i], gs[i], θ, jaccfgs[i], Val{false}())
end
end
else
cons_h = function (res, θ)
for i in 1:num_cons
ReverseDiff.hessian!(res[i], fncs[i], θ)
end
end
end
else

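For readers unfamiliar with the pattern, the compiled Hessian paths above differentiate a compiled ReverseDiff gradient tape with ForwardDiff (forward-over-reverse). A standalone sketch of that pattern follows; the objective f, the tag struct HessSketchTag, and all variable names are illustrative rather than package API, but the calls mirror the ones in the diff.

```julia
using ReverseDiff, ReverseDiff.ForwardDiff

struct HessSketchTag end  # illustrative stand-in for OptimizationReverseDiffTag

f(x) = sum(abs2, x) + x[1] * x[2]  # toy objective
x = [0.5, -1.0, 2.0]

# Record and compile the gradient tape at Dual-valued inputs so the compiled
# tape can carry ForwardDiff partials through the reverse pass.
T = ForwardDiff.Tag(HessSketchTag(), eltype(x))
xdual = ForwardDiff.Dual{typeof(T), eltype(x), length(x)}.(x,
    Ref(ForwardDiff.Partials((ones(eltype(x), length(x))...,))))
htape = ReverseDiff.compile(ReverseDiff.GradientTape(f, xdual))

# Gradient of f evaluated through the compiled (Dual-element) tape.
function g(θ)
    res = zeros(eltype(θ), length(θ))
    ReverseDiff.gradient!(res, htape, θ)
    return res
end

# Differentiate the gradient with ForwardDiff. The config reuses the tag T so
# the seeded Dual type matches the tape's element type, and Val{false}() skips
# the tag consistency check (the tag is not derived from g).
jaccfg = ForwardDiff.JacobianConfig(g, x, ForwardDiff.Chunk(x), T)
H = zeros(length(x), length(x))
ForwardDiff.jacobian!(H, g, x, jaccfg, Val{false}())  # H is now the Hessian of f at x
```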