Skip to content

Commit

Permalink
Merge pull request #173 from SciML/noncrete_solve
Browse files Browse the repository at this point in the history
end of concrete_solve
  • Loading branch information
Vaibhavdixit02 authored May 30, 2020
2 parents 500a6e0 + 140cc2e commit acb0219
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"

[compat]
ApproxBayes = "0.3"
DiffEqBase = "6.5"
DiffEqBase = "6.36"
DiffResults = "0.0.4, 1.0"
Distances = "0.8, 0.9"
Distributions = "0.21, 0.22, 0.23"
Expand Down
2 changes: 1 addition & 1 deletion src/abc_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ function createabcfunction(prob, t, distancefunction, alg; save_idxs = nothing,
else
u0 = prob.u0
end
sol = concrete_solve(STANDARD_PROB_GENERATOR(prob, params), alg, u0; saveat = t, save_idxs = save_idxs, kwargs...)
sol = solve(prob, alg, u0=u0, p=params, saveat = t, save_idxs = save_idxs, kwargs...)
if size(sol, 2) < length(t)
return Inf,nothing
else
Expand Down
10 changes: 5 additions & 5 deletions src/dynamichmc_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ function (P::DynamicHMCPosterior)(θ)
u0 = convert.(T, sample_u0 ? parameters[1:nu] : problem.u0)
p = convert.(T, sample_u0 ? parameters[(nu + 1):end] : parameters)
if length(u0) < length(problem.u0)
# assumes u is ordered such that the observed variables are in the begining, consistent with ordered theta
# assumes u is ordered such that the observed variables are in the begining, consistent with ordered theta
for i in length(u0):length(problem.u0)
push!(u0, convert(T,problem.u0[i]))
end
end
_saveat = t === nothing ? Float64[] : t
sol = concrete_solve(problem, algorithm, u0, p; saveat = _saveat, save_idxs = save_idxs, solve_kwargs...)
sol = solve(problem, algorithm; u0=u0, p=p, saveat = _saveat, save_idxs = save_idxs, solve_kwargs...)
failure = size(sol, 2) < length(_saveat)
failure && return T(0) * sum(σ) + T(-Inf)
log_likelihood = sum(sum(map(logpdf, Normal.(0.0, σ), sol[:, i] .- data[:, i])) for (i, t) in enumerate(t))
Expand Down Expand Up @@ -93,15 +93,15 @@ posterior values (transformed from `ℝⁿ`).
function dynamichmc_inference(problem::DiffEqBase.DEProblem, algorithm, t, data,
parameter_priors, parameter_transformations=as(Vector, asℝ₊, length(parameter_priors));
σ_priors = fill(Normal(0, 5), size(data, 1)),sample_u0 = false, rng = Random.GLOBAL_RNG,
num_samples = 1000, AD_gradient_kind = Val(:ForwardDiff), save_idxs = nothing,solve_kwargs = (),
num_samples = 1000, AD_gradient_kind = Val(:ForwardDiff), save_idxs = nothing,solve_kwargs = (),
mcmc_kwargs = (initialization = (q = zeros(length(parameter_priors) + (save_idxs === nothing ? length(data[:,1]) : length(save_idxs))),),))
P = DynamicHMCPosterior(; algorithm = algorithm, problem = problem, t = t, data = data,
parameter_priors = parameter_priors, σ_priors = σ_priors,
parameter_priors = parameter_priors, σ_priors = σ_priors,
solve_kwargs = solve_kwargs, sample_u0 = sample_u0, save_idxs = save_idxs)
trans = as((parameters = parameter_transformations,
σ = as(Vector, asℝ₊, length(σ_priors))))
= TransformedLogDensity(trans, P)
∇ℓ = LogDensityProblems.ADgradient(AD_gradient_kind, ℓ)
results = mcmc_with_warmup(rng, ∇ℓ, num_samples; mcmc_kwargs...)
merge((posterior = transform.(Ref(trans), results.chain), ), results)
end
end
8 changes: 4 additions & 4 deletions src/turing_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ function turing_inference(
num_samples=1000, sampler = Turing.NUTS(0.65),
syms = [Turing.@varname(theta[i]) for i in 1:length(priors)],
sample_u0 = false,
save_idxs = nothing,
progress = false,
save_idxs = nothing,
progress = false,
kwargs...,
)
N = length(priors)
Expand All @@ -29,13 +29,13 @@ function turing_inference(
u0 = convert.(T, sample_u0 ? theta[1:nu] : prob.u0)
p = convert.(T, sample_u0 ? theta[(nu + 1):end] : theta)
if length(u0) < length(prob.u0)
# assumes u is ordered such that the observed variables are in the begining, consistent with ordered theta
# assumes u is ordered such that the observed variables are in the begining, consistent with ordered theta
for i in length(u0):length(prob.u0)
push!(u0, convert(T,prob.u0[i]))
end
end
_saveat = t === nothing ? Float64[] : t
sol = concrete_solve(prob, alg, u0, p; saveat = _saveat, progress = progress, save_idxs = save_idxs, kwargs...)
sol = solve(prob, alg; u0=u0, p=p, saveat = _saveat, progress = progress, save_idxs = save_idxs, kwargs...)
failure = size(sol, 2) < length(_saveat)

if failure
Expand Down

0 comments on commit acb0219

Please sign in to comment.