Skip to content

Commit

Permalink
More scope adjustment
Browse files Browse the repository at this point in the history
  • Loading branch information
Keno committed Sep 28, 2024
1 parent 1714a6b commit d2ddb17
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 35 deletions.
18 changes: 9 additions & 9 deletions ext/DAECompilerModelingToolkitExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,8 @@ function declare_parameters(model, struct_name)
backing::B
end
)


constructor_expr =:(
@generated function _check_parameter_names(::Type{$struct_name}, param_kwargs::NamedTuple)
unexpected_parameters = setdiff(fieldnames(param_kwargs), $param_names_tuple_expr)
Expand Down Expand Up @@ -108,7 +108,7 @@ function declare_parameters(model, struct_name)
if name === $param_name
return if hasfield(B, $param_name)
getfield(getfield(this, :backing), $param_name)
else
else
$param_value
end
end
Expand All @@ -118,7 +118,7 @@ function declare_parameters(model, struct_name)
return getfield(getfield(this, :backing), name)
))
getproperty_expr.args[end].args[end] = Expr(:block, getproperty_body...)

return Expr(:block, struct_expr, constructor_expr, propertynames_expr, getproperty_expr)
end

Expand Down Expand Up @@ -206,7 +206,7 @@ end

macro DAECompiler.declare_MTKConnector(mtk_component, ports...)
# We do need to do run time eval, because we can't decide what to construct with just lexical information.
# we need the values of the
# we need the values of the
:(Base.eval(@__MODULE__, $MTKConnector_AST($(esc(mtk_component)), $(esc.(ports)...))))
end

Expand All @@ -219,7 +219,7 @@ function MTKConnector_AST(model::MTK.ODESystem, ports...)
end

while !isnothing(MTK.get_parent(model))
# Undo any call to structural_simplify
# Undo any call to structural_simplify
# (Should we give a warning here? They did waste CPU cycles simplfying it in first place)
model = MTK.get_parent(model)
end
Expand All @@ -239,11 +239,11 @@ function MTKConnector_AST(model::MTK.ODESystem, ports...)


struct_name = gensym(nameof(model))

return quote
$(declare_parameters(model, struct_name))

function (this::$struct_name)($(port_names...); dscope=$(_c(Scope))())
function (this::$struct_name)($(map(port->:($(port)::Float64), port_names)...); dscope=$(_c(Scope))())
$(declare_vars(model, :dscope))
$(declare_derivatives(state))
$(declare_equations(state, model, :dscope, ports))
Expand All @@ -258,4 +258,4 @@ function MTKConnector_AST(model::MTK.ODESystem, ports...)
end


end # module
end # module
2 changes: 1 addition & 1 deletion src/analysis/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ has_any_genscope(sc::PartialStruct) = false # TODO

function _make_argument_lattice_elem(which::Argument, @nospecialize(argt), add_variable!, add_equation!, add_scope!)
if isa(argt, Const)
@assert !isa(argt.val, Scope) # Shouldn't have been forwarded
#@assert !isa(argt.val, Scope) # Shouldn't have been forwarded
return argt
elseif isa(argt, Type) && argt <: Intrinsics.AbstractScope
return PartialScope(add_scope!(which))
Expand Down
3 changes: 3 additions & 0 deletions src/analysis/lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,9 @@ function CC._getfield_tfunc(🥬::DAELattice, @nospecialize(s00), @nospecialize(
return Union{}
end
rt = CC._getfield_tfunc(CC.widenlattice(🥬), s00.typ, name, setfield)
if rt == Union{}
return Union{}
end
if isempty(s00)
return Incidence(rt)
end
Expand Down
36 changes: 15 additions & 21 deletions src/state_mapping.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,9 @@ using SciMLBase, SymbolicIndexingInterface
struct ScopeRef{T, ST}
sys::T
scope::Scope{ST}

# (Optional) opaque data structure to facilitate faster `getproperty`.
cursor
end
Base.Broadcast.broadcastable(ref::ScopeRef) = Ref(ref) # broadcast as scalar

Expand Down Expand Up @@ -135,33 +138,24 @@ function SciMLBase.sym_to_index(sr::ScopeRef, A::SciMLBase.DEIntegrator)
end

function Base.getproperty(sys::IRODESystem, name::Symbol)
haskey(StructuralAnalysisResult(sys).names, name) || throw(Base.UndefRefError())
return ScopeRef(sys, Scope(Scope(), name))
names = StructuralAnalysisResult(sys).names
cursor = get(names, name, nothing)
cursor === nothing && throw(Base.UndefRefError())
return ScopeRef(sys, Scope(Scope(), name), cursor)
end

function Base.propertynames(sr::ScopeRef)
scope = getfield(sr, :scope)
stack = sym_stack(scope)
strct = NameLevel(StructuralAnalysisResult(IRODESystem(sr)).names)
for s in reverse(stack)
strct = strct.children[s]
strct.children === nothing && return keys(Dict{Symbol, Any}())
end
return keys(strct.children)
cursor = getfield(sr, :cursor)
cursor.children === nothing && return keys(Dict{Symbol, Any}())
return keys(cursor.children)
end

function Base.getproperty(sr::ScopeRef{IRODESystem}, name::Symbol)
scope = getfield(sr, :scope)
stack = sym_stack(scope)
strct = NameLevel(StructuralAnalysisResult(IRODESystem(sr)).names)
for s in reverse(stack)
strct = strct.children[s]
strct.children === nothing && throw(Base.UndefRefError())
end
if !haskey(strct.children, name)
throw(Base.UndefRefError())
end
ScopeRef(IRODESystem(sr), Scope(getfield(sr, :scope), name))
cursor = getfield(sr, :cursor)
cursor.children === nothing && return throw(Base.UndefRefError())
new_cursor = get(cursor.children, name, nothing)
new_cursor === nothing && return throw(Base.UndefRefError())
return ScopeRef(IRODESystem(sr), Scope(getfield(sr, :scope), name), new_cursor)
end

function Base.show(io::IO, scope::Scope)
Expand Down
8 changes: 4 additions & 4 deletions test/MSL/modeling_toolkit_helper.jl
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,10 @@ function Base.getproperty(sys::IRODESystem, name::Symbol)
namespaces = split_namespaces_var(name)
if haskey(names, namespaces[1])
# Normal DAECompiler way
return return get_scope_ref(sys, namespaces)
return return get_scope_ref(sys, namespaces, names[namespaces[1]])
elseif length(namespaces) > 1 && haskey(names, namespaces[2])
# Ignore first namespace it's cos we are not fully consistent with if we include the system name or not
return return get_scope_ref(sys, namespaces; start_idx=2)
return return get_scope_ref(sys, namespaces, names[namespaces[2]]; start_idx=2)
else # It could be from the mtksys
mtksys = sys_map[sys_map_key(sys)]
if hasproperty(mtksys, name) # if it is actually from the MTK system (which allows unflattened names)
Expand All @@ -273,8 +273,8 @@ function Base.getproperty(sys::IRODESystem, name::Symbol)
end
throw(Base.KeyError(name)) # should be a UndefRef but key error useful for findout what broke it.
end
function get_scope_ref(sys, names; start_idx=1)
ref = DAECompiler.ScopeRef(sys, DAECompiler.Scope(DAECompiler.Scope(), names[start_idx]))
function get_scope_ref(sys, names, cursor; start_idx=1)
ref = DAECompiler.ScopeRef(sys, DAECompiler.Scope(DAECompiler.Scope(), names[start_idx]), cursor)
for name in @view names[(start_idx+1):end]
ref = getproperty(ref, name)
end
Expand Down

0 comments on commit d2ddb17

Please sign in to comment.