Skip to content

Commit

Permalink
add precompile support for recording fields to change
Browse files Browse the repository at this point in the history
Somewhat generalizes our support for changing Ptr to C_NULL. Not
particularly fast, since it is just using the builtins implementation of
setfield, and delaying the actual stores, but it should suffice.
  • Loading branch information
vtjnash authored and KristofferC committed Oct 21, 2024
1 parent e9bfc9c commit 0438f2a
Show file tree
Hide file tree
Showing 7 changed files with 167 additions and 17 deletions.
38 changes: 31 additions & 7 deletions base/lock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@

const ThreadSynchronizer = GenericCondition{Threads.SpinLock}

"""
    current_task()

Get the currently running [`Task`](@ref).
"""
function current_task()
    # Ask the runtime for the task object that is executing this call.
    return ccall(:jl_get_current_task, Ref{Task}, ())
end

# Advisory reentrant lock
"""
ReentrantLock()
Expand Down Expand Up @@ -606,16 +613,23 @@ mutable struct PerProcess{T, F}
const initializer::F
const lock::ReentrantLock

PerProcess{T}(initializer::F) where {T, F} = new{T,F}(nothing, 0x00, true, initializer, ReentrantLock())
PerProcess{T,F}(initializer::F) where {T, F} = new{T,F}(nothing, 0x00, true, initializer, ReentrantLock())
PerProcess(initializer) = new{Base.promote_op(initializer), typeof(initializer)}(nothing, 0x00, true, initializer, ReentrantLock())
# Inner constructor: create the object in its initial state, then register with the
# runtime that the `x` and `state` fields should be replaced by those same initial
# values (`nothing` and `0x00`) when precompile output is generated.
function PerProcess{T,F}(initializer::F) where {T, F}
    obj = new{T,F}(nothing, 0x00, true, initializer, ReentrantLock())
    ccall(:jl_set_precompile_field_replace, Cvoid, (Any, Any, Any), obj, :x, nothing)
    ccall(:jl_set_precompile_field_replace, Cvoid, (Any, Any, Any), obj, :state, 0x00)
    return obj
end
end
# Convenience constructor: the caller fixes the result type `T`; the initializer's
# type `F` is taken from the argument.
function PerProcess{T}(initializer::F) where {T, F}
    return PerProcess{T, F}(initializer)
end
# Convenience constructor: derive the result type from the initializer via
# `Base.promote_op`.
function PerProcess(initializer)
    T = Base.promote_op(initializer)
    return PerProcess{T, typeof(initializer)}(initializer)
end
@inline function (once::PerProcess{T})() where T
state = (@atomic :acquire once.state)
if state != 0x01
(@noinline function init_perprocesss(once, state)
state == 0x02 && error("PerProcess initializer failed previously")
Base.__precompile__(once.allow_compile_time)
once.allow_compile_time || __precompile__(false)
lock(once.lock)
try
state = @atomic :monotonic once.state
Expand Down Expand Up @@ -644,6 +658,8 @@ function copyto_monotonic!(dest::AtomicMemory, src)
for j in eachindex(src)
if isassigned(src, j)
@atomic :monotonic dest[i] = src[j]
#else
# _unsafeindex_atomic!(dest, i, src[j], :monotonic)
end
i += 1
end
Expand Down Expand Up @@ -701,10 +717,18 @@ mutable struct PerThread{T, F}
@atomic ss::AtomicMemory{UInt8} # states: 0=initial, 1=hasrun, 2=error, 3==concurrent
const initializer::F

PerThread{T}(initializer::F) where {T, F} = new{T,F}(AtomicMemory{T}(), AtomicMemory{UInt8}(), initializer)
PerThread{T,F}(initializer::F) where {T, F} = new{T,F}(AtomicMemory{T}(), AtomicMemory{UInt8}(), initializer)
PerThread(initializer) = (T = Base.promote_op(initializer); new{T, typeof(initializer)}(AtomicMemory{T}(), AtomicMemory{UInt8}(), initializer))
# Inner constructor: allocate the value and state memories, build the object, then
# register with the runtime that the `xs` and `ss` fields should be replaced by the
# memories the object was constructed with when precompile output is generated.
function PerThread{T,F}(initializer::F) where {T, F}
    xs = AtomicMemory{T}()
    ss = AtomicMemory{UInt8}()
    obj = new{T,F}(xs, ss, initializer)
    ccall(:jl_set_precompile_field_replace, Cvoid, (Any, Any, Any), obj, :xs, xs)
    ccall(:jl_set_precompile_field_replace, Cvoid, (Any, Any, Any), obj, :ss, ss)
    return obj
end
end
# Convenience constructor: the caller fixes the element type `T`; the initializer's
# type `F` is taken from the argument.
function PerThread{T}(initializer::F) where {T, F}
    return PerThread{T,F}(initializer)
end
# Convenience constructor: derive the element type from the initializer via
# `Base.promote_op`.
function PerThread(initializer)
    T = Base.promote_op(initializer)
    return PerThread{T, typeof(initializer)}(initializer)
end
@inline function getindex(once::PerThread, tid::Integer)
tid = Int(tid)
ss = @atomic :acquire once.ss
Expand Down
7 changes: 0 additions & 7 deletions base/task.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,13 +143,6 @@ macro task(ex)
:(Task($thunk))
end

"""
current_task()
Get the currently running [`Task`](@ref).
"""
current_task() = ccall(:jl_get_current_task, Ref{Task}, ())

# task states

const task_state_runnable = UInt8(0)
Expand Down
2 changes: 1 addition & 1 deletion src/builtins.c
Original file line number Diff line number Diff line change
Expand Up @@ -1008,7 +1008,7 @@ static inline size_t get_checked_fieldindex(const char *name, jl_datatype_t *st,
else {
jl_value_t *ts[2] = {(jl_value_t*)jl_long_type, (jl_value_t*)jl_symbol_type};
jl_value_t *t = jl_type_union(ts, 2);
jl_type_error("getfield", t, arg);
jl_type_error(name, t, arg);
}
if (mutabl && jl_field_isconst(st, idx)) {
jl_errorf("%s: const field .%s of type %s cannot be changed", name,
Expand Down
2 changes: 2 additions & 0 deletions src/gc-stock.c
Original file line number Diff line number Diff line change
Expand Up @@ -2786,6 +2786,8 @@ static void gc_mark_roots(jl_gc_markqueue_t *mq)
gc_heap_snapshot_record_gc_roots((jl_value_t*)jl_global_roots_list, "global_roots_list");
gc_try_claim_and_push(mq, jl_global_roots_keyset, NULL);
gc_heap_snapshot_record_gc_roots((jl_value_t*)jl_global_roots_keyset, "global_roots_keyset");
gc_try_claim_and_push(mq, precompile_field_replace, NULL);
gc_heap_snapshot_record_gc_roots((jl_value_t*)precompile_field_replace, "precompile_field_replace");
}

// find unmarked objects that need to be finalized from the finalizer list "list".
Expand Down
2 changes: 2 additions & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -885,6 +885,8 @@ extern jl_genericmemory_t *jl_global_roots_keyset JL_GLOBALLY_ROOTED;
extern arraylist_t *jl_entrypoint_mis;
JL_DLLEXPORT int jl_is_globally_rooted(jl_value_t *val JL_MAYBE_UNROOTED) JL_NOTSAFEPOINT;
JL_DLLEXPORT jl_value_t *jl_as_global_root(jl_value_t *val, int insert) JL_GLOBALLY_ROOTED;
// Pending field replacements requested via jl_set_precompile_field_replace.
// NULL until the first request is recorded; marked as a GC root in gc_mark_roots.
extern jl_svec_t *precompile_field_replace JL_GLOBALLY_ROOTED;
// Record that field `field` of `val` should be written as `newval` when output is
// generated. Does nothing unless jl_generating_output() is true.
JL_DLLEXPORT void jl_set_precompile_field_replace(jl_value_t *val, jl_value_t *field, jl_value_t *newval) JL_GLOBALLY_ROOTED;

jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
jl_value_t *source, jl_value_t **env, size_t nenv, int do_compile);
Expand Down
111 changes: 110 additions & 1 deletion src/staticdata.c
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ void *native_functions; // opaque jl_native_code_desc_t blob used for fetching

// table of struct field addresses to rewrite during saving
static htable_t field_replace;
// table of inline (non-pointer) field addresses whose stored bits are replaced during
// saving; each address maps to the value written in place of the field's current contents
static htable_t bits_replace;
static htable_t relocatable_ext_cis;

// array of definitions for the predefined function pointers
Expand Down Expand Up @@ -1666,7 +1667,23 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
write_padding(f, offset - tot);
tot = offset;
size_t fsz = jl_field_size(t, i);
if (t->name->mutabl && jl_is_cpointer_type(jl_field_type_concrete(t, i)) && *(intptr_t*)slot != -1) {
jl_value_t *replace = (jl_value_t*)ptrhash_get(&bits_replace, (void*)slot);
if (replace != HT_NOTFOUND) {
assert(t->name->mutabl && !jl_field_isptr(t, i));
jl_value_t *rty = jl_typeof(replace);
size_t sz = jl_datatype_size(rty);
ios_write(f, (const char*)replace, sz);
jl_value_t *ft = jl_field_type_concrete(t, i);
int isunion = jl_is_uniontype(ft);
unsigned nth = 0;
if (!jl_find_union_component(ft, rty, &nth))
assert(0 && "invalid field assignment to isbits union");
assert(sz <= fsz - isunion);
write_padding(f, fsz - sz - isunion);
if (isunion)
write_uint8(f, nth);
}
else if (t->name->mutabl && jl_is_cpointer_type(jl_field_type_concrete(t, i)) && *(intptr_t*)slot != -1) {
// reset Ptr fields to C_NULL (but keep MAP_FAILED / INVALID_HANDLE)
assert(!jl_field_isptr(t, i));
write_pointer(f);
Expand Down Expand Up @@ -2660,6 +2677,65 @@ jl_mutex_t global_roots_lock;
extern jl_mutex_t world_counter_lock;
extern size_t jl_require_world;

// Held by jl_set_precompile_field_replace while it creates or appends to the list below.
jl_mutex_t precompile_field_replace_lock;
// Lazily created svec of three parallel arrays: [0] the objects, [1] the boxed field
// indices, [2] the replacement values. Stays NULL until the first request is recorded.
jl_svec_t *precompile_field_replace JL_GLOBALLY_ROOTED;

// Validate a field designator for a `setfield!`-style operation on an object of type
// `st` and return it as a boxed 0-based field index.
//
//   name   - operation name used as the prefix/context of error messages
//   st     - datatype of the object being accessed
//   v      - the object itself (only used when reporting a bounds error)
//   arg    - field designator: a 1-based integer or a field-name symbol
//   mutabl - non-zero when the field is going to be modified; this additionally
//            rejects modules, immutable structs and `const` fields
//
// The result is always the boxed 0-based index, regardless of whether `arg` was an
// integer or a symbol. Previously an integer designator was returned unchanged
// (1-based) while a symbol was converted to a 0-based index, so callers that unbox the
// result and use it as a field index addressed the wrong field for integer designators.
static inline jl_value_t *get_checked_fieldindex(const char *name, jl_datatype_t *st, jl_value_t *v, jl_value_t *arg, int mutabl)
{
    if (mutabl) {
        if (st == jl_module_type)
            jl_error("cannot assign variables in other modules");
        if (!st->name->mutabl)
            jl_errorf("%s: immutable struct of type %s cannot be changed", name, jl_symbol_name(st->name->name));
    }
    size_t idx;
    if (jl_is_long(arg)) {
        // integer designators are 1-based; 0 and negative values wrap to a huge
        // unsigned value and are rejected by the bounds check
        idx = jl_unbox_long(arg) - 1;
        if (idx >= jl_datatype_nfields(st))
            jl_bounds_error(v, arg);
    }
    else if (jl_is_symbol(arg)) {
        idx = jl_field_index(st, (jl_sym_t*)arg, 1);
    }
    else {
        jl_value_t *ts[2] = {(jl_value_t*)jl_long_type, (jl_value_t*)jl_symbol_type};
        jl_value_t *t = jl_type_union(ts, 2);
        jl_type_error(name, t, arg);
    }
    if (mutabl && jl_field_isconst(st, idx)) {
        jl_errorf("%s: const field .%s of type %s cannot be changed", name,
                  jl_symbol_name((jl_sym_t*)jl_svecref(jl_field_names(st), idx)), jl_symbol_name(st->name->name));
    }
    // normalize both designator kinds to the same 0-based representation
    return jl_box_long(idx);
}

// Record a request that field `field` of `val` be treated as holding `newval` when
// output is generated. The request is validated like a `setfield!` (mutable struct,
// non-const field, value of the field's type) but `val` itself is not modified here;
// the (object, field index, value) triple is appended to `precompile_field_replace`.
// Does nothing when no output is being generated.
JL_DLLEXPORT void jl_set_precompile_field_replace(jl_value_t *val, jl_value_t *field, jl_value_t *newval)
{
    if (!jl_generating_output())
        return;
    jl_datatype_t *vt = (jl_datatype_t*)jl_typeof(val);
    jl_value_t *boxed_idx = get_checked_fieldindex("setfield!", vt, val, field, 1);
    JL_GC_PUSH1(&boxed_idx);
    jl_value_t *fieldty = jl_field_type_concrete(vt, (size_t)jl_unbox_long(boxed_idx));
    if (!jl_isa(newval, fieldty))
        jl_type_error("setfield!", fieldty, newval);
    JL_LOCK(&precompile_field_replace_lock);
    size_t k;
    if (precompile_field_replace == NULL) {
        // first request: create the three parallel lists
        precompile_field_replace = jl_alloc_svec(3);
        for (k = 0; k < 3; k++)
            jl_svecset(precompile_field_replace, k, jl_alloc_vec_any(0));
    }
    // append one entry to each list, in the order: object, field index, new value
    jl_value_t *entry[3] = {val, boxed_idx, newval};
    for (k = 0; k < 3; k++)
        jl_array_ptr_1d_push((jl_array_t*)jl_svecref(precompile_field_replace, k), entry[k]);
    JL_GC_POP();
    JL_UNLOCK(&precompile_field_replace_lock);
}


JL_DLLEXPORT int jl_is_globally_rooted(jl_value_t *val JL_MAYBE_UNROOTED) JL_NOTSAFEPOINT
{
if (jl_is_datatype(val)) {
Expand Down Expand Up @@ -2779,6 +2855,7 @@ static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
jl_array_t *ext_targets, jl_array_t *edges) JL_GC_DISABLED
{
htable_new(&field_replace, 0);
htable_new(&bits_replace, 0);
// strip metadata and IR when requested
if (jl_options.strip_metadata || jl_options.strip_ir)
jl_strip_all_codeinfos();
Expand All @@ -2790,6 +2867,37 @@ static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
arraylist_new(&gvars, 0);
arraylist_t external_fns;
arraylist_new(&external_fns, 0);
// prepare hash table with any fields the user wanted us to rewrite during serialization
// (the requests were accumulated by jl_set_precompile_field_replace as three parallel
// arrays: objects, boxed field indices, and replacement values)
if (precompile_field_replace) {
jl_array_t *vals = (jl_array_t*)jl_svecref(precompile_field_replace, 0);
jl_array_t *fields = (jl_array_t*)jl_svecref(precompile_field_replace, 1);
jl_array_t *newvals = (jl_array_t*)jl_svecref(precompile_field_replace, 2);
size_t i, l = jl_array_nrows(vals);
assert(jl_array_nrows(fields) == l && jl_array_nrows(newvals) == l);
for (i = 0; i < l; i++) {
jl_value_t *val = jl_array_ptr_ref(vals, i);
size_t field = jl_unbox_long(jl_array_ptr_ref(fields, i));
jl_value_t *newval = jl_array_ptr_ref(newvals, i);
jl_datatype_t *st = (jl_datatype_t*)jl_typeof(val);
// address of the field's storage inside the live object; used as the lookup key
size_t offs = jl_field_offset(st, field);
char *fldaddr = (char*)val + offs;
if (jl_field_isptr(st, field)) {
// field is stored as a reference: register the replacement reference for that slot
record_field_change((jl_value_t**)fldaddr, newval);
}
else {
// replace the bits
ptrhash_put(&bits_replace, (void*)fldaddr, newval);
// and any pointers inside
// NOTE(review): presumably jl_ptr_offset is in pointer-sized words, matching the
// jl_value_t** arithmetic below - confirm
jl_datatype_t *rty = (jl_datatype_t*)jl_typeof(newval);
const jl_datatype_layout_t *layout = rty->layout;
size_t j, np = layout->npointers;
for (j = 0; j < np; j++) {
uint32_t ptr = jl_ptr_offset(rty, j);
record_field_change((jl_value_t**)fldaddr + ptr, *(((jl_value_t**)newval) + ptr));
}
}
}
}

int en = jl_gc_enable(0);
if (native_functions) {
Expand Down Expand Up @@ -3130,6 +3238,7 @@ static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
arraylist_free(&gvars);
arraylist_free(&external_fns);
htable_free(&field_replace);
htable_free(&bits_replace);
htable_free(&serialization_order);
htable_free(&nullptrs);
htable_free(&symbol_table);
Expand Down
22 changes: 21 additions & 1 deletion test/threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,19 @@ let once = PerProcess{Int}(() -> error("expected"))
@test_throws ErrorException("PerProcess initializer failed previously") once()
end

let once = PerThread(() -> return [nothing])
let e = Base.Event(true),
started = Channel{Int16}(Inf),
once = PerThread() do
push!(started, threadid())
wait(e)
return [nothing]
end
@test typeof(once) <: PerThread{Vector{Nothing}}
notify(e)
x = once()
@test x === once() === fetch(@async once())
@test take!(started) == threadid()
@test isempty(started)
tids = zeros(UInt, 50)
onces = Vector{Vector{Nothing}}(undef, length(tids))
for i = 1:length(tids)
Expand All @@ -420,7 +429,18 @@ let once = PerThread(() -> return [nothing])
err == 0 || Base.uv_error("uv_thread_join", err)
end
end
# let them finish in 5 batches of 10
for i = 1:length(tids) ÷ 10
for i = 1:10
@test take!(started) != threadid()
end
for i = 1:10
notify(e)
end
end
@test isempty(started)
waitallthreads(tids)
@test isempty(started)
@test length(IdSet{eltype(onces)}(onces)) == length(onces) # make sure every object is unique

end
Expand Down

0 comments on commit 0438f2a

Please sign in to comment.