Skip to content

Commit

Permalink
inference: backward constraint propagation from call signatures
Browse files Browse the repository at this point in the history
This PR implements another (limited) backward analysis pass in abstract
interpretation; it exploits signatures of matching methods and refines
types of slots.

Here are couple of examples where these changes will improve the accuracy:

> generic function example
```julia
addi(a::Integer, b::Integer) = a + b
Base.return_types((Any,Any,)) do a, b
    c = addi(a, b)
    return a, b, c # now the compiler understands `a::Integer`, `b::Integer`
end
```

> `typeassert` example
```julia
Base.return_types((Any,)) do a
    typeassert(a, Int)
    return a # now the compiler understands `a::Int`
end
```

This PR consists of two main parts: 1.) obtain refinement information
and back-propagate it, and 2.) apply state updates

As for 1., unlike conditional constraints, refinement information isn't
encoded within lattice element, but rather they are stored in the
newly defined field `InferenceState.state_updates`, which is refreshed
on each program counter increment. For now refinement information is
obtained from call signatures of generic functions and `typeassert`.

Finally, in order to apply multiple state updates, this PR extends
`StateUpdate` and `stupdate` so that they can hold and apply multiple
state updates.
  • Loading branch information
aviatesk committed Jul 22, 2024
1 parent b451a1c commit 6e8d7a4
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 28 deletions.
59 changes: 54 additions & 5 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
multiple_matches = napplicable > 1
fargs = arginfo.fargs
all_effects = EFFECTS_TOTAL
# if fargs !== nothing
# # keeps refinement information on slot types obtained from call signature
# refine_targets = Union{Nothing,Tuple{SlotNumber,Any}}[]
# for i = 1:length(fargs)
# x = fargs[i]
# push!(refine_targets, isa(x, SlotNumber) ? (x, Bottom) : nothing)
# end
# else
# refine_targets = nothing
# end
refine_targets = nothing

for i in 1:napplicable
match = applicable[i]::MethodMatch
Expand Down Expand Up @@ -162,6 +173,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
conditionals[2][i] = tmerge(𝕃ᵢ, conditionals[2][i], cnd.elsetype)
end
end
if refine_targets !== nothing
for i in 1:length(refine_targets)
target = refine_targets[i]
if target !== nothing
refine_targets[i] = (target[1], tmerge(𝕃ᵢ, fieldtype(sig, i), target[2]))
end
end
end
if bail_out_call(interp, InferenceLoopState(sig, rettype, all_effects), sv)
add_remark!(interp, sv, "Call inference reached maximally imprecise information. Bailing on.")
break
Expand All @@ -177,6 +196,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# there is unanalyzed candidate, widen type and effects to the top
rettype = excttype = Any
all_effects = Effects()
refine_targets = nothing
elseif isa(matches, MethodMatches) ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches))
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
Expand All @@ -203,6 +223,19 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
end

# if refinement information on slot types is available, apply it now
anyrefined = false
if rettype !== Bottom && refine_targets !== nothing
for target in refine_targets
if target !== nothing
slot, t = target
if t !== Bottom
push!(frame.curr_stmt_changes, (slot, t))
anyrefined = true # TODO limit this when t ⋤ old
end
end
end
end
if call_result_unused(si) && !(rettype === Bottom)
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
# We're mainly only here because the optimizer might want this code,
Expand All @@ -213,7 +246,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# and avoid keeping track of a more complex result type.
rettype = Any
end
add_call_backedges!(interp, rettype, all_effects, edges, matches, atype, sv)
add_call_backedges!(interp, rettype, all_effects, anyrefined, edges, matches, atype, sv)
if isa(sv, InferenceState)
# TODO (#48913) implement a proper recursion handling for irinterp:
# This works just because currently the `:terminate` condition guarantees that
Expand Down Expand Up @@ -492,8 +525,9 @@ function conditional_argtype(𝕃ᵢ::AbstractLattice, @nospecialize(rt), @nospe
end
end

function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype), all_effects::Effects,
edges::Vector{MethodInstance}, matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype),
all_effects::Effects, anyrefined::Bool, edges::Vector{MethodInstance},
matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
sv::AbsIntState)
# don't bother to add backedges when both type and effects information are already
# maximized to the top since a new method couldn't refine or widen them anyway
Expand All @@ -503,7 +537,9 @@ function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype)
if !isoverlayed(method_table(interp))
all_effects = Effects(all_effects; nonoverlayed=ALWAYS_FALSE)
end
all_effects === Effects() && return nothing
if all_effects === Effects() && !anyrefined
return nothing
end
end
for edge in edges
add_backedge!(sv, edge)
Expand Down Expand Up @@ -1821,7 +1857,12 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
ft = popfirst!(argtypes)
rt = builtin_tfunction(interp, f, argtypes, sv)
pushfirst!(argtypes, ft)
if has_mustalias(𝕃ᵢ) && f === getfield && isa(fargs, Vector{Any}) && la 3
if f === typeassert
# perform very limited back-propagation of invariants after this type asertion
if rt !== Bottom && isa(fargs, Vector{Any}) && (x2 = fargs[2]; isa(x2, SlotNumber))
push!(sv.curr_stmt_changes, (x2, rt))
end
elseif has_mustalias(𝕃ᵢ) && f === getfield && isa(fargs, Vector{Any}) && la 3
a3 = argtypes[3]
if isa(a3, Const)
if rt !== Bottom && !isalreadyconst(rt)
Expand Down Expand Up @@ -3374,6 +3415,14 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if changes !== nothing
stoverwrite1!(currstate, changes)
end
while !isempty(frame.curr_stmt_changes)
var, typ = pop!(frame.curr_stmt_changes)
if changes !== nothing && var == changes.var
continue # type propagation from statement (like assignment) should have the precedence
end
stmt_changes = StateUpdate(var, VarState(typ, false), currstate, false)
stoverwrite1!(currstate, stmt_changes)
end
if rt === nothing
ssavaluetypes[currpc] = Any
continue
Expand Down
17 changes: 16 additions & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,9 @@ mutable struct InferenceState
ssavaluetypes::Vector{Any}
stmt_edges::Vector{Vector{Any}}
stmt_info::Vector{CallInfo}
# additional state updates at current statement made by means other than the assignment
# e.g. type information refinement from `typeassert` call itself
curr_stmt_changes::Vector{Tuple{SlotNumber,Any}}

#= intermediate states for interprocedural abstract interpretation =#
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
Expand Down Expand Up @@ -301,6 +304,8 @@ mutable struct InferenceState
stmt_edges = Vector{Vector{Any}}(undef, nstmts)
stmt_info = CallInfo[ NoCallInfo() for i = 1:nstmts ]

curr_stmt_changes = Tuple{SlotNumber,Any}[]

nslots = length(src.slotflags)
slottypes = Vector{Any}(undef, nslots)
bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ]
Expand Down Expand Up @@ -350,7 +355,7 @@ mutable struct InferenceState

this = new(
mi, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info, curr_stmt_changes,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
Expand Down Expand Up @@ -833,6 +838,16 @@ function empty_backedges!(frame::InferenceState, currpc::Int=frame.currpc)
return nothing
end

function add_stmt_change!(slot::SlotNumber, @nospecialize(new), frame::InferenceState)
state = frame.stmt_types[frame.currpc]::VarTable
old = get_varstate(state, slot).typ
if !(old new) # new ⋤ old
push!(frame.curr_stmt_changes, (slot, new))
return true
end
return false
end

function print_callstack(sv::InferenceState)
print("=================== Callstack: ==================\n")
idx = 0
Expand Down
22 changes: 0 additions & 22 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -724,28 +724,6 @@ function invalidate_slotwrapper(vt::VarState, changeid::Int, ignore_conditional:
return nothing
end

function stupdate!(lattice::AbstractLattice, state::VarTable, changes::StateUpdate)
changed = false
changeid = slot_id(changes.var)
for i = 1:length(state)
if i == changeid
newtype = changes.vtype
else
newtype = changes.state[i]
end
invalidated = invalidate_slotwrapper(newtype, changeid, changes.conditional)
if invalidated !== nothing
newtype = invalidated
end
oldtype = state[i]
if schanged(lattice, newtype, oldtype)
state[i] = smerge(lattice, oldtype, newtype)
changed = true
end
end
return changed
end

function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable)
changed = false
for i = 1:length(state)
Expand Down
120 changes: 120 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4203,6 +4203,126 @@ end
end
end == [Union{Some{Float64}, Some{Int}, Some{UInt8}}]

@testset "constraint back-propagation from typeassert" begin
@test Base.infer_return_type((Any,)) do a
typeassert(a, Int)
return a
end == Int

@test Base.infer_return_type((Any,Bool)) do a, b
if b
typeassert(a, Int64)
else
typeassert(a, Int32)
end
return a
end == Union{Int32,Int64}
end

callsig_backprop_basic(::Int) = nothing
callsig_backprop_unionsplit(::Int32) = nothing
callsig_backprop_unionsplit(::Int64) = nothing
callsig_backprop_multi1(::Int) = nothing
callsig_backprop_multi2(::Nothing) = nothing
callsig_backprop_any(::Any) = nothing
callsig_backprop_lhs(::Int) = nothing
callsig_backprop_bailout(::Val{0}) = 0
callsig_backprop_bailout(::Val{1}) = undefvar # ::Any
callsig_backprop_bailout(::Val{2}) = 2
callsig_backprop_addinteger(a::Integer, b::Integer) = a + b # too many maching methods, and return type should be annotated as `Any` (and thus caught by `bail_out_call`)
@test Base.infer_return_type(callsig_backprop_addinteger) == Any
let effects = Base.infer_effects(callsig_backprop_addinteger)
@test !Core.Compiler.is_consistent(effects)
@test !Core.Compiler.is_effect_free(effects)
@test !Core.Compiler.is_nothrow(effects)
@test !Core.Compiler.is_terminates(effects)
end

@testset "constraint back-propagation from call signature" begin
# basic case
@test Base.infer_return_type(a->(callsig_backprop_basic(a); return a), (Any,)) == Int

# union-split case
@test Base.infer_return_type(a->(callsig_backprop_unionsplit(a); return a), (Any,)) == Union{Int32,Int64}

# multiple state updates
@test Base.infer_return_type((Any,Any)) do a, b
callsig_backprop_multi1(a)
callsig_backprop_multi2(b)
return a, b
end == Tuple{Int,Nothing}

# refinement should happen only when it's worthwhile
@test Base.infer_return_type(a->(callsig_backprop_any(a); return a), (Integer,)) == Integer

# state update on lhs slot (assignment effect should have the precedence)
@test Base.infer_return_type((Any,)) do a
a = callsig_backprop_lhs(a)
return a
end == Nothing

# make sure to throw away an intermediate refinement information when we bail out early
# (inference would bail out on `callsig_backprop_bailout(::Val{1})`)
@test Base.infer_return_type(a->(callsig_backprop_bailout(a); return a), (Any,)) == Any

# if we see all the matching methods, we don't need to throw away refinement information
# even if it's caught by `bail_out_call` check
@test Base.infer_return_type((Any,Any)) do a, b
callsig_backprop_addinteger(a, b)
return a, b
end == Tuple{Integer,Integer}

let
m = Module()
@eval m outer(a) = (_inner!(a); return a)

@test (@eval m begin
_inner!(::Int) = globalvar # ::Any
Base.return_types((Any,)) do a
return outer(a) # ::Int
end
end) == Any[Int]

# new definition of `_inner!` should invalidate `outer`
# (even if the previous return type is annotated as `Any`)
@test (@eval m begin
_inner!(::Nothing) = globalvar # ::Any
Base.return_types((Any,)) do a
# since inference will bail out at the first matched `_inner!` and so call signature constraint won't be available
return outer(a) # ::Union{Int,Nothing} ideally, but ::Any
end
end) Any[Int]
end
end

# make sure to add backedges when we use call signature constraint
function callsig_backprop_invalidation_outer(a)
callsig_backprop_invalidation_inner!(a)
return a
end
@eval callsig_backprop_invalidation_inner!(::Int) = $(gensym(:undefvar)) # ::Any
@test Base.infer_return_type((Any,)) do a
callsig_backprop_invalidation_outer(a)
end == Int
# new definition of `callsig_backprop_invalidation_inner!` should invalidate `callsig_backprop_invalidation_outer`
# (even if the previous return type is annotated as `Any`)
@eval callsig_backprop_invalidation_inner!(::Nothing) = $(gensym(:undefvar)) # ::Any
@test Base.infer_return_type((Any,)) do a
# since inference will bail out at the first matched `_inner!` and so call signature constraint won't be available
callsig_backprop_invalidation_outer(a)
end Int

# https://github.com/JuliaLang/julia/issues/37866
function issue37866(v::Vector{Union{Nothing,Float64}})
for x in v
if x > 5.0
return x # x > 5.0 is MethodError for Nothing so can assume ::Float64
end
end
return 0.0
end
@test Base.infer_return_type(issue37866, (Vector{Union{Nothing,Float64}},)) == Float64

# make sure inference on a recursive call graph with nested `Type`s terminates
# https://github.com/JuliaLang/julia/issues/40336
f40336(@nospecialize(t)) = f40336(Type{t})
Expand Down

0 comments on commit 6e8d7a4

Please sign in to comment.