Skip to content

Commit

Permalink
Merge pull request #649 from SciML/Vaibhavdixit02-patch-4
Browse files Browse the repository at this point in the history
Pass state to OptimizationOptimisers callback
  • Loading branch information
Vaibhavdixit02 authored Dec 27, 2023
2 parents f9f7dfb + 7f2f357 commit 80d8465
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 7 deletions.
10 changes: 5 additions & 5 deletions lib/OptimizationOptimisers/src/OptimizationOptimisers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,18 +63,18 @@ function SciMLBase.__solve(cache::OptimizationCache{
Optimization.@withprogress cache.progress name="Training" begin
for (i, d) in enumerate(data)
cache.f.grad(G, θ, d...)
x = cache.f(θ, cache.p, d...)
x = (cache.f(θ, cache.p, d...), state, i)
cb_call = cache.callback(θ, x...)
if !(cb_call isa Bool)
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the sciml_train documentation for information.")
error("The callback should return a boolean `halt` for whether to stop the optimization process. Please see the `solve` documentation for information.")
elseif cb_call
break
end
msg = @sprintf("loss: %.3g", x[1])
msg = @sprintf("loss: %.3g", first(x)[1])
cache.progress && ProgressLogging.@logprogress msg i/maxiters

if cache.solver_args.save_best
if first(x) < first(min_err) #found a better solution
if first(x)[1] < first(min_err)[1] #found a better solution
min_opt = opt
min_err = x
min_θ = copy(θ)
Expand All @@ -93,7 +93,7 @@ function SciMLBase.__solve(cache::OptimizationCache{

t1 = time()

SciMLBase.build_solution(cache, cache.opt, θ, x[1], solve_time = t1 - t0)
SciMLBase.build_solution(cache, cache.opt, θ, first(x)[1], solve_time = t1 - t0)
# here should be build_solution to create the output message
end

Expand Down
16 changes: 16 additions & 0 deletions lib/OptimizationOptimisers/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,20 @@ using Zygote
sol = Optimization.solve!(cache)
@test sol.u ≈ [2.0] atol=1e-3
end

@testset "callback" begin
    # Rosenbrock objective: minimized at u = (q[1], q[1]^2).
    rosen(u, q) = (q[1] - u[1])^2 + q[2] * (u[2] - u[1]^2)^2

    u0 = zeros(2)
    params = [1.0, 100.0]
    initial_loss = rosen(u0, params)

    objective = OptimizationFunction(rosen, Optimization.AutoZygote())
    prob = OptimizationProblem(objective, u0, params)

    # Exercise the extended callback signature (θ, loss, state, iter):
    # the optimiser state is passed through, so the learning rate can be
    # decayed as 0.1/iter via Optimisers.adjust!. Returning false means
    # "do not halt".
    decay_cb = function (θ, loss, state, iter)
        Optimisers.adjust!(state, 0.1 / iter)
        return false
    end

    sol = solve(prob, Optimisers.Adam(0.1), maxiters = 1000, progress = false,
        callback = decay_cb)
end
end
2 changes: 1 addition & 1 deletion test/diffeqfluxtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ function loss_neuralode(p)
end

iter = 0
callback = function (p, l, pred)
callback = function (p, l, pred, args...)
global iter
iter += 1

Expand Down
2 changes: 1 addition & 1 deletion test/minibatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function dudt_(u, p, t)
ann(u, p, st)[1] .* u
end

callback = function (p, l, pred; doplot = false) #callback function to observe training
callback = function (p, l, pred, args...; doplot = false) #callback function to observe training
display(l)
# plot current prediction against data
if doplot
Expand Down

0 comments on commit 80d8465

Please sign in to comment.