Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reorder global/local caching management #571

Merged
merged 1 commit into from
Nov 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions examples/dispatch_analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ function CC.finish!(analyzer::DispatchAnalyzer, frame::CC.InferenceState)

## get the source before running `finish!` to keep the reference to `OptimizationState`
src = caller.src

## run `finish!(::AbstractAnalyzer, ::CC.InferenceState)` first to convert the optimized `IRCode` into optimized `CodeInfo`
ret = Base.@invoke CC.finish!(analyzer::AbstractAnalyzer, frame::CC.InferenceState)
if src isa CC.OptimizationState{typeof(analyzer)}
## allow the following analysis passes to see the optimized `CodeInfo`
caller.src = CC.ir_to_codeinf!(src)
end

if analyzer.frame_filter(frame.linfo)
if isa(src, Core.Const) # the optimization was very successful, nothing to report
Expand All @@ -93,7 +94,7 @@ function CC.finish!(analyzer::DispatchAnalyzer, frame::CC.InferenceState)
end
end

return ret
return @invoke CC.finish!(analyzer::AbstractAnalyzer, frame::CC.InferenceState)
end

@jetreport struct OptimizationFailureReport <: InferenceErrorReport end
Expand Down
78 changes: 45 additions & 33 deletions src/abstractinterpret/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ end
function CC.transform_result_for_cache(analyzer::AbstractAnalyzer,
linfo::MethodInstance, valid_worlds::WorldRange, result::InferenceResult)
cache = InferenceErrorReport[]
for report in get_reports(analyzer, result)
for report in get_any_reports(analyzer, result)
@static if JET_DEV_MODE
actual, expected = first(report.vst).linfo, linfo
@assert actual === expected "invalid global caching detected, expected $expected but got $actual"
Expand Down Expand Up @@ -464,6 +464,44 @@ function filter_lineages!(analyzer::AbstractAnalyzer, caller::InferenceResult, c
filter!(!islineage(caller.linfo, current), get_reports(analyzer, caller))
end

# Per-frame bookkeeping run after inference finishes on `frame`:
# deduplicate the reports collected for it and, when the frame has a caller,
# propagate the reports inter-procedurally and stash them into the local cache.
function finish_frame!(analyzer::AbstractAnalyzer, frame::InferenceState)
    result = frame.result
    frame_reports = get_reports(analyzer, result)

    # XXX dirty workaround for a performance problem; a more principled fix is
    # tracked at https://github.com/aviatesk/JET.jl/issues/75
    unique!(aggregation_policy(analyzer), frame_reports)

    # a toplevel frame has no caller to hand anything back to
    frame.parent === nothing && return nothing

    # inter-procedural handling: hand back to the caller what we got from this frame
    stash_report!(analyzer, frame_reports)

    # local cache management
    # TODO duplicated work between here and `transform_result_for_cache`
    return cache_reports_locally!(analyzer, result, frame_reports)
end

# Convert each in-flight report into its cacheable form and attach the
# resulting collection to `caller`'s local (result-level) cache.
function cache_reports_locally!(analyzer::AbstractAnalyzer, caller::InferenceResult,
                                reports::Vector{InferenceErrorReport})
    cacheable = InferenceErrorReport[]
    foreach(report -> cache_report!(cacheable, report), reports)
    return set_cached_result!(analyzer, caller, cacheable)
end

@static if VERSION ≥ v"1.11.0-DEV.737"

function CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState)
    # run the native interpreter's finalization first, then layer JET's
    # per-frame report/caching management on top of its result
    retval = @invoke CC.finish!(analyzer::AbstractInterpreter, frame::InferenceState)
    finish_frame!(analyzer, frame)
    return retval
end

else

# in this overload we can work on `frame.src::CodeInfo` (and also `frame::InferenceState`)
# where type inference (and also optimization if applied) already ran on
function CC._typeinf(analyzer::AbstractAnalyzer, frame::InferenceState)
Expand All @@ -486,14 +524,6 @@ function CC._typeinf(analyzer::AbstractAnalyzer, frame::InferenceState)
caller.inferred = true
end
end
# NOTE we don't discard `InferenceState`s here so that some analyzers can use them in `finish!`
# # collect results for the new expanded frame
# results = Tuple{InferenceResult, Vector{Any}, Bool}[
# ( frames[i].result,
# frames[i].stmt_edges[1]::Vector{Any},
# frames[i].cached )
# for i in 1:length(frames) ]
# empty!(frames)
for frame in frames
caller = frame.result
opt = caller.src
Expand All @@ -505,42 +535,20 @@ function CC._typeinf(analyzer::AbstractAnalyzer, frame::InferenceState)
end
end
end

for frame in frames
caller = frame.result
edges = frame.stmt_edges[1]::Vector{Any}
cached = frame.cached
valid_worlds = caller.valid_worlds
if CC.last(valid_worlds) >= get_world_counter()
# if we aren't cached, we don't need this edge
# but our caller might, so let's just make it anyways
CC.store_backedges(caller, edges)
end
CC.finish!(analyzer, frame)

reports = get_reports(analyzer, caller)

# XXX this is a dirty fix for performance problem, we need more "proper" fix
# https://github.com/aviatesk/JET.jl/issues/75
unique!(aggregation_policy(analyzer), reports)

# global cache management
if cached && !istoplevel(frame)
if frame.cached && !istoplevel(frame)
CC.cache_result!(analyzer, caller)
end

if frame.parent !== nothing
# inter-procedural handling: get back to the caller what we got from these results
stash_report!(analyzer, reports)

# local cache management
# TODO there are duplicated work here and `transform_result_for_cache`
cache = InferenceErrorReport[]
for report in reports
cache_report!(cache, report)
end
set_cached_result!(analyzer, caller, cache)
end
end

return true
Expand All @@ -550,7 +558,11 @@ end
# but the only reason we have this overload is that some analyzers (like `JETAnalyzer`)
# can further overload this to generate `InferenceErrorReport` with an access to `frame`
function CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState)
return CC.finish!(analyzer, frame.result)
ret = CC.finish!(analyzer, frame.result)
finish_frame!(analyzer, frame)
return ret
end

end

# top-level bridge
Expand Down
10 changes: 7 additions & 3 deletions src/analyzers/jetanalyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,17 +251,21 @@ function CC.InferenceState(result::InferenceResult, cache::Symbol, analyzer::JET
end

function CC.finish!(analyzer::JETAnalyzer, caller::InferenceState)
ret = @invoke CC.finish!(analyzer::AbstractAnalyzer, caller::InferenceState)
src = caller.result.src
if src isa OptimizationState
# allow the following analysis passes to see the optimized `CodeInfo`
src = caller.result.src = CC.ir_to_codeinf!(src)
end

if isnothing(src)
# caught in cycle, similar error should have been reported where the source is available
return ret
else
code = (src::CodeInfo).code
# report pass for uncaught `throw` calls
ReportPass(analyzer)(UncaughtExceptionReport, analyzer, caller, code)
return ret
end

return @invoke CC.finish!(analyzer::AbstractAnalyzer, caller::InferenceState)
end

function CC.abstract_call_gf_by_type(analyzer::JETAnalyzer,
Expand Down
33 changes: 17 additions & 16 deletions src/analyzers/optanalyzer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -291,37 +291,38 @@ function CC.finish!(analyzer::OptAnalyzer, frame::InferenceState)
caller = frame.result

# get the source before running `finish!` to keep the reference to `OptimizationState`
oldsrc = caller.src

ret = @invoke CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState)

newsrc = caller.src

analyze = popfirst!(analyzer.__analyze_frame)
if !analyze && isa(newsrc, CodeInfo)
# if this inferred source is not "compileable" but still is going to be inlined,
# we should add report runtime dispatches within it
analyze = CC.is_inlineable(newsrc)
src = caller.src
if src isa OptimizationState{typeof(analyzer)}
# allow the following analysis passes to see the optimized `CodeInfo`
caller.src = CC.ir_to_codeinf!(src)

if !analyze
# if this inferred source is not "compileable" but still is going to be inlined,
# we should add report runtime dispatches within it
analyze = CC.is_inlineable(src.src)
end
end

if analyze
ReportPass(analyzer)(OptimizationFailureReport, analyzer, caller)

if oldsrc isa OptimizationState{typeof(analyzer)}
ReportPass(analyzer)(RuntimeDispatchReport, analyzer, caller, oldsrc)
if src isa OptimizationState{typeof(analyzer)}
ReportPass(analyzer)(RuntimeDispatchReport, analyzer, caller, src)
elseif (@static JET_DEV_MODE ? true : false)
if (@static VERSION < v"1.10.0-DEV.551" && true) && isa(oldsrc, CC.ConstAPI)
if (@static VERSION < v"1.10.0-DEV.551" && true) && isa(src, CC.ConstAPI)
# the optimization was very successful (i.e. fully constant folded),
# nothing to report
elseif oldsrc === nothing # the optimization didn't happen
elseif src === nothing # the optimization didn't happen
else # and this pass should never happen
# NOTE `src` never be `CodeInfo` since `CC.may_discard_trees(::OptAnalyzer) === false`
Core.eval(@__MODULE__, :(oldsrc = $oldsrc))
Core.eval(@__MODULE__, :(src = $src))
throw("unexpected state happened, inspect `$(@__MODULE__).src`")
end
end
end

return ret
return @invoke CC.finish!(analyzer::AbstractAnalyzer, frame::InferenceState)
end

# report optimization failure due to recursive calls, etc.
Expand Down
Loading