Skip to content

Commit

Permalink
Fix AbstractMCMC 5 compatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jan 7, 2024
1 parent 99a121c commit b04db61
Show file tree
Hide file tree
Showing 6 changed files with 17 additions and 17 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.30.0"
version = "0.30.1"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down Expand Up @@ -46,7 +46,7 @@ TuringOptimExt = "Optim"

[compat]
ADTypes = "0.2"
AbstractMCMC = "4, 5"
AbstractMCMC = "5"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6"
AdvancedMH = "0.8"
AdvancedPS = "0.5.4"
Expand Down
8 changes: 4 additions & 4 deletions src/mcmc/emcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ function AbstractMCMC.step(
model::Model,
spl::Sampler{<:Emcee};
resume_from = nothing,
init_params = nothing,
initial_params = nothing,
kwargs...
)
if resume_from !== nothing
Expand All @@ -38,11 +38,11 @@ function AbstractMCMC.step(
vis = [VarInfo(rng, model, SampleFromPrior()) for _ in 1:n]

# Update the parameters if provided.
if init_params !== nothing
length(init_params) == n || throw(
if initial_params !== nothing
length(initial_params) == n || throw(

Check warning on line 42 in src/mcmc/emcee.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/emcee.jl#L41-L42

Added lines #L41 - L42 were not covered by tests
ArgumentError("initial parameters have to be specified for each walker")
)
vis = map(vis, init_params) do vi, init
vis = map(vis, initial_params) do vi, init

Check warning on line 45 in src/mcmc/emcee.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/emcee.jl#L45

Added line #L45 was not covered by tests
vi = DynamicPPL.initialize_parameters!!(vi, init, spl, model)

# Update log joint probability.
Expand Down
6 changes: 3 additions & 3 deletions src/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ function DynamicPPL.initialstep(
model::AbstractModel,
spl::Sampler{<:Hamiltonian},
vi::AbstractVarInfo;
init_params=nothing,
initial_params=nothing,
nadapts=0,
kwargs...
)
Expand Down Expand Up @@ -164,11 +164,11 @@ function DynamicPPL.initialstep(

# If no initial parameters are provided, resample until the log probability
# and its gradient are finite.
if init_params === nothing
if initial_params === nothing

Check warning on line 167 in src/mcmc/hmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/hmc.jl#L167

Added line #L167 was not covered by tests
init_attempt_count = 1
while !isfinite(z)
if init_attempt_count == 10
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `init_params` keyword"
@warn "failed to find valid initial parameters in $(init_attempt_count) tries; consider providing explicit initial parameters using the `initial_params` keyword"

Check warning on line 171 in src/mcmc/hmc.jl

View check run for this annotation

Codecov / codecov/patch

src/mcmc/hmc.jl#L171

Added line #L171 was not covered by tests
end

# NOTE: This will sample in the unconstrained space.
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractMCMC = "4, 5"
AbstractMCMC = "5"
AdvancedMH = "0.6, 0.7, 0.8"
AdvancedPS = "0.5.4"
AdvancedVI = "0.2"
Expand Down
8 changes: 4 additions & 4 deletions test/mcmc/emcee.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,18 @@
nwalkers = 250
spl = Emcee(nwalkers, 2.0)

# No initial parameters, with im- and explicit `init_params=nothing`
# No initial parameters, with im- and explicit `initial_params=nothing`
Random.seed!(1234)
chain1 = sample(gdemo_default, spl, 1)
Random.seed!(1234)
chain2 = sample(gdemo_default, spl, 1; init_params=nothing)
chain2 = sample(gdemo_default, spl, 1; initial_params=nothing)
@test Array(chain1) == Array(chain2)

# Initial parameters have to be specified for every walker
@test_throws ArgumentError sample(gdemo_default, spl, 1; init_params=[2.0, 1.0])
@test_throws ArgumentError sample(gdemo_default, spl, 1; initial_params=[2.0, 1.0])

# Initial parameters
chain = sample(gdemo_default, spl, 1; init_params=fill([2.0, 1.0], nwalkers))
chain = sample(gdemo_default, spl, 1; initial_params=fill([2.0, 1.0], nwalkers))
@test chain[:s] == fill(2.0, 1, nwalkers)
@test chain[:m] == fill(1.0, 1, nwalkers)
end
Expand Down
6 changes: 3 additions & 3 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@

@turing_testset "warning for difficult init params" begin
attempt = 0
@model function demo_warn_init_params()
@model function demo_warn_initial_params()
x ~ Normal()
if (attempt += 1) < 30
Turing.@addlogprob! -Inf
Expand All @@ -241,9 +241,9 @@

@test_logs (
:warn,
"failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `init_params` keyword",
"failed to find valid initial parameters in 10 tries; consider providing explicit initial parameters using the `initial_params` keyword",
) (:info,) match_mode=:any begin
sample(demo_warn_init_params(), NUTS(; adtype=adbackend), 5)
sample(demo_warn_initial_params(), NUTS(; adtype=adbackend), 5)
end
end
end

0 comments on commit b04db61

Please sign in to comment.