Skip to content

Commit

Permalink
Specify Callback Argument Type in init Functions in ODE solvers (2N, …
Browse files Browse the repository at this point in the history
…3Sstar, PairedExplicitRK, SSP) (#2026)

* Update function signatures to include type annotations for callback parameter

* Update error messages for unsupported continuous callbacks in time integration methods

* Update error messages for unsupported continuous callbacks in time integration methods

* Update src/time_integration/methods_2N.jl

Co-authored-by: Joshua Lampert <[email protected]>

* Update src/time_integration/methods_3Sstar.jl

Co-authored-by: Joshua Lampert <[email protected]>

* change error message in ssp integrator

* fix error message in ssp

---------

Co-authored-by: Joshua Lampert <[email protected]>
  • Loading branch information
warisa-r and JoshuaLampert authored Aug 13, 2024
1 parent e1aabb3 commit 5f5a232
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 16 deletions.
6 changes: 2 additions & 4 deletions src/time_integration/methods_2N.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ function Base.getproperty(integrator::SimpleIntegrator2N, field::Symbol)
end

function init(ode::ODEProblem, alg::SimpleAlgorithm2N;
dt, callback = nothing, kwargs...)
dt, callback::Union{CallbackSet, Nothing} = nothing, kwargs...)
u = copy(ode.u0)
du = similar(u)
u_tmp = similar(u)
Expand All @@ -119,13 +119,11 @@ function init(ode::ODEProblem, alg::SimpleAlgorithm2N;
# initialize callbacks
if callback isa CallbackSet
foreach(callback.continuous_callbacks) do cb
error("unsupported")
throw(ArgumentError("Continuous callbacks are unsupported with the 2N storage time integration methods."))
end
foreach(callback.discrete_callbacks) do cb
cb.initialize(cb, integrator.u, integrator.t, integrator)
end
elseif !isnothing(callback)
error("unsupported")
end

return integrator
Expand Down
6 changes: 2 additions & 4 deletions src/time_integration/methods_3Sstar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ function Base.getproperty(integrator::SimpleIntegrator3Sstar, field::Symbol)
end

function init(ode::ODEProblem, alg::SimpleAlgorithm3Sstar;
dt, callback = nothing, kwargs...)
dt, callback::Union{CallbackSet, Nothing} = nothing, kwargs...)
u = copy(ode.u0)
du = similar(u)
u_tmp1 = similar(u)
Expand All @@ -189,13 +189,11 @@ function init(ode::ODEProblem, alg::SimpleAlgorithm3Sstar;
# initialize callbacks
if callback isa CallbackSet
foreach(callback.continuous_callbacks) do cb
error("unsupported")
throw(ArgumentError("Continuous callbacks are unsupported with the 3 star time integration methods."))
end
foreach(callback.discrete_callbacks) do cb
cb.initialize(cb, integrator.u, integrator.t, integrator)
end
elseif !isnothing(callback)
error("unsupported")
end

return integrator
Expand Down
6 changes: 2 additions & 4 deletions src/time_integration/methods_SSP.jl
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ of type `SimpleAlgorithmSSP`.
This is an experimental feature and may change in future releases.
"""
function solve(ode::ODEProblem, alg = SimpleSSPRK33()::SimpleAlgorithmSSP;
dt, callback = nothing, kwargs...)
dt, callback::Union{CallbackSet, Nothing} = nothing, kwargs...)
u = copy(ode.u0)
du = similar(u)
r0 = similar(u)
Expand All @@ -157,13 +157,11 @@ function solve(ode::ODEProblem, alg = SimpleSSPRK33()::SimpleAlgorithmSSP;
# initialize callbacks
if callback isa CallbackSet
foreach(callback.continuous_callbacks) do cb
error("unsupported")
throw(ArgumentError("Continuous callbacks are unsupported with the SSP time integration methods."))
end
foreach(callback.discrete_callbacks) do cb
cb.initialize(cb, integrator.u, integrator.t, integrator)
end
elseif !isnothing(callback)
error("unsupported")
end

for stage_callback in alg.stage_callbacks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ function Base.getproperty(integrator::PairedExplicitRK, field::Symbol)
end

function init(ode::ODEProblem, alg::PairedExplicitRK2;
dt, callback = nothing, kwargs...)
dt, callback::Union{CallbackSet, Nothing} = nothing, kwargs...)
u0 = copy(ode.u0)
du = zero(u0)
u_tmp = zero(u0)
Expand All @@ -286,13 +286,11 @@ function init(ode::ODEProblem, alg::PairedExplicitRK2;
# initialize callbacks
if callback isa CallbackSet
for cb in callback.continuous_callbacks
error("unsupported")
throw(ArgumentError("Continuous callbacks are unsupported with paired explicit Runge-Kutta methods."))
end
for cb in callback.discrete_callbacks
cb.initialize(cb, integrator.u, integrator.t, integrator)
end
elseif !isnothing(callback)
error("unsupported")
end

return integrator
Expand Down

0 comments on commit 5f5a232

Please sign in to comment.