diff --git a/src/turing_inference.jl b/src/turing_inference.jl index 8b673f15..70555007 100644 --- a/src/turing_inference.jl +++ b/src/turing_inference.jl @@ -21,7 +21,7 @@ function turing_inference( kwargs..., ) N = length(priors) - Turing.@model mf(x, ::Type{T} = Float64) where {T <: Real} = begin + Turing.@model function mf(x, ::Type{T} = Float64) where {T <: Real} theta = Vector{T}(undef, length(priors)) for i in 1:length(priors) theta[i] ~ NamedDist(priors[i], syms[i]) @@ -39,13 +39,12 @@ function turing_inference( push!(u0, convert(T,prob.u0[i])) end end - _saveat = isnothing(t) ? Float64[] : t + _saveat = t === nothing ? Float64[] : t sol = concrete_solve(prob, alg, u0, p; saveat = _saveat, progress = progress, save_idxs = save_idxs, kwargs...) failure = size(sol, 2) < length(_saveat) if failure - S = typeof(Turing.Inference.getlogp(_varinfo)) - Turing.Inference.acclogp!(_varinfo, S(-Inf)) + Turing.acclogp!(_varinfo, -Inf) return end if ndims(sol) == 1 @@ -56,7 +55,7 @@ function turing_inference( end end return - end + end false # Instantiate a Model object. model = mf(data) diff --git a/test/runtests.jl b/test/runtests.jl index f4f3f9e0..529cc581 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,7 +8,7 @@ const is_TRAVIS = haskey(ENV,"TRAVIS") @time begin if GROUP == "All" || GROUP == "Core" @time @safetestset "DynamicHMC" begin include("dynamicHMC.jl") end - @time @safetestset "Turing" begin include("turing.jl") end # Doesn't work on v0.6 + @time @safetestset "Turing" begin include("turing.jl") end @time @safetestset "ABC" begin include("abc.jl") end end end