Skip to content

Commit

Permalink
Merge pull request #601 from SciML/Vaibhavdixit02-patch-3
Browse files Browse the repository at this point in the history
Add callback to MOI
  • Loading branch information
ChrisRackauckas authored Oct 5, 2023
2 parents 7805260 + cb304de commit eddf8c4
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 11 deletions.
2 changes: 1 addition & 1 deletion lib/OptimizationMOI/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "OptimizationMOI"
uuid = "fd9f6733-72f4-499f-8506-86b2bdd0dea1"
authors = ["Vaibhav Dixit <[email protected]> and contributors"]
version = "0.1.15"
version = "0.1.16"

[deps]
Ipopt_jll = "9cc047cb-c261-5740-88fc-0cf96f7bdcc7"
Expand Down
8 changes: 4 additions & 4 deletions lib/OptimizationMOI/src/OptimizationMOI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@ function SciMLBase.allowsconstraints(opt::Union{MOI.AbstractOptimizer,
MOI.OptimizerWithAttributes})
true
end
function SciMLBase.allowscallback(opt::Union{MOI.AbstractOptimizer,
MOI.OptimizerWithAttributes})
false
end
# function SciMLBase.allowscallback(opt::Union{MOI.AbstractOptimizer,
# MOI.OptimizerWithAttributes})
# false
# end

function _create_new_optimizer(opt::MOI.OptimizerWithAttributes)
return _create_new_optimizer(MOI.instantiate(opt, with_bridge_type = Float64))
Expand Down
21 changes: 17 additions & 4 deletions lib/OptimizationMOI/src/nlp.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
mutable struct MOIOptimizationNLPEvaluator{T, F <: OptimizationFunction, RC, LB, UB,
I,
JT <: DenseOrSparse{T}, HT <: DenseOrSparse{T},
CHT <: DenseOrSparse{T}, S} <:
CHT <: DenseOrSparse{T}, S, CB} <:
MOI.AbstractNLPEvaluator
f::F
reinit_cache::RC
Expand All @@ -14,6 +14,7 @@ mutable struct MOIOptimizationNLPEvaluator{T, F <: OptimizationFunction, RC, LB,
J::JT
H::HT
cons_H::Vector{CHT}
callback::CB
end

function Base.getproperty(evaluator::MOIOptimizationNLPEvaluator, x::Symbol)
Expand Down Expand Up @@ -101,7 +102,7 @@ function SciMLBase.get_paramsyms(sol::SciMLBase.OptimizationSolution{
sol.cache.evaluator.f.paramsyms
end

function MOIOptimizationNLPCache(prob::OptimizationProblem, opt; kwargs...)
function MOIOptimizationNLPCache(prob::OptimizationProblem, opt; callback = nothing, kwargs...)
reinit_cache = Optimization.ReInitCache(prob.u0, prob.p) # everything that can be changed via `reinit`

num_cons = prob.ucons === nothing ? 0 : length(prob.ucons)
Expand Down Expand Up @@ -142,7 +143,8 @@ function MOIOptimizationNLPCache(prob::OptimizationProblem, opt; kwargs...)
prob.sense,
J,
H,
cons_H)
cons_H,
callback)
return MOIOptimizationNLPCache(evaluator, opt, NamedTuple(kwargs))
end

Expand All @@ -169,7 +171,13 @@ function MOI.initialize(evaluator::MOIOptimizationNLPEvaluator,
end

function MOI.eval_objective(evaluator::MOIOptimizationNLPEvaluator, x)
return evaluator.f(x, evaluator.p)
if evaluator.callback === nothing
return evaluator.f(x, evaluator.p)
else
l = evaluator.f(x, evaluator.p)
evaluator.callback(x, l)
return l
end
end

function MOI.eval_constraint(evaluator::MOIOptimizationNLPEvaluator, g, x)
Expand Down Expand Up @@ -406,6 +414,11 @@ function SciMLBase.__solve(cache::MOIOptimizationNLPCache)
MOI.set(opt_setup,
MOI.NLPBlock(),
MOI.NLPBlockData(con_bounds, cache.evaluator, true))

if cache.evaluator.callback !== nothing
MOI.set(opt_setup, MOI.Silent(), true)
end

MOI.optimize!(opt_setup)
if MOI.get(opt_setup, MOI.ResultCount()) >= 1
minimizer = MOI.get(opt_setup, MOI.VariablePrimal(), θ)
Expand Down
12 changes: 10 additions & 2 deletions lib/OptimizationMOI/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,16 @@ end

optprob = OptimizationFunction((x, p) -> -rosenbrock(x, p), Optimization.AutoZygote())
prob = OptimizationProblem(optprob, x0, _p; sense = Optimization.MaxSense)

sol = solve(prob, Ipopt.Optimizer())
global iter = 0
callback = function (p, l)
global iter
iter += 1

display(l)
return false
end

sol = solve(prob, Ipopt.Optimizer(); callback)
@test 10 * sol.objective < l1

# cache interface
Expand Down

0 comments on commit eddf8c4

Please sign in to comment.