Type error in backward pass #282

Closed
Red-Portal opened this issue Oct 4, 2024 · 14 comments · Fixed by #299
Labels
bug Something isn't working high priority

@Red-Portal

Hi, here's the MWE for the failing test in AdvancedVI.

```julia
using Bijectors
using DifferentiationInterface
using Distributions
using Mooncake
using Optimisers
using Random

# Objective whose gradient w.r.t. `params` triggers the type error in Mooncake's backward pass.
function f(params, aux)
    (; re, samples, baseline, q_stop) = aux
    q = re(params)  # rebuild the transformed distribution from the flat parameter vector

    ℓq = logpdf.(Ref(q), eachcol(samples))
    ℓq_stop = logpdf.(Ref(q_stop), eachcol(samples))
    ℓπ = sum(abs2, samples, dims=1)[1, :]
    ℓπ_mean = mean(ℓπ)
    score_grad = mean(@. ℓq * (ℓπ - baseline))
    score_grad_stop = mean(@. ℓq_stop * (ℓπ - baseline))
    energy = ℓπ_mean + (score_grad - score_grad_stop)
    energy
end

function main()
    rng = Random.default_rng()
    q0  = MvNormal(zeros(3), ones(3))
    b   = Bijectors.Stacked(
        Bijectors.bijector.([LogNormal(0, 1), MvNormal(zeros(2), ones(2))]),
        [1:1, 2:3]
    )
    q   = Bijectors.transformed(q0, Bijectors.inverse(b))

    params, re = Optimisers.destructure(q)

    adtype = AutoMooncake(; config=nothing)
    aux = (
        samples  = rand(rng, q, 10),
        baseline = 1.0,
        re       = re,
        q_stop   = q,
    )
    # `aux` is passed as a `Constant` context, so only `params` is differentiated.
    value_and_gradient(f, adtype, params, Constant(aux))
end

main()
```

@willtebbutt
Member

Thanks for narrowing this down. I won't have time to look at it today, unfortunately, but I should be able to on Monday.

@willtebbutt willtebbutt added bug Something isn't working high priority labels Oct 4, 2024
@willtebbutt
Member

Alas, the arrival of Julia v1.11 in CI means that I need to focus on that. I'll try to make sure this is resolved by the end of the week.

@Red-Portal
Author

The v0.3 release of AdvancedVI will wait until this is resolved!

@willtebbutt
Member

Annoyingly the 1.11 transition is taking longer than I had planned, so this isn't going to get a look until next week. Apologies!

@Red-Portal
Author

I've seen all the work you're going through. No apology needed! And thanks for all the effort.

@willtebbutt
Member

@Red-Portal I believe that this is resolved in 0.4.18. Please let me know whether it is, and close this issue if so!

@willtebbutt willtebbutt reopened this Oct 21, 2024
@willtebbutt
Member

edit: I meant 0.4.17

@Red-Portal
Author

Running the tests just now! Thanks for all the work.

@Red-Portal
Author

@willtebbutt The original error seems to have disappeared, but a new one has taken its place. Until I can put together an MWE, here is the error.

@willtebbutt
Member

Thanks for this. I'll take a look later today!

@willtebbutt
Member

@Red-Portal this problem should be fixed in 0.4.20. Please let me know whether it is, and close this issue if it is resolved.

@yebai
Contributor

yebai commented Oct 22, 2024

This now works on Julia 1.11, but it is still broken on Julia 1.10; see https://github.com/TuringLang/AdvancedVI.jl/actions/runs/11459971691/job/31885647008?pr=99#step:6:831

@willtebbutt
Member

@yebai this should work on 0.4.24, which will be available shortly. Could you restart CI once it's released and let me know whether the problem has been resolved?

@yebai
Contributor

yebai commented Oct 23, 2024

Thanks @willtebbutt -- it now works.

The remaining errors in TuringLang/AdvancedVI.jl#99 are no longer associated with Mooncake.

@yebai yebai closed this as completed Oct 23, 2024