Skip to content

Commit

Permalink
use callback to terminate minibatch tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Sep 15, 2024
1 parent 2a803ff commit 42896b6
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
2 changes: 1 addition & 1 deletion test/diffeqfluxtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ function loss_neuralode(p)
end

iter = 0
callback = function (st, l)
callback = function (st, l, pred...)
global iter
iter += 1

Expand Down
14 changes: 7 additions & 7 deletions test/minibatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ end

# Callback invoked by Optimization.solve after each iteration to observe training.
# Arguments:
#   state — optimizer state (unused here).
#   l     — current loss value.
# Returns `true` to request early termination once the loss drops below 1e-2,
# `false` otherwise (solve continues). The diff rendering left the removed
# `return false` line in place, which made the termination check unreachable;
# only the loss-threshold return is kept.
function callback(state, l) #callback function to observe training
    display(l)
    # Terminate the run early once the loss is small enough.
    return l < 1e-2
end

u0 = Float32[200.0]
Expand Down Expand Up @@ -58,11 +58,11 @@ optfun = OptimizationFunction(loss_adjoint,
Optimization.AutoZygote())
optprob = OptimizationProblem(optfun, pp, train_loader)

res1 = Optimization.solve(optprob,
Optimization.Sophia(; η = 0.5,
λ = 0.0), callback = callback,
maxiters = 1000)
@test 10res1.objective < l1
# res1 = Optimization.solve(optprob,
# Optimization.Sophia(; η = 0.5,
# λ = 0.0), callback = callback,
# maxiters = 1000)
# @test 10res1.objective < l1

optfun = OptimizationFunction(loss_adjoint,
Optimization.AutoForwardDiff())
Expand Down Expand Up @@ -100,7 +100,7 @@ function callback(st, l, pred; doplot = false)
scatter!(pl, t, pred[1, :], label = "prediction")
display(plot(pl))
end
return false
return l < 1e-3
end

optfun = OptimizationFunction(loss_adjoint,
Expand Down

0 comments on commit 42896b6

Please sign in to comment.