Skip to content

Commit

Permalink
fix use-after-free in test (detected in win32 CI)
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash authored and KristofferC committed Oct 21, 2024
1 parent 53ca3b2 commit 3bff707
Showing 1 changed file with 41 additions and 34 deletions.
75 changes: 41 additions & 34 deletions test/threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -415,8 +415,8 @@ let e = Base.Event(true),
newthreads = zeros(Int16, length(tids))
onces = Vector{Vector{Nothing}}(undef, length(tids))
allonces = Vector{Vector{Vector{Nothing}}}(undef, length(tids))
for i = 1:length(tids)
function cl()
# allocate closure memory to last until all threads are started
cls = [function cl()
GC.gc(false) # stress test the GC-safepoint mechanics of jl_adopt_thread
try
newthreads[i] = threadid()
Expand All @@ -434,43 +434,50 @@ let e = Base.Event(true),
GC.gc(false) # stress test the GC-safepoint mechanics of jl_delete_thread
nothing
end
function threadcallclosure(cl::F) where {F} # create sparam so we can reference the type of cl in the ccall type
threadwork = @cfunction cl -> cl() Cvoid (Ref{F},) # create a cfunction that specializes on cl as an argument and calls it
err = @ccall uv_thread_create(Ref(tids, i)::Ptr{UInt}, threadwork::Ptr{Cvoid}, cl::Ref{F})::Cint # call that on a thread
err == 0 || Base.uv_error("uv_thread_create", err)
end
threadcallclosure(cl)
end
@noinline function waitallthreads(tids)
for i = 1:length(tids)]
GC.@preserve cls begin # this memory must survive until each corresponding thread exits (waitallthreads / uv_thread_join)
Base.preserve_handle(cls)
for i = 1:length(tids)
tid = Ref(tids, i)
tidp = Base.unsafe_convert(Ptr{UInt}, tid)::Ptr{UInt}
gc_state = @ccall jl_gc_safe_enter()::Int8
GC.@preserve tid err = @ccall uv_thread_join(tidp::Ptr{UInt})::Cint
@ccall jl_gc_safe_leave(gc_state::Int8)::Cvoid
err == 0 || Base.uv_error("uv_thread_join", err)
end
end
try
# let them finish in batches of 10
for i = 1:length(tids) ÷ 10
for i = 1:10
newid = take!(started)
@test newid != threadid()
function threadcallclosure(tid::Ref{UInt}, cl::Ref{F}) where {F} # create sparam so we can reference the type of cl in the ccall type
threadwork = @cfunction cl -> cl() Cvoid (Ref{F},) # create a cfunction that specializes on cl as an argument and calls it
err = @ccall uv_thread_create(tid::Ptr{UInt}, threadwork::Ptr{Cvoid}, cl::Ref{F})::Cint # call that on a thread
err == 0 || Base.uv_error("uv_thread_create", err)
nothing
end
for i = 1:10
push!(finish, nothing)
threadcallclosure(Ref(tids, i), Ref(cls, i))
end
@noinline function waitallthreads(tids, cls)
for i = 1:length(tids)
tid = Ref(tids, i)
tidp = Base.unsafe_convert(Ptr{UInt}, tid)::Ptr{UInt}
gc_state = @ccall jl_gc_safe_enter()::Int8
GC.@preserve tid err = @ccall uv_thread_join(tidp::Ptr{UInt})::Cint
@ccall jl_gc_safe_leave(gc_state::Int8)::Cvoid
err == 0 || Base.uv_error("uv_thread_join", err)
end
Base.unpreserve_handle(cls)
end
@test isempty(started)
# now run the second part of the test where they all try to access the other threads elements
notify(starttest2)
finally
for _ = 1:length(tids)
# run IO loop until all threads are close to exiting
take!(exiting)
try
# let them finish in batches of 10
for i = 1:length(tids) ÷ 10
for i = 1:10
newid = take!(started)
@test newid != threadid()
end
for i = 1:10
push!(finish, nothing)
end
end
@test isempty(started)
# now run the second part of the test where they all try to access the other threads elements
notify(starttest2)
finally
for _ = 1:length(tids)
# run IO loop until all threads are close to exiting
take!(exiting)
end
waitallthreads(tids, cls)
end
waitallthreads(tids)
end
@test isempty(started)
@test isempty(finish)
Expand Down

0 comments on commit 3bff707

Please sign in to comment.