Skip to content

Commit

Permalink
inference: backward constraint propagation from call signatures
Browse files Browse the repository at this point in the history
This PR implements another (limited) backward analysis pass in abstract
interpretation; it exploits signatures of matching methods and refines
types of slots.

Here are couple of examples where these changes will improve the accuracy:

> generic function example
```julia
addint(a::Int, b::Int) = a + b
@test Base.infer_return_type((Any,Any,)) do a, b
    c = addint(a, b)
    return a, b, c # now the compiler understands `a::Int`, `b::Int`
end == Tuple{Int,Int,Int}
```

> `typeassert` example
```julia
@test Base.infer_return_type((Any,)) do a
    a::Int
    return a # now the compiler understands `a::Int`
end == Int
```

Unlike `Conditional` constrained type propagation, this type refinement
information isn't encoded within any lattice element, but rather they
are propagated within the newly added field `frame.curr_stmt_change` of
`frame::InferenceState`.
For now this commit exploits refinement information available from call
signatures of generic functions and `typeassert`.
  • Loading branch information
aviatesk committed Jul 23, 2024
1 parent 364bd1a commit 6e94124
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 6 deletions.
60 changes: 55 additions & 5 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,16 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
multiple_matches = napplicable > 1
fargs = arginfo.fargs
all_effects = EFFECTS_TOTAL
if fargs !== nothing
# keeps refinement information on slot types obtained from call signature
refine_targets = Union{Nothing,StmtChange}[]
for i = 1:length(fargs)
x = fargs[i]
push!(refine_targets, isa(x, SlotNumber) ? StmtChange(x, Bottom) : nothing)
end
else
refine_targets = nothing
end

for i in 1:napplicable
match = applicable[i]::MethodMatch
Expand Down Expand Up @@ -162,6 +172,14 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
conditionals[2][i] = tmerge(𝕃ᵢ, conditionals[2][i], cnd.elsetype)
end
end
if refine_targets !== nothing
for i in 1:length(refine_targets)
target = refine_targets[i]
if target !== nothing
refine_targets[i] = StmtChange(target.slot, tmerge(𝕃ᵢ, fieldtype(sig, i), target.typ))
end
end
end
if bail_out_call(interp, InferenceLoopState(sig, rettype, all_effects), sv)
add_remark!(interp, sv, "Call inference reached maximally imprecise information. Bailing on.")
break
Expand All @@ -177,6 +195,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# there is unanalyzed candidate, widen type and effects to the top
rettype = excttype = Any
all_effects = Effects()
refine_targets = nothing
elseif isa(matches, MethodMatches) ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches))
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
Expand All @@ -203,6 +222,18 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
end

# if refinement information on slot types is available, apply it now
anyrefined = false
if rettype !== Bottom && refine_targets !== nothing
for target in refine_targets
if target !== nothing
if target.typ !== Bottom
push!(sv.curr_stmt_changes, target)
anyrefined = true # TODO limit this when t ⋤ old
end
end
end
end
if call_result_unused(si) && !(rettype === Bottom)
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
# We're mainly only here because the optimizer might want this code,
Expand All @@ -213,7 +244,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# and avoid keeping track of a more complex result type.
rettype = Any
end
add_call_backedges!(interp, rettype, all_effects, edges, matches, atype, sv)
add_call_backedges!(interp, rettype, all_effects, anyrefined, edges, matches, atype, sv)
if isa(sv, InferenceState)
# TODO (#48913) implement a proper recursion handling for irinterp:
# This works just because currently the `:terminate` condition guarantees that
Expand Down Expand Up @@ -492,8 +523,9 @@ function conditional_argtype(𝕃ᵢ::AbstractLattice, @nospecialize(rt), @nospe
end
end

function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype), all_effects::Effects,
edges::Vector{MethodInstance}, matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype),
all_effects::Effects, anyrefined::Bool, edges::Vector{MethodInstance},
matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
sv::AbsIntState)
# don't bother to add backedges when both type and effects information are already
# maximized to the top since a new method couldn't refine or widen them anyway
Expand All @@ -503,7 +535,9 @@ function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype)
if !isoverlayed(method_table(interp))
all_effects = Effects(all_effects; nonoverlayed=ALWAYS_FALSE)
end
all_effects === Effects() && return nothing
if all_effects === Effects() && !anyrefined
return nothing
end
end
for edge in edges
add_backedge!(sv, edge)
Expand Down Expand Up @@ -1821,7 +1855,12 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
ft = popfirst!(argtypes)
rt = builtin_tfunction(interp, f, argtypes, sv)
pushfirst!(argtypes, ft)
if has_mustalias(𝕃ᵢ) && f === getfield && isa(fargs, Vector{Any}) && la 3
if f === typeassert
# perform very limited back-propagation of invariants after this type assertion
if rt !== Bottom && isa(fargs, Vector{Any}) && (x2 = fargs[2]; isa(x2, SlotNumber))
push!(sv.curr_stmt_changes, StmtChange(x2, rt))
end
elseif has_mustalias(𝕃ᵢ) && f === getfield && isa(fargs, Vector{Any}) && la 3
a3 = argtypes[3]
if isa(a3, Const)
if rt !== Bottom && !isalreadyconst(rt)
Expand Down Expand Up @@ -3384,6 +3423,17 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if changes !== nothing
stoverwrite1!(currstate, changes)
end
while !isempty(frame.curr_stmt_changes)
stmtchange = pop!(frame.curr_stmt_changes)
if changes !== nothing && stmtchange.slot == changes.var
continue # type propagation from statement (like assignment) should have the precedence
end
vtype = currstate[slot_id(stmtchange.slot)]
if (𝕃ᵢ, stmtchange.typ, vtype.typ)
stmtupdate = StateUpdate(stmtchange.slot, VarState(stmtchange.typ, vtype.undef), false)
stoverwrite1!(currstate, stmtupdate)
end
end
if rt === nothing
ssavaluetypes[currpc] = Any
continue
Expand Down
13 changes: 12 additions & 1 deletion base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,12 @@ struct HandlerInfo
handler_at::Vector{Tuple{Int,Int}} # tuple of current (handler, exception stack) value at the pc
end

struct StmtChange
slot::SlotNumber
typ
StmtChange(slot::SlotNumber, @nospecialize(typ)) = new(slot, typ)
end

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
Expand All @@ -249,6 +255,9 @@ mutable struct InferenceState
ssavaluetypes::Vector{Any}
stmt_edges::Vector{Vector{Any}}
stmt_info::Vector{CallInfo}
# additional state updates at current statement made by means other than the assignment
# e.g. type information refinement from `typeassert` call itself
curr_stmt_changes::Vector{StmtChange}

#= intermediate states for interprocedural abstract interpretation =#
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
Expand Down Expand Up @@ -301,6 +310,8 @@ mutable struct InferenceState
stmt_edges = Vector{Vector{Any}}(undef, nstmts)
stmt_info = CallInfo[ NoCallInfo() for i = 1:nstmts ]

curr_stmt_changes = StmtChange[]

nslots = length(src.slotflags)
slottypes = Vector{Any}(undef, nslots)
bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ]
Expand Down Expand Up @@ -350,7 +361,7 @@ mutable struct InferenceState

this = new(
mi, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info, curr_stmt_changes,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
Expand Down
98 changes: 98 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4200,6 +4200,104 @@ end
end
end == [Union{Some{Float64}, Some{Int}, Some{UInt8}}]

@testset "constraint back-propagation from typeassert" begin
@test Base.infer_return_type((Any,)) do a
typeassert(a, Int)
return a
end == Int

@test Base.infer_return_type((Any,Bool)) do a, b
if b
typeassert(a, Int64)
else
typeassert(a, Int32)
end
return a
end == Union{Int32,Int64}
end

callsig_backprop_basic(::Int) = nothing
callsig_backprop_unionsplit(::Int32) = nothing
callsig_backprop_unionsplit(::Int64) = nothing
callsig_backprop_multi1(::Int) = nothing
callsig_backprop_multi2(::Nothing) = nothing
callsig_backprop_any(::Any) = nothing
callsig_backprop_lhs(::Int) = nothing
callsig_backprop_bailout(::Val{0}) = 0
callsig_backprop_bailout(::Val{1}) = undefvar # undefvar::Any triggers `bail_out_call`
callsig_backprop_bailout(::Val{2}) = 2
callsig_backprop_addinteger(a::Integer, b::Integer) = a + b # results in too many matching methods and triggers `bail_out_call`)
@test Base.infer_return_type(callsig_backprop_addinteger) == Any
let effects = Base.infer_effects(callsig_backprop_addinteger)
@test !Core.Compiler.is_consistent(effects)
@test !Core.Compiler.is_effect_free(effects)
@test !Core.Compiler.is_nothrow(effects)
@test !Core.Compiler.is_terminates(effects)
end

@testset "constraint back-propagation from call signature" begin
# basic case
@test Base.infer_return_type(a->(callsig_backprop_basic(a); return a), (Any,)) == Int

# union-split case
@test Base.infer_return_type(a->(callsig_backprop_unionsplit(a); return a), (Any,)) == Union{Int32,Int64}

# multiple state updates
@test Base.infer_return_type((Any,Any)) do a, b
callsig_backprop_multi1(a)
callsig_backprop_multi2(b)
return a, b
end == Tuple{Int,Nothing}

# refinement should happen only when it's worthwhile
@test Base.infer_return_type(a->(callsig_backprop_any(a); return a), (Integer,)) == Integer

# state update on lhs slot (assignment effect should have the precedence)
@test Base.infer_return_type((Any,)) do a
a = callsig_backprop_lhs(a)
return a
end == Nothing

# make sure to throw away an intermediate refinement information when we bail out early
# (inference would bail out on `callsig_backprop_bailout(::Val{1})`)
@test Base.infer_return_type(a->(callsig_backprop_bailout(a); return a), (Any,)) == Any

# if we see all the matching methods, we don't need to throw away refinement information
# even if it's caught by `bail_out_call` check
@test Base.infer_return_type((Any,Any)) do a, b
callsig_backprop_addinteger(a, b)
return a, b
end == Tuple{Integer,Integer}
end

# make sure to add backedges when we use call signature constraint
function callsig_backprop_invalidation_outer(a)
callsig_backprop_invalidation_inner!(a)
return a
end
@eval callsig_backprop_invalidation_inner!(::Int) = $(gensym(:undefvar)) # ::Any
@test Base.infer_return_type((Any,)) do a
callsig_backprop_invalidation_outer(a)
end == Int
# new definition of `callsig_backprop_invalidation_inner!` should invalidate `callsig_backprop_invalidation_outer`
# (even if the previous return type is annotated as `Any`)
@eval callsig_backprop_invalidation_inner!(::Nothing) = $(gensym(:undefvar)) # ::Any
@test Base.infer_return_type((Any,)) do a
# since inference will bail out at the first matched `_inner!` and so call signature constraint won't be available
callsig_backprop_invalidation_outer(a)
end Int

# https://github.com/JuliaLang/julia/issues/37866
function issue37866(v::Vector{Union{Nothing,Float64}})
for x in v
if x > 5.0
return x # x > 5.0 is MethodError for Nothing so can assume ::Float64
end
end
return 0.0
end
@test Base.infer_return_type(issue37866, (Vector{Union{Nothing,Float64}},)) == Float64

# make sure inference on a recursive call graph with nested `Type`s terminates
# https://github.com/JuliaLang/julia/issues/40336
f40336(@nospecialize(t)) = f40336(Type{t})
Expand Down

0 comments on commit 6e94124

Please sign in to comment.