diff --git a/src/transform/state_reconstruct/derivative.jl b/src/transform/state_reconstruct/derivative.jl index 6e0fde9..2f17785 100644 --- a/src/transform/state_reconstruct/derivative.jl +++ b/src/transform/state_reconstruct/derivative.jl @@ -86,7 +86,7 @@ The `vars` that are being reconstructed must appear in sorted order. T == AbstractVector{<:Number} ? Vector{Float64} : T end...} - F! = JITOpaqueClosure{:reconstruct_derivative, goldclass_sig}() do arg_types... + F! = JITOpaqueClosure{with_eps ? :reconstruct_derivative_with_eps : :reconstruct_derivative, goldclass_sig}() do arg_types... ir = copy(ir) ir.argtypes[2:end] .= arg_types @@ -192,9 +192,9 @@ function define_transform_for_reconstruct_der(var_assignment, vars, obs, param_b @assert with_eps eps_ii = epsnum(inst[:type]) input_basis_row = ntuple(neqs + nparams + neps) do active_state_ii - Float64(var_ii == (active_state_ii - neqs + nparams)) + Float64(eps_ii == (active_state_ii - neqs + nparams)) end - replace_call!(ir, ssa, Expr(:call, BatchOfBundles{neqs + nparams + neps}, u_ii, input_basis_row...)) + replace_call!(ir, ssa, Expr(:call, BatchOfBundles{neqs + nparams + neps}, 0., input_basis_row...)) return nothing elseif is_solved_variable(stmt) || is_known_invoke(stmt, observed!, ir) if is_solved_variable(stmt) @@ -248,7 +248,7 @@ function get_reconstruct_der_visit_custom!(var_assignment) end stmt = ir[ssa][:inst] - if is_known_invoke_or_call(stmt, variable, ir) || is_known_invoke_or_call(stmt, state_ddt, ir) + if is_known_invoke_or_call(stmt, variable, ir) || is_known_invoke_or_call(stmt, state_ddt, ir) || is_known_invoke(stmt, epsilon, ir) return true elseif is_known_invoke_or_call(stmt, solved_variable, ir) recurse(stmt.args[end])