From 40451573131bdf8293707bcef1a4f6aa94db94da Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Thu, 17 Oct 2024 04:22:08 -0400 Subject: [PATCH] Adjust to stackless compiler changes (#9) * Adjust to stackless compiler changes Depends on: - https://github.com/JuliaDiff/Diffractor.jl/pull/295 - https://github.com/JuliaLang/julia/pull/55972 * More compiler adjust --- Manifest.toml | 10 ++-- src/analysis/compiler.jl | 7 +-- src/analysis/interpreter.jl | 97 ++++++++++++++++++------------------- src/transform/common.jl | 2 +- 4 files changed, 59 insertions(+), 57 deletions(-) diff --git a/Manifest.toml b/Manifest.toml index be6aff4..3d16ae7 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -444,9 +444,11 @@ version = "1.15.1" [[deps.Diffractor]] deps = ["AbstractDifferentiation", "ChainRules", "ChainRulesCore", "Combinatorics", "Cthulhu", "InteractiveUtils", "OffsetArrays", "PrecompileTools", "StaticArrays", "StructArrays"] -git-tree-sha1 = "e9472ffeff4ec8958e96cf3ddcae5e700cbeacbd" +git-tree-sha1 = "2a9b827fce47e27ef32471df18a96dc4ff1123bd" +repo-rev = "kf/compileradjust" +repo-url = "https://github.com/JuliaDiff/Diffractor.jl.git" uuid = "9f5e2b26-1114-432f-b630-d3fe2085c51c" -version = "0.2.10" +version = "0.2.8" [[deps.Distances]] deps = ["LinearAlgebra", "Statistics", "StatsAPI"] @@ -1235,7 +1237,7 @@ weakdeps = ["Adapt"] [[deps.OpenBLAS_jll]] deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" -version = "0.3.27+1" +version = "0.3.28+2" [[deps.OpenLibm_jll]] deps = ["Artifacts", "Libdl"] @@ -1830,7 +1832,7 @@ uuid = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9" [[deps.SuiteSparse_jll]] deps = ["Artifacts", "Libdl", "libblastrampoline_jll"] uuid = "bea87d4a-7f5b-5778-9afe-8cc45184846c" -version = "7.7.0+0" +version = "7.8.0+0" [[deps.Sundials]] deps = ["CEnum", "DataStructures", "DiffEqBase", "Libdl", "LinearAlgebra", "Logging", "PrecompileTools", "Reexport", "SciMLBase", "SparseArrays", "Sundials_jll"] diff --git a/src/analysis/compiler.jl b/src/analysis/compiler.jl index 971fd74..8b91323 100644 --- a/src/analysis/compiler.jl +++ b/src/analysis/compiler.jl @@ -717,7 +717,8 @@ end record_ir!(debug_config, "pre_incidence_propagation", ir) # TODO better work here? - method_info = CC.MethodInfo(#=propagate_inbounds=#true, nothing) + (nargs, isva) = isa(mi.def, Method) ? (mi.def.nargs, mi.def.isva) : (0, false) + method_info = CC.SpecInfo(nargs, isva, #=propagate_inbounds=#true, nothing) min_world = world = get_inference_world(interp) max_world = get_world_counter() if caller !== nothing @@ -726,7 +727,7 @@ end analysis_interp = DAEInterpreter(interp; var_to_diff, var_kind, eq_kind, in_analysis=interp.ipo_analysis_mode) irsv = CC.IRInterpretationState(analysis_interp, method_info, ir, mi, argtypes, world, min_world, max_world) - ultimate_rt, _ = CC._ir_abstract_constant_propagation(analysis_interp, irsv; externally_refined) + ultimate_rt, _ = CC.ir_abstract_constant_propagation(analysis_interp, irsv; externally_refined) record_ir!(debug_config, "incidence_propagation", ir) # Encountering a `ddt` during abstract interpretation can add variables, @@ -745,7 +746,7 @@ end # recalculate domtree (inference could have changed the cfg) domtree = CC.construct_domtree(ir.cfg.blocks) - # We use the _ir_abstract_constant_propagation pass for three things: + # We use the ir_abstract_constant_propagation pass for three things: # 1. To establish incidence # 2. To constant propagate scope information that may not have been # available at inference time diff --git a/src/analysis/interpreter.jl b/src/analysis/interpreter.jl index 261b622..66f5dc1 100644 --- a/src/analysis/interpreter.jl +++ b/src/analysis/interpreter.jl @@ -7,7 +7,7 @@ using .CC: AbstractInterpreter, NativeInterpreter, InferenceParams, Optimization StmtInfo, MethodCallResult, ConstCallResults, ConstPropResult, MethodTableView, CachedMethodTable, InternalMethodTable, OverlayMethodTable, CallMeta, CallInfo, IRCode, LazyDomtree, IRInterpretationState, set_inlineable!, block_for_inst, - BitSetBoundedMinPrioritySet, AbsIntState + BitSetBoundedMinPrioritySet, AbsIntState, Future using Base: IdSet using StateSelection: DiffGraph @@ -282,13 +282,13 @@ widenincidence(@nospecialize(x)) = x if length(argtypes) == 2 xarg = argtypes[2] if isa(xarg, Union{Incidence, Const}) - return structural_inc_ddt(interp.var_to_diff, interp.var_kind, xarg) + return Future{CallMeta}(structural_inc_ddt(interp.var_to_diff, interp.var_kind, xarg)) end end end if interp.in_analysis && !isa(f, Core.Builtin) && !isa(f, Core.IntrinsicFunction) # We don't want to do new inference here - return CallMeta(Any, Any, CC.Effects(), CC.NoCallInfo()) + return Future{CallMeta}(CallMeta(Any, Any, CC.Effects(), CC.NoCallInfo())) end ret = @invoke CC.abstract_call_known(interp::AbstractInterpreter, f::Any, arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int) @@ -305,27 +305,27 @@ widenincidence(@nospecialize(x)) = x end end arginfo = ArgInfo(arginfo.fargs, map(widenincidence, arginfo.argtypes)) - r = Diffractor.fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret) - r !== nothing && return r - return ret + return Diffractor.fwd_abstract_call_gf_by_type(interp, f, arginfo, si, sv, ret) end @override function CC.abstract_call_method(interp::DAEInterpreter, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, si::StmtInfo, sv::InferenceState) - ret = @invoke CC.abstract_call_method(interp::AbstractInterpreter, + mret = @invoke CC.abstract_call_method(interp::AbstractInterpreter, method::Method, sig::Any, sparams::SimpleVector, hardlimit::Bool, si::StmtInfo, sv::InferenceState) - edge = ret.edge - if edge !== nothing - cache = CC.get(CC.code_cache(interp), edge, nothing) - if cache !== nothing - src = @atomic :monotonic cache.inferred - if isa(src, DAECache) - info = src.info - merge_daeinfo!(interp, sv.result, info) + return Future{MethodCallResult}(mret, interp, sv) do ret, interp, sv + edge = ret.edge + if edge !== nothing + cache = CC.get(CC.code_cache(interp), edge, nothing) + if cache !== nothing + src = @atomic :monotonic cache.inferred + if isa(src, DAECache) + info = src.info + merge_daeinfo!(interp, sv.result, info) + end end end + return ret end - return ret end @override function CC.const_prop_call(interp::DAEInterpreter, @@ -443,9 +443,9 @@ end # TODO propagate debug configurations here @override function CC.optimize(interp::DAEInterpreter, opt::OptimizationState, caller::InferenceResult) - ir = CC.run_passes_ipo_safe(opt.src, opt, caller) + ir = CC.run_passes_ipo_safe(opt.src, opt) ir = run_dae_passes(interp, ir) - CC.ipo_dataflow_analysis!(interp, ir, caller) + CC.ipo_dataflow_analysis!(interp, opt, ir, caller) if interp.ipo_analysis_mode result = ipo_dae_analysis!(interp, ir, caller.linfo, caller) if result !== nothing @@ -524,14 +524,10 @@ end src === nothing && return nothing (; inferred, ir) = src::DAECache (isa(inferred, CodeInfo) && isa(ir, IRCode)) || return nothing - method_info = CC.MethodInfo(inferred) + method_info = CC.SpecInfo(inferred) ir = copy(ir) (; min_world, max_world) = inferred - if Base.__has_internal_change(v"1.12-alpha", :codeinfonargs) - argtypes = CC.va_process_argtypes(CC.optimizer_lattice(interp), argtypes, inferred.nargs, inferred.isva) - elseif VERSION >= v"1.12.0-DEV.341" - argtypes = CC.va_process_argtypes(CC.optimizer_lattice(interp), argtypes, mi) - end + argtypes = CC.va_process_argtypes(CC.optimizer_lattice(interp), argtypes, inferred.nargs, inferred.isva) return IRInterpretationState(interp, method_info, ir, mi, argtypes, world, min_world, max_world) end @@ -974,34 +970,36 @@ function _abstract_eval_invoke_inst(interp::DAEInterpreter, inst::Union{CC.Instr end @override function CC.abstract_eval_statement_expr(interp::DAEInterpreter, inst::Expr, vtypes::Nothing, irsv::IRInterpretationState) - (; rt, exct, effects) = @invoke CC.abstract_eval_statement_expr(interp::AbstractInterpreter, inst::Expr, vtypes::Nothing, irsv::IRInterpretationState) - - if (!interp.ipo_analysis_mode || interp.in_analysis) && !isa(rt, Const) && !isa(rt, Incidence) && !CC.isType(rt) && !is_all_inc_or_const(Any[rt]) - argtypes = CC.collect_argtypes(interp, inst.args, nothing, irsv) - if argtypes === nothing - return CC.RTEffects(rt, exct, effects) - end - if is_all_inc_or_const(argtypes) - if inst.head in (:call, :invoke) && CC.hasintersect(widenconst(argtypes[inst.head === :call ? 1 : 2]), Union{typeof(variable), typeof(sim_time), typeof(state_ddt)}) - # The `variable` and `state_ddt` intrinsics can source Incidence. For all other - # calls, if there's no incidence in the arguments, there cannot be any incidence - # in the result. + ret = @invoke CC.abstract_eval_statement_expr(interp::AbstractInterpreter, inst::Expr, vtypes::Nothing, irsv::IRInterpretationState) + return Future{CC.RTEffects}(ret, interp, irsv) do ret, interp, irsv + (; rt, exct, effects) = ret + if (!interp.ipo_analysis_mode || interp.in_analysis) && !isa(rt, Const) && !isa(rt, Incidence) && !CC.isType(rt) && !is_all_inc_or_const(Any[rt]) + argtypes = CC.collect_argtypes(interp, inst.args, nothing, irsv) + if argtypes === nothing return CC.RTEffects(rt, exct, effects) end - fb_inci = _fallback_incidence(argtypes) - if fb_inci !== nothing - update_type(t::Type) = Incidence(t, fb_inci.row, fb_inci.eps) - update_type(t::Incidence) = t - update_type(t::Const) = t - update_type(t::CC.PartialTypeVar) = t - update_type(t::PartialStruct) = PartialStruct(t.typ, Any[Base.isvarargtype(f) ? f : update_type(f) for f in t.fields]) - update_type(t::CC.Conditional) = CC.Conditional(t.slot, update_type(t.thentype), update_type(t.elsetype)) - newrt = update_type(rt) - return CC.RTEffects(newrt, exct, effects) + if is_all_inc_or_const(argtypes) + if inst.head in (:call, :invoke) && CC.hasintersect(widenconst(argtypes[inst.head === :call ? 1 : 2]), Union{typeof(variable), typeof(sim_time), typeof(state_ddt)}) + # The `variable` and `state_ddt` intrinsics can source Incidence. For all other + # calls, if there's no incidence in the arguments, there cannot be any incidence + # in the result. + return CC.RTEffects(rt, exct, effects) + end + fb_inci = _fallback_incidence(argtypes) + if fb_inci !== nothing + update_type(t::Type) = Incidence(t, fb_inci.row, fb_inci.eps) + update_type(t::Incidence) = t + update_type(t::Const) = t + update_type(t::CC.PartialTypeVar) = t + update_type(t::PartialStruct) = PartialStruct(t.typ, Any[Base.isvarargtype(f) ? f : update_type(f) for f in t.fields]) + update_type(t::CC.Conditional) = CC.Conditional(t.slot, update_type(t.thentype), update_type(t.elsetype)) + newrt = update_type(rt) + return CC.RTEffects(newrt, exct, effects) + end end end + return CC.RTEffects(rt, exct, effects) end - return CC.RTEffects(rt, exct, effects) end @override function CC.compute_forwarded_argtypes(interp::DAEInterpreter, arginfo::ArgInfo, sv::AbsIntState) @@ -1218,11 +1216,12 @@ function infer_ir!(ir, interp::AbstractInterpreter, mi::MethodInstance) end end - method_info = CC.MethodInfo(#=propagate_inbounds=#true, nothing) + (nargs, isva) = isa(mi.def, Method) ? (mi.def.nargs, mi.def.isva) : (0, false) + method_info = CC.SpecInfo(nargs, isva, #=propagate_inbounds=#true, nothing) min_world = world = get_inference_world(interp) max_world = get_world_counter() irsv = IRInterpretationState(interp, method_info, ir, mi, ir.argtypes, world, min_world, max_world) - (rt, nothrow) = CC._ir_abstract_constant_propagation(interp, irsv) + (rt, nothrow) = CC.ir_abstract_constant_propagation(interp, irsv) return rt end diff --git a/src/transform/common.jl b/src/transform/common.jl index d25a05c..97d9129 100644 --- a/src/transform/common.jl +++ b/src/transform/common.jl @@ -57,7 +57,7 @@ function remap_info(remap_ir!, info) if isa(result, CC.SemiConcreteResult) let ir = copy(result.ir) remap_ir!(ir) - CC.SemiConcreteResult(result.mi, ir, result.effects) + CC.SemiConcreteResult(result.mi, ir, result.effects, result.spec_info) end elseif isa(result, CC.ConstPropResult) if isa(result.result.src, DAECache)