Skip to content

Commit

Permalink
Use replace_call!() to replace Expr(:call, ...) values
Browse files Browse the repository at this point in the history
This is necessary to prevent the callinfo field from falling out of sync
with the call itself, causing future optimization passes (such as inlining)
to compute incorrect results.
  • Loading branch information
staticfloat committed Jul 10, 2023
1 parent dc3c60e commit 077de69
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
11 changes: 4 additions & 7 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::I
frule_result = insert_node!(ir, ssa, NewInstruction(
frule_call, frule_rt, info.frule_call.info, inst[:line],
frule_flag))
ir[ssa][:inst] = Expr(:call, GlobalRef(Core, :getfield), frule_result, 1)
replace_call!(ir, ssa, Expr(:call, GlobalRef(Core, :getfield), frule_result, 1))
Δssa = insert_node!(ir, ssa, NewInstruction(
Expr(:call, GlobalRef(Core, :getfield), frule_result, 2), CC.getfield_tfunc(CC.typeinf_lattice(interp), frule_rt, Const(2))), #=attach_after=#true)
return Δssa
Expand Down Expand Up @@ -285,15 +285,13 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
newargs = map(stmt.args[2:end]) do @nospecialize arg
maparg(arg, SSAValue(ssa), order)
end
inst[:inst] = Expr(:call, ∂☆{order}(), newargs...)
inst[:type] = Any
replace_call!(ir, SSAValue(ssa), Expr(:call, ∂☆{order}(), newargs...))
elseif isexpr(stmt, :call) || isexpr(stmt, :new)
newargs = map(stmt.args) do @nospecialize arg
maparg(arg, SSAValue(ssa), order)
end
f = isexpr(stmt, :call) ? ∂☆{order}() : ∂☆new{order}()
inst[:inst] = Expr(:call, f, newargs...)
inst[:type] = Any
replace_call!(ir, SSAValue(ssa), Expr(:call, f, newargs...))
elseif isa(stmt, PiNode)
# TODO: New PiNode that discriminates based on primal?
inst[:inst] = maparg(stmt.val, SSAValue(ssa), order)
Expand All @@ -304,8 +302,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
stmt = insert_node!(ir, ssa, NewInstruction(inst))
end

inst[:inst] = Expr(:call, ZeroBundle{order}, stmt)
inst[:type] = Any
replace_call!(ir, SSAValue(ssa), Expr(:call, ZeroBundle{order}, stmt))
elseif isa(stmt, SSAValue) || isa(stmt, QuoteNode)
inst[:inst] = maparg(stmt, SSAValue(ssa), order)
inst[:type] = Any
Expand Down
8 changes: 7 additions & 1 deletion src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,10 @@ function find_end_of_phi_block(ir, start_search_idx::Int)
stmt !== nothing && !isa(stmt, PhiNode) && return idx
end
return end_search_idx
end
end

function replace_call!(ir, idx::SSAValue, new_call)
ir[idx][:inst] = new_call
ir[idx][:type] = Any
ir[idx][:info] = CC.NoCallInfo()
end

0 comments on commit 077de69

Please sign in to comment.