Skip to content

Commit

Permalink
inference: backward constraint propagation from call signatures
Browse files Browse the repository at this point in the history
This PR implements another (limited) backward analysis pass in abstract
interpretation; it exploits signatures of matching methods and refines
types of slots.

Here are couple of examples where these changes will improve the accuracy:

> generic function example
```julia
addi(a::Integer, b::Integer) = a + b
Base.return_types((Any,Any,)) do a, b
    c = addi(a, b)
    return a, b, c # now the compiler understands `a::Integer`, `b::Integer`
end
```

> `typeassert` example
```julia
Base.return_types((Any,)) do a
    typeassert(a, Int)
    return a # now the compiler understands `a::Int`
end
```

This PR consists of two main parts: 1.) obtain refinement information
and back-propagate it, and 2.) apply state updates

As for 1., unlike conditional constraints, refinement information isn't
encoded within lattice element, but rather they are conveyed by the
newly defined field `InferenceState.state_updates`, which is refreshed
on each program counter increment. For now refinement information is
obtained from call signatures of generic functions and `typeassert`.

Finally, in order to apply multiple state updates, this PR extends
`StateUpdate` and `stupdate` so that they can hold and apply multiple
state updates.
  • Loading branch information
aviatesk committed May 20, 2021
1 parent 9117b4d commit e87ad2b
Show file tree
Hide file tree
Showing 4 changed files with 241 additions and 25 deletions.
95 changes: 87 additions & 8 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
any_const_result = false
const_results = Union{InferenceResult,Nothing}[]
multiple_matches = napplicable > 1
refine_targets = nothing # keeps refinement information on slot types obtained from call signature
if fargs !== nothing
refine_targets = Union{Nothing,Tuple{SlotNumber,Any}}[]
for x in fargs
push!(refine_targets, isa(x, SlotNumber) ? (x, Bottom) : nothing)
end
end

if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch)
val = pure_eval_call(f, argtypes)
Expand Down Expand Up @@ -197,6 +204,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
conditionals[2][i] = tmerge(conditionals[2][i], elsetype)
end
end
if refine_targets !== nothing
for i in 1:length(refine_targets)
target = refine_targets[i]
if target !== nothing
slot, t = target
refine_targets[i] = (slot, tmerge(fieldtype(sig, i), t))
end
end
end
if bail_out_call(interp, rettype, sv)
break
end
Expand All @@ -209,6 +225,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
info = ConstCallInfo(info, const_results)
end

# refinement information from call signatures is valid only after we succeed in inferring
# all the matching signatures and we should invalidate it if we bailed out early
if seen napplicable
refine_targets = nothing
end

if rettype isa LimitedAccuracy
union!(sv.pclimitations, rettype.causes)
rettype = rettype.typ
Expand Down Expand Up @@ -263,6 +285,18 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
@assert !(rettype isa InterConditional) "invalid lattice element returned from inter-procedural context"

# if refinement information on slot types is available, apply it now
anyrefined = false
if rettype !== Bottom && refine_targets !== nothing
for target in refine_targets
if target !== nothing
slot, t = target
if t !== Bottom
anyrefined |= add_state_update!(slot, t, sv)
end
end
end
end
if call_result_unused(sv) && !(rettype === Bottom)
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
# We're mainly only here because the optimizer might want this code,
Expand All @@ -273,7 +307,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# and avoid keeping track of a more complex result type.
rettype = Any
end
add_call_backedges!(interp, rettype, edges, fullmatch, mts, atype, sv)
add_call_backedges!(interp, anyrefined, rettype, edges, fullmatch, mts, atype, sv)
if !isempty(sv.pclimitations) # remove self, if present
delete!(sv.pclimitations, sv)
for caller in sv.callers_in_cycle
Expand All @@ -285,13 +319,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end

function add_call_backedges!(interp::AbstractInterpreter,
@nospecialize(rettype),
anyrefined::Bool, @nospecialize(rettype),
edges::Vector{MethodInstance},
fullmatch::Vector{Bool}, mts::Vector{Core.MethodTable}, @nospecialize(atype),
sv::InferenceState)
if rettype === Any
# for `NativeInterpreter`, we don't add backedges when a new method couldn't refine
# (widen) this type
if !anyrefined && rettype === Any
# for `NativeInterpreter`, we don't add backedges when we've not used refinement
# information from call signature and a new method couldn't refine (widen) this type
return
end
for edge in edges
Expand Down Expand Up @@ -1000,6 +1034,11 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U
if 1 <= idx <= length(cti)
rt = unwrapva(cti[idx])
end
elseif f === typeassert
# perform very limited back-propagation of invariants after this type asertion
if rt !== Bottom && isa(fargs, Vector{Any}) && (x2 = fargs[2]; isa(x2, SlotNumber))
add_state_update!(x2, rt, sv)
end
elseif (rt === Bool || (isa(rt, Const) && isa(rt.val, Bool))) && isa(fargs, Vector{Any})
# perform very limited back-propagation of type information for `is` and `isa`
if f === isa
Expand Down Expand Up @@ -1658,6 +1697,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
stmt = frame.src.code[pc]
changes = states[pc]::VarTable
t = nothing
empty!(frame.state_updates)

hd = isa(stmt, Expr) ? stmt.head : nothing

Expand Down Expand Up @@ -1778,12 +1818,12 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
frame.src.ssavaluetypes[pc] = t
lhs = stmt.args[1]
if isa(lhs, SlotNumber)
changes = StateUpdate(lhs, VarState(t, false), changes, false)
changes = StateUpdate([lhs], [VarState(t, false)], changes, false)
end
elseif hd === :method
fname = stmt.args[1]
if isa(fname, SlotNumber)
changes = StateUpdate(fname, VarState(Any, false), changes, false)
changes = StateUpdate([fname], [VarState(Any, false)], changes, false)
end
elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect
# these do not generate code
Expand Down Expand Up @@ -1821,6 +1861,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)

pc´ > n && break # can't proceed with the fast-path fall-through
frame.handler_at[pc´] = frame.cur_hand
changes = collect_state_updates!(changes, frame)
newstate = stupdate!(states[pc´], changes)
if isa(stmt, GotoNode) && frame.pc´´ < pc´
# if we are processing a goto node anyways,
Expand All @@ -1846,14 +1887,52 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
nothing
end

function add_state_update!(slot::SlotNumber, @nospecialize(new), frame::InferenceState)
states = frame.stmt_types[frame.currpc]::VarTable
old = ((states[slot_id(slot)])::VarState).typ
if !(old new) # new ⋤ old
push!(frame.state_updates, (slot, new))
return true
end
return false
end

function collect_state_updates!(changes::StateUpdate, frame::InferenceState)
state_updates = frame.state_updates
vars = changes.vars
vtypes = changes.vtypes
while !isempty(state_updates)
var, typ = pop!(state_updates)
var in vars && continue # state update from lhs assigment should always has the precedence
push!(vars, var)
vtype = VarState(typ, (changes.state[slot_id(var)]::VarState).undef)
push!(vtypes, vtype)
end
return changes
end

function collect_state_updates!(changes::VarTable, frame::InferenceState)
state_updates = frame.state_updates
isempty(state_updates) && return changes
vars = SlotNumber[]
vtypes = VarState[]
while !isempty(state_updates)
var, typ = pop!(state_updates)
push!(vars, var)
vtype = VarState(typ, (changes[slot_id(var)]::VarState).undef)
push!(vtypes, vtype)
end
return StateUpdate(vars, vtypes, changes, false)
end

function conditional_changes(changes::VarTable, @nospecialize(typ), var::SlotNumber)
oldtyp = (changes[slot_id(var)]::VarState).typ
# approximate test for `typ ∩ oldtyp` being better than `oldtyp`
# since we probably formed these types with `typesubstract`, the comparison is likely simple
if ignorelimited(typ) ignorelimited(oldtyp)
# typ is better unlimited, but we may still need to compute the tmeet with the limit "causes" since we ignored those in the comparison
oldtyp isa LimitedAccuracy && (typ = tmerge(typ, LimitedAccuracy(Bottom, oldtyp.causes)))
return StateUpdate(var, VarState(typ, false), changes, true)
return StateUpdate([var], [VarState(typ, false)], changes, true)
end
return changes
end
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ mutable struct InferenceState
stmt_types::Vector{Union{Nothing, Vector{Any}}} # ::Vector{Union{Nothing, VarTable}}
stmt_edges::Vector{Union{Nothing, Vector{Any}}}
stmt_info::Vector{Any}
state_updates::Vector{Tuple{SlotNumber,Any}} # additional state update obtained at currpc
# return type
bestguess #::Type
# current active instruction pointers
Expand Down Expand Up @@ -108,7 +109,7 @@ mutable struct InferenceState
sp, slottypes, inmodule, 0,
IdSet{InferenceState}(), IdSet{InferenceState}(),
src, get_world_counter(interp), valid_worlds,
nargs, s_types, s_edges, stmt_info,
nargs, s_types, s_edges, stmt_info, Tuple{SlotNumber,Any}[],
Union{}, W, 1, n,
cur_hand, handler_at, n_handlers,
ssavalue_uses, throw_blocks,
Expand Down
43 changes: 27 additions & 16 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,8 @@ end
const VarTable = Array{Any,1}

struct StateUpdate
var::SlotNumber
vtype::VarState
vars::Vector{SlotNumber}
vtypes::Vector{VarState}
state::VarTable
conditional::Bool
end
Expand Down Expand Up @@ -320,32 +320,40 @@ ignorelimited(@nospecialize typ) = typ
ignorelimited(typ::LimitedAccuracy) = typ.typ

function stupdate!(state::Nothing, changes::StateUpdate)
newst = copy(changes.state)
changeid = slot_id(changes.var)
newst[changeid] = changes.vtype
newstate = copy(changes.state)
changeids = Int[]
for (var, vtype) in zip(changes.vars, changes.vtypes)
changeid = slot_id(var)
newstate[changeid] = vtype
push!(changeids, changeid)
end
# remove any Conditional for this slot from the vtable
# (unless this change is came from the conditional)
if !changes.conditional
for i = 1:length(newst)
newtype = newst[i]
for i = 1:length(newstate)
newtype = newstate[i]
if isa(newtype, VarState)
newtypetyp = ignorelimited(newtype.typ)
if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid
if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) in changeids
newtypetyp = widenwrappedconditional(newtype.typ)
newst[i] = VarState(newtypetyp, newtype.undef)
newstate[i] = VarState(newtypetyp, newtype.undef)
end
end
end
end
return newst
return newstate
end

function stupdate!(state::VarTable, changes::StateUpdate)
changeids = Int[]
for var in changes.vars
push!(changeids, slot_id(var))
end
newstate = nothing
changeid = slot_id(changes.var)
for i = 1:length(state)
if i == changeid
newtype = changes.vtype
j = findfirst(==(i), changeids)
if j !== nothing
newtype = changes.vtypes[j]
else
newtype = changes.state[i]
end
Expand All @@ -354,7 +362,7 @@ function stupdate!(state::VarTable, changes::StateUpdate)
# (unless this change is came from the conditional)
if !changes.conditional && isa(newtype, VarState)
newtypetyp = ignorelimited(newtype.typ)
if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid
if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) in changeids
newtypetyp = widenwrappedconditional(newtype.typ)
newtype = VarState(newtypetyp, newtype.undef)
end
Expand Down Expand Up @@ -385,7 +393,10 @@ stupdate!(state::Nothing, changes::VarTable) = copy(changes)
stupdate!(state::Nothing, changes::Nothing) = nothing

function stupdate1!(state::VarTable, change::StateUpdate)
changeid = slot_id(change.var)
vars, vtypes = change.vars, change.vtypes
@assert length(vars) == length(vtypes) == 1
var, vtype = vars[1], vtypes[1]
changeid = slot_id(var)
# remove any Conditional for this slot from the catch block vtable
# (unless this change is came from the conditional)
if !change.conditional
Expand All @@ -404,7 +415,7 @@ function stupdate1!(state::VarTable, change::StateUpdate)
end
end
# and update the type of it
newtype = change.vtype
newtype = vtype
oldtype = state[changeid]
if schanged(newtype, oldtype)
state[changeid] = smerge(oldtype, newtype)
Expand Down
Loading

0 comments on commit e87ad2b

Please sign in to comment.