Skip to content

Commit

Permalink
Invalidate methods when binding is typed/const-defined
Browse files Browse the repository at this point in the history
This allows for patterns like:
```
julia> function foo(N)
    for i = 1:N
        x = bar(i)
    end
end

julia> foo(1_000_000_000)
ERROR: UndefVarError: `bar` not defined
```

not to suffer a tremendous performance regression because of the fact
that `foo` was inferred with `bar` still undefined.

Strictly speaking the original code remains valid, but for performance
reasons once the global is defined we'd like to invalidate the code
anyway to get an improved inference result.

```
julia> bar(x) = 3x
bar (generic function with 1 method)

julia> foo(1_000_000_000) # w/o PR: takes > 30 seconds
```
  • Loading branch information
topolarity committed Jun 7, 2024
1 parent 9477472 commit 38d10b6
Show file tree
Hide file tree
Showing 17 changed files with 208 additions and 28 deletions.
12 changes: 8 additions & 4 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2602,6 +2602,7 @@ function abstract_eval_isdefined(interp::AbstractInterpreter, e::Expr, vtypes::U
elseif isdefinedconst_globalref(sym)
rt = Const(true)
else
add_binding_backedge!(sv, sym, :const)
effects = Effects(EFFECTS_TOTAL; consistent=ALWAYS_FALSE)
end
elseif isexpr(sym, :static_parameter)
Expand Down Expand Up @@ -2822,18 +2823,21 @@ end
isdefined_globalref(g::GlobalRef) = !iszero(ccall(:jl_globalref_boundp, Cint, (Any,), g))
isdefinedconst_globalref(g::GlobalRef) = isconst(g) && isdefined_globalref(g)

function abstract_eval_globalref_type(g::GlobalRef)
function abstract_eval_globalref_type(g::GlobalRef, sv::Union{AbsIntState,Nothing}=nothing)
if isdefinedconst_globalref(g)
return Const(ccall(:jl_get_globalref_value, Any, (Any,), g))
end
ty = ccall(:jl_get_binding_type, Any, (Any, Any), g.mod, g.name)
ty === nothing && return Any
if ty === nothing
sv !== nothing && add_binding_backedge!(sv, g, :type)
return Any
end
return ty
end
abstract_eval_global(M::Module, s::Symbol) = abstract_eval_globalref_type(GlobalRef(M, s))
abstract_eval_global(M::Module, s::Symbol, sv::Union{AbsIntState,Nothing}=nothing) = abstract_eval_globalref_type(GlobalRef(M, s), sv)

function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, sv::AbsIntState)
rt = abstract_eval_globalref_type(g)
rt = abstract_eval_globalref_type(g, sv)
consistent = inaccessiblememonly = ALWAYS_FALSE
nothrow = false
if isa(rt, Const)
Expand Down
8 changes: 8 additions & 0 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1038,6 +1038,14 @@ function add_mt_backedge!(irsv::IRInterpretationState, mt::MethodTable, @nospeci
return push!(irsv.edges, mt, typ)
end

function add_binding_backedge!(caller::InferenceState, g::GlobalRef, kind::Symbol)
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), g, kind)
end
function add_binding_backedge!(irsv::IRInterpretationState, g::GlobalRef)
return push!(irsv.edges, g, kind)
end

get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc]
get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx][:flag]

Expand Down
2 changes: 2 additions & 0 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,8 @@ function store_backedges(caller::MethodInstance, edges::Vector{Any})
callee = itr.caller
if isa(callee, MethodInstance)
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller)
elseif isa(callee, GlobalRef)
ccall(:jl_globalref_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller)
else
typeassert(callee, MethodTable)
ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller)
Expand Down
7 changes: 4 additions & 3 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -336,16 +336,17 @@ end
const empty_backedge_iter = BackedgeIterator(Any[])

struct BackedgePair
sig # ::Union{Nothing,Type}
caller::Union{MethodInstance,MethodTable}
BackedgePair(@nospecialize(sig), caller::Union{MethodInstance,MethodTable}) = new(sig, caller)
sig # ::Union{Nothing,Symbol,Type}
caller::Union{MethodInstance,MethodTable,GlobalRef}
BackedgePair(@nospecialize(sig), caller::Union{MethodInstance,MethodTable,GlobalRef}) = new(sig, caller)
end

function iterate(iter::BackedgeIterator, i::Int=1)
backedges = iter.backedges
i > length(backedges) && return nothing
item = backedges[i]
isa(item, MethodInstance) && return BackedgePair(nothing, item), i+1 # regular dispatch
isa(item, GlobalRef) && return BackedgePair(backedges[i+1], item), i+2 # (untyped) binding
isa(item, MethodTable) && return BackedgePair(backedges[i+1], item), i+2 # abstract dispatch
return BackedgePair(item, backedges[i+1]::MethodInstance), i+2 # `invoke` calls
end
Expand Down
15 changes: 13 additions & 2 deletions src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1378,7 +1378,10 @@ JL_CALLABLE(jl_f_get_binding_type)
if (b2 != b)
return (jl_value_t*)jl_any_type;
jl_value_t *old_ty = NULL;
jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type);
while (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type)) {
if (old_ty && !jl_is_binding_edges(old_ty))
break;
}
return jl_atomic_load_relaxed(&b->ty);
}
return ty;
Expand All @@ -1395,8 +1398,15 @@ JL_CALLABLE(jl_f_set_binding_type)
JL_TYPECHK(set_binding_type!, type, ty);
jl_binding_t *b = jl_get_binding_wr(m, s);
jl_value_t *old_ty = NULL;
if (jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, ty)) {
while (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, ty)) {
if (old_ty && !jl_is_binding_edges(old_ty))
break;
}
if (!old_ty)
jl_gc_wb(b, ty);
else if (jl_is_binding_edges(old_ty)) {
jl_gc_wb(b, ty);
jl_binding_invalidate(ty, /* is_const */ 0, (jl_binding_edges_t *)old_ty);
}
else if (nargs != 2 && !jl_types_equal(ty, old_ty)) {
jl_errorf("cannot set type for global %s.%s. It already has a value or is already set to a different type.",
Expand Down Expand Up @@ -2525,6 +2535,7 @@ void jl_init_primitives(void) JL_GC_DISABLED
add_builtin("QuoteNode", (jl_value_t*)jl_quotenode_type);
add_builtin("NewvarNode", (jl_value_t*)jl_newvarnode_type);
add_builtin("Binding", (jl_value_t*)jl_binding_type);
add_builtin("BindingEdges", (jl_value_t*)jl_binding_edges_type);
add_builtin("GlobalRef", (jl_value_t*)jl_globalref_type);
add_builtin("NamedTuple", (jl_value_t*)jl_namedtuple_type);

Expand Down
4 changes: 2 additions & 2 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3201,7 +3201,7 @@ static jl_cgval_t emit_globalref(jl_codectx_t &ctx, jl_module_t *mod, jl_sym_t *
return mark_julia_const(ctx, v);
ty = jl_atomic_load_relaxed(&bnd->ty);
}
if (ty == nullptr)
if (ty == nullptr || jl_is_binding_edges(ty))
ty = (jl_value_t*)jl_any_type;
return update_julia_type(ctx, emit_checked_var(ctx, bp, name, (jl_value_t*)mod, false, ctx.tbaa().tbaa_binding), ty);
}
Expand All @@ -3217,7 +3217,7 @@ static jl_cgval_t emit_globalop(jl_codectx_t &ctx, jl_module_t *mod, jl_sym_t *s
return jl_cgval_t();
if (bnd && !bnd->constp) {
jl_value_t *ty = jl_atomic_load_relaxed(&bnd->ty);
if (ty != nullptr) {
if (ty != nullptr && !jl_is_binding_edges(ty)) {
const std::string fname = issetglobal ? "setglobal!" : isreplaceglobal ? "replaceglobal!" : isswapglobal ? "swapglobal!" : ismodifyglobal ? "modifyglobal!" : "setglobalonce!";
if (!ismodifyglobal) {
// TODO: use typeassert in jl_check_binding_wr too
Expand Down
63 changes: 63 additions & 0 deletions src/gf.c
Original file line number Diff line number Diff line change
Expand Up @@ -1747,6 +1747,69 @@ static void invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_w
}
}

/**
* Invalidate the edges accumulated in `be` - this should be called when a binding has just
* acquired a type or a const value.
*
* ty is the new type of the binding (optional if const), and `is_const` is whether the new
* binding ended up being const. These will be used to filter the edge invalidations, so that
* e.g. an `isdefined` edge is not invalidated by a non-const binding
**/
JL_DLLEXPORT void jl_binding_invalidate(jl_value_t *ty, int is_const, jl_binding_edges_t *be)
{
if (!is_const && ty == (jl_value_t *)jl_any_type)
return; // no improvement to inference information

jl_array_t *edges = be->edges;
jl_method_instance_t *mi = NULL;
JL_GC_PUSH2(&edges, mi);
JL_LOCK(&world_counter_lock);
// Narrow the world age on the methods to make them uncallable
size_t world = jl_atomic_load_relaxed(&jl_world_counter);
for (int i = 0; i < jl_array_len(edges) / 2; i++) {
mi = (jl_method_instance_t *)jl_array_ptr_ref(edges, 2 * i);
jl_sym_t *kind = (jl_sym_t *)jl_array_ptr_ref(edges, 2 * i + 1);
if (!is_const && kind == jl_symbol("const"))
continue; // this is an `isdefined` edge, which has not improved

invalidate_method_instance(mi, world, /* depth */ 0);
}
jl_atomic_store_release(&jl_world_counter, world + 1);
JL_UNLOCK(&world_counter_lock);
JL_GC_POP();
}

JL_DLLEXPORT void jl_globalref_add_backedge(jl_globalref_t *callee, jl_sym_t *kind, jl_method_instance_t *caller)
{
jl_binding_t *b = jl_get_module_binding(callee->mod, callee->name, /* alloc */ 0);
assert(b != NULL);
jl_binding_edges_t *edges = (jl_binding_edges_t *)jl_atomic_load_acquire(&b->ty);
if (edges && !jl_is_binding_edges(edges))
return; // TODO: Handle case where the invalidation happens before the edge arrives

jl_array_t *array = NULL;
JL_GC_PUSH2(&array, &edges);
if (edges == NULL) {
array = jl_alloc_vec_any(0);
edges = (jl_binding_edges_t *)jl_gc_alloc(
jl_current_task->ptls, sizeof(jl_binding_edges_t),
jl_binding_edges_type
);
edges->edges = array;
jl_value_t *old_ty = NULL;
if (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t *)edges))
return; // TODO: Handle case where ty was swapped out from under us
jl_gc_wb(b, edges);
}
else {
array = edges->edges;
}
jl_array_ptr_1d_push(array, (jl_value_t *)caller);
jl_array_ptr_1d_push(array, (jl_value_t *)kind);
JL_GC_POP();
return;
}

// add a backedge from callee to caller
JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee, jl_value_t *invokesig, jl_method_instance_t *caller)
{
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_data.inc
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
XX(jl_floatingpoint_type) \
XX(jl_function_type) \
XX(jl_binding_type) \
XX(jl_binding_edges_type) \
XX(jl_globalref_type) \
XX(jl_gotoifnot_type) \
XX(jl_enternode_type) \
Expand Down
2 changes: 2 additions & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
XX(jl_backtrace_from_here) \
XX(jl_base_relative_to) \
XX(jl_binding_resolved_p) \
XX(jl_binding_invalidate) \
XX(jl_bitcast) \
XX(jl_boundp) \
XX(jl_bounds_error) \
Expand Down Expand Up @@ -237,6 +238,7 @@
XX(jl_get_world_counter) \
XX(jl_get_zero_subnormals) \
XX(jl_gf_invoke_lookup) \
XX(jl_globalref_add_backedge) \
XX(jl_method_lookup_by_tt) \
XX(jl_method_lookup) \
XX(jl_gf_invoke_lookup_worlds) \
Expand Down
5 changes: 5 additions & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -3108,6 +3108,11 @@ void jl_init_types(void) JL_GC_DISABLED
const static uint32_t binding_constfields[] = { 0x0002 }; // Set fields 2 as constant
jl_binding_type->name->constfields = binding_constfields;

jl_binding_edges_type =
jl_new_datatype(jl_symbol("BindingBackedges"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(1, "edges"), jl_svec(1, jl_any_type),
jl_emptysvec, 0, 0, 1);

jl_globalref_type =
jl_new_datatype(jl_symbol("GlobalRef"), core, jl_any_type, jl_emptysvec,
jl_perm_symsvec(3, "mod", "name", "binding"),
Expand Down
7 changes: 7 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,11 @@ typedef struct _jl_binding_t {
uint8_t padding:1;
} jl_binding_t;

typedef struct _jl_binding_edges_t {
JL_DATA_TYPE
jl_array_t *edges;
} jl_binding_edges_t;

typedef struct {
uint64_t hi;
uint64_t lo;
Expand Down Expand Up @@ -930,6 +935,7 @@ extern JL_DLLIMPORT jl_value_t *jl_memoryref_uint8_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_value_t *jl_memoryref_any_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_expr_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_binding_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_binding_edges_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_globalref_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_linenumbernode_type JL_GLOBALLY_ROOTED;
extern JL_DLLIMPORT jl_datatype_t *jl_gotonode_type JL_GLOBALLY_ROOTED;
Expand Down Expand Up @@ -1503,6 +1509,7 @@ static inline int jl_field_isconst(jl_datatype_t *st, int i) JL_NOTSAFEPOINT
#define jl_is_slotnumber(v) jl_typetagis(v,jl_slotnumber_type)
#define jl_is_expr(v) jl_typetagis(v,jl_expr_type)
#define jl_is_binding(v) jl_typetagis(v,jl_binding_type)
#define jl_is_binding_edges(v) jl_typetagis(v,jl_binding_edges_type)
#define jl_is_globalref(v) jl_typetagis(v,jl_globalref_type)
#define jl_is_gotonode(v) jl_typetagis(v,jl_gotonode_type)
#define jl_is_gotoifnot(v) jl_typetagis(v,jl_gotoifnot_type)
Expand Down
2 changes: 2 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -835,6 +835,7 @@ JL_DLLEXPORT jl_value_t *jl_nth_slot_type(jl_value_t *sig JL_PROPAGATES_ROOT, si
void jl_compute_field_offsets(jl_datatype_t *st);
void jl_module_run_initializer(jl_module_t *m);
JL_DLLEXPORT jl_binding_t *jl_get_module_binding(jl_module_t *m JL_PROPAGATES_ROOT, jl_sym_t *var, int alloc);
JL_DLLEXPORT void jl_binding_invalidate(jl_value_t *ty, int is_const, jl_binding_edges_t *be);
JL_DLLEXPORT void jl_binding_deprecation_warning(jl_module_t *m, jl_sym_t *sym, jl_binding_t *b);
extern jl_array_t *jl_module_init_order JL_GLOBALLY_ROOTED;
extern htable_t jl_current_modules JL_GLOBALLY_ROOTED;
Expand Down Expand Up @@ -1041,6 +1042,7 @@ JL_DLLEXPORT jl_value_t *jl_methtable_lookup(jl_methtable_t *mt JL_PROPAGATES_RO
JL_DLLEXPORT jl_method_instance_t *jl_specializations_get_linfo(
jl_method_t *m JL_PROPAGATES_ROOT, jl_value_t *type, jl_svec_t *sparams);
jl_method_instance_t *jl_specializations_get_or_insert(jl_method_instance_t *mi_ins);
JL_DLLEXPORT void jl_globalref_add_backedge(jl_globalref_t *callee, jl_sym_t *kind, jl_method_instance_t *caller);
JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee, jl_value_t *invokesig, jl_method_instance_t *caller);
JL_DLLEXPORT void jl_method_table_add_backedge(jl_methtable_t *mt, jl_value_t *typ, jl_value_t *caller);
JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT,
Expand Down
11 changes: 10 additions & 1 deletion src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -1104,6 +1104,7 @@ JL_DLLEXPORT jl_value_t *jl_generic_function_def(jl_sym_t *name,
_Atomic(jl_value_t*) *bp,
jl_binding_t *bnd)
{
// TODO: Revisit all of this...
jl_value_t *gf = NULL;

assert(name && bp);
Expand All @@ -1113,9 +1114,17 @@ JL_DLLEXPORT jl_value_t *jl_generic_function_def(jl_sym_t *name,
if (gf != NULL) {
if (!jl_is_datatype_singleton((jl_datatype_t*)jl_typeof(gf)) && !jl_is_type(gf))
jl_errorf("cannot define function %s; it already has a value", jl_symbol_name(name));
} else if (bnd) {
jl_value_t *old_ty = NULL;
while (!jl_atomic_cmpswap_relaxed(&bnd->ty, &old_ty, (jl_value_t*)jl_any_type)) {
assert(!old_ty || jl_is_binding_edges(old_ty));
}
if (old_ty)
jl_binding_invalidate((jl_value_t *)jl_any_type, /* is_const */ 1, (jl_binding_edges_t *)old_ty);
}
if (bnd)
if (bnd) {
bnd->constp = 1; // XXX: use jl_declare_constant and jl_checked_assignment
}
if (gf == NULL) {
gf = (jl_value_t*)jl_new_generic_function(name, module);
jl_atomic_store(bp, gf); // TODO: fix constp assignment data race
Expand Down
Loading

0 comments on commit 38d10b6

Please sign in to comment.