Update docs remove extra returns from loss and extra args from callback
Vaibhavdixit02 committed Sep 28, 2024
1 parent 63cd6dd commit 2b83735
Showing 4 changed files with 17 additions and 21 deletions.
29 changes: 12 additions & 17 deletions docs/src/getting_started/fit_simulation.md
@@ -98,12 +98,14 @@ function loss(newp)
newprob = remake(prob, p = newp)
sol = solve(newprob, saveat = 1)
loss = sum(abs2, sol .- xy_data)
-return loss, sol
+return loss
end
# Define a callback function to monitor optimization progress
-function callback(p, l, sol)
+function callback(state, l)
display(l)
+newprob = remake(prob, p = state.u)
+sol = solve(newprob, saveat = 1)
plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
display(plt)
@@ -277,37 +279,28 @@ function loss(newp)
newprob = remake(prob, p = newp)
sol = solve(newprob, saveat = 1)
l = sum(abs2, sol .- xy_data)
-return l, sol
+return l
end
```

-Notice that our loss function returns the loss value as the first return,
-but returns extra information (the ODE solution with the new parameters)
-as an extra return argument.
-We will explain why this extra return information is helpful in the next section.

### Step 5: Solve the Optimization Problem

-This step will look very similar to [the first optimization tutorial](@ref first_opt),
-except now we have a new loss function `loss` which returns both the loss value
-and the associated ODE solution.
-(In the previous tutorial, `L` only returned the loss value.)
+This step will look very similar to [the first optimization tutorial](@ref first_opt).
The `Optimization.solve` function can accept an optional callback function
-to monitor the optimization process using extra arguments returned from `loss`.
+to monitor the optimization process.

The callback syntax is always:

```
callback(
-optimization variables,
+state,
the current loss value,
-other arguments returned from the loss function, ...
)
```

-In this case, we will provide the callback the arguments `(p, l, sol)`,
-since it always takes the current state of the optimization first (`p`)
-then the returns from the loss function (`l, sol`).
+In this case, we will provide the callback the arguments `(state, l)`,
+since it always takes the current state of the optimization first (`state`)
+then the current loss value (`l`).
The return value of the callback function should default to `false`.
`Optimization.solve` will halt if/when the callback function returns `true` instead.
Typically the `return` statement would monitor the loss value
@@ -317,8 +310,10 @@ More details about callbacks in Optimization.jl can be found
[here](https://docs.sciml.ai/Optimization/stable/API/solve/).

```@example odefit
-function callback(p, l, sol)
+function callback(p, l)
display(l)
+newprob = remake(prob, p = p)
+sol = solve(newprob, saveat = 1)
plt = plot(sol, ylim = (0, 6), label = ["Current x Prediction" "Current y Prediction"])
scatter!(plt, t_data, xy_data', label = ["x Data" "y Data"])
display(plt)
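Taken together, the changes to this file follow the current Optimization.jl callback pattern: the loss returns only a scalar, and the callback rebuilds anything else it needs from the optimizer state. A minimal sketch of how this wires into `Optimization.solve` — the `AutoForwardDiff` choice, the initial guess `pguess`, and the `BFGS` solver are illustrative assumptions, not part of this commit:

```julia
using Optimization, OptimizationOptimJL

# `loss` and `callback` as defined in fit_simulation.md above;
# `pguess` is an assumed initial parameter vector.
optf = OptimizationFunction((u, _) -> loss(u), Optimization.AutoForwardDiff())
optprob = OptimizationProblem(optf, pguess)

# The two-argument callback receives the optimizer state first, then the loss.
result = solve(optprob, BFGS(); callback = callback)
```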
5 changes: 3 additions & 2 deletions docs/src/showcase/blackhole.md
@@ -495,7 +495,7 @@ function loss(NN_params)
pred_waveform = compute_waveform(dt_data, pred, mass_ratio, model_params)[1]
loss = ( sum(abs2, view(waveform,obs_to_use_for_training) .- view(pred_waveform,obs_to_use_for_training) ) )
-return loss, pred_waveform
+return loss
end
```

@@ -510,10 +510,11 @@ We'll use the following callback to save the history of the loss values.
```@example ude
losses = []
-callback(θ,l,pred_waveform; doplot = true) = begin
+callback(state, l; doplot = true) = begin
push!(losses, l)
#= Disable plotting as it trains since in docs
display(l)
+waveform = compute_waveform(dt_data, soln, mass_ratio, model_params)[1]
# plot current prediction against data
plt = plot(tsteps, waveform,
markershape=:circle, markeralpha = 0.25,
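The history-saving callback above can be sketched in its new two-argument form as follows; the early-stopping threshold is an illustrative assumption, not part of the original docs:

```julia
losses = Float64[]

callback = function (state, l; doplot = false)
    # Record one loss value per optimizer step.
    push!(losses, l)
    # `state.u` holds the current NN parameters if a fresh prediction
    # needs to be computed inside the callback.
    # Returning true halts Optimization.solve early.
    return l < 1e-6
end
```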
2 changes: 1 addition & 1 deletion docs/src/showcase/missing_physics.md
@@ -222,7 +222,7 @@ current loss:
```@example ude
losses = Float64[]
-callback = function (p, l)
+callback = function (state, l)
push!(losses, l)
if length(losses) % 50 == 0
println("Current loss after $(length(losses)) iterations: $(losses[end])")
2 changes: 1 addition & 1 deletion docs/src/showcase/pinngpu.md
@@ -148,7 +148,7 @@ prob = discretize(pde_system, discretization)
## Step 6: Solve the Optimization Problem

```@example pinn
-callback = function (p, l)
+callback = function (state, l)
println("Current loss is: $l")
return false
end
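Since the first callback argument is now an `Optimization.OptimizationState` rather than the raw parameter vector, fields such as `state.u` (current parameters) and `state.iter` (iteration count, where the solver populates it) become available. A hedged sketch of a slightly richer logging callback — the parameter-norm printout is illustrative, not from the original docs:

```julia
callback = function (state, l)
    # `state.u` is the current parameter vector.
    println("Current loss is: $l (parameter norm: $(sqrt(sum(abs2, state.u))))")
    return false  # keep optimizing; returning true would halt the solve
end
```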
