From a7bafb567cfbda331c3c734784d1ac7d90d2196a Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 8 Jan 2024 11:10:56 +0100 Subject: [PATCH] Fix AbstractMCMC 5 compatibility (#2153) --- Project.toml | 4 ++-- src/mcmc/emcee.jl | 8 ++++---- src/mcmc/hmc.jl | 6 +++--- test/Project.toml | 2 +- test/mcmc/emcee.jl | 8 ++++---- test/mcmc/hmc.jl | 6 +++--- 6 files changed, 17 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index 041b2635a..b274792dc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.30.0" +version = "0.30.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -46,7 +46,7 @@ TuringOptimExt = "Optim" [compat] ADTypes = "0.2" -AbstractMCMC = "4, 5" +AbstractMCMC = "5" AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6" AdvancedMH = "0.8" AdvancedPS = "0.5.4" diff --git a/src/mcmc/emcee.jl b/src/mcmc/emcee.jl index d41596075..f46b4fd7d 100644 --- a/src/mcmc/emcee.jl +++ b/src/mcmc/emcee.jl @@ -25,7 +25,7 @@ function AbstractMCMC.step( model::Model, spl::Sampler{<:Emcee}; resume_from = nothing, - init_params = nothing, + initial_params = nothing, kwargs... ) if resume_from !== nothing @@ -38,11 +38,11 @@ function AbstractMCMC.step( vis = [VarInfo(rng, model, SampleFromPrior()) for _ in 1:n] # Update the parameters if provided. - if init_params !== nothing - length(init_params) == n || throw( + if initial_params !== nothing + length(initial_params) == n || throw( ArgumentError("initial parameters have to be specified for each walker") ) - vis = map(vis, init_params) do vi, init + vis = map(vis, initial_params) do vi, init vi = DynamicPPL.initialize_parameters!!(vi, init, spl, model) # Update log joint probability. diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index d9a47f5c5..e5e93d69a 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -132,7 +132,7 @@ function DynamicPPL.initialstep( model::AbstractModel, spl::Sampler{<:Hamiltonian}, vi::AbstractVarInfo; - init_params=nothing, + initial_params=nothing, nadapts=0, kwargs... ) @@ -164,11 +164,11 @@ function DynamicPPL.initialstep( # If no initial parameters are provided, resample until the log probability # and its gradient are finite. - if init_params === nothing + if initial_params === nothing init_attempt_count = 1 while !isfinite(z) if init_attempt_count == 10 - @warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `init_params` keyword" + @warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword" end # NOTE: This will sample in the unconstrained space. diff --git a/test/Project.toml b/test/Project.toml index cb51a68ce..cac7ecd5d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -32,7 +32,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] -AbstractMCMC = "4, 5" +AbstractMCMC = "5" AdvancedMH = "0.6, 0.7, 0.8" AdvancedPS = "0.5.4" AdvancedVI = "0.2" diff --git a/test/mcmc/emcee.jl b/test/mcmc/emcee.jl index 100cfc43e..929506f95 100644 --- a/test/mcmc/emcee.jl +++ b/test/mcmc/emcee.jl @@ -25,18 +25,18 @@ nwalkers = 250 spl = Emcee(nwalkers, 2.0) - # No initial parameters, with im- and explicit `init_params=nothing` + # No initial parameters, with im- and explicit `initial_params=nothing` Random.seed!(1234) chain1 = sample(gdemo_default, spl, 1) Random.seed!(1234) - chain2 = sample(gdemo_default, spl, 1; init_params=nothing) + chain2 = sample(gdemo_default, spl, 1; initial_params=nothing) @test Array(chain1) == Array(chain2) # Initial parameters have to be specified for every walker - @test_throws ArgumentError sample(gdemo_default, spl, 1; init_params=[2.0, 1.0]) + @test_throws ArgumentError sample(gdemo_default, spl, 1; initial_params=[2.0, 1.0]) # Initial parameters - chain = sample(gdemo_default, spl, 1; init_params=fill([2.0, 1.0], nwalkers)) + chain = sample(gdemo_default, spl, 1; initial_params=fill([2.0, 1.0], nwalkers)) @test chain[:s] == fill(2.0, 1, nwalkers) @test chain[:m] == fill(1.0, 1, nwalkers) end diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index fe18fa773..3ffd8de06 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -232,7 +232,7 @@ @turing_testset "warning for difficult init params" begin attempt = 0 - @model function demo_warn_init_params() + @model function demo_warn_initial_params() x ~ Normal() if (attempt += 1) < 30 Turing.@addlogprob! -Inf @@ -241,9 +241,9 @@ @test_logs ( :warn, - "failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `init_params` keyword", + "failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword", ) (:info,) match_mode=:any begin - sample(demo_warn_init_params(), NUTS(; adtype=adbackend), 5) + sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5) end end end