Skip to content

Commit

Permalink
Generalize symbol type for debug scopes
Browse files Browse the repository at this point in the history
  • Loading branch information
Keno committed Sep 20, 2024
1 parent 20d1ed4 commit 1714a6b
Show file tree
Hide file tree
Showing 14 changed files with 663 additions and 234 deletions.
682 changes: 542 additions & 140 deletions Manifest.toml

Large diffs are not rendered by default.

27 changes: 4 additions & 23 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ uuid = "32805668-c3d0-42c2-aafd-0d0a9857a104"
version = "1.21.0"
authors = ["JuliaHub, Inc. and other contributors"]

[workspace]
projects = ["test"]

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
CentralizedCaches = "d1073d05-2d26-4019-b855-dfa0385fef5e"
Expand Down Expand Up @@ -30,12 +33,11 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
StateSelection = "64909d44-ed92-46a8-bbd9-f047dfbdc84b"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Tracy = "e689c965-62c8-4b79-b2c5-8359227902fd"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"

[sources]
ModelingToolkitStandardLibrary = {rev = "ox/dae_compatible5", url = "https://github.com/CedarEDA/ModelingToolkitStandardLibrary.jl"}
SciMLBase = {rev = "os/dae-get-du2", url = "https://github.com/CedarEDA/SciMLBase.jl"}
SciMLSensitivity = {rev = "kf/mindep2", url = "https://github.com/CedarEDA/SciMLSensitivity.jl"}

Expand All @@ -54,12 +56,10 @@ Cthulhu = "2.10.1"
DiffEqBase = "6.149.2"
Diffractor = "0.2.7"
ForwardDiff = "0.10.36"
ModelingToolkitStandardLibrary = "2.6.0"
NonlinearSolve = "3.5.0"
OrderedCollections = "1.6.3"
PrecompileTools = "1"
Preferences = "1.4"
Roots = "2.0.22"
SciMLBase = "2.24.0"
SciMLSensitivity = "7.47"
StateSelection = "0.2.0"
Expand All @@ -68,24 +68,5 @@ Sundials = "4.19"
SymbolicIndexingInterface = "0.3"
julia = "1.11"

[extras]
ControlSystemsBase = "aaaaaaaa-a6ca-5380-bf3e-84a91bcd477e"
DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0"
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
ModelingToolkitStandardLibrary = "16a59e39-deab-5bd0-87e4-056b12336739"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Roots = "f2b01f46-fcfa-551c-844a-d8ac1e96c665"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Sundials = "c3572dad-4567-51f8-b174-8c6c989267f4"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[preferences.LinearSolve]
LoadMKL_JLL = false

[targets]
test = ["ControlSystemsBase", "DataInterpolations", "FiniteDiff", "FiniteDifferences", "IfElse", "InteractiveUtils", "ModelingToolkit", "ModelingToolkitStandardLibrary", "OrdinaryDiffEq", "SafeTestsets", "Sundials", "Test", "Roots", "StaticArrays"]
9 changes: 5 additions & 4 deletions ext/DAECompilerSciMLSensitivityExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ with one column per time step in `ts` and one one row per `variable`/`observed!`
"""
function DAECompiler.reconstruct_sensitivities(sol::SciMLBase.AbstractODESolution, syms::Vector{<:DAECompiler.ScopeRef}, ts=sol.t)
us, du_dparams = extract_local_sensitivities(sol, ts)
var_inds, obs_inds = DAECompiler.split_and_sort_syms(syms)

transformed_sys = DAECompiler.get_transformed_sys(sol)
sys = DAECompiler.get_sys(transformed_sys)
var_inds, obs_inds = DAECompiler.split_and_sort_syms(sys, syms)

dreconstruct! = get!(sol.prob.f.observed.derivative_cache, (var_inds, obs_inds, false)) do
DAECompiler.compile_batched_reconstruct_derivatives(transformed_sys, var_inds, obs_inds, false, false;)
end

num_params = length(du_dparams)
dout_vars_per_param = [similar(us, (length(var_inds), length(ts))) for _ in 1:num_params]
dout_obs_per_param = [similar(us, (length(obs_inds), length(ts))) for _ in 1:num_params]
Expand All @@ -67,7 +68,7 @@ function DAECompiler.reconstruct_sensitivities(sol::SciMLBase.AbstractODESolutio
end

return map(dout_vars_per_param, dout_obs_per_param) do dout_vars, dout_obs
DAECompiler.join_syms(syms, dout_vars, dout_obs, (var_inds, obs_inds))
DAECompiler.join_syms(sys, syms, dout_vars, dout_obs, (var_inds, obs_inds))
end
end

Expand Down
30 changes: 17 additions & 13 deletions src/analysis/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ function make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_va
end

function resolve_genscopes(names)
new_names = OrderedDict{LevelKey, NameLevel}()
new_names = OrderedDict{Any, NameLevel}()
for (key, val) in collect(names)
if val.children !== nothing
@reset val.children = resolve_genscopes(val.children)
Expand Down Expand Up @@ -423,7 +423,7 @@ Perform the structural analysis on optimized code of `mi` and return `structure:
end
end

function refresh_identities(names::OrderedDict{LevelKey, NameLevel})
function refresh_identities(names::OrderedDict{LevelKey, NameLevel}) where {LevelKey, NameLevel}
new_names = OrderedDict{LevelKey, NameLevel}()
for (key, val) in names
if isa(key, Gen)
Expand Down Expand Up @@ -502,7 +502,7 @@ end
eq_kind = VarEqKind[]
warnings = UnsupportedIRException[]

names = OrderedDict{LevelKey, NameLevel}()
names = OrderedDict{Any, NameLevel}()

nsysmscopes = 0
ncallees = 0
Expand Down Expand Up @@ -1191,7 +1191,7 @@ function process_ipo_return!(ultimate_rt::PartialStruct, args...)
return PartialStruct(ultimate_rt.typ, fields), nimplicitoutpairs
end

function get_variable_name(names::OrderedDict{LevelKey, NameLevel}, var_to_diff, var_idx)
function get_variable_name(names::OrderedDict, var_to_diff, var_idx)
var_names = build_var_names(names, var_to_diff)
return var_names[var_idx]
end
Expand Down Expand Up @@ -1221,7 +1221,7 @@ function get_inline_backtrace(ir::IRCode, v::SSAValue)
return frames
end

function walk_dict(names::OrderedDict{LevelKey, NameLevel}, stack::Vector{<:LevelKey})
function walk_dict(names::OrderedDict{LevelKey, NameLevel}, stack::Vector) where {LevelKey, NameLevel}
for i = length(stack):-1:2
s = stack[i]
if !haskey(names, s)
Expand All @@ -1235,11 +1235,11 @@ end
is_valid_partial_scope(_) = false
is_valid_partial_scope(ps::PartialScope) = true
function is_valid_partial_scope(ps::PartialStruct)
if ps.typ === Scope
if ps.typ <: Scope
isa(ps.fields[2], Const) || return false
isa(ps.fields[2].val, Symbol) || return false
return is_valid_partial_scope(ps.fields[1])
elseif ps.typ === GenScope
elseif ps.typ <: GenScope
isa(ps.fields[1], Const) || return false
return is_valid_partial_scope(ps.fields[2])
else
Expand All @@ -1248,11 +1248,11 @@ function is_valid_partial_scope(ps::PartialStruct)
end

function sym_stack(ps::PartialStruct)
if ps.typ === Scope
if ps.typ <: Scope
sym = (ps.fields[2]::Const).val::Symbol
return pushfirst!(sym_stack(ps.fields[1]), sym)
else
@assert ps.typ === GenScope
@assert ps.typ <: GenScope
stack = sym_stack(ps.fields[2])
scope_identity = ((ps.fields[1]::Const).val)::Intrinsics.ScopeIdentity
stack[1] = Gen(scope_identity, stack[1])
Expand All @@ -1261,7 +1261,7 @@ function sym_stack(ps::PartialStruct)
end

sym_stack(ps::PartialScope) = LevelKey[ps]
function record_scope!(ir::IRCode, names::OrderedDict{LevelKey, NameLevel}, scope::Union{Scope, GenScope, PartialStruct, PartialScope},
function record_scope!(ir::IRCode, names::OrderedDict, scope::Union{Scope, GenScope, PartialStruct, PartialScope},
varssa::Vector, idx::Int, lens)

stack = sym_stack(scope)
Expand All @@ -1282,11 +1282,15 @@ function record_scope!(ir::IRCode, names::OrderedDict{LevelKey, NameLevel}, scop
end

function merge_scopes!(names::OrderedDict{LevelKey, NameLevel}, key::LevelKey, val::NameLevel,
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int)
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int) where {LevelKey, NameLevel}

haskey(names, key) || (names[key] = NameLevel())
existing = names[key]
for (offset, lens) in ((x->(only(findnz(mapping.var_coeffs[x].row)[1])), @o _.var),
function remap_var(x)
r = only(findnz(mapping.var_coeffs[x].row)[1]) - 1
return r
end
for (offset, lens) in ((remap_var, @o _.var),
(x->(x+obsoffset), @o _.obs),
(x->mapping.eqs[x], @o _.eq), (x->(x+epsoffset), @o _.eps))
if lens(val) !== nothing
Expand All @@ -1312,7 +1316,7 @@ function merge_scopes!(names::OrderedDict{LevelKey, NameLevel}, key::LevelKey, v
end

function merge_scopes!(names::OrderedDict{LevelKey, NameLevel}, key::Union{Scope, PartialStruct}, val::NameLevel,
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int)
mapping::CalleeMapping, obsoffset::Int, epsoffset::Int) where {LevelKey, NameLevel}

stack = sym_stack(key)
if isempty(stack)
Expand Down
4 changes: 2 additions & 2 deletions src/analysis/debugging.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
using StateSelection
using StateSelection.BipartiteGraphs

function build_var_names(names::OrderedDict{LevelKey, NameLevel}, var_to_diff)
function build_var_names(names::OrderedDict, var_to_diff)
var_names = OrderedDict{Int,String}()
build_var_names!(var_names, names, var_to_diff)
return var_names
end
function build_var_names!(var_names, names::OrderedDict{LevelKey, NameLevel}, var_to_diff, prefix=String[])
function build_var_names!(var_names, names::OrderedDict, var_to_diff, prefix=String[])
for name in keys(names)
name_path = join(vcat(prefix..., name), ".")
level = names[name]
Expand Down
15 changes: 6 additions & 9 deletions src/analysis/interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,7 @@ function process_template!(𝕃, coeffs, eq_mapping, applied_scopes, argtypes, t
eq_mapping[idnum(template)] = idnum(arg)
elseif CC.is_const_argtype(template)
#@Core.Compiler.show (arg, template)
@assert CC.is_lattice_equal(DAE_LATTICE, arg, template)
#@assert CC.is_lattice_equal(DAE_LATTICE, arg, template)
elseif isa(template, PartialScope)
id = idnum(template)
(id > length(applied_scopes)) && resize!(applied_scopes, id)
Expand Down Expand Up @@ -919,7 +919,7 @@ function _abstract_eval_invoke_inst(interp::DAEInterpreter, inst::Union{CC.Instr
argtypes = CC.collect_argtypes(interp, stmt.args, nothing, irsv)[2:end]
callee_result = dae_result_for_inst(interp, inst)
callee_result === nothing && return RT(nothing, (false, false))
if isa(callee_result.extended_rt, Const) || isa(callee_result.extended_rt, Type)
if isa(callee_result, UncompilableIPOResult) || isa(callee_result.extended_rt, Const) || isa(callee_result.extended_rt, Type)
return RT(nothing, (false, false))
end
mapping = CalleeMapping(CC.optimizer_lattice(interp), argtypes, callee_result)
Expand Down Expand Up @@ -1030,14 +1030,11 @@ end
# -----

function typeinf_dae(@nospecialize(tt), world::UInt=get_world_counter();
method_table::Union{Nothing,MethodTable} = nothing,
ipo_analysis_mode::Bool = false)
interp = DAEInterpreter(world; method_table, ipo_analysis_mode)
match = Base._which(tt;
method_table=CC.method_table(interp),
world=get_inference_world(interp),
raise=false)
match === nothing && single_match_error(tt)
interp = DAEInterpreter(world; ipo_analysis_mode)
match = Base._methods_by_ftype(tt, 1, world)
isempty(match) && single_match_error(tt)
match = only(match)
mi = CC.specialize_method(match)
ci = CC.typeinf_ext(interp, mi, Core.Compiler.SOURCE_MODE_ABI)
return interp, ci
Expand Down
7 changes: 4 additions & 3 deletions src/cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,12 @@ struct NameLevel
obs::Union{Nothing, Int}
eq::Union{Nothing, Int}
eps::Union{Nothing, Int}
children::Union{Nothing, OrderedDict{LevelKey, NameLevel}}
# TODO: This should be an OrderedIdDict
children::Union{Nothing, OrderedDict{Any, NameLevel}}
end
NameLevel() =
NameLevel(nothing, nothing, nothing, nothing, nothing)
NameLevel(children::OrderedDict{LevelKey, NameLevel}) =
NameLevel(children::OrderedDict{Any, NameLevel}) =
NameLevel(nothing, nothing, nothing, nothing, children)

struct UnsupportedIRException <: Exception
Expand Down Expand Up @@ -77,7 +78,7 @@ struct DAEIPOResult
total_incidence::Vector{Any}
eq_kind::Vector{VarEqKind}
eq_callee_mapping::Vector{Union{Nothing, Vector{Pair{SSAValue, Int}}}}
names::OrderedDict{LevelKey, NameLevel}
names::OrderedDict{Any, NameLevel} # TODO: OrderedIdDict
nobserved::Int
neps::Int
ic_nzc::Int
Expand Down
4 changes: 2 additions & 2 deletions src/irodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ struct IRODESystem
fallback_interp::AbstractInterpreter = Core.Compiler.NativeInterpreter(),
debug_config = (;),
ipo_analysis_mode = false,
world::UInt=get_world_counter())
world::UInt=Base.tls_world_age())
debug_config = DebugConfig(debug_config, tt)
@may_timeit debug_config "typeinf_dae" interp, ci = typeinf_dae(tt, world; ipo_analysis_mode)
mi = ci.def
Expand All @@ -183,7 +183,7 @@ mutable struct IRTransformationState <: TransformationState{IRODESystem}
ir::IRCode
callback_func::Function
structure::SystemStructure
const names::OrderedDict{LevelKey, NameLevel}
const names::OrderedDict{Any, NameLevel}
const nobserved::Int
const neps::Int
const ic_nzc::Int
Expand Down
16 changes: 8 additions & 8 deletions src/runtime.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,22 @@ module Intrinsics

abstract type AbstractScope; end

struct Scope <: AbstractScope
struct Scope{T} <: AbstractScope
parent::AbstractScope
name::Symbol
Scope() = new()
Scope(s::AbstractScope, sym::Symbol) = new(s, sym)
name::T
Scope() = new{Union{}}()
Scope(s::AbstractScope, name::T) where {T} = new{T}(s, name)
end
(scope::Scope)(s::Symbol) = Scope(scope, s)
# Scope(), but will less function calls, so marginally easier on the compiler
# Scope(), but with less function calls, so marginally easier on the compiler
const root_scope = Scope()

mutable struct ScopeIdentity; end

struct GenScope <: AbstractScope
struct GenScope{T} <: AbstractScope
identity::ScopeIdentity
sc::Scope
GenScope(sc::Scope) = new(ScopeIdentity(), sc)
sc::Scope{T}
GenScope(sc::Scope{T}) where {T} = new{T}(ScopeIdentity(), sc)
end
GenScope(parent::AbstractScope, name::Symbol) =
GenScope(Scope(parent, name))
Expand Down
Loading

0 comments on commit 1714a6b

Please sign in to comment.