Skip to content

Commit

Permalink
make CallInfo propagate the edges list of CodeInstances
Browse files Browse the repository at this point in the history
Remaining TODOs:
- Finalize the format for `sv.edges`. There might be cases where no
  `edge::CodeInstance` exists as a result of `abstract_call_method`,
  and in such cases, we might still need to use `MethodInstance` in the
  `edges` list.
- Ensure that when the local caching mode is specified (i.e. for
  const-prop'ed calls and call-site-inlined calls), the const-propped
  edge should be propagated instead of the regular edge.
- Make use of the `CodeInstance` held by `CallInfo` during inlining
  for slightly better performance by avoiding the global cache lookup.
  • Loading branch information
aviatesk committed Oct 23, 2024
1 parent 4d1a1f1 commit 3b57317
Show file tree
Hide file tree
Showing 6 changed files with 168 additions and 145 deletions.
105 changes: 57 additions & 48 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
napplicable = length(applicable)
multiple_matches = napplicable > 1
while i <= napplicable
match = applicable[i]::MethodMatch
(; match, edges, edge_idx) = applicable[i]
method = match.method
sig = match.spec_types
if bail_out_toplevel_call(interp, InferenceLoopState(sig, rettype, all_effects), sv)
Expand All @@ -94,7 +94,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
#end
mresult = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, si, sv)::Future
function handle1(interp, sv)
local (; rt, exct, effects, volatile_inf_result) = mresult[]
local (; rt, exct, effects, edge, volatile_inf_result) = mresult[]
this_conditional = ignorelimited(rt)
this_rt = widenwrappedconditional(rt)
this_exct = exct
Expand All @@ -106,6 +106,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
mresult[], f.contents, this_arginfo, si, match, sv)
const_result = volatile_inf_result
if const_call_result !== nothing
# TODO override the edge with the const-prop' edge
this_const_conditional = ignorelimited(const_call_result.rt)
this_const_rt = widenwrappedconditional(const_call_result.rt)
if this_const_rt ₚ this_rt
Expand Down Expand Up @@ -163,14 +164,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
conditionals[2][i] = conditionals[2][i] ᵢ cnd.elsetype
end
end
edges[edge_idx] = edge
if i < napplicable && bail_out_call(interp, InferenceLoopState(sig, rettype, all_effects), sv)
add_remark!(interp, sv, "Call inference reached maximally imprecise information. Bailing on.")
seenall = false
i = napplicable # break in outer function
end
i += 1
return true
end
end # function handle1
if isready(mresult) && handle1(interp, sv)
continue
else
Expand Down Expand Up @@ -206,7 +208,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
if (isa(sv, InferenceState) && infer_compilation_signature(interp) &&
(seenall && 1 == napplicable) && rettype !== Any && rettype !== Bottom &&
!is_removable_if_unused(all_effects))
match = applicable[1]::MethodMatch
(; match) = applicable[1]
method = match.method
sig = match.spec_types
mi = specialize_method(match; preexisting=true)
Expand Down Expand Up @@ -243,7 +245,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),

gfresult[] = CallMeta(rettype, exctype, all_effects, info, slotrefinements)
return true
end # infercalls
end # function infercalls
# start making progress on the first call
infercalls(interp, sv) || push!(sv.tasks, infercalls)
return gfresult
Expand All @@ -253,8 +255,14 @@ struct FailedMethodMatch
reason::String
end

struct MethodMatchTarget
match::MethodMatch
edges::Vector{Union{Nothing,CodeInstance}}
edge_idx::Int
end

struct MethodMatches
applicable::Vector{Any}
applicable::Vector{MethodMatchTarget}
info::MethodMatchInfo
valid_worlds::WorldRange
end
Expand All @@ -265,7 +273,7 @@ fully_covering(info::MethodMatchInfo) = info.fullmatch
fully_covering(m::MethodMatches) = fully_covering(m.info)

struct UnionSplitMethodMatches
applicable::Vector{Any}
applicable::Vector{MethodMatchTarget}
applicable_argtypes::Vector{Vector{Any}}
info::UnionSplitInfo
valid_worlds::WorldRange
Expand Down Expand Up @@ -301,7 +309,7 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
@nospecialize(atype), max_methods::Int)
split_argtypes = switchtupleunion(typeinf_lattice(interp), argtypes)
infos = MethodMatchInfo[]
applicable = Any[]
applicable = MethodMatchTarget[]
applicable_argtypes = Vector{Any}[] # arrays like `argtypes`, including constants, for each match
valid_worlds = WorldRange()
for i in 1:length(split_argtypes)
Expand All @@ -314,14 +322,14 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
if thismatches === nothing
return FailedMethodMatch("For one of the union split cases, too many methods matched")
end
for m in thismatches
push!(applicable, m)
push!(applicable_argtypes, arg_n)
end
valid_worlds = intersect(valid_worlds, thismatches.valid_worlds)
thisfullmatch = any(match::MethodMatch->match.fully_covers, thismatches)
thisinfo = MethodMatchInfo(thismatches, mt, sig_n, thisfullmatch)
push!(infos, thisinfo)
for idx = 1:length(thismatches)
push!(applicable, MethodMatchTarget(thismatches[idx], thisinfo.edges, idx))
push!(applicable_argtypes, arg_n)
end
end
info = UnionSplitInfo(infos)
return UnionSplitMethodMatches(
Expand All @@ -342,7 +350,8 @@ function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(a
end
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
info = MethodMatchInfo(matches, mt, atype, fullmatch)
return MethodMatches(matches.matches, info, matches.valid_worlds)
applicable = MethodMatchTarget[MethodMatchTarget(matches[idx], info.edges, idx) for idx = 1:length(matches)]
return MethodMatches(applicable, info, matches.valid_worlds)
end

"""
Expand Down Expand Up @@ -513,7 +522,7 @@ function conditional_argtype(𝕃ᵢ::AbstractLattice, @nospecialize(rt), @nospe
end
end

function collect_slot_refinements(𝕃ᵢ::AbstractLattice, applicable::Vector{Any},
function collect_slot_refinements(𝕃ᵢ::AbstractLattice, applicable::Vector{MethodMatchTarget},
argtypes::Vector{Any}, fargs::Vector{Any}, sv::InferenceState)
, = strictpartialorder(𝕃ᵢ), join(𝕃ᵢ)
slotrefinements = nothing
Expand All @@ -527,7 +536,7 @@ function collect_slot_refinements(𝕃ᵢ::AbstractLattice, applicable::Vector{A
end
sigt = Bottom
for j = 1:length(applicable)
match = applicable[j]::MethodMatch
(;match) = applicable[j]
valid_as_lattice(match.spec_types, true) || continue
sigt = sigt fieldtype(match.spec_types, i)
end
Expand All @@ -551,9 +560,9 @@ function abstract_call_method(interp::AbstractInterpreter,
hardlimit::Bool, si::StmtInfo, sv::AbsIntState)
sigtuple = unwrap_unionall(sig)
sigtuple isa DataType ||
return Future(MethodCallResult(Any, Any, false, false, nothing, Effects()))
return Future(MethodCallResult(Any, Any, Effects(), nothing, false, false))
all(@nospecialize(x) -> valid_as_lattice(unwrapva(x), true), sigtuple.parameters) ||
return Future(MethodCallResult(Union{}, Any, false, false, nothing, EFFECTS_THROWS)) # catch bad type intersections early
return Future(MethodCallResult(Union{}, Any, EFFECTS_THROWS, nothing, false, false)) # catch bad type intersections early

if is_nospecializeinfer(method)
sig = get_nospecializeinfer_sig(method, sig, sparams)
Expand All @@ -578,7 +587,7 @@ function abstract_call_method(interp::AbstractInterpreter,
# we have a self-cycle in the call-graph, but not in the inference graph (typically):
# break this edge now (before we record it) by returning early
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
return Future(MethodCallResult(Any, Any, true, true, nothing, Effects()))
return Future(MethodCallResult(Any, Any, Effects(), nothing, true, true))
end
topmost = nothing
edgecycle = true
Expand Down Expand Up @@ -633,7 +642,7 @@ function abstract_call_method(interp::AbstractInterpreter,
# since it's very unlikely that we'll try to inline this,
# or want make an invoke edge to its calling convention return type.
# (non-typically, this means that we lose the ability to detect a guaranteed StackOverflow in some cases)
return Future(MethodCallResult(Any, Any, true, true, nothing, Effects()))
return Future(MethodCallResult(Any, Any, Effects(), nothing, true, true))
end
add_remark!(interp, sv, washardlimit ? RECURSION_MSG_HARDLIMIT : RECURSION_MSG)
# TODO (#48913) implement a proper recursion handling for irinterp:
Expand Down Expand Up @@ -759,9 +768,9 @@ function matches_sv(parent::AbsIntState, sv::AbsIntState)
method_for_inference_limit_heuristics(sv) === method_for_inference_limit_heuristics(parent))
end

function is_edge_recursed(edge::MethodInstance, caller::AbsIntState)
function is_edge_recursed(edge::CodeInstance, caller::AbsIntState)
return any(AbsIntStackUnwind(caller)) do sv::AbsIntState
return edge === frame_instance(sv)
return edge.def === frame_instance(sv)
end
end

Expand All @@ -788,18 +797,15 @@ end
struct MethodCallResult
rt
exct
effects::Effects
edge::Union{Nothing,CodeInstance}
edgecycle::Bool
edgelimited::Bool
edge::Union{Nothing,MethodInstance}
effects::Effects
volatile_inf_result::Union{Nothing,VolatileInferenceResult}
function MethodCallResult(@nospecialize(rt), @nospecialize(exct),
edgecycle::Bool,
edgelimited::Bool,
edge::Union{Nothing,MethodInstance},
effects::Effects,
function MethodCallResult(@nospecialize(rt), @nospecialize(exct), effects::Effects,
edge::Union{Nothing,CodeInstance}, edgecycle::Bool, edgelimited::Bool,
volatile_inf_result::Union{Nothing,VolatileInferenceResult}=nothing)
return new(rt, exct, edgecycle, edgelimited, edge, effects, volatile_inf_result)
return new(rt, exct, effects, edge, edgecycle, edgelimited, volatile_inf_result)
end
end

Expand All @@ -809,12 +815,12 @@ struct InvokeCall
InvokeCall(@nospecialize(types), @nospecialize(lookupsig)) = new(types, lookupsig)
end

struct ConstCallResults
struct ConstCallResult
rt::Any
exct::Any
const_result::ConstResult
effects::Effects
function ConstCallResults(
function ConstCallResult(
@nospecialize(rt), @nospecialize(exct),
const_result::ConstResult,
effects::Effects)
Expand Down Expand Up @@ -901,7 +907,8 @@ function concrete_eval_eligible(interp::AbstractInterpreter,
return :none
end
end
mi = result.edge
codeinst = result.edge
mi = codeinst === nothing ? nothing : codeinst.def
if mi !== nothing && is_foldable(effects, #=check_rtcall=#true)
if f !== nothing && is_all_const_arg(arginfo, #=start=#2)
if (is_nonoverlayed(interp) || is_nonoverlayed(effects) ||
Expand Down Expand Up @@ -964,15 +971,15 @@ function concrete_eval_call(interp::AbstractInterpreter,
f = invoke
end
world = get_inference_world(interp)
edge = result.edge::MethodInstance
edge = (result.edge::CodeInstance).def
value = try
Core._call_in_world_total(world, f, args...)
catch e
# The evaluation threw. By :consistent-cy, we're guaranteed this would have happened at runtime.
# Howevever, at present, :consistency does not mandate the type of the exception
return ConstCallResults(Bottom, Any, ConcreteResult(edge, result.effects), result.effects)
return ConstCallResult(Bottom, Any, ConcreteResult(edge, result.effects), result.effects)
end
return ConstCallResults(Const(value), Union{}, ConcreteResult(edge, EFFECTS_TOTAL, value), EFFECTS_TOTAL)
return ConstCallResult(Const(value), Union{}, ConcreteResult(edge, EFFECTS_TOTAL, value), EFFECTS_TOTAL)
end

# check if there is a cycle and duplicated inference of `mi`
Expand Down Expand Up @@ -1216,9 +1223,9 @@ function semi_concrete_eval_call(interp::AbstractInterpreter,
mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo, sv::AbsIntState)
world = frame_world(sv)
mi_cache = WorldView(code_cache(interp), world)
code = get(mi_cache, mi, nothing)
if code !== nothing
irsv = IRInterpretationState(interp, code, mi, arginfo.argtypes, world)
codeinst = get(mi_cache, mi, nothing)
if codeinst !== nothing
irsv = IRInterpretationState(interp, codeinst, mi, arginfo.argtypes, world)
if irsv !== nothing
assign_parentchild!(irsv, sv)
rt, (nothrow, noub) = ir_abstract_constant_propagation(interp, irsv)
Expand All @@ -1237,16 +1244,16 @@ function semi_concrete_eval_call(interp::AbstractInterpreter,
effects = Effects(effects; noub=ALWAYS_TRUE)
end
exct = refine_exception_type(result.exct, effects)
return ConstCallResults(rt, exct, SemiConcreteResult(mi, ir, effects, spec_info(irsv)), effects)
return ConstCallResult(rt, exct, SemiConcreteResult(codeinst, ir, effects, spec_info(irsv)), effects)
end
end
end
return nothing
end

const_prop_result(inf_result::InferenceResult) =
ConstCallResults(inf_result.result, inf_result.exc_result, ConstPropResult(inf_result),
inf_result.ipo_effects)
ConstCallResult(inf_result.result, inf_result.exc_result, ConstPropResult(inf_result),
inf_result.ipo_effects)

# return cached result of constant analysis
return_localcache_result(::AbstractInterpreter, inf_result::InferenceResult, ::AbsIntState) =
Expand All @@ -1259,7 +1266,7 @@ end

function const_prop_call(interp::AbstractInterpreter,
mi::MethodInstance, result::MethodCallResult, arginfo::ArgInfo, sv::AbsIntState,
concrete_eval_result::Union{Nothing, ConstCallResults}=nothing)
concrete_eval_result::Union{Nothing, ConstCallResult}=nothing)
inf_cache = get_inference_cache(interp)
𝕃ᵢ = typeinf_lattice(interp)
forwarded_argtypes = compute_forwarded_argtypes(interp, arginfo, sv)
Expand Down Expand Up @@ -1645,7 +1652,7 @@ function abstract_iteration(interp::AbstractInterpreter, @nospecialize(itft), @n
end
iterateresult[] = AbstractIterationResult(ret, AbstractIterationInfo(calls, false))
return true
end # inferiterate_2arg
end # function inferiterate_2arg
# continue making progress as much as possible, on iterate(arg, state)
inferiterate_2arg(interp, sv) || push!(sv.tasks, inferiterate_2arg)
return true
Expand Down Expand Up @@ -1815,7 +1822,7 @@ function abstract_apply(interp::AbstractInterpreter, argtypes::Vector{Any}, si::
# For now, only propagate info if we don't also union-split the iteration
applyresult[] = CallMeta(res, exctype, all_effects, retinfo)
return true
end
end # function infercalls
# start making progress on the first call
infercalls(interp, sv) || push!(sv.tasks, infercalls)
return applyresult
Expand Down Expand Up @@ -2184,7 +2191,7 @@ function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::Stmt
mresult = abstract_call_method(interp, method, ti, env, false, si, sv)::Future
match = MethodMatch(ti, env, method, argtype <: method.sig)
return Future{CallMeta}(mresult, interp, sv) do result, interp, sv
(; rt, exct, effects, volatile_inf_result) = result
(; rt, exct, effects, edge, volatile_inf_result) = result
res = nothing
sig = match.spec_types
argtypes′ = invoke_rewrite(argtypes)
Expand All @@ -2204,6 +2211,7 @@ function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::Stmt
result, f, arginfo, si, match, sv, invokecall)
const_result = volatile_inf_result
if const_call_result !== nothing
# TODO override the edge with the const-prop' edge
if const_call_result.rt rt
(; rt, effects, const_result) = const_call_result
end
Expand All @@ -2212,7 +2220,7 @@ function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::Stmt
end
end
rt = from_interprocedural!(interp, rt, sv, arginfo, sig)
info = InvokeCallInfo(match, const_result, lookupsig)
info = InvokeCallInfo(edge, match, const_result, lookupsig)
if !match.fully_covers
effects = Effects(effects; nothrow=false)
exct = exct TypeError
Expand Down Expand Up @@ -2408,14 +2416,15 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
mresult = abstract_call_method(interp, ocmethod, sig, Core.svec(), false, si, sv)
ocsig_box = Core.Box(ocsig)
return Future{CallMeta}(mresult, interp, sv) do result, interp, sv
(; rt, exct, effects, volatile_inf_result, edgecycle) = result
(; rt, exct, effects, volatile_inf_result, edge, edgecycle) = result
𝕃ₚ = ipo_lattice(interp)
, , = partialorder(𝕃ₚ), strictneqpartialorder(𝕃ₚ), join(𝕃ₚ)
const_result = volatile_inf_result
if !edgecycle
const_call_result = abstract_call_method_with_const_args(interp, result,
#=f=#nothing, arginfo, si, match, sv)
if const_call_result !== nothing
# TODO override the edge with the const-prop' edge
if const_call_result.rt rt
(; rt, effects, const_result) = const_call_result
end
Expand All @@ -2434,7 +2443,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
end
end
rt = from_interprocedural!(interp, rt, sv, arginfo, match.spec_types)
info = OpaqueClosureCallInfo(match, const_result)
info = OpaqueClosureCallInfo(edge, match, const_result)
return CallMeta(rt, exct, effects, info)
end
end
Expand Down
Loading

0 comments on commit 3b57317

Please sign in to comment.