diff --git a/examples/hybrid/driver.jl b/examples/hybrid/driver.jl index 140b5fad43..c17149938e 100644 --- a/examples/hybrid/driver.jl +++ b/examples/hybrid/driver.jl @@ -99,27 +99,11 @@ else ) end p = get_cache(ᶜlocal_geometry, ᶠlocal_geometry, Y, dt, upwinding_mode) -if ode_algorithm <: Union{ - OrdinaryDiffEq.OrdinaryDiffEqImplicitAlgorithm, - OrdinaryDiffEq.OrdinaryDiffEqAdaptiveImplicitAlgorithm, -} - use_transform = !(ode_algorithm in (Rosenbrock23, Rosenbrock32)) - W = SchurComplementW(Y, use_transform, jacobian_flags, test_implicit_solver) - jac_kwargs = - use_transform ? (; jac_prototype = W, Wfact_t = Wfact!) : - (; jac_prototype = W, Wfact = Wfact!) - - alg_kwargs = (; linsolve = linsolve!) - if ode_algorithm <: Union{ - OrdinaryDiffEq.OrdinaryDiffEqNewtonAlgorithm, - OrdinaryDiffEq.OrdinaryDiffEqNewtonAdaptiveAlgorithm, - } - alg_kwargs = - (; alg_kwargs..., nlsolve = NLNewton(; max_iter = max_newton_iters)) - end -else - jac_kwargs = alg_kwargs = (;) -end + +include("ode_config.jl") + +ode_algo = + ode_configuration(FT; ode_name = string(ode_algorithm), max_newton_iters) if haskey(ENV, "OUTPUT_DIR") output_dir = ENV["OUTPUT_DIR"] @@ -164,7 +148,7 @@ callback = problem = SplitODEProblem( ODEFunction( implicit_tendency!; - jac_kwargs..., + jac_kwargs(ode_algo, Y, jacobian_flags)..., tgrad = (∂Y∂t, Y, p, t) -> (∂Y∂t .= FT(0)), ), remaining_tendency!, @@ -174,7 +158,7 @@ problem = SplitODEProblem( ) integrator = OrdinaryDiffEq.init( problem, - ode_algorithm(; alg_kwargs...); + ode_algo; saveat = dt_save_to_sol == 0 ? [] : dt_save_to_sol, callback = callback, dt = dt, diff --git a/examples/hybrid/ode_config.jl b/examples/hybrid/ode_config.jl new file mode 100644 index 0000000000..20b6ac9e4b --- /dev/null +++ b/examples/hybrid/ode_config.jl @@ -0,0 +1,88 @@ +import DiffEqBase +import OrdinaryDiffEq as ODE +import ClimaTimeSteppers as CTS + +is_explicit_CTS_algo_type(alg_or_tableau) = + alg_or_tableau <: CTS.ERKAlgorithmName + +is_imex_CTS_algo_type(alg_or_tableau) = + alg_or_tableau <: CTS.IMEXARKAlgorithmName + +is_implicit_type(::typeof(ODE.IMEXEuler)) = true +is_implicit_type(alg_or_tableau) = + alg_or_tableau <: Union{ + ODE.OrdinaryDiffEqImplicitAlgorithm, + ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm, + } || is_imex_CTS_algo_type(alg_or_tableau) + +is_ordinary_diffeq_newton(::typeof(ODE.IMEXEuler)) = true +is_ordinary_diffeq_newton(alg_or_tableau) = + alg_or_tableau <: Union{ + ODE.OrdinaryDiffEqNewtonAlgorithm, + ODE.OrdinaryDiffEqNewtonAdaptiveAlgorithm, + } + +is_imex_CTS_algo(::CTS.IMEXAlgorithm) = true +is_imex_CTS_algo(::DiffEqBase.AbstractODEAlgorithm) = false + +is_implicit(::ODE.OrdinaryDiffEqImplicitAlgorithm) = true +is_implicit(::ODE.OrdinaryDiffEqAdaptiveImplicitAlgorithm) = true +is_implicit(ode_algo) = is_imex_CTS_algo(ode_algo) + +is_rosenbrock(::ODE.Rosenbrock23) = true +is_rosenbrock(::ODE.Rosenbrock32) = true +is_rosenbrock(::DiffEqBase.AbstractODEAlgorithm) = false +use_transform(ode_algo) = + !(is_imex_CTS_algo(ode_algo) || is_rosenbrock(ode_algo)) + +function jac_kwargs(ode_algo, Y, jacobi_flags) + if is_implicit(ode_algo) + W = SchurComplementW(Y, use_transform(ode_algo), jacobi_flags) + if use_transform(ode_algo) + return (; jac_prototype = W, Wfact_t = Wfact!) + else + return (; jac_prototype = W, Wfact = Wfact!) + end + else + return NamedTuple() + end +end + +function ode_configuration( + ::Type{FT}; + ode_name::Union{String, Nothing} = nothing, + max_newton_iters = nothing, +) where {FT} + if occursin(".", ode_name) + ode_name = split(ode_name, ".")[end] + end + ode_sym = Symbol(ode_name) + alg_or_tableau = if hasproperty(ODE, ode_sym) + @warn "apply_limiter flag is ignored for OrdinaryDiffEq algorithms" + getproperty(ODE, ode_sym) + else + getproperty(CTS, ode_sym) + end + @info "Using ODE config: `$alg_or_tableau`" + + if is_explicit_CTS_algo_type(alg_or_tableau) + return CTS.ExplicitAlgorithm(alg_or_tableau()) + elseif !is_implicit_type(alg_or_tableau) + return alg_or_tableau() + elseif is_ordinary_diffeq_newton(alg_or_tableau) + if max_newton_iters == 1 + error("OridinaryDiffEq requires at least 2 Newton iterations") + end + # κ like a relative tolerance; its default value in ODE is 0.01 + nlsolve = ODE.NLNewton(; + κ = max_newton_iters == 2 ? Inf : 0.01, + max_iter = max_newton_iters, + ) + return alg_or_tableau(; linsolve = linsolve!, nlsolve) + elseif is_imex_CTS_algo_type(alg_or_tableau) + newtons_method = CTS.NewtonsMethod(; max_iters = max_newton_iters) + return CTS.IMEXAlgorithm(alg_or_tableau(), newtons_method) + else + return alg_or_tableau(; linsolve = linsolve!) + end +end