Skip to content

Commit

Permalink
Merge pull request #180 from JuliaDiff/sf/replace_call
Browse files Browse the repository at this point in the history
Use `replace_call!()` to replace `Expr(:call, ...)` values
  • Loading branch information
oxinabox authored Jul 11, 2023
2 parents dc3c60e + 077de69 commit e308c82
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
11 changes: 4 additions & 7 deletions src/codegen/forward_demand.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ function forward_diff_uncached!(ir::IRCode, interp::AbstractInterpreter, irsv::I
frule_result = insert_node!(ir, ssa, NewInstruction(
frule_call, frule_rt, info.frule_call.info, inst[:line],
frule_flag))
ir[ssa][:inst] = Expr(:call, GlobalRef(Core, :getfield), frule_result, 1)
replace_call!(ir, ssa, Expr(:call, GlobalRef(Core, :getfield), frule_result, 1))
Δssa = insert_node!(ir, ssa, NewInstruction(
Expr(:call, GlobalRef(Core, :getfield), frule_result, 2), CC.getfield_tfunc(CC.typeinf_lattice(interp), frule_rt, Const(2))), #=attach_after=#true)
return Δssa
Expand Down Expand Up @@ -285,15 +285,13 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
newargs = map(stmt.args[2:end]) do @nospecialize arg
maparg(arg, SSAValue(ssa), order)
end
inst[:inst] = Expr(:call, ∂☆{order}(), newargs...)
inst[:type] = Any
replace_call!(ir, SSAValue(ssa), Expr(:call, ∂☆{order}(), newargs...))
elseif isexpr(stmt, :call) || isexpr(stmt, :new)
newargs = map(stmt.args) do @nospecialize arg
maparg(arg, SSAValue(ssa), order)
end
f = isexpr(stmt, :call) ? ∂☆{order}() : ∂☆new{order}()
inst[:inst] = Expr(:call, f, newargs...)
inst[:type] = Any
replace_call!(ir, SSAValue(ssa), Expr(:call, f, newargs...))
elseif isa(stmt, PiNode)
# TODO: New PiNode that discriminates based on primal?
inst[:inst] = maparg(stmt.val, SSAValue(ssa), order)
Expand All @@ -304,8 +302,7 @@ function forward_diff_no_inf!(ir::IRCode, to_diff::Vector{Pair{SSAValue,Int}};
stmt = insert_node!(ir, ssa, NewInstruction(inst))
end

inst[:inst] = Expr(:call, ZeroBundle{order}, stmt)
inst[:type] = Any
replace_call!(ir, SSAValue(ssa), Expr(:call, ZeroBundle{order}, stmt))
elseif isa(stmt, SSAValue) || isa(stmt, QuoteNode)
inst[:inst] = maparg(stmt, SSAValue(ssa), order)
inst[:type] = Any
Expand Down
8 changes: 7 additions & 1 deletion src/stage1/compiler_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,4 +89,10 @@ function find_end_of_phi_block(ir, start_search_idx::Int)
stmt !== nothing && !isa(stmt, PhiNode) && return idx
end
return end_search_idx
end
end

function replace_call!(ir, idx::SSAValue, new_call)
ir[idx][:inst] = new_call
ir[idx][:type] = Any
ir[idx][:info] = CC.NoCallInfo()
end

0 comments on commit e308c82

Please sign in to comment.