From 38d10b6d8c139e1b703de4801597f40beb6eaa6e Mon Sep 17 00:00:00 2001 From: Cody Tapscott Date: Tue, 4 Jun 2024 03:46:37 -0400 Subject: [PATCH] Invalidate methods when binding is typed/const-defined This allows for patterns like: ``` julia> function foo(N) for i = 1:N x = bar(i) end end julia> foo(1_000_000_000) ERROR: UndefVarError: `bar` not defined ``` not to suffer a tremendous performance regression because of the fact that `foo` was inferred with `bar` still undefined. Strictly speaking the original code remains valid, but for performance reasons once the global is defined we'd like to invalidate the code anyway to get an improved inference result. ``` julia> bar(x) = 3x bar (generic function with 1 method) julia> foo(1_000_000_000) # w/o PR: takes > 30 seconds ``` --- base/compiler/abstractinterpretation.jl | 12 ++-- base/compiler/inferencestate.jl | 8 +++ base/compiler/typeinfer.jl | 2 + base/compiler/utilities.jl | 7 ++- src/builtins.c | 15 ++++- src/codegen.cpp | 4 +- src/gf.c | 63 +++++++++++++++++++++ src/jl_exported_data.inc | 1 + src/jl_exported_funcs.inc | 2 + src/jltypes.c | 5 ++ src/julia.h | 7 +++ src/julia_internal.h | 2 + src/method.c | 11 +++- src/module.c | 73 ++++++++++++++++++++----- src/staticdata.c | 10 +++- src/toplevel.c | 5 +- test/compiler/inference.jl | 9 +++ 17 files changed, 208 insertions(+), 28 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 46e15d0c3ad79e..0e430e87343b94 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -2602,6 +2602,7 @@ function abstract_eval_isdefined(interp::AbstractInterpreter, e::Expr, vtypes::U elseif isdefinedconst_globalref(sym) rt = Const(true) else + add_binding_backedge!(sv, sym, :const) effects = Effects(EFFECTS_TOTAL; consistent=ALWAYS_FALSE) end elseif isexpr(sym, :static_parameter) @@ -2822,18 +2823,21 @@ end isdefined_globalref(g::GlobalRef) = !iszero(ccall(:jl_globalref_boundp, Cint, (Any,), g)) isdefinedconst_globalref(g::GlobalRef) = isconst(g) && isdefined_globalref(g) -function abstract_eval_globalref_type(g::GlobalRef) +function abstract_eval_globalref_type(g::GlobalRef, sv::Union{AbsIntState,Nothing}=nothing) if isdefinedconst_globalref(g) return Const(ccall(:jl_get_globalref_value, Any, (Any,), g)) end ty = ccall(:jl_get_binding_type, Any, (Any, Any), g.mod, g.name) - ty === nothing && return Any + if ty === nothing + sv !== nothing && add_binding_backedge!(sv, g, :type) + return Any + end return ty end -abstract_eval_global(M::Module, s::Symbol) = abstract_eval_globalref_type(GlobalRef(M, s)) +abstract_eval_global(M::Module, s::Symbol, sv::Union{AbsIntState,Nothing}=nothing) = abstract_eval_globalref_type(GlobalRef(M, s), sv) function abstract_eval_globalref(interp::AbstractInterpreter, g::GlobalRef, sv::AbsIntState) - rt = abstract_eval_globalref_type(g) + rt = abstract_eval_globalref_type(g, sv) consistent = inaccessiblememonly = ALWAYS_FALSE nothrow = false if isa(rt, Const) diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index e26845b9da1f0b..1a87c845fe43b3 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -1038,6 +1038,14 @@ function add_mt_backedge!(irsv::IRInterpretationState, mt::MethodTable, @nospeci return push!(irsv.edges, mt, typ) end +function add_binding_backedge!(caller::InferenceState, g::GlobalRef, kind::Symbol) + isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance + return push!(get_stmt_edges!(caller), g, kind) +end +function add_binding_backedge!(irsv::IRInterpretationState, g::GlobalRef) + return push!(irsv.edges, g, kind) +end + get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc] get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx][:flag] diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index ba8be9ca561597..dae2451b64a8d0 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -641,6 +641,8 @@ function store_backedges(caller::MethodInstance, edges::Vector{Any}) callee = itr.caller if isa(callee, MethodInstance) ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller) + elseif isa(callee, GlobalRef) + ccall(:jl_globalref_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller) else typeassert(callee, MethodTable) ccall(:jl_method_table_add_backedge, Cvoid, (Any, Any, Any), callee, itr.sig, caller) diff --git a/base/compiler/utilities.jl b/base/compiler/utilities.jl index 65563dab795fb5..76fad95238fff8 100644 --- a/base/compiler/utilities.jl +++ b/base/compiler/utilities.jl @@ -336,9 +336,9 @@ end const empty_backedge_iter = BackedgeIterator(Any[]) struct BackedgePair - sig # ::Union{Nothing,Type} - caller::Union{MethodInstance,MethodTable} - BackedgePair(@nospecialize(sig), caller::Union{MethodInstance,MethodTable}) = new(sig, caller) + sig # ::Union{Nothing,Symbol,Type} + caller::Union{MethodInstance,MethodTable,GlobalRef} + BackedgePair(@nospecialize(sig), caller::Union{MethodInstance,MethodTable,GlobalRef}) = new(sig, caller) end function iterate(iter::BackedgeIterator, i::Int=1) @@ -346,6 +346,7 @@ function iterate(iter::BackedgeIterator, i::Int=1) i > length(backedges) && return nothing item = backedges[i] isa(item, MethodInstance) && return BackedgePair(nothing, item), i+1 # regular dispatch + isa(item, GlobalRef) && return BackedgePair(backedges[i+1], item), i+2 # (untyped) binding isa(item, MethodTable) && return BackedgePair(backedges[i+1], item), i+2 # abstract dispatch return BackedgePair(item, backedges[i+1]::MethodInstance), i+2 # `invoke` calls end diff --git a/src/builtins.c b/src/builtins.c index eb48f726191b94..b07b258e7cd2cb 100644 --- a/src/builtins.c +++ b/src/builtins.c @@ -1378,7 +1378,10 @@ JL_CALLABLE(jl_f_get_binding_type) if (b2 != b) return (jl_value_t*)jl_any_type; jl_value_t *old_ty = NULL; - jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type); + while (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type)) { + if (old_ty && !jl_is_binding_edges(old_ty)) + break; + } return jl_atomic_load_relaxed(&b->ty); } return ty; @@ -1395,8 +1398,15 @@ JL_CALLABLE(jl_f_set_binding_type) JL_TYPECHK(set_binding_type!, type, ty); jl_binding_t *b = jl_get_binding_wr(m, s); jl_value_t *old_ty = NULL; - if (jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, ty)) { + while (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, ty)) { + if (old_ty && !jl_is_binding_edges(old_ty)) + break; + } + if (!old_ty) + jl_gc_wb(b, ty); + else if (jl_is_binding_edges(old_ty)) { jl_gc_wb(b, ty); + jl_binding_invalidate(ty, /* is_const */ 0, (jl_binding_edges_t *)old_ty); } else if (nargs != 2 && !jl_types_equal(ty, old_ty)) { jl_errorf("cannot set type for global %s.%s. It already has a value or is already set to a different type.", @@ -2525,6 +2535,7 @@ void jl_init_primitives(void) JL_GC_DISABLED add_builtin("QuoteNode", (jl_value_t*)jl_quotenode_type); add_builtin("NewvarNode", (jl_value_t*)jl_newvarnode_type); add_builtin("Binding", (jl_value_t*)jl_binding_type); + add_builtin("BindingEdges", (jl_value_t*)jl_binding_edges_type); add_builtin("GlobalRef", (jl_value_t*)jl_globalref_type); add_builtin("NamedTuple", (jl_value_t*)jl_namedtuple_type); diff --git a/src/codegen.cpp b/src/codegen.cpp index 31f40470962eec..1125218597d773 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -3201,7 +3201,7 @@ static jl_cgval_t emit_globalref(jl_codectx_t &ctx, jl_module_t *mod, jl_sym_t * return mark_julia_const(ctx, v); ty = jl_atomic_load_relaxed(&bnd->ty); } - if (ty == nullptr) + if (ty == nullptr || jl_is_binding_edges(ty)) ty = (jl_value_t*)jl_any_type; return update_julia_type(ctx, emit_checked_var(ctx, bp, name, (jl_value_t*)mod, false, ctx.tbaa().tbaa_binding), ty); } @@ -3217,7 +3217,7 @@ static jl_cgval_t emit_globalop(jl_codectx_t &ctx, jl_module_t *mod, jl_sym_t *s return jl_cgval_t(); if (bnd && !bnd->constp) { jl_value_t *ty = jl_atomic_load_relaxed(&bnd->ty); - if (ty != nullptr) { + if (ty != nullptr && !jl_is_binding_edges(ty)) { const std::string fname = issetglobal ? "setglobal!" : isreplaceglobal ? "replaceglobal!" : isswapglobal ? "swapglobal!" : ismodifyglobal ? "modifyglobal!" : "setglobalonce!"; if (!ismodifyglobal) { // TODO: use typeassert in jl_check_binding_wr too diff --git a/src/gf.c b/src/gf.c index e5a33ecf68c5d5..a875a2e44abfff 100644 --- a/src/gf.c +++ b/src/gf.c @@ -1747,6 +1747,69 @@ static void invalidate_backedges(jl_method_instance_t *replaced_mi, size_t max_w } } +/** + * Invalidate the edges accumulated in `be` - this should be called when a binding has just + * acquired a type or a const value. + * + * ty is the new type of the binding (optional if const), and `is_const` is whether the new + * binding ended up being const. These will be used to filter the edge invalidations, so that + * e.g. an `isdefined` edge is not invalidated by a non-const binding + **/ +JL_DLLEXPORT void jl_binding_invalidate(jl_value_t *ty, int is_const, jl_binding_edges_t *be) +{ + if (!is_const && ty == (jl_value_t *)jl_any_type) + return; // no improvement to inference information + + jl_array_t *edges = be->edges; + jl_method_instance_t *mi = NULL; + JL_GC_PUSH2(&edges, mi); + JL_LOCK(&world_counter_lock); + // Narrow the world age on the methods to make them uncallable + size_t world = jl_atomic_load_relaxed(&jl_world_counter); + for (int i = 0; i < jl_array_len(edges) / 2; i++) { + mi = (jl_method_instance_t *)jl_array_ptr_ref(edges, 2 * i); + jl_sym_t *kind = (jl_sym_t *)jl_array_ptr_ref(edges, 2 * i + 1); + if (!is_const && kind == jl_symbol("const")) + continue; // this is an `isdefined` edge, which has not improved + + invalidate_method_instance(mi, world, /* depth */ 0); + } + jl_atomic_store_release(&jl_world_counter, world + 1); + JL_UNLOCK(&world_counter_lock); + JL_GC_POP(); +} + +JL_DLLEXPORT void jl_globalref_add_backedge(jl_globalref_t *callee, jl_sym_t *kind, jl_method_instance_t *caller) +{ + jl_binding_t *b = jl_get_module_binding(callee->mod, callee->name, /* alloc */ 0); + assert(b != NULL); + jl_binding_edges_t *edges = (jl_binding_edges_t *)jl_atomic_load_acquire(&b->ty); + if (edges && !jl_is_binding_edges(edges)) + return; // TODO: Handle case where the invalidation happens before the edge arrives + + jl_array_t *array = NULL; + JL_GC_PUSH2(&array, &edges); + if (edges == NULL) { + array = jl_alloc_vec_any(0); + edges = (jl_binding_edges_t *)jl_gc_alloc( + jl_current_task->ptls, sizeof(jl_binding_edges_t), + jl_binding_edges_type + ); + edges->edges = array; + jl_value_t *old_ty = NULL; + if (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t *)edges)) + return; // TODO: Handle case where ty was swapped out from under us + jl_gc_wb(b, edges); + } + else { + array = edges->edges; + } + jl_array_ptr_1d_push(array, (jl_value_t *)caller); + jl_array_ptr_1d_push(array, (jl_value_t *)kind); + JL_GC_POP(); + return; +} + // add a backedge from callee to caller JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee, jl_value_t *invokesig, jl_method_instance_t *caller) { diff --git a/src/jl_exported_data.inc b/src/jl_exported_data.inc index 79ff4378418790..454ab5e2ef233b 100644 --- a/src/jl_exported_data.inc +++ b/src/jl_exported_data.inc @@ -51,6 +51,7 @@ XX(jl_floatingpoint_type) \ XX(jl_function_type) \ XX(jl_binding_type) \ + XX(jl_binding_edges_type) \ XX(jl_globalref_type) \ XX(jl_gotoifnot_type) \ XX(jl_enternode_type) \ diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index 5e70566ab310e3..d94a9c00ee825c 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -43,6 +43,7 @@ XX(jl_backtrace_from_here) \ XX(jl_base_relative_to) \ XX(jl_binding_resolved_p) \ + XX(jl_binding_invalidate) \ XX(jl_bitcast) \ XX(jl_boundp) \ XX(jl_bounds_error) \ @@ -237,6 +238,7 @@ XX(jl_get_world_counter) \ XX(jl_get_zero_subnormals) \ XX(jl_gf_invoke_lookup) \ + XX(jl_globalref_add_backedge) \ XX(jl_method_lookup_by_tt) \ XX(jl_method_lookup) \ XX(jl_gf_invoke_lookup_worlds) \ diff --git a/src/jltypes.c b/src/jltypes.c index 59807226fb4a93..e2d84cb3c958a9 100644 --- a/src/jltypes.c +++ b/src/jltypes.c @@ -3108,6 +3108,11 @@ void jl_init_types(void) JL_GC_DISABLED const static uint32_t binding_constfields[] = { 0x0002 }; // Set fields 2 as constant jl_binding_type->name->constfields = binding_constfields; + jl_binding_edges_type = + jl_new_datatype(jl_symbol("BindingBackedges"), core, jl_any_type, jl_emptysvec, + jl_perm_symsvec(1, "edges"), jl_svec(1, jl_any_type), + jl_emptysvec, 0, 0, 1); + jl_globalref_type = jl_new_datatype(jl_symbol("GlobalRef"), core, jl_any_type, jl_emptysvec, jl_perm_symsvec(3, "mod", "name", "binding"), diff --git a/src/julia.h b/src/julia.h index 4534d00caa8884..f6d4979dcc836f 100644 --- a/src/julia.h +++ b/src/julia.h @@ -642,6 +642,11 @@ typedef struct _jl_binding_t { uint8_t padding:1; } jl_binding_t; +typedef struct _jl_binding_edges_t { + JL_DATA_TYPE + jl_array_t *edges; +} jl_binding_edges_t; + typedef struct { uint64_t hi; uint64_t lo; @@ -930,6 +935,7 @@ extern JL_DLLIMPORT jl_value_t *jl_memoryref_uint8_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_value_t *jl_memoryref_any_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_expr_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_binding_type JL_GLOBALLY_ROOTED; +extern JL_DLLIMPORT jl_datatype_t *jl_binding_edges_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_globalref_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_linenumbernode_type JL_GLOBALLY_ROOTED; extern JL_DLLIMPORT jl_datatype_t *jl_gotonode_type JL_GLOBALLY_ROOTED; @@ -1503,6 +1509,7 @@ static inline int jl_field_isconst(jl_datatype_t *st, int i) JL_NOTSAFEPOINT #define jl_is_slotnumber(v) jl_typetagis(v,jl_slotnumber_type) #define jl_is_expr(v) jl_typetagis(v,jl_expr_type) #define jl_is_binding(v) jl_typetagis(v,jl_binding_type) +#define jl_is_binding_edges(v) jl_typetagis(v,jl_binding_edges_type) #define jl_is_globalref(v) jl_typetagis(v,jl_globalref_type) #define jl_is_gotonode(v) jl_typetagis(v,jl_gotonode_type) #define jl_is_gotoifnot(v) jl_typetagis(v,jl_gotoifnot_type) diff --git a/src/julia_internal.h b/src/julia_internal.h index 6f71b6018606f9..35abf0fdb542ab 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -835,6 +835,7 @@ JL_DLLEXPORT jl_value_t *jl_nth_slot_type(jl_value_t *sig JL_PROPAGATES_ROOT, si void jl_compute_field_offsets(jl_datatype_t *st); void jl_module_run_initializer(jl_module_t *m); JL_DLLEXPORT jl_binding_t *jl_get_module_binding(jl_module_t *m JL_PROPAGATES_ROOT, jl_sym_t *var, int alloc); +JL_DLLEXPORT void jl_binding_invalidate(jl_value_t *ty, int is_const, jl_binding_edges_t *be); JL_DLLEXPORT void jl_binding_deprecation_warning(jl_module_t *m, jl_sym_t *sym, jl_binding_t *b); extern jl_array_t *jl_module_init_order JL_GLOBALLY_ROOTED; extern htable_t jl_current_modules JL_GLOBALLY_ROOTED; @@ -1041,6 +1042,7 @@ JL_DLLEXPORT jl_value_t *jl_methtable_lookup(jl_methtable_t *mt JL_PROPAGATES_RO JL_DLLEXPORT jl_method_instance_t *jl_specializations_get_linfo( jl_method_t *m JL_PROPAGATES_ROOT, jl_value_t *type, jl_svec_t *sparams); jl_method_instance_t *jl_specializations_get_or_insert(jl_method_instance_t *mi_ins); +JL_DLLEXPORT void jl_globalref_add_backedge(jl_globalref_t *callee, jl_sym_t *kind, jl_method_instance_t *caller); JL_DLLEXPORT void jl_method_instance_add_backedge(jl_method_instance_t *callee, jl_value_t *invokesig, jl_method_instance_t *caller); JL_DLLEXPORT void jl_method_table_add_backedge(jl_methtable_t *mt, jl_value_t *typ, jl_value_t *caller); JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT, diff --git a/src/method.c b/src/method.c index cb426514a6d544..89cba145f06f94 100644 --- a/src/method.c +++ b/src/method.c @@ -1104,6 +1104,7 @@ JL_DLLEXPORT jl_value_t *jl_generic_function_def(jl_sym_t *name, _Atomic(jl_value_t*) *bp, jl_binding_t *bnd) { + // TODO: Revisit all of this... jl_value_t *gf = NULL; assert(name && bp); @@ -1113,9 +1114,17 @@ JL_DLLEXPORT jl_value_t *jl_generic_function_def(jl_sym_t *name, if (gf != NULL) { if (!jl_is_datatype_singleton((jl_datatype_t*)jl_typeof(gf)) && !jl_is_type(gf)) jl_errorf("cannot define function %s; it already has a value", jl_symbol_name(name)); + } else if (bnd) { + jl_value_t *old_ty = NULL; + while (!jl_atomic_cmpswap_relaxed(&bnd->ty, &old_ty, (jl_value_t*)jl_any_type)) { + assert(!old_ty || jl_is_binding_edges(old_ty)); + } + if (old_ty) + jl_binding_invalidate((jl_value_t *)jl_any_type, /* is_const */ 1, (jl_binding_edges_t *)old_ty); } - if (bnd) + if (bnd) { bnd->constp = 1; // XXX: use jl_declare_constant and jl_checked_assignment + } if (gf == NULL) { gf = (jl_value_t*)jl_new_generic_function(name, module); jl_atomic_store(bp, gf); // TODO: fix constp assignment data race diff --git a/src/module.c b/src/module.c index 9242a659502017..58db9b91f08487 100644 --- a/src/module.c +++ b/src/module.c @@ -407,6 +407,37 @@ static jl_binding_t *jl_resolve_owner(jl_binding_t *b/*optional*/, jl_module_t * // concurrent import return owner; } + jl_value_t *old_ty = jl_atomic_exchange_relaxed(&b->ty, NULL); + if (old_ty) { + assert(jl_is_binding_edges(old_ty)); + + // Load the owner type, attempting to insert a new jl_binding_edges_t if it's NULL + jl_value_t *owner_ty = jl_atomic_load_relaxed(&b2->ty); + if (owner_ty == NULL) { + jl_array_t *array = jl_alloc_vec_any(0); + JL_GC_PUSH1(&array); + jl_binding_edges_t *edges = (jl_binding_edges_t *)jl_gc_alloc( + jl_current_task->ptls, sizeof(jl_binding_edges_t), + jl_binding_edges_type + ); + edges->edges = array; + jl_atomic_cmpswap_relaxed(&b2->ty, &owner_ty, (jl_value_t *)edges); + if (owner_ty == NULL) { + jl_gc_wb(b2, (jl_value_t *)edges); + owner_ty = (jl_value_t *)edges; + } + JL_GC_POP(); + } + if (jl_is_binding_edges(owner_ty)) { + // TODO: Add a lock to make sure we don't collide with this on invalidation + jl_array_ptr_1d_append( + ((jl_binding_edges_t *)owner_ty)->edges, + ((jl_binding_edges_t *)old_ty)->edges + ); + } else if (owner_ty != (jl_value_t *)jl_any_type || b2->constp) { + jl_binding_invalidate(owner_ty, b2->constp, (jl_binding_edges_t *)old_ty); + } + } if (b2->deprecated) { b->deprecated = 1; // we will warn about this below, but we might want to warn at the use sites too if (m != jl_main_module && m != jl_base_module && @@ -448,7 +479,7 @@ JL_DLLEXPORT jl_value_t *jl_get_binding_type(jl_module_t *m, jl_sym_t *var) if (b == NULL) return jl_nothing; jl_value_t *ty = jl_atomic_load_relaxed(&b->ty); - return ty ? ty : jl_nothing; + return (ty && !jl_is_binding_edges(ty)) ? ty : jl_nothing; } JL_DLLEXPORT jl_binding_t *jl_get_binding(jl_module_t *m, jl_sym_t *var) @@ -805,14 +836,19 @@ JL_DLLEXPORT void jl_set_const(jl_module_t *m JL_ROOTING_ARGUMENT, jl_sym_t *var if (!jl_atomic_cmpswap(&bp->owner, &b2, bp) && b2 != bp) jl_errorf("invalid redefinition of constant %s", jl_symbol_name(var)); if (jl_atomic_load_relaxed(&bp->value) == NULL) { - jl_value_t *old_ty = NULL; - jl_atomic_cmpswap_relaxed(&bp->ty, &old_ty, (jl_value_t*)jl_any_type); uint8_t constp = 0; // if (jl_atomic_cmpswap(&bp->constp, &constp, 1)) { if (constp = bp->constp, bp->constp = 1, constp == 0) { jl_value_t *old = NULL; if (jl_atomic_cmpswap(&bp->value, &old, val)) { jl_gc_wb(bp, val); + jl_value_t *old_ty = NULL; + while (!jl_atomic_cmpswap_relaxed(&bp->ty, &old_ty, (jl_value_t*)jl_any_type)) { + if (old_ty && !jl_is_binding_edges(old_ty)) + break; + } + if (old_ty && jl_is_binding_edges(old_ty)) + jl_binding_invalidate((jl_value_t *)jl_any_type, /* is_const */ 1, (jl_binding_edges_t *)old_ty); return; } } @@ -889,18 +925,22 @@ void jl_binding_deprecation_warning(jl_module_t *m, jl_sym_t *s, jl_binding_t *b jl_value_t *jl_check_binding_wr(jl_binding_t *b, jl_module_t *mod, jl_sym_t *var, jl_value_t *rhs JL_MAYBE_UNROOTED, int reassign) { jl_value_t *old_ty = NULL; - if (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type)) { - if (old_ty != (jl_value_t*)jl_any_type && jl_typeof(rhs) != old_ty) { - JL_GC_PUSH1(&rhs); // callee-rooted - if (!jl_isa(rhs, old_ty)) - jl_errorf("cannot assign an incompatible value to the global %s.%s.", - jl_symbol_name(mod->name), jl_symbol_name(var)); - JL_GC_POP(); - } + while (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type)) { + if (old_ty && !jl_is_binding_edges(old_ty)) + break; } - else { + if (!old_ty || jl_is_binding_edges(old_ty)) { + // edges are intentionally dropped on the floor here, since the new `Any` + // type is not a refinement of any inference information old_ty = (jl_value_t*)jl_any_type; } + else if (old_ty != (jl_value_t*)jl_any_type && jl_typeof(rhs) != old_ty) { + JL_GC_PUSH1(&rhs); // callee-rooted + if (!jl_isa(rhs, old_ty)) + jl_errorf("cannot assign an incompatible value to the global %s.%s.", + jl_symbol_name(mod->name), jl_symbol_name(var)); + JL_GC_POP(); + } if (b->constp) { if (reassign) { jl_value_t *old = NULL; @@ -950,8 +990,15 @@ JL_DLLEXPORT jl_value_t *jl_checked_replace(jl_binding_t *b, jl_module_t *mod, j JL_DLLEXPORT jl_value_t *jl_checked_modify(jl_binding_t *b, jl_module_t *mod, jl_sym_t *var, jl_value_t *op, jl_value_t *rhs) { jl_value_t *ty = NULL; - if (jl_atomic_cmpswap_relaxed(&b->ty, &ty, (jl_value_t*)jl_any_type)) + while (!jl_atomic_cmpswap_relaxed(&b->ty, &ty, (jl_value_t*)jl_any_type)) { + if (ty && !jl_is_binding_edges(ty)) + break; + } + if (!ty || jl_is_binding_edges(ty)) { + // edges are intentionally dropped on the floor here, since the new `Any` + // type is not a refinement of any inference information ty = (jl_value_t*)jl_any_type; + } if (b->constp) jl_errorf("invalid redefinition of constant %s.%s", jl_symbol_name(mod->name), jl_symbol_name(var)); diff --git a/src/staticdata.c b/src/staticdata.c index 28051d52eb1055..3cde57855549b6 100644 --- a/src/staticdata.c +++ b/src/staticdata.c @@ -100,7 +100,7 @@ extern "C" { // TODO: put WeakRefs on the weak_refs list during deserialization // TODO: handle finalizers -#define NUM_TAGS 190 +#define NUM_TAGS 191 // An array of references that need to be restored from the sysimg // This is a manually constructed dual of the gvars array, which would be produced by codegen for Julia code, for C. @@ -122,6 +122,7 @@ jl_value_t **const*const get_tags(void) { INSERT_TAG(jl_array_type); INSERT_TAG(jl_expr_type); INSERT_TAG(jl_binding_type); + INSERT_TAG(jl_binding_edges_type); INSERT_TAG(jl_globalref_type); INSERT_TAG(jl_string_type); INSERT_TAG(jl_module_type); @@ -1351,7 +1352,12 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED // We don't want these accidentally managing to diverge later in different compilation units. if (jl_atomic_load_relaxed(&b->owner) == b) { jl_value_t *old_ty = NULL; - jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type); + while (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type)) { + if (old_ty && !jl_is_binding_edges(old_ty)) + break; + } + // TODO: This should be re-written to use the same open-closed semantics as value + // then we don't have to drop the edges (and they can still be defined in the non-pkgimage) } } } diff --git a/src/toplevel.c b/src/toplevel.c index 1899c9e18db304..d551406b646809 100644 --- a/src/toplevel.c +++ b/src/toplevel.c @@ -329,7 +329,10 @@ void jl_eval_global_expr(jl_module_t *m, jl_expr_t *ex, int set_type) { if (set_type) { jl_value_t *old_ty = NULL; // maybe set the type too, perhaps - jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type); + while (!jl_atomic_cmpswap_relaxed(&b->ty, &old_ty, (jl_value_t*)jl_any_type)) { + if (old_ty && jl_is_binding_edges(old_ty)) + break; + } } } } diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index eca937dddc5aba..ce88d5984f6774 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -5757,3 +5757,12 @@ end bar54341(args...) = foo54341(4, args...) @test Core.Compiler.return_type(bar54341, Tuple{Vararg{Int}}) === Int + +should_be_invalidated_by_binding_edge() = unknown_foo() +# Trigger an inference result before all definitions are available +@test Any === Core.Compiler.return_type(should_be_invalidated_by_binding_edge, Tuple{}) + +# Binding backedges should guarantee that when `unknown_foo` is const-defined, this is invalidated +unknown_foo() = rand(Int) +# Inference in the new world should give a good result +@test Int === Core.Compiler.return_type(should_be_invalidated_by_binding_edge, Tuple{})