Skip to content

Commit

Permalink
tests pass now pls
Browse files Browse the repository at this point in the history
  • Loading branch information
Vaibhavdixit02 committed Sep 13, 2024
1 parent 6e1999d commit 2a803ff
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ LinearAlgebra = "1.10"
Logging = "1.10"
LoggingExtras = "0.4, 1"
MLUtils = "0.4.4"
OptimizationBase = "2.0.2"
OptimizationBase = "2.0.3"
Printf = "1.10"
ProgressLogging = "0.1"
Reexport = "1.2"
Expand Down
10 changes: 5 additions & 5 deletions src/sophia.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using Optimization.LinearAlgebra
using Optimization.LinearAlgebra, MLUtils

struct Sophia
η::Float64
Expand Down Expand Up @@ -80,14 +80,14 @@ function SciMLBase.__solve(cache::OptimizationCache{
for _ in 1:maxiters
for (i, d) in enumerate(data)
if cache.f.fg !== nothing && dataiterate
x = cache.f.fg(G, θ, d)
x = cache.f.fg(gₜ, θ, d)
elseif dataiterate
cache.f.grad(G, θ, d)
cache.f.grad(gₜ, θ, d)
x = cache.f(θ, d)
elseif cache.f.fg !== nothing
x = cache.f.fg(G, θ)
x = cache.f.fg(gₜ, θ)
else
cache.f.grad(G, θ)
cache.f.grad(gₜ, θ)
x = cache.f(θ)
end
opt_state = Optimization.OptimizationState(; iter = i,
Expand Down
2 changes: 1 addition & 1 deletion test/minibatch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ function dudt_(u, p, t)
ann(u, p, st)[1] .* u
end

function callback(state, l, pred) #callback function to observe training
function callback(state, l) #callback function to observe training
display(l)
return false
end
Expand Down

0 comments on commit 2a803ff

Please sign in to comment.