Skip to content

Commit

Permalink
Merge pull request #161 from devmotion/turing_improvements
Browse files Browse the repository at this point in the history
Remove warning about internal variables
  • Loading branch information
ChrisRackauckas authored May 9, 2020
2 parents 3282d66 + f2d5ea6 commit 8b99b42
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
9 changes: 4 additions & 5 deletions src/turing_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function turing_inference(
kwargs...,
)
N = length(priors)
Turing.@model mf(x, ::Type{T} = Float64) where {T <: Real} = begin
Turing.@model function mf(x, ::Type{T} = Float64) where {T <: Real}
theta = Vector{T}(undef, length(priors))
for i in 1:length(priors)
theta[i] ~ NamedDist(priors[i], syms[i])
Expand All @@ -39,13 +39,12 @@ function turing_inference(
push!(u0, convert(T,prob.u0[i]))
end
end
_saveat = isnothing(t) ? Float64[] : t
_saveat = t === nothing ? Float64[] : t
sol = concrete_solve(prob, alg, u0, p; saveat = _saveat, progress = progress, save_idxs = save_idxs, kwargs...)
failure = size(sol, 2) < length(_saveat)

if failure
S = typeof(Turing.Inference.getlogp(_varinfo))
Turing.Inference.acclogp!(_varinfo, S(-Inf))
Turing.acclogp!(_varinfo, -Inf)
return
end
if ndims(sol) == 1
Expand All @@ -56,7 +55,7 @@ function turing_inference(
end
end
return
end
end false

# Instantiate a Model object.
model = mf(data)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ const is_TRAVIS = haskey(ENV,"TRAVIS")
@time begin
if GROUP == "All" || GROUP == "Core"
@time @safetestset "DynamicHMC" begin include("dynamicHMC.jl") end
@time @safetestset "Turing" begin include("turing.jl") end # Doesn't work on v0.6
@time @safetestset "Turing" begin include("turing.jl") end
@time @safetestset "ABC" begin include("abc.jl") end
end
end
Expand Down

0 comments on commit 8b99b42

Please sign in to comment.