Skip to content

Commit

Permalink
Fix epsilon derivatives (#5)
Browse files Browse the repository at this point in the history
Fixes inverter_noise test in Cedar.
  • Loading branch information
Keno authored Aug 2, 2024
1 parent ef6e297 commit 9afb5c5
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/transform/state_reconstruct/derivative.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ The `vars` that are being reconstructed must appear in sorted order.
T == AbstractVector{<:Number} ? Vector{Float64} :
T
end...}
F! = JITOpaqueClosure{:reconstruct_derivative, goldclass_sig}() do arg_types...
F! = JITOpaqueClosure{with_eps ? :reconstruct_derivative_with_eps : :reconstruct_derivative, goldclass_sig}() do arg_types...
ir = copy(ir)
ir.argtypes[2:end] .= arg_types

Expand Down Expand Up @@ -192,9 +192,9 @@ function define_transform_for_reconstruct_der(var_assignment, vars, obs, param_b
@assert with_eps
eps_ii = epsnum(inst[:type])
input_basis_row = ntuple(neqs + nparams + neps) do active_state_ii
Float64(var_ii == (active_state_ii - neqs + nparams))
Float64(eps_ii == (active_state_ii - neqs + nparams))
end
replace_call!(ir, ssa, Expr(:call, BatchOfBundles{neqs + nparams + neps}, u_ii, input_basis_row...))
replace_call!(ir, ssa, Expr(:call, BatchOfBundles{neqs + nparams + neps}, 0., input_basis_row...))
return nothing
elseif is_solved_variable(stmt) || is_known_invoke(stmt, observed!, ir)
if is_solved_variable(stmt)
Expand Down Expand Up @@ -248,7 +248,7 @@ function get_reconstruct_der_visit_custom!(var_assignment)
end

stmt = ir[ssa][:inst]
if is_known_invoke_or_call(stmt, variable, ir) || is_known_invoke_or_call(stmt, state_ddt, ir)
if is_known_invoke_or_call(stmt, variable, ir) || is_known_invoke_or_call(stmt, state_ddt, ir) || is_known_invoke(stmt, epsilon, ir)
return true
elseif is_known_invoke_or_call(stmt, solved_variable, ir)
recurse(stmt.args[end])
Expand Down

0 comments on commit 9afb5c5

Please sign in to comment.