Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Invalidate methods when binding is typed/const-defined #54733

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2602,6 +2602,7 @@ function abstract_eval_isdefined(interp::AbstractInterpreter, e::Expr, vtypes::U
elseif isdefinedconst_globalref(sym)
rt = Const(true)
else
add_binding_backedge!(sv, sym, :const)
effects = Effects(EFFECTS_TOTAL; consistent=ALWAYS_FALSE)
end
elseif isexpr(sym, :static_parameter)
Expand Down Expand Up @@ -2822,18 +2823,21 @@ end
isdefined_globalref(g::GlobalRef) = !iszero(ccall(:jl_globalref_boundp, Cint, (Any,), g))
isdefinedconst_globalref(g::GlobalRef) = isconst(g) && isdefined_globalref(g)

function abstract_eval_globalref_type(g::GlobalRef)
function abstract_eval_globalref_type(g::GlobalRef, sv::Union{AbsIntState,Nothing}=nothing)
if isdefinedconst_globalref(g)
return Const(ccall(:jl_get_globalref_value, Any, (Any,), g))
end
ty = ccall(:jl_get_binding_type, Any, (Any, Any), g.mod, g.name)
ty === nothing && return Any
if ty === nothing
sv !== nothing && add_binding_backedge!(sv, g, :type)
return Any
end
return ty
end
abstract_eval_global(M::Module, s::Symbol) = abstract_eval_globalref_type(GlobalRef(M, s))
abstract_eval_global(M::Module, s::Symbol, sv::Union{AbsIntState,Nothing}=nothing) = abstract_eval_globalref_type(GlobalRef(M, s), sv)

function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, sv::AbsIntState)
rt = abstract_eval_globalref_type(g)
rt = abstract_eval_globalref_type(g, sv)
consistent = inaccessiblememonly = ALWAYS_FALSE
nothrow = false
if isa(rt, Const)
Expand Down
8 changes: 8 additions & 0 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,14 @@ function add_mt_backedge!(irsv::IRInterpretationState, mt::MethodTable, @nospeci
return push!(irsv.edges, mt, typ)
end

function add_binding_backedge!(caller::InferenceState, g::GlobalRef, kind::Symbol)
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), g, kind)
end
function add_binding_backedge!(irsv::IRInterpretationState, g::GlobalRef)
return push!(irsv.edges, g, kind)
end

get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc]
get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx][:flag]

Expand Down
2 changes: 2 additions & 0 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,8 @@ function store_backedges(caller::MethodInstance, edges::Vector{Any})
callee = itr.caller
if isa(callee, MethodInstance)
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller)
elseif isa(callee, GlobalRef)
ccall(:jl_globalref_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller)
else
typeassert(callee, MethodTable)
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller)
Expand Down
7 changes: 4 additions & 3 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,16 +336,17 @@ end
const empty_backedge_iter = BackedgeIterator(Any[])

struct BackedgePair
sig # ::Union{Nothing,Type}
caller::Union{MethodInstance,MethodTable}
BackedgePair(@nospecialize(sig), caller::Union{MethodInstance,MethodTable}) = new(sig, caller)
sig # ::Union{Nothing,Symbol,Type}
caller::Union{MethodInstance,MethodTable,GlobalRef}
BackedgePair(@nospecialize(sig), caller::Union{MethodInstance,MethodTable,GlobalRef}) = new(sig, caller)
end

function iterate(iter::BackedgeIterator, i::Int=1)
backedges = iter.backedges
i > length(backedges) && return nothing
item = backedges[i]
isa(item, MethodInstance) && return BackedgePair(nothing, item), i+1 # regular dispatch
isa(item, GlobalRef) && return BackedgePair(backedges[i+1], item), i+2 # (untyped) binding
isa(item, MethodTable) && return BackedgePair(backedges[i+1], item), i+2 # abstract dispatch
return BackedgePair(item, backedges[i+1]::MethodInstance), i+2 # `invoke` calls
end
Expand Down
15 changes: 13 additions & 2 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1378,7 +1378,10 @@ JL_CALLABLE(jl_f_get_binding_type)
if (b2 != b)
return (jl_value_t*)jl_any_type;
jl_value_t *old_ty = NULL;
jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type);
while (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type)) {
if (old_ty && !jl_is_binding_edges(old_ty))
break;
}
return jl_atomic_load_relaxed(&b->ty);
}
return ty;
Expand All @@ -1395,8 +1398,15 @@ JL_CALLABLE(jl_f_set_binding_type)
JL_TYPECHK(set_binding_type!, type, ty);
jl_binding_t *b = jl_get_binding_wr(m, s);
jl_value_t *old_ty = NULL;
if (jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, ty)) {
while (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, ty)) {
if (old_ty && !jl_is_binding_edges(old_ty))
break;
}
if (!old_ty)
jl_gc_wb(b, ty);
else if (jl_is_binding_edges(old_ty)) {
jl_gc_wb(b, ty);
jl_binding_invalidate(ty, /* is_const */ 0, (jl_binding_edges_t *)old_ty);
}
else if (nargs != 2 && !jl_types_equal(ty, old_ty)) {
jl_errorf("cannot set type for global %s.%s. It already has a value or is already set to a different type.",
Expand Down Expand Up @@ -2525,6 +2535,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
add_builtin("QuoteNode", (jl_value_t*)jl_quotenode_type);
add_builtin("NewvarNode", (jl_value_t*)jl_newvarnode_type);
add_builtin("Binding", (jl_value_t*)jl_binding_type);
add_builtin("BindingEdges", (jl_value_t*)jl_binding_edges_type);
add_builtin("GlobalRef", (jl_value_t*)jl_globalref_type);
add_builtin("NamedTuple", (jl_value_t*)jl_namedtuple_type);

Expand Down
4 changes: 2 additions & 2 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3201,7 +3201,7 @@ static jl_cgval_t emit_globalref(jl_codectx_t &ctx, jl_module_t *mod, jl_sym_t *
return mark_julia_const(ctx, v);
ty = jl_atomic_load_relaxed(&bnd->ty);
}
if (ty == nullptr)
if (ty == nullptr || jl_is_binding_edges(ty))
ty = (jl_value_t*)jl_any_type;
return update_julia_type(ctx, emit_checked_var(ctx, bp, name, (jl_value_t*)mod, false, ctx.tbaa().tbaa_binding), ty);
}
Expand All @@ -3217,7 +3217,7 @@ static jl_cgval_t emit_globalop(jl_codectx_t &ctx, jl_module_t *mod, jl_sym_t *s
return jl_cgval_t();
if (bnd && !bnd->constp) {
jl_value_t *ty = jl_atomic_load_relaxed(&bnd->ty);
if (ty != nullptr) {
if (ty != nullptr && !jl_is_binding_edges(ty)) {
const std::string fname = issetglobal ? "setglobal!" : isreplaceglobal ? "replaceglobal!" : isswapglobal ? "swapglobal!" : ismodifyglobal ? "modifyglobal!" : "setglobalonce!";
if (!ismodifyglobal) {
// TODO: use typeassert in jl_check_binding_wr too
Expand Down
63 changes: 63 additions & 0 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1747,6 +1747,69 @@ static void invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_w
}
}

/**
* Invalidate the edges accumulated in `be` - this should be called when a binding has just
* acquired a type or a const value.
*
* ty is the new type of the binding (optional if const), and `is_const` is whether the new
* binding ended up being const. These will be used to filter the edge invalidations, so that
* e.g. an `isdefined` edge is not invalidated by a non-const binding
**/
JL_DLLEXPORT void jl_binding_invalidate(jl_value_t *ty, int is_const, jl_binding_edges_t *be)
{
if (!is_const && ty == (jl_value_t *)jl_any_type)
return; // no improvement to inference information

jl_array_t *edges = be->edges;
jl_method_instance_t *mi = NULL;
JL_GC_PUSH2(&edges, mi);
JL_LOCK(&world_counter_lock);
// Narrow the world age on the methods to make them uncallable
size_t world = jl_atomic_load_relaxed(&jl_world_counter);
for (int i = 0; i < jl_array_len(edges) / 2; i++) {
mi = (jl_method_instance_t *)jl_array_ptr_ref(edges, 2 * i);
jl_sym_t *kind = (jl_sym_t *)jl_array_ptr_ref(edges, 2 * i + 1);
if (!is_const && kind == jl_symbol("const"))
continue; // this is an `isdefined` edge, which has not improved

invalidate_method_instance(mi, world, /* depth */ 0);
}
jl_atomic_store_release(&jl_world_counter, world + 1);
JL_UNLOCK(&world_counter_lock);
JL_GC_POP();
}

JL_DLLEXPORT void jl_globalref_add_backedge(jl_globalref_t *callee, jl_sym_t *kind, jl_method_instance_t *caller)
{
jl_binding_t *b = jl_get_module_binding(callee->mod, callee->name, /* alloc */ 0);
assert(b != NULL);
jl_binding_edges_t *edges = (jl_binding_edges_t *)jl_atomic_load_acquire(&b->ty);
if (edges && !jl_is_binding_edges(edges))
return; // TODO: Handle case where the invalidation happens before the edge arrives

jl_array_t *array = NULL;
JL_GC_PUSH2(&array, &edges);
if (edges == NULL) {
array = jl_alloc_vec_any(0);
edges = (jl_binding_edges_t *)jl_gc_alloc(
jl_current_task->ptls, sizeof(jl_binding_edges_t),
jl_binding_edges_type
);
edges->edges = array;
jl_value_t *old_ty = NULL;
if (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t *)edges))
return; // TODO: Handle case where ty was swapped out from under us
jl_gc_wb(b, edges);
}
else {
array = edges->edges;
}
jl_array_ptr_1d_push(array, (jl_value_t *)caller);
jl_array_ptr_1d_push(array, (jl_value_t *)kind);
JL_GC_POP();
return;
}

// add a backedge from callee to caller
JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee, jl_value_t *invokesig, jl_method_instance_t *caller)
{
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_data.inc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
XX(jl_floatingpoint_type) \
XX(jl_function_type) \
XX(jl_binding_type) \
XX(jl_binding_edges_type) \
XX(jl_globalref_type) \
XX(jl_gotoifnot_type) \
XX(jl_enternode_type) \
Expand Down
2 changes: 2 additions & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
XX(jl_backtrace_from_here) \
XX(jl_base_relative_to) \
XX(jl_binding_resolved_p) \
XX(jl_binding_invalidate) \
XX(jl_bitcast) \
XX(jl_boundp) \
XX(jl_bounds_error) \
Expand Down Expand Up @@ -237,6 +238,7 @@
XX(jl_get_world_counter) \
XX(jl_get_zero_subnormals) \
XX(jl_gf_invoke_lookup) \
XX(jl_globalref_add_backedge) \
XX(jl_method_lookup_by_tt) \
XX(jl_method_lookup) \
XX(jl_gf_invoke_lookup_worlds) \
Expand Down
5 changes: 5 additions & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3108,6 +3108,11 @@ void jl_init_types(void) JL_GC_DISABLED
const static uint32_t binding_constfields[] = { 0x0002 }; // Set fields 2 as constant
jl_binding_type->name->constfields = binding_constfields;

jl_binding_edges_type =
jl_new_datatype(jl_symbol("BindingBackedges"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(1, "edges"), jl_svec(1, jl_any_type),
jl_emptysvec, 0, 0, 1);

jl_globalref_type =
jl_new_datatype(jl_symbol("GlobalRef"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(3, "mod", "name", "binding"),
Expand Down
7 changes: 7 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,11 @@ typedef struct _jl_binding_t {
uint8_t padding:1;
} jl_binding_t;

typedef struct _jl_binding_edges_t {
JL_DATA_TYPE
jl_array_t *edges;
} jl_binding_edges_t;

typedef struct {
uint64_t hi;
uint64_t lo;
Expand Down Expand Up @@ -930,6 +935,7 @@ extern JL_DLLIMPORT jl_value_t *jl_memoryref_uint8_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_value_t *jl_memoryref_any_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_expr_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_binding_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_binding_edges_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_globalref_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_linenumbernode_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_gotonode_type JL_GLOBALLY_ROOTED;
Expand Down Expand Up @@ -1503,6 +1509,7 @@ static inline int jl_field_isconst(jl_datatype_t *st, int i) JL_NOTSAFEPOINT
#define jl_is_slotnumber(v) jl_typetagis(v,jl_slotnumber_type)
#define jl_is_expr(v) jl_typetagis(v,jl_expr_type)
#define jl_is_binding(v) jl_typetagis(v,jl_binding_type)
#define jl_is_binding_edges(v) jl_typetagis(v,jl_binding_edges_type)
#define jl_is_globalref(v) jl_typetagis(v,jl_globalref_type)
#define jl_is_gotonode(v) jl_typetagis(v,jl_gotonode_type)
#define jl_is_gotoifnot(v) jl_typetagis(v,jl_gotoifnot_type)
Expand Down
2 changes: 2 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,7 @@ JL_DLLEXPORT jl_value_t *jl_nth_slot_type(jl_value_t *sig JL_PROPAGATES_ROOT, si
void jl_compute_field_offsets(jl_datatype_t *st);
void jl_module_run_initializer(jl_module_t *m);
JL_DLLEXPORT jl_binding_t *jl_get_module_binding(jl_module_t *m JL_PROPAGATES_ROOT, jl_sym_t *var, int alloc);
JL_DLLEXPORT void jl_binding_invalidate(jl_value_t *ty, int is_const, jl_binding_edges_t *be);
JL_DLLEXPORT void jl_binding_deprecation_warning(jl_module_t *m, jl_sym_t *sym, jl_binding_t *b);
extern jl_array_t *jl_module_init_order JL_GLOBALLY_ROOTED;
extern htable_t jl_current_modules JL_GLOBALLY_ROOTED;
Expand Down Expand Up @@ -1041,6 +1042,7 @@ JL_DLLEXPORT jl_value_t *jl_methtable_lookup(jl_methtable_t *mt JL_PROPAGATES_RO
JL_DLLEXPORT jl_method_instance_t *jl_specializations_get_linfo(
jl_method_t *m JL_PROPAGATES_ROOT, jl_value_t *type, jl_svec_t *sparams);
jl_method_instance_t *jl_specializations_get_or_insert(jl_method_instance_t *mi_ins);
JL_DLLEXPORT void jl_globalref_add_backedge(jl_globalref_t *callee, jl_sym_t *kind, jl_method_instance_t *caller);
JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee, jl_value_t *invokesig, jl_method_instance_t *caller);
JL_DLLEXPORT void jl_method_table_add_backedge(jl_methtable_t *mt, jl_value_t *typ, jl_value_t *caller);
JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT,
Expand Down
7 changes: 7 additions & 0 deletions src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -1113,6 +1113,13 @@ JL_DLLEXPORT jl_value_t *jl_generic_function_def(jl_sym_t *name,
if (gf != NULL) {
if (!jl_is_datatype_singleton((jl_datatype_t*)jl_typeof(gf)) && !jl_is_type(gf))
jl_errorf("cannot define function %s; it already has a value", jl_symbol_name(name));
} else if (bnd) {
jl_value_t *old_ty = NULL;
while (!jl_atomic_cmpswap_relaxed(&bnd->ty, &old_ty, (jl_value_t*)jl_any_type)) {
assert(!old_ty || jl_is_binding_edges(old_ty));
}
if (old_ty)
jl_binding_invalidate((jl_value_t *)jl_any_type, /* is_const */ 1, (jl_binding_edges_t *)old_ty);
}
if (bnd)
bnd->constp = 1; // XXX: use jl_declare_constant and jl_checked_assignment
Expand Down
Loading