Skip to content

Commit

Permalink
Add better thread safety for State callback functions.
Browse files Browse the repository at this point in the history
  • Loading branch information
dpad committed May 4, 2021
1 parent 628b33a commit 15bd62f
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions src/dynamics/orbital_trajectories.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,16 +104,19 @@ DiffEqBase.solve(state::State, args...; reltol=1e-10, abstol=1e-10, kwargs...) =
DiffEqBase.__solve(state::State; kwargs...) = DiffEqBase.__solve(state, DEFAULT_ALG; kwargs...)

# The __solve() method does the actual heavy lifting, including converting to a Trajectory.
function DiffEqBase.__solve(state::State, alg::DiffEqBase.DEAlgorithm; userdata=Dict(), kwargs...)
function DiffEqBase.__solve(state::State, alg::DiffEqBase.DEAlgorithm; userdata=Dict(), callback=nothing, kwargs...)
default_frame = default_reference_frame(state.model)
real_state = convert_to_frame(state, default_frame)

# Pass the default state into the underlying solver
# TODO: Remove the need for this in DiffCorrectAxisymmetric
userdata_new = deepcopy(userdata)
userdata_new[:real_state] = real_state
userdata = deepcopy(userdata)
userdata[:real_state] = real_state

# Copy the callbacks (for thread-safety)
callback = deepcopy(callback)

# Call the underlying solver
raw_sol = solve(real_state.prob, alg; userdata=userdata_new, kwargs...)
raw_sol = solve(real_state.prob, alg; userdata, callback, kwargs...)
return Trajectory(state.model, default_frame, raw_sol)
end

0 comments on commit 15bd62f

Please sign in to comment.