Skip to content

Commit

Permalink
wip: implement a better statement selection logic
Browse files Browse the repository at this point in the history
Specifically, this commit aims to review the implementation of
`add_control_flow!` and improves its accuracy. Ideally, it should pass
JET's existing test cases as well as the newly added ones, including the
test cases from JuliaDebug/LoweredCodeUtils.jl#99. The goal is to share
the same high-precision CFG selection logic between LoweredCodeUtils
and JET.

The new algorithm is based on what was proposed in [^Wei84]. If there is
even one active block in the blocks reachable from a conditional branch
up to its successors' nearest common post-dominator (referred to as
**INFL** in the paper), it is necessary to follow that conditional
branch and execute the code. Otherwise, execution can be short-circuited
from the conditional branch to the nearest common post-dominator.

COMBAK: It is important to note that in Julia's IR (`CodeInfo`),
"short-circuiting" a specific code region is not a simple task. Simply
ignoring the path to the post-dominator does not guarantee fall-through
to the post-dominator. Therefore, a more careful implementation is
required for this aspect.

[Wei84]: M. Weiser, "Program Slicing," IEEE Transactions on Software Engineering, 10, pages 352-357, July 1984.
  • Loading branch information
aviatesk committed Sep 7, 2024
1 parent 2dfcd4e commit d47b751
Show file tree
Hide file tree
Showing 3 changed files with 221 additions and 66 deletions.
7 changes: 4 additions & 3 deletions src/JET.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ using .CC: @nospecs, ⊑,
InvokeCallInfo, MethodCallResult, MethodMatchInfo, MethodMatches, NOT_FOUND,
OptimizationState, OptimizationParams, OverlayMethodTable, StmtInfo, UnionSplitInfo,
UnionSplitMethodMatches, VarState, VarTable, WorldRange, WorldView,
argextype, argtype_by_index, argtypes_to_type, hasintersect, ignorelimited,
instanceof_tfunc, istopfunction, singleton_type, slot_id, specialize_method,
tmeet, tmerge, typeinf_lattice, widenconst, widenlattice
argextype, argtype_by_index, argtypes_to_type, compute_basic_blocks, construct_domtree,
construct_postdomtree, hasintersect, ignorelimited, instanceof_tfunc, istopfunction,
nearest_common_dominator, singleton_type, slot_id, specialize_method, tmeet, tmerge,
typeinf_lattice, widenconst, widenlattice

using Base: IdSet, get_world_counter

Expand Down
169 changes: 111 additions & 58 deletions src/toplevel/virtualprocess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1091,16 +1091,32 @@ end

# select statements that should be concretized, and actually interpreted rather than abstracted
function select_statements(mod::Module, src::CodeInfo)
stmts = src.code
cl = LoweredCodeUtils.CodeLinks(mod, src) # make `CodeEdges` hold `CodeLinks`?
edges = LoweredCodeUtils.CodeEdges(src, cl)

concretize = falses(length(stmts))

select_direct_requirement!(concretize, stmts, edges)

concretize = falses(length(src.code))
select_direct_requirement!(concretize, src.code, edges)
select_dependencies!(concretize, src, edges, cl)
return concretize
end

# just for testing, and debugging
function select_statements(mod::Module, src::CodeInfo, names::Symbol...)
cl = LoweredCodeUtils.CodeLinks(mod, src) # make `CodeEdges` hold `CodeLinks`?
edges = LoweredCodeUtils.CodeEdges(src, cl)
concretize = falses(length(src.code))
objs = Set{GlobalRef}(GlobalRef(mod, name) for name in names)
LoweredCodeUtils.add_requests!(concretize, objs, edges, ())
select_dependencies!(concretize, src, edges, cl)
return concretize
end
function select_statements(mod::Module, src::CodeInfo, idxs::Int...)
cl = LoweredCodeUtils.CodeLinks(mod, src) # make `CodeEdges` hold `CodeLinks`?
edges = LoweredCodeUtils.CodeEdges(src, cl)
concretize = falses(length(src.code))
for idx = idxs
concretize[idx] |= true
end
select_dependencies!(concretize, src, edges, cl)
return concretize
end

Expand Down Expand Up @@ -1173,67 +1189,70 @@ end

# The goal of this function is to request concretization of the minimal necessary control
# flow to evaluate statements whose concretization have already been requested.
# The basic approach is to check if there are any active successors for each basic block,
# and if there is an active successor and the terminator is not a fall-through, then request
# the concretization of that terminator. Additionally, for conditional terminators, a simple
# optimization using post-domination analysis is also performed.
function add_control_flow!(concretize::BitVector, src::CodeInfo, cfg::CFG, postdomtree)
# The basic algorithm is based on what was proposed in [^Wei84]. If there is even one active
# block in the blocks reachable from a conditional branch up to its successors' nearest
# common post-dominator (referred to as **INFL** in the paper), it is necessary to follow
# that conditional branch and execute the code. Otherwise, execution can be short-circuited
# from the conditional branch to the nearest common post-dominator.
#
# COMBAK: It is important to note that in Julia's intermediate code representation (`CodeInfo`),
# "short-circuiting" a specific code region is not a simple task. Simply ignoring the path
# to the post-dominator does not guarantee fall-through to the post-dominator. Therefore,
# a more careful implementation is required for this aspect.
#
# [Wei84]: M. Weiser, "Program Slicing," IEEE Transactions on Software Engineering, 10, pages 352-357, July 1984.
function add_control_flow!(concretize::BitVector, src::CodeInfo, cfg::CFG, domtree, postdomtree)
local changed::Bool = false
function mark_concretize!(idx::Int)
if !concretize[idx]
concretize[idx] = true
changed |= concretize[idx] = true
return true
end
return false
end
nblocks = length(cfg.blocks)
for bbidx = 1:nblocks
bb = cfg.blocks[bbidx] # forward traversal
for bbidx = 1:length(cfg.blocks) # forward traversal
bb = cfg.blocks[bbidx]
nsuccs = length(bb.succs)
if nsuccs == 0
continue
elseif nsuccs == 1
terminator_idx = bb.stmts[end]
if src.code[terminator_idx] isa GotoNode
# If the destination of this `GotoNode` is not active, it's fine to ignore
# the control flow caused by this `GotoNode` and treat it as a fall-through.
# If the block that is fallen through to is active and has a dependency on
# this goto block, then the concretization of this goto block should already
# be requested (at some point of the higher concretization convergence cycle
# of `select_dependencies`), and thus, this `GotoNode` will be concretized.
if any(@view concretize[cfg.blocks[only(bb.succs)].stmts])
changed |= mark_concretize!(terminator_idx)
termidx = bb.stmts[end]
if src.code[termidx] isa GotoNode
succ = only(bb.succs)
if any(@view concretize[cfg.blocks[succ].stmts])
dominator = nearest_common_dominator(domtree, bbidx, succ)
if dominator succ
for blk in reachable_blocks(cfg, dominator, succ)
if blk == dominator || blk == succ
continue
end
if any(@view concretize[cfg.blocks[blk].stmts])
mark_concretize!(termidx)
break
end
end
else
mark_concretize!(termidx)
end
end
end
continue # we can just fall-through
elseif nsuccs == 2
terminator_idx = bb.stmts[end]
@assert is_conditional_terminator(src.code[terminator_idx]) "invalid IR"
termidx = bb.stmts[end]
@assert is_conditional_terminator(src.code[termidx]) "invalid IR"
succ1, succ2 = bb.succs
succ1_req = any(@view concretize[cfg.blocks[succ1].stmts])
succ2_req = any(@view concretize[cfg.blocks[succ2].stmts])
if succ1_req
if succ2_req
changed |= mark_concretize!(terminator_idx)
else
active_bb, inactive_bb = succ1, succ2
@goto asymmetric_case
postdominator = nearest_common_dominator(postdomtree, succ1, succ2)
inflblks = reachable_blocks(cfg, succ1, postdominator) reachable_blocks(cfg, succ2, postdominator)
for blk in inflblks
if blk == postdominator
continue
end
elseif succ2_req
active_bb, inactive_bb = succ2, succ1
@label asymmetric_case
# We can ignore the control flow of this conditional terminator and treat
# it as a fall-through if only one of its successors is active and the
# active block post-dominates the inactive one, since the post-domination
# ensures that the active basic block will be reached regardless of the
# control flow.
if CC.postdominates(postdomtree, active_bb, inactive_bb)
# fall through this block
else
changed |= mark_concretize!(terminator_idx)
if any(@view concretize[cfg.blocks[blk].stmts])
mark_concretize!(termidx)
break
end
else
# both successors are inactive, just fall through this block
end
# we can just fall-through to the post dominator block (by ignoring all statements between)
end
end
return changed
Expand All @@ -1242,6 +1261,25 @@ end
is_conditional_terminator(@nospecialize stmt) = stmt isa GotoIfNot ||
(@static @isdefined(EnterNode) ? stmt isa EnterNode : isexpr(stmt, :enter))

function reachable_blocks(cfg::CFG, from_bb::Int, to_bb::Int)
worklist = Int[from_bb]
visited = BitSet(from_bb)
if to_bb == from_bb
return visited
end
push!(visited, to_bb)
function visit!(bb::Int)
if bb visited
push!(visited, bb)
push!(worklist, bb)
end
end
while !isempty(worklist)
foreach(visit!, cfg.blocks[pop!(worklist)].succs)
end
return visited
end

function add_required_inplace!(concretize::BitVector, src::CodeInfo, edges, cl)
changed = false
for i = 1:length(src.code)
Expand Down Expand Up @@ -1272,27 +1310,42 @@ function is_arg_requested(@nospecialize(arg), concretize, edges, cl)
return false
end

# The purpose of this function is to find other statements that affect the execution of the
# statements choosen by `select_direct_dependencies!`. In other words, it extracts the
# minimal amount of code necessary to realize the required concretization.
# This technique is generally referred to as "program slicing," and specifically as
# "static program slicing". The basic algorithm implemented here is an extension of the one
# proposed in https://dl.acm.org/doi/10.5555/800078.802557, which is especially tuned for
# Julia's intermediate code representation.
function select_dependencies!(concretize::BitVector, src::CodeInfo, edges, cl)
typedefs = LoweredCodeUtils.find_typedefs(src)
cfg = CC.compute_basic_blocks(src.code)
postdomtree = CC.construct_postdomtree(cfg.blocks)
cfg = compute_basic_blocks(src.code)
domtree = construct_domtree(cfg.blocks)
postdomtree = construct_postdomtree(cfg.blocks)

while true
changed = false

# discover struct/method definitions at the beginning,
# and propagate the definition requirements by tracking SSA precedessors
# Discover Dtruct/method definitions at the beginning,
# and propagate the definition requirements by tracking SSA precedessors.
# (TODO maybe hoist this out of the loop?)
changed |= LoweredCodeUtils.add_typedefs!(concretize, src, edges, typedefs, ())
changed |= add_ssa_preds!(concretize, src, edges, ())

# mark some common inplace operations like `push!(x, ...)` and `setindex!(x, ...)`
# when `x` has been marked already: otherwise we may end up using it with invalid state
# Mark some common inplace operations like `push!(x, ...)` and `setindex!(x, ...)`
# when `x` has been marked already: otherwise we may end up using it with invalid state.
# However, note that this is an incomplete approach, and note that the slice created
# by this routine will not be sound because of this. This is because
# `add_required_inplace!` only requires certain special-cased function calls and
# does not take into account the possibility that arguments may be mutated in
# arbitrary function calls. Ideally, function calls should be required while
# considering the effects of these statements, or by some sort of an
# inter-procedural program slicing
changed |= add_required_inplace!(concretize, src, edges, cl)
changed |= add_ssa_preds!(concretize, src, edges, ())

# mark necessary control flows,
# and propagate the definition requirements by tracking SSA precedessors
changed |= add_control_flow!(concretize, src, cfg, postdomtree)
# Mark necessary control flows.
changed |= add_control_flow!(concretize, src, cfg, domtree, postdomtree)
changed |= add_ssa_preds!(concretize, src, edges, ())

changed || break
Expand Down
111 changes: 106 additions & 5 deletions test/toplevel/test_virtualprocess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1713,17 +1713,16 @@ end
# this particular example is adapted from https://en.wikipedia.org/wiki/Program_slicing
let src = @src let
N = 10
sum = 0
s = 0
product = 1 # should NOT be selected
w = 7
for i in 1:N
sum += i + w
s += i + w
product *= i # should NOT be selected
end
@eval global getsum() = $sum # concretization is forced
write(product) # should NOT be selected
end
slice = JET.select_statements(@__MODULE__, src)
slice = JET.select_statements(@__MODULE__, src, :sum)

found_N = found_sum = found_product = found_w = found_write = false
for (i, stmt) in enumerate(src.code)
Expand All @@ -1748,7 +1747,7 @@ end
found_write = true
@test !slice[i]
elseif (JET.isexpr(stmt, :call) && (arg1 = stmt.args[1]; arg1 isa Core.SSAValue) &&
src.code[arg1.id] === :write)
src.code[arg1.id] === :write)
found_write = true
@test !slice[i]
end
Expand Down Expand Up @@ -1778,6 +1777,108 @@ end
@test isempty(s)
end

# A more complex test case (xref: https://github.com/JuliaDebug/LoweredCodeUtils.jl/pull/99#issuecomment-2236373067)
# This test case might seem simple at first glance, but note that `x2` and `a2` are
# defined at the top level (because of the `begin` at the top).
# Since global variable type declarations have been allowed since v1.
# 10, a conditional branch that includes `Core.get_binding_type` is generated for
# these simple global variable assignments.
# Specifically, the code is lowered into something like this:
# 1 1: conditional branching based on `x2`'s binding type
# │╲
# │ ╲
# │ ╲ 2: goto block for the case when no conversion is required for the value of `x2`
# 2 3 3: fall-through block for the case when a conversion is required for the value of `x2`
# │ ╱
# │ ╱
# │╱
# 4 4: assignment to `x2`, **and**
# │╲ conditional branching based on `a2`'s binding type
# │ ╲
# │ ╲ 5: goto block for the case when no conversion is required for the value of `a2`
# 5 6 6: fall-through block for the case when a conversion is required for the value of `a2`
# │ ╱
# │ ╱
# │╱
# 7 7: assignment to `a2`
# What's important to note here is that since there's an assignment to `a2`,
# concretization of the blocks 4-6 is necessary. However, at the same time we also want
# to skip concretizing the blocks 1-3.
let src = @src begin
x2 = 5
a2 = 1
end
slice = JET.select_statements(@__MODULE__, src, :a2)

found_a2 = found_a2_get_binding_type = found_x2 = found_x2_get_binding_type = false
for (i, stmt) in enumerate(src.code)
if JET.isexpr(stmt, :(=))
lhs, rhs = stmt.args
if lhs isa GlobalRef
lhs = lhs.name
end
if lhs === :a2
found_a2 = true
@test slice[i]
elseif lhs === :x2
found_x2 = true
@test !slice[i] # this is easy to meet
end
elseif JET.@capture(stmt, $(GlobalRef(Core, :get_binding_type))(_, :a2))
found_a2_get_binding_type = true
@test slice[i]
elseif JET.@capture(stmt, $(GlobalRef(Core, :get_binding_type))(_, :x2))
found_x2_get_binding_type = true
@test !slice[i] # this is difficult to meet
end
end
@test found_a2; @test found_a2_get_binding_type; @test found_x2; @test found_x2_get_binding_type
end
let src = @src begin
cond = true
if cond
x = 1
y = 1
else
x = 2
y = 2
end
end
slice = JET.select_statements(@__MODULE__, src, :x)

found_cond = found_cond_get_binding_type = false
found_x = found_x_get_binding_type = found_y = found_y_get_binding_type = 0
for (i, stmt) in enumerate(src.code)
if JET.isexpr(stmt, :(=))
lhs, rhs = stmt.args
if lhs isa GlobalRef
lhs = lhs.name
end
if lhs === :cond
found_cond = true
@test slice[i]
elseif lhs === :x
found_x += 1
@test slice[i]
elseif lhs === :y
found_y += 1
@test !slice[i]
end
elseif JET.@capture(stmt, $(GlobalRef(Core, :get_binding_type))(_, :cond))
found_cond_get_binding_type = true
@test slice[i]
elseif JET.@capture(stmt, $(GlobalRef(Core, :get_binding_type))(_, :x))
found_x_get_binding_type += 1
@test slice[i]
elseif JET.@capture(stmt, $(GlobalRef(Core, :get_binding_type))(_, :y))
found_y_get_binding_type += 1
@test !slice[i]
end
end
@test found_cond; @test found_cond_get_binding_type
@test found_x == found_x_get_binding_type == found_y == found_y_get_binding_type == 2
end

@testset "captured variables" begin
let (vmod, res) = @analyze_toplevel2 begin
begin
Expand Down

0 comments on commit d47b751

Please sign in to comment.