NaN in gradients #6

Closed · SebastianM-C opened this issue Apr 6, 2024 · 15 comments · Fixed by #8

@SebastianM-C (Collaborator)

When working locally on this I initially encountered an issue where the gradient would always be NaN, which I think is what's causing #5. I enabled NaN-safe mode and that seemed to fix the issue. Is that a bug, or should we just document this?

Also, what's the best way of setting this up in CI?
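
(For reference, a minimal sketch of enabling ForwardDiff's NaN-safe mode via Preferences.jl, assuming the "nansafe_mode" preference key from the ForwardDiff docs; the resulting LocalPreferences.toml can be committed so CI picks it up as well:)

using Preferences, ForwardDiff

# Writes a LocalPreferences.toml entry; restart Julia for it to take effect.
Preferences.set_preferences!(ForwardDiff, "nansafe_mode" => true)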

@ChrisRackauckas (Member)

What do you mean by NaN safe mode?

@SebastianM-C (Collaborator, Author)

@ChrisRackauckas (Member)

You shouldn't need that here.

@ChrisRackauckas (Member)

Why do you get a NaN in the first place?

@SebastianM-C added the bug (Something isn't working) label Apr 7, 2024
@SebastianM-C (Collaborator, Author)

SebastianM-C commented Apr 7, 2024

It looks like the problem comes from the loss computation:

using ForwardDiff
using OrdinaryDiffEq
using SciMLStructures: SciMLStructures, Tunable

# seed the tunable parameters with dual numbers carrying a single partial
x0′ = ForwardDiff.Dual{:tag}.(x0, 1)

test_p = SciMLStructures.replace(Tunable(), prob.p, x0′)
test_prob = remake(prob, p = test_p)

test_sol = solve(test_prob, Rodas4(autodiff = false), saveat = sol_ref.t)
sum(sqrt.(abs2.(get_vars(test_sol, 1) .- get_refs(sol_ref, 1))))

gives

Dual{:tag}(0.0,NaN)

I can also see where the NaNs appear if I log from inside the loss loop:

for i in eachindex(new_sol.u)
    loss += sum(sqrt.(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i))))
    if any(isnan.(ForwardDiff.partials(loss)))
        @info i
    end
end

@ChrisRackauckas (Member)

What's the first spot of NaN?

@SebastianM-C (Collaborator, Author)

It's due to sqrt.

@ChrisRackauckas (Member)

What are the values?

@SebastianM-C (Collaborator, Author)

julia> sqrt.(abs2.(get_vars(test_sol, 1) .- get_refs(sol_ref, 1)))
2-element Vector{ForwardDiff.Dual{:tag, Float64, 1}}:
 Dual{:tag}(0.0,NaN)
 Dual{:tag}(0.0,NaN)

@ChrisRackauckas (Member)

Yes but what are the values that go in?

@SebastianM-C (Collaborator, Author)

julia> get_vars(test_sol, 1)
2-element Vector{ForwardDiff.Dual{:tag, Float64, 1}}:
 Dual{:tag}(3.1,0.0)
 Dual{:tag}(1.5,0.0)

julia> get_refs(sol_ref, 1)
2-element Vector{Float64}:
 3.1
 1.5

Hmm, let me check why they are the same 🤔

@SebastianM-C (Collaborator, Author)

Ah, it's because we start with the same initial conditions.

sum(sqrt.(abs2.(get_vars(test_sol, 2) .- get_refs(sol_ref, 2))))

gives Dual{:tag}(0.2685941909005718,0.08984350230039442)

@ChrisRackauckas (Member)

ChrisRackauckas commented Apr 7, 2024

Yeah the gradient at zero is NaN for sqrt. That seems like a loss function issue.
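
(A standalone illustration, with an arbitrary value, of why a zero residual under sqrt ∘ abs2 produces a NaN partial: the chain rule divides a zero inner derivative by a zero sqrt value.)

julia> using ForwardDiff

julia> ForwardDiff.derivative(x -> sqrt(abs2(x - 3.1)), 3.1)  # residual is exactly zero here
NaN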

@SebastianM-C (Collaborator, Author)

SebastianM-C commented Apr 7, 2024

I started with the same initial conditions as in https://docs.sciml.ai/Overview/stable/showcase/missing_physics/, which means that at the very first time point the residual is exactly 0, so its sqrt gets a NaN partial, and that NaN ends up poisoning the whole loss.
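
(One possible way around it, sketched here with the same get_vars/get_refs helpers as above and not necessarily what #8 ends up doing, is to accumulate abs2 directly and skip the per-point sqrt, so the partials stay finite when a residual is exactly zero:)

loss = 0.0
for i in eachindex(new_sol.u)
    # abs2 is smooth at zero, so the dual partials stay finite even when
    # the solution and the reference coincide at a time point
    loss += sum(abs2.(get_vars(new_sol, i) .- get_refs(sol_ref, i)))
end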

@SebastianM-C (Collaborator, Author)

So not a bug, but we should document this.

@SebastianM-C removed the bug (Something isn't working) label Apr 7, 2024
@SebastianM-C mentioned this issue Apr 7, 2024