diff --git a/src/solver/impls/arkode/arkode_mri.cxx b/src/solver/impls/arkode/arkode_mri.cxx
new file mode 100644
index 0000000000..1fb9a79a24
--- /dev/null
+++ b/src/solver/impls/arkode/arkode_mri.cxx
@@ -0,0 +1,1090 @@
+/**************************************************************************
+ * Experimental interface to SUNDIALS ARKode MRI solver
+ *
+ * NOTE: ARKode is still in beta testing so use with cautious optimism
+ *
+ **************************************************************************
+ * Copyright 2010-2024 BOUT++ contributors
+ *
+ * This file is part of BOUT++.
+ *
+ * BOUT++ is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * BOUT++ is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public License
+ * along with BOUT++. If not, see .
+ *
+ **************************************************************************/
+
+#include "bout/build_config.hxx"
+
+#include "arkode_mri.hxx"
+
+#if BOUT_HAS_ARKODE
+
+#include "bout/bout_enum_class.hxx"
+#include "bout/boutcomm.hxx"
+#include "bout/boutexception.hxx"
+#include "bout/field3d.hxx"
+#include "bout/mesh.hxx"
+#include "bout/msg_stack.hxx"
+#include "bout/options.hxx"
+#include "bout/output.hxx"
+#include "bout/unused.hxx"
+#include "bout/utils.hxx"
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+class Field2D;
+
+// NOLINTBEGIN(readability-identifier-length)
+namespace {
+int arkode_rhs_s_explicit(BoutReal t, N_Vector u, N_Vector du, void* user_data);
+int arkode_rhs_s_implicit(BoutReal t, N_Vector u, N_Vector du, void* user_data);
+int arkode_rhs_f_explicit(BoutReal t, N_Vector u, N_Vector du, void* user_data);
+int arkode_rhs_f_implicit(BoutReal t, N_Vector u, N_Vector du, void* user_data);
+int arkode_s_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data);
+int arkode_f_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data);
+
+int arkode_s_bbd_rhs(sunindextype Nlocal, BoutReal t, N_Vector u, N_Vector du,
+ void* user_data);
+int arkode_f_bbd_rhs(sunindextype Nlocal, BoutReal t, N_Vector u, N_Vector du,
+ void* user_data);
+int arkode_s_pre(BoutReal t, N_Vector yy, N_Vector yp, N_Vector rvec, N_Vector zvec,
+ BoutReal gamma, BoutReal delta, int lr, void* user_data);
+int arkode_f_pre(BoutReal t, N_Vector yy, N_Vector yp, N_Vector rvec, N_Vector zvec,
+ BoutReal gamma, BoutReal delta, int lr, void* user_data);
+
+} // namespace
+// NOLINTEND(readability-identifier-length)
+
+ArkodeMRISolver::ArkodeMRISolver(Options* opts)
+ : Solver(opts), diagnose((*options)["diagnose"]
+ .doc("Print some additional diagnostics")
+ .withDefault(false)),
+ mxsteps((*options)["mxstep"]
+ .doc("Maximum number of steps to take between outputs")
+ .withDefault(500)),
+ treatment((*options)["treatment"]
+ .doc("Use default capability (imex) or provide a specific treatment: "
+ "implicit or explicit")
+ .withDefault(MRI_Treatment::ImEx)),
+ inner_treatment((*options)["inner_treatment"]
+ .doc("Use default capability (imex) or provide a specific inner_treatment: "
+ "implicit or explicit")
+ .withDefault(MRI_Treatment::ImEx)),
+ set_linear(
+ (*options)["set_linear"]
+ .doc("Use linear implicit solver (only evaluates jacobian inversion once)")
+ .withDefault(false)),
+ inner_set_linear(
+ (*options)["inner_set_linear"]
+ .doc("Use linear implicit solver (only evaluates jacobian inversion once)")
+ .withDefault(false)),
+ fixed_step((*options)["fixed_step"]
+ .doc("Solve explicit portion in fixed timestep mode. NOTE: This is "
+ "not recommended except for code comparison")
+ .withDefault(false)),
+ order((*options)["order"].doc("Order of internal step").withDefault(4)),
+ adap_method(
+ (*options)["adap_method"]
+ .doc("Set timestep adaptivity function: pid, pi, i, explicit_gustafsson, "
+ "implicit_gustafsson, imex_gustafsson.")
+ .withDefault(MRI_AdapMethod::PID)),
+ abstol((*options)["atol"].doc("Absolute tolerance").withDefault(1.0e-12)),
+ reltol((*options)["rtol"].doc("Relative tolerance").withDefault(1.0e-5)),
+ use_vector_abstol((*options)["use_vector_abstol"]
+ .doc("Use separate absolute tolerance for each field")
+ .withDefault(false)),
+ use_precon((*options)["use_precon"]
+ .doc("Use user-supplied preconditioner function")
+ .withDefault(false)),
+ inner_use_precon((*options)["inner_use_precon"]
+ .doc("Use user-supplied preconditioner function")
+ .withDefault(false)),
+ maxl(
+ (*options)["maxl"].doc("Number of Krylov basis vectors to use").withDefault(0)),
+ inner_maxl(
+ (*options)["inner_maxl"].doc("Number of Krylov basis vectors to use").withDefault(0)),
+ rightprec((*options)["rightprec"]
+ .doc("Use right preconditioning instead of left preconditioning")
+ .withDefault(false)),
+ suncontext(createSUNContext(BoutComm::get())) {
+ has_constraints = false; // This solver doesn't have constraints
+
+ // Add diagnostics to output
+ add_int_diagnostic(nsteps, "arkode_nsteps", "Cumulative number of internal steps");
+ add_int_diagnostic(inner_nsteps, "arkode_inner_nsteps", "Cumulative number of inner internal steps");
+ add_int_diagnostic(nfe_evals, "arkode_nfe_evals",
+ "No. of calls to fe (explicit portion of the right-hand-side "
+ "function) function");
+ add_int_diagnostic(inner_nfe_evals, "arkode_inner_nfe_evals",
+ "No. of calls to fe (explicit portion of the inner right-hand-side "
+ "function) function");
+ add_int_diagnostic(nfi_evals, "arkode_nfi_evals",
+ "No. of calls to fi (implicit portion of the right-hand-side "
+ "function) function");
+ add_int_diagnostic(inner_nfi_evals, "arkode_inner_nfi_evals",
+ "No. of calls to fi (implicit portion of inner the right-hand-side "
+ "function) function");
+ add_int_diagnostic(nniters, "arkode_nniters", "No. of nonlinear solver iterations");
+ add_int_diagnostic(inner_nniters, "arkode_inner_nniters", "No. of inner nonlinear solver iterations");
+ add_int_diagnostic(npevals, "arkode_npevals", "No. of preconditioner evaluations");
+ add_int_diagnostic(inner_npevals, "arkode_inner_npevals", "No. of inner preconditioner evaluations");
+ add_int_diagnostic(nliters, "arkode_nliters", "No. of linear iterations");
+ add_int_diagnostic(inner_nliters, "arkode_inner_nliters", "No. of inner linear iterations");
+}
+
+ArkodeMRISolver::~ArkodeMRISolver() {
+ N_VDestroy(uvec);
+ ARKodeFree(&arkode_mem);
+ ARKodeFree(&inner_arkode_mem);
+ SUNLinSolFree(sun_solver);
+ SUNLinSolFree(inner_sun_solver);
+ SUNNonlinSolFree(nonlinear_solver);
+ SUNNonlinSolFree(inner_nonlinear_solver);
+ MRIStepInnerStepper_Free(&inner_stepper);
+
+ SUNAdaptController_Destroy(controller);
+ SUNAdaptController_Destroy(inner_controller);
+}
+
+/**************************************************************************
+ * Initialise
+ **************************************************************************/
+
+int ArkodeMRISolver::init() {
+ TRACE("Initialising ARKODE MRI solver");
+
+ Solver::init();
+
+ output.write("Initialising SUNDIALS' ARKODE MRI solver\n");
+
+ // Calculate number of variables (in generic_solver)
+ const int local_N = getLocalN();
+
+ // Get total problem size
+ int neq;
+ if (bout::globals::mpi->MPI_Allreduce(&local_N, &neq, 1, MPI_INT, MPI_SUM,
+ BoutComm::get())) {
+ throw BoutException("Allreduce localN -> GlobalN failed!\n");
+ }
+
+ output.write("\t3d fields = {:d}, 2d fields = {:d} neq={:d}, local_N={:d}\n", n3Dvars(),
+ n2Dvars(), neq, local_N);
+
+ // Allocate memory
+ uvec = callWithSUNContext(N_VNew_Parallel, suncontext, BoutComm::get(), local_N, neq);
+ if (uvec == nullptr) {
+ throw BoutException("SUNDIALS memory allocation failed\n");
+ }
+
+ // Put the variables into uvec
+ save_vars(N_VGetArrayPointer(uvec));
+
+ switch (inner_treatment) {
+ case MRI_Treatment::ImEx:
+ inner_arkode_mem = callWithSUNContext(ARKStepCreate, suncontext, arkode_rhs_f_explicit,
+ arkode_rhs_f_implicit, simtime, uvec);
+ output_info.write("\tUsing ARKode ImEx inner solver \n");
+ break;
+ case MRI_Treatment::Explicit:
+ inner_arkode_mem =
+ callWithSUNContext(ARKStepCreate, suncontext, arkode_f_rhs, nullptr, simtime, uvec);
+ output_info.write("\tUsing ARKode Explicit inner solver \n");
+ break;
+ case MRI_Treatment::Implicit:
+ inner_arkode_mem =
+ callWithSUNContext(ARKStepCreate, suncontext, nullptr, arkode_f_rhs, simtime, uvec);
+ output_info.write("\tUsing ARKode Implicit inner solver \n");
+ break;
+ default:
+ throw BoutException("Invalid inner_treatment: {}\n", toString(inner_treatment));
+ }
+ if (inner_arkode_mem == nullptr) {
+ throw BoutException("ARKStepCreate failed\n");
+ }
+
+ // For callbacks, need pointer to solver object
+ if (ARKodeSetUserData(inner_arkode_mem, this) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetUserData failed\n");
+ }
+
+ if(inner_treatment != MRI_Treatment::Explicit)
+ if (ARKodeSetLinear(inner_arkode_mem, inner_set_linear) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetLinear failed\n");
+ }
+
+ if (fixed_step) {
+ // If not given, default to adaptive timestepping
+ const BoutReal inner_fixed_timestep = (*options)["inner_timestep"];
+ if (ARKodeSetFixedStep(inner_arkode_mem, inner_fixed_timestep) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetFixedStep failed\n");
+ }
+ }
+
+ if (ARKodeSetOrder(inner_arkode_mem, order) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetOrder failed\n");
+ }
+
+ if (ARKStepCreateMRIStepInnerStepper(inner_arkode_mem, &inner_stepper) != ARK_SUCCESS) {
+ throw BoutException("ARKStepCreateMRIStepInnerStepper failed\n");
+ }
+
+ // Initialize the slow integrator. Specify the explicit slow right-hand side
+ // function in y'=fe(t,y)+fi(t,y)+ff(t,y), the inital time T0, the
+ // initial dependent variable vector y, and the fast integrator.
+
+ switch (treatment) {
+ case MRI_Treatment::ImEx:
+ arkode_mem = callWithSUNContext(MRIStepCreate, suncontext, arkode_rhs_s_explicit, arkode_rhs_s_implicit,
+ simtime, uvec, inner_stepper);
+ output_info.write("\tUsing ARKode ImEx solver \n");
+ break;
+ case MRI_Treatment::Explicit:
+ arkode_mem = callWithSUNContext(MRIStepCreate, suncontext, arkode_s_rhs, nullptr,
+ simtime, uvec, inner_stepper);
+ output_info.write("\tUsing ARKode Explicit solver \n");
+ break;
+ case MRI_Treatment::Implicit:
+ arkode_mem = callWithSUNContext(MRIStepCreate, suncontext, nullptr, arkode_s_rhs,
+ simtime, uvec, inner_stepper);
+ output_info.write("\tUsing ARKode Implicit solver \n");
+ break;
+ default:
+ throw BoutException("Invalid treatment: {}\n", toString(treatment));
+ }
+ if (arkode_mem == nullptr) {
+ throw BoutException("MRIStepCreate failed\n");
+ }
+
+ // For callbacks, need pointer to solver object
+ if (ARKodeSetUserData(arkode_mem, this) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetUserData failed\n");
+ }
+
+ if(treatment != MRI_Treatment::Explicit)
+ if (ARKodeSetLinear(arkode_mem, set_linear) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetLinear failed\n");
+ }
+
+ if (fixed_step) {
+ // If not given, default to adaptive timestepping
+ const BoutReal fixed_timestep = (*options)["timestep"];
+ if (ARKodeSetFixedStep(arkode_mem, fixed_timestep) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetFixedStep failed\n");
+ }
+ }
+
+ if (ARKodeSetOrder(arkode_mem, order) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetOrder failed\n");
+ }
+
+ switch (adap_method) {
+ case MRI_AdapMethod::PID:
+ controller = SUNAdaptController_PID(suncontext);
+ inner_controller = SUNAdaptController_PID(suncontext);
+ break;
+ case MRI_AdapMethod::PI:
+ controller = SUNAdaptController_PI(suncontext);
+ inner_controller = SUNAdaptController_PI(suncontext);
+ break;
+ case MRI_AdapMethod::I:
+ controller = SUNAdaptController_I(suncontext);
+ inner_controller = SUNAdaptController_I(suncontext);
+ break;
+ case MRI_AdapMethod::Explicit_Gustafsson:
+ controller = SUNAdaptController_ExpGus(suncontext);
+ inner_controller = SUNAdaptController_ExpGus(suncontext);
+ break;
+ case MRI_AdapMethod::Implicit_Gustafsson:
+ controller = SUNAdaptController_ImpGus(suncontext);
+ inner_controller = SUNAdaptController_ImpGus(suncontext);
+ break;
+ case MRI_AdapMethod::ImEx_Gustafsson:
+ controller = SUNAdaptController_ImExGus(suncontext);
+ inner_controller = SUNAdaptController_ImExGus(suncontext);
+ break;
+ default:
+ throw BoutException("Invalid adap_method\n");
+ }
+
+ // if (ARKodeSetAdaptController(arkode_mem, controller) != ARK_SUCCESS) {
+ // throw BoutException("ARKodeSetAdaptController failed\n");
+ // }
+
+ // if (ARKodeSetAdaptivityAdjustment(arkode_mem, 0) != ARK_SUCCESS) {
+ // throw BoutException("ARKodeSetAdaptivityAdjustment failed\n");
+ // }
+
+ if (ARKodeSetFixedStep(arkode_mem, 0.001) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetAdaptController failed\n");
+ }
+
+ if (ARKodeSetAdaptController(inner_arkode_mem, controller) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetAdaptController failed\n");
+ }
+
+ if (ARKodeSetAdaptivityAdjustment(inner_arkode_mem, 0) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetAdaptivityAdjustment failed\n");
+ }
+
+ if (use_vector_abstol) {
+ std::vector f2dtols;
+ f2dtols.reserve(f2d.size());
+ std::transform(begin(f2d), end(f2d), std::back_inserter(f2dtols),
+ [abstol = abstol](const VarStr& f2) {
+ auto& f2_options = Options::root()[f2.name];
+ const auto wrong_name = f2_options.isSet("abstol");
+ if (wrong_name) {
+ output_warn << "WARNING: Option 'abstol' for field " << f2.name
+ << " is deprecated. Please use 'atol' instead\n";
+ }
+ const std::string atol_name = wrong_name ? "abstol" : "atol";
+ return f2_options[atol_name].withDefault(abstol);
+ });
+
+ std::vector f3dtols;
+ f3dtols.reserve(f3d.size());
+ std::transform(begin(f3d), end(f3d), std::back_inserter(f3dtols),
+ [abstol = abstol](const VarStr& f3) {
+ return Options::root()[f3.name]["atol"].withDefault(abstol);
+ });
+
+ N_Vector abstolvec = N_VClone(uvec);
+ if (abstolvec == nullptr) {
+ throw BoutException("SUNDIALS memory allocation (abstol vector) failed\n");
+ }
+
+ set_abstol_values(N_VGetArrayPointer(abstolvec), f2dtols, f3dtols);
+
+ if (ARKodeSVtolerances(arkode_mem, reltol, abstolvec) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSVtolerances failed\n");
+ }
+ if (ARKodeSVtolerances(inner_arkode_mem, reltol, abstolvec) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSVtolerances failed\n");
+ }
+
+ N_VDestroy(abstolvec);
+ } else {
+ if (ARKodeSStolerances(arkode_mem, reltol, abstol) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSStolerances failed\n");
+ }
+ if (ARKodeSStolerances(inner_arkode_mem, reltol, abstol) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSStolerances failed\n");
+ }
+ }
+
+ if (ARKodeSetMaxNumSteps(arkode_mem, mxsteps) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetMaxNumSteps failed\n");
+ }
+ if (ARKodeSetMaxNumSteps(inner_arkode_mem, mxsteps) != ARK_SUCCESS) {
+ throw BoutException("ARKodeSetMaxNumSteps failed\n");
+ }
+
+ if (inner_treatment == MRI_Treatment::ImEx or inner_treatment == MRI_Treatment::Implicit) {
+ {
+ output.write("\tUsing Newton iteration for inner solver\n");
+
+ const auto prectype =
+ inner_use_precon ? (rightprec ? SUN_PREC_RIGHT : SUN_PREC_LEFT) : SUN_PREC_NONE;
+ inner_sun_solver = callWithSUNContext(SUNLinSol_SPGMR, suncontext, uvec, prectype, inner_maxl);
+ if (inner_sun_solver == nullptr) {
+ throw BoutException("Creating SUNDIALS inner linear solver failed\n");
+ }
+ if (ARKodeSetLinearSolver(inner_arkode_mem, inner_sun_solver, nullptr) != ARKLS_SUCCESS) {
+ throw BoutException("ARKodeSetLinearSolver failed for inner solver\n");
+ }
+
+ /// Set Preconditioner
+ if (inner_use_precon) {
+ if (hasPreconditioner()) { // change to inner_hasPreconditioner when it is available
+ output.write("\tUsing user-supplied preconditioner for inner solver\n");
+
+ if (ARKodeSetPreconditioner(inner_arkode_mem, nullptr, arkode_f_pre)
+ != ARKLS_SUCCESS) {
+ throw BoutException("ARKodeSetPreconditioner failed for inner solver\n");
+ }
+ } else {
+ output.write("\tUsing BBD preconditioner for inner solver\n");
+
+ /// Get options
+ // Compute band_width_default from actually added fields, to allow for multiple
+ // Mesh objects
+ //
+ // Previous implementation was equivalent to:
+ // int MXSUB = mesh->xend - mesh->xstart + 1;
+ // int band_width_default = n3Dvars()*(MXSUB+2);
+ const int band_width_default = std::accumulate(
+ begin(f3d), end(f3d), 0, [](int acc, const VarStr& fvar) {
+ Mesh* localmesh = fvar.var->getMesh();
+ return acc + localmesh->xend - localmesh->xstart + 3;
+ });
+
+ const auto mudq = (*options)["mudq"]
+ .doc("Upper half-bandwidth to be used in the difference "
+ "quotient Jacobian approximation")
+ .withDefault(band_width_default);
+ const auto mldq = (*options)["mldq"]
+ .doc("Lower half-bandwidth to be used in the difference "
+ "quotient Jacobian approximation")
+ .withDefault(band_width_default);
+ const auto mukeep = (*options)["mukeep"]
+ .doc("Upper half-bandwidth of the retained banded "
+ "approximate Jacobian block")
+ .withDefault(n3Dvars() + n2Dvars());
+ const auto mlkeep = (*options)["mlkeep"]
+ .doc("Lower half-bandwidth of the retained banded "
+ "approximate Jacobian block")
+ .withDefault(n3Dvars() + n2Dvars());
+
+ if (ARKBBDPrecInit(inner_arkode_mem, local_N, mudq, mldq, mukeep, mlkeep, 0,
+ arkode_f_bbd_rhs, nullptr)
+ != ARKLS_SUCCESS) {
+ throw BoutException("ARKBBDPrecInit failed for inner solver\n");
+ }
+ }
+ } else {
+ // Not using preconditioning
+ output.write("\tNo inner preconditioning\n");
+ }
+ }
+
+ /// Set Jacobian-vector multiplication function
+ output.write("\tUsing difference quotient approximation for Jacobian in the inner solver\n");
+ }
+
+
+ if (treatment == MRI_Treatment::ImEx or treatment == MRI_Treatment::Implicit) {
+ {
+ output.write("\tUsing Newton iteration\n");
+
+ const auto prectype =
+ use_precon ? (rightprec ? SUN_PREC_RIGHT : SUN_PREC_LEFT) : SUN_PREC_NONE;
+ sun_solver = callWithSUNContext(SUNLinSol_SPGMR, suncontext, uvec, prectype, maxl);
+ if (sun_solver == nullptr) {
+ throw BoutException("Creating SUNDIALS linear solver failed\n");
+ }
+ if (ARKodeSetLinearSolver(arkode_mem, sun_solver, nullptr) != ARKLS_SUCCESS) {
+ throw BoutException("ARKodeSetLinearSolver failed\n");
+ }
+
+ /// Set Preconditioner
+ if (use_precon) {
+ if (hasPreconditioner()) {
+ output.write("\tUsing user-supplied preconditioner\n");
+
+ if (ARKodeSetPreconditioner(arkode_mem, nullptr, arkode_s_pre)
+ != ARKLS_SUCCESS) {
+ throw BoutException("ARKodeSetPreconditioner failed\n");
+ }
+ } else {
+ output.write("\tUsing BBD preconditioner\n");
+
+ /// Get options
+ // Compute band_width_default from actually added fields, to allow for multiple
+ // Mesh objects
+ //
+ // Previous implementation was equivalent to:
+ // int MXSUB = mesh->xend - mesh->xstart + 1;
+ // int band_width_default = n3Dvars()*(MXSUB+2);
+ const int band_width_default = std::accumulate(
+ begin(f3d), end(f3d), 0, [](int acc, const VarStr& fvar) {
+ Mesh* localmesh = fvar.var->getMesh();
+ return acc + localmesh->xend - localmesh->xstart + 3;
+ });
+
+ const auto mudq = (*options)["mudq"]
+ .doc("Upper half-bandwidth to be used in the difference "
+ "quotient Jacobian approximation")
+ .withDefault(band_width_default);
+ const auto mldq = (*options)["mldq"]
+ .doc("Lower half-bandwidth to be used in the difference "
+ "quotient Jacobian approximation")
+ .withDefault(band_width_default);
+ const auto mukeep = (*options)["mukeep"]
+ .doc("Upper half-bandwidth of the retained banded "
+ "approximate Jacobian block")
+ .withDefault(n3Dvars() + n2Dvars());
+ const auto mlkeep = (*options)["mlkeep"]
+ .doc("Lower half-bandwidth of the retained banded "
+ "approximate Jacobian block")
+ .withDefault(n3Dvars() + n2Dvars());
+
+ if (ARKBBDPrecInit(arkode_mem, local_N, mudq, mldq, mukeep, mlkeep, 0,
+ arkode_s_bbd_rhs, nullptr)
+ != ARKLS_SUCCESS) {
+ throw BoutException("ARKBBDPrecInit failed\n");
+ }
+ }
+ } else {
+ // Not using preconditioning
+ output.write("\tNo preconditioning\n");
+ }
+ }
+
+ /// Set Jacobian-vector multiplication function
+ output.write("\tUsing difference quotient approximation for Jacobian\n");
+ }
+
+ return 0;
+}
+
+/**************************************************************************
+ * Run - Advance time
+ **************************************************************************/
+
+int ArkodeMRISolver::run() {
+ TRACE("ArkodeMRISolver::run()");
+
+ if (!initialised) {
+ throw BoutException("ArkodeMRISolver not initialised\n");
+ }
+
+ for (int i = 0; i < getNumberOutputSteps(); i++) {
+
+ /// Run the solver for one output timestep
+ simtime = run(simtime + getOutputTimestep());
+
+ /// Check if the run succeeded
+ if (simtime < 0.0) {
+ // Step failed
+ output.write("Timestep failed. Aborting\n");
+
+ throw BoutException("ARKode timestep failed\n");
+ }
+
+ // Get additional diagnostics
+ long int temp_long_int, temp_long_int2;
+ ARKodeGetNumSteps(arkode_mem, &temp_long_int);
+ nsteps = int(temp_long_int);
+ MRIStepGetNumRhsEvals(arkode_mem, &temp_long_int, &temp_long_int2); //Change after the release
+ nfe_evals = int(temp_long_int);
+ nfi_evals = int(temp_long_int2);
+ if (treatment == MRI_Treatment::ImEx or treatment == MRI_Treatment::Implicit) {
+ ARKodeGetNumNonlinSolvIters(arkode_mem, &temp_long_int);
+ nniters = int(temp_long_int);
+ ARKodeGetNumPrecEvals(arkode_mem, &temp_long_int);
+ npevals = int(temp_long_int);
+ ARKodeGetNumLinIters(arkode_mem, &temp_long_int);
+ nliters = int(temp_long_int);
+ }
+
+ ARKodeGetNumSteps(inner_arkode_mem, &temp_long_int);
+ inner_nsteps = int(temp_long_int);
+ ARKStepGetNumRhsEvals(inner_arkode_mem, &temp_long_int, &temp_long_int2); //Change after the release
+ inner_nfe_evals = int(temp_long_int);
+ inner_nfi_evals = int(temp_long_int2);
+ if (inner_treatment == MRI_Treatment::ImEx or inner_treatment == MRI_Treatment::Implicit) {
+ ARKodeGetNumNonlinSolvIters(inner_arkode_mem, &temp_long_int);
+ inner_nniters = int(temp_long_int);
+ ARKodeGetNumPrecEvals(inner_arkode_mem, &temp_long_int);
+ inner_npevals = int(temp_long_int);
+ ARKodeGetNumLinIters(inner_arkode_mem, &temp_long_int);
+ inner_nliters = int(temp_long_int);
+ }
+
+ if (diagnose) {
+ output.write("\nARKODE: nsteps {:d}, nfe_evals {:d}, nfi_evals {:d}, nniters {:d}, "
+ "npevals {:d}, nliters {:d}\n",
+ nsteps, nfe_evals, nfi_evals, nniters, npevals, nliters);
+ if (treatment == MRI_Treatment::ImEx or treatment == MRI_Treatment::Implicit) {
+ output.write(" -> Newton iterations per step: {:e}\n",
+ static_cast(nniters) / static_cast(nsteps));
+ output.write(" -> Linear iterations per Newton iteration: {:e}\n",
+ static_cast(nliters) / static_cast(nniters));
+ output.write(" -> Preconditioner evaluations per Newton: {:e}\n",
+ static_cast(npevals) / static_cast(nniters));
+ }
+
+ output.write("\nARKODE Inner: inner_nsteps {:d}, inner_nfe_evals {:d}, inner_nfi_evals {:d}, inner_nniters {:d}, "
+ "inner_npevals {:d}, inner_nliters {:d}\n",
+ inner_nsteps, inner_nfe_evals, inner_nfi_evals, inner_nniters, inner_npevals, inner_nliters);
+ if (inner_treatment == MRI_Treatment::ImEx or inner_treatment == MRI_Treatment::Implicit) {
+ output.write(" -> Inner Newton iterations per step: {:e}\n",
+ static_cast(inner_nniters) / static_cast(inner_nsteps));
+ output.write(" -> Inner Linear iterations per Newton iteration: {:e}\n",
+ static_cast(inner_nliters) / static_cast(inner_nniters));
+ output.write(" -> Inner Preconditioner evaluations per Newton: {:e}\n",
+ static_cast(inner_npevals) / static_cast(inner_nniters));
+ }
+ }
+
+ if (call_monitors(simtime, i, getNumberOutputSteps())) {
+ // User signalled to quit
+ break;
+ }
+ }
+
+ return 0;
+}
+
+BoutReal ArkodeMRISolver::run(BoutReal tout) {
+ TRACE("Running solver: solver::run({:e})", tout);
+
+ bout::globals::mpi->MPI_Barrier(BoutComm::get());
+
+ pre_Wtime_s = 0.0;
+ pre_ncalls_s = 0;
+
+ int flag;
+ if (!monitor_timestep) {
+ // Run in normal mode
+ flag = ARKodeEvolve(arkode_mem, tout, uvec, &simtime, ARK_NORMAL);
+ } else {
+ // Run in single step mode, to call timestep monitors
+ BoutReal internal_time;
+ ARKodeGetCurrentTime(arkode_mem, &internal_time);
+ while (internal_time < tout) {
+ // Run another step
+ const BoutReal last_time = internal_time;
+ flag = ARKodeEvolve(arkode_mem, tout, uvec, &internal_time, ARK_ONE_STEP);
+
+ if (flag != ARK_SUCCESS) {
+ output_error.write("ERROR ARKODE solve failed at t = {:e}, flag = {:d}\n",
+ internal_time, flag);
+ return -1.0;
+ }
+
+ // Call timestep monitor
+ call_timestep_monitors(internal_time, internal_time - last_time);
+ }
+ // Get output at the desired time
+ flag = ARKodeGetDky(arkode_mem, tout, 0, uvec);
+ simtime = tout;
+ }
+
+ // Copy variables
+ load_vars(N_VGetArrayPointer(uvec));
+ // Call rhs function to get extra variables at this time
+ run_rhs(simtime);
+ // run_diffusive(simtime);
+ if (flag != ARK_SUCCESS) {
+ output_error.write("ERROR ARKODE solve failed at t = {:e}, flag = {:d}\n", simtime,
+ flag);
+ return -1.0;
+ }
+
+ return simtime;
+}
+
+/**************************************************************************
+ * Explicit RHS function du = F^s_E(t, u)
+ **************************************************************************/
+
+void ArkodeMRISolver::rhs_se(BoutReal t, BoutReal* udata, BoutReal* dudata) {
+ TRACE("Running RHS: ArkodeMRISolver::rhs_e({:e})", t);
+
+ // Load state from udata
+ load_vars(udata);
+
+ // Get the current timestep
+ // Note: ARKodeGetCurrentStep updated too late in older versions
+ ARKodeGetLastStep(arkode_mem, &hcur);
+
+ // Call RHS function
+ run_rhs_se(t);
+
+ // Save derivatives to dudata
+ save_derivs(dudata);
+}
+
+/**************************************************************************
+ * Implicit RHS function du = F^s_I(t, u)
+ **************************************************************************/
+
+void ArkodeMRISolver::rhs_si(BoutReal t, BoutReal* udata, BoutReal* dudata) {
+ TRACE("Running RHS: ArkodeMRISolver::rhs_si({:e})", t);
+
+ load_vars(udata);
+ ARKodeGetLastStep(arkode_mem, &hcur);
+ // Call Implicit RHS function
+ run_rhs_si(t);
+ save_derivs(dudata);
+}
+
+/**************************************************************************
+ * Explicit RHS function du = F^f_E(t, u)
+ **************************************************************************/
+
+void ArkodeMRISolver::rhs_fe(BoutReal t, BoutReal* udata, BoutReal* dudata) {
+ TRACE("Running RHS: ArkodeMRISolver::rhs_e({:e})", t);
+
+ // Load state from udata
+ load_vars(udata);
+
+ // Get the current timestep
+ // Note: ARKodeGetCurrentStep updated too late in older versions
+ ARKodeGetLastStep(arkode_mem, &hcur);
+
+ // Call RHS function
+ run_rhs_fe(t);
+
+ // Save derivatives to dudata
+ save_derivs(dudata);
+}
+
+/**************************************************************************
+ * Implicit RHS function du = F^f_I(t, u)
+ **************************************************************************/
+
+void ArkodeMRISolver::rhs_fi(BoutReal t, BoutReal* udata, BoutReal* dudata) {
+ TRACE("Running RHS: ArkodeMRISolver::rhs_si({:e})", t);
+
+ load_vars(udata);
+ ARKodeGetLastStep(arkode_mem, &hcur);
+ // Call Implicit RHS function
+ run_rhs_fi(t);
+ save_derivs(dudata);
+}
+
+/**************************************************************************
+ * Slow RHS function du = F^s(t, u)
+ **************************************************************************/
+
+void ArkodeMRISolver::rhs_s(BoutReal t, BoutReal* udata, BoutReal* dudata) {
+ TRACE("Running RHS: ArkodeMRISolver::rhs_e({:e})", t);
+
+ // Load state from udata
+ load_vars(udata);
+
+ // Get the current timestep
+ // Note: ARKodeGetCurrentStep updated too late in older versions
+ ARKodeGetLastStep(arkode_mem, &hcur);
+
+ // Call RHS function
+ // run_rhs_s(t);
+ run_rhs_s(t);
+
+ // Save derivatives to dudata
+ save_derivs(dudata);
+}
+
+/**************************************************************************
+ * Fast RHS function du = F^f(t, u)
+ **************************************************************************/
+
+void ArkodeMRISolver::rhs_f(BoutReal t, BoutReal* udata, BoutReal* dudata) {
+ TRACE("Running RHS: ArkodeMRISolver::rhs_e({:e})", t);
+
+ // Load state from udata
+ load_vars(udata);
+
+ // Get the current timestep
+ // Note: ARKodeGetCurrentStep updated too late in older versions
+ ARKodeGetLastStep(arkode_mem, &hcur);
+
+ // Call RHS function
+ // run_rhs_f(t);
+ run_rhs_f(t);
+
+ // Save derivatives to dudata
+ save_derivs(dudata);
+}
+
+/**************************************************************************
+ * Preconditioner functions
+ **************************************************************************/
+
+void ArkodeMRISolver::pre_s(BoutReal t, BoutReal gamma, BoutReal delta, BoutReal* udata,
+ BoutReal* rvec, BoutReal* zvec) {
+ TRACE("Running preconditioner: ArkodeMRISolver::pre({:e})", t);
+
+ const BoutReal tstart = bout::globals::mpi->MPI_Wtime();
+
+ if (!hasPreconditioner()) {
+ // Identity (but should never happen)
+ const auto length = N_VGetLocalLength_Parallel(uvec);
+ std::copy(rvec, rvec + length, zvec);
+ return;
+ }
+
+ // Load state from udata (as with res function)
+ load_vars(udata);
+
+ // Load vector to be inverted into F_vars
+ load_derivs(rvec);
+
+ runPreconditioner(t, gamma, delta);
+
+ // Save the solution from F_vars
+ save_derivs(zvec);
+
+ pre_Wtime_s += bout::globals::mpi->MPI_Wtime() - tstart;
+ pre_ncalls_s++;
+}
+
+void ArkodeMRISolver::pre_f(BoutReal t, BoutReal gamma, BoutReal delta, BoutReal* udata,
+ BoutReal* rvec, BoutReal* zvec) {
+ TRACE("Running preconditioner: ArkodeMRISolver::pre({:e})", t);
+
+ const BoutReal tstart = bout::globals::mpi->MPI_Wtime();
+
+ if (!hasPreconditioner()) {
+ // Identity (but should never happen)
+ const auto length = N_VGetLocalLength_Parallel(uvec);
+ std::copy(rvec, rvec + length, zvec);
+ return;
+ }
+
+ // Load state from udata (as with res function)
+ load_vars(udata);
+
+ // Load vector to be inverted into F_vars
+ load_derivs(rvec);
+
+ runPreconditioner(t, gamma, delta);
+
+ // Save the solution from F_vars
+ save_derivs(zvec);
+
+ pre_Wtime_s += bout::globals::mpi->MPI_Wtime() - tstart;
+ pre_ncalls_s++;
+}
+
+/**************************************************************************
+ * Jacobian-vector multiplication functions
+ **************************************************************************/
+
+void ArkodeMRISolver::jac_s(BoutReal t, BoutReal* ydata, BoutReal* vdata, BoutReal* Jvdata) {
+ TRACE("Running Jacobian: ArkodeMRISolver::jac({:e})", t);
+
+ if (not hasJacobian()) {
+ throw BoutException("No jacobian function supplied!\n");
+ }
+
+ // Load state from ydate
+ load_vars(ydata);
+
+ // Load vector to be multiplied into F_vars
+ load_derivs(vdata);
+
+ // Call function
+ runJacobian(t);
+
+ // Save Jv from vars
+ save_derivs(Jvdata);
+}
+
+void ArkodeMRISolver::jac_f(BoutReal t, BoutReal* ydata, BoutReal* vdata, BoutReal* Jvdata) {
+ TRACE("Running Jacobian: ArkodeMRISolver::jac({:e})", t);
+
+ if (not hasJacobian()) {
+ throw BoutException("No jacobian function supplied!\n");
+ }
+
+ // Load state from ydate
+ load_vars(ydata);
+
+ // Load vector to be multiplied into F_vars
+ load_derivs(vdata);
+
+ // Call function
+ runJacobian(t);
+
+ // Save Jv from vars
+ save_derivs(Jvdata);
+}
+
+/**************************************************************************
+ * ARKODE explicit RHS functions
+ **************************************************************************/
+
+// NOLINTBEGIN(readability-identifier-length)
+namespace {
+int arkode_rhs_s_explicit(BoutReal t, N_Vector u, N_Vector du, void* user_data) {
+
+ BoutReal* udata = N_VGetArrayPointer(u);
+ BoutReal* dudata = N_VGetArrayPointer(du);
+
+ auto* s = static_cast(user_data);
+
+ // Calculate RHS function
+ try {
+ s->rhs_se(t, udata, dudata);
+ } catch (BoutRhsFail& error) {
+ return 1;
+ }
+ return 0;
+}
+
+int arkode_rhs_s_implicit(BoutReal t, N_Vector u, N_Vector du, void* user_data) {
+
+ BoutReal* udata = N_VGetArrayPointer(u);
+ BoutReal* dudata = N_VGetArrayPointer(du);
+
+ auto* s = static_cast(user_data);
+
+ // Calculate RHS function
+ try {
+ s->rhs_si(t, udata, dudata);
+ } catch (BoutRhsFail& error) {
+ return 1;
+ }
+ return 0;
+}
+
+int arkode_rhs_f_explicit(BoutReal t, N_Vector u, N_Vector du, void* user_data) {
+
+ BoutReal* udata = N_VGetArrayPointer(u);
+ BoutReal* dudata = N_VGetArrayPointer(du);
+
+ auto* s = static_cast(user_data);
+
+ // Calculate RHS function
+ try {
+ s->rhs_fe(t, udata, dudata);
+ } catch (BoutRhsFail& error) {
+ return 1;
+ }
+ return 0;
+}
+
+int arkode_rhs_f_implicit(BoutReal t, N_Vector u, N_Vector du, void* user_data) {
+
+ BoutReal* udata = N_VGetArrayPointer(u);
+ BoutReal* dudata = N_VGetArrayPointer(du);
+
+ auto* s = static_cast(user_data);
+
+ // Calculate RHS function
+ try {
+ s->rhs_fi(t, udata, dudata);
+ } catch (BoutRhsFail& error) {
+ return 1;
+ }
+ return 0;
+}
+
+int arkode_s_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) {
+
+ BoutReal* udata = N_VGetArrayPointer(u);
+ BoutReal* dudata = N_VGetArrayPointer(du);
+
+ auto* s = static_cast(user_data);
+
+ // Calculate RHS function
+ try {
+ s->rhs_s(t, udata, dudata);
+ } catch (BoutRhsFail& error) {
+ return 1;
+ }
+ return 0;
+}
+
+int arkode_f_rhs(BoutReal t, N_Vector u, N_Vector du, void* user_data) {
+
+ BoutReal* udata = N_VGetArrayPointer(u);
+ BoutReal* dudata = N_VGetArrayPointer(du);
+
+ auto* s = static_cast(user_data);
+
+ // Calculate RHS function
+ try {
+ s->rhs_f(t, udata, dudata);
+ } catch (BoutRhsFail& error) {
+ return 1;
+ }
+ return 0;
+}
+
+/// RHS function for BBD preconditioner
+int arkode_s_bbd_rhs(sunindextype UNUSED(Nlocal), BoutReal t, N_Vector u, N_Vector du,
+ void* user_data) {
+ return arkode_rhs_s_implicit(t, u, du, user_data);
+}
+
+int arkode_f_bbd_rhs(sunindextype UNUSED(Nlocal), BoutReal t, N_Vector u, N_Vector du,
+ void* user_data) {
+ return arkode_rhs_f_implicit(t, u, du, user_data);
+}
+
+/// Preconditioner function
+int arkode_s_pre(BoutReal t, N_Vector yy, N_Vector UNUSED(yp), N_Vector rvec, N_Vector zvec,
+ BoutReal gamma, BoutReal delta, int UNUSED(lr), void* user_data) {
+ BoutReal* udata = N_VGetArrayPointer(yy);
+ BoutReal* rdata = N_VGetArrayPointer(rvec);
+ BoutReal* zdata = N_VGetArrayPointer(zvec);
+
+ auto* s = static_cast(user_data);
+
+ // Calculate residuals
+ s->pre_s(t, gamma, delta, udata, rdata, zdata);
+
+ return 0;
+}
+
+int arkode_f_pre(BoutReal t, N_Vector yy, N_Vector UNUSED(yp), N_Vector rvec, N_Vector zvec,
+ BoutReal gamma, BoutReal delta, int UNUSED(lr), void* user_data) {
+ BoutReal* udata = N_VGetArrayPointer(yy);
+ BoutReal* rdata = N_VGetArrayPointer(rvec);
+ BoutReal* zdata = N_VGetArrayPointer(zvec);
+
+ auto* s = static_cast(user_data);
+
+ // Calculate residuals
+ s->pre_f(t, gamma, delta, udata, rdata, zdata);
+
+ return 0;
+}
+
+} // namespace
+// NOLINTEND(readability-identifier-length)
+
+/**************************************************************************
+ * vector abstol functions
+ **************************************************************************/
+
+void ArkodeMRISolver::set_abstol_values(BoutReal* abstolvec_data,
+ std::vector& f2dtols,
+ std::vector& f3dtols) {
+ int p = 0; // Counter for location in abstolvec_data array
+
+ // All boundaries
+ for (const auto& i2d : bout::globals::mesh->getRegion2D("RGN_BNDRY")) {
+ loop_abstol_values_op(i2d, abstolvec_data, p, f2dtols, f3dtols, true);
+ }
+ // Bulk of points
+ for (const auto& i2d : bout::globals::mesh->getRegion2D("RGN_NOBNDRY")) {
+ loop_abstol_values_op(i2d, abstolvec_data, p, f2dtols, f3dtols, false);
+ }
+}
+
+void ArkodeMRISolver::loop_abstol_values_op(Ind2D UNUSED(i2d), BoutReal* abstolvec_data,
+ int& p, std::vector& f2dtols,
+ std::vector& f3dtols, bool bndry) {
+ // Loop over 2D variables
+ for (std::vector::size_type i = 0; i < f2dtols.size(); i++) {
+ if (bndry && !f2d[i].evolve_bndry) {
+ continue;
+ }
+ abstolvec_data[p] = f2dtols[i];
+ p++;
+ }
+
+ for (int jz = 0; jz < bout::globals::mesh->LocalNz; jz++) {
+ // Loop over 3D variables
+ for (std::vector::size_type i = 0; i < f3dtols.size(); i++) {
+ if (bndry && !f3d[i].evolve_bndry) {
+ continue;
+ }
+ abstolvec_data[p] = f3dtols[i];
+ p++;
+ }
+ }
+}
+
+#endif
diff --git a/src/solver/impls/arkode/arkode_mri.hxx b/src/solver/impls/arkode/arkode_mri.hxx
new file mode 100644
index 0000000000..80c9c51ac8
--- /dev/null
+++ b/src/solver/impls/arkode/arkode_mri.hxx
@@ -0,0 +1,176 @@
+/**************************************************************************
+ * Interface to ARKODE MRI solver
+ * NOTE: ARKode is currently in beta testing so use with cautious optimism
+ *
+ * NOTE: Only one solver can currently be compiled in
+ *
+ **************************************************************************
+ * Copyright 2010-2024 BOUT++ contributors
+ *
+ * Contact: Ben Dudson, dudson2@llnl.gov
+ *
+ * This file is part of BOUT++.
+ *
+ * BOUT++ is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU Lesser General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * BOUT++ is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public License
+ * along with BOUT++. If not, see .
+ *
+ **************************************************************************/
+
+#ifndef BOUT_ARKODE_MRI_SOLVER_H
+#define BOUT_ARKODE_MRI_SOLVER_H
+
+#include "bout/build_config.hxx"
+#include "bout/solver.hxx"
+
+#if not BOUT_HAS_ARKODE
+
+namespace {
+RegisterUnavailableSolver
+ registerunavailablearkodemri("arkode_mri", "BOUT++ was not configured with ARKODE/SUNDIALS");
+}
+
+#else
+
+#include "bout/bout_enum_class.hxx"
+#include "bout/bout_types.hxx"
+#include "bout/sundials_backports.hxx"
+
+#include
+#include
+#include
+
+#include
+
+#include
+
+class ArkodeMRISolver;
+class Options;
+
+namespace {
+RegisterSolver registersolverarkodemri("arkode_mri");
+}
+
+// enum describing treatment of equations
+// Note: Capitalized because `explicit` is a C++ reserved keyword
+BOUT_ENUM_CLASS(MRI_Treatment, ImEx, Implicit, Explicit);
+
+// Adaptivity method
+BOUT_ENUM_CLASS(MRI_AdapMethod, PID, PI, I, Explicit_Gustafsson, Implicit_Gustafsson,
+ ImEx_Gustafsson);
+
+class ArkodeMRISolver : public Solver {
+public:
+ explicit ArkodeMRISolver(Options* opts = nullptr);
+ ~ArkodeMRISolver();
+
+ BoutReal getCurrentTimestep() override { return hcur; }
+
+ int init() override;
+
+ int run() override;
+ BoutReal run(BoutReal tout);
+
+ // These functions used internally (but need to be public)
+ void rhs_se(BoutReal t, BoutReal* udata, BoutReal* dudata);
+ void rhs_si(BoutReal t, BoutReal* udata, BoutReal* dudata);
+ void rhs_fe(BoutReal t, BoutReal* udata, BoutReal* dudata);
+ void rhs_fi(BoutReal t, BoutReal* udata, BoutReal* dudata);
+ void rhs_s(BoutReal t, BoutReal* udata, BoutReal* dudata);
+ void rhs_f(BoutReal t, BoutReal* udata, BoutReal* dudata);
+ void pre_s(BoutReal t, BoutReal gamma, BoutReal delta, BoutReal* udata, BoutReal* rvec,
+ BoutReal* zvec);
+ void pre_f(BoutReal t, BoutReal gamma, BoutReal delta, BoutReal* udata, BoutReal* rvec,
+ BoutReal* zvec);
+ void jac_s(BoutReal t, BoutReal* ydata, BoutReal* vdata, BoutReal* Jvdata);
+ void jac_f(BoutReal t, BoutReal* ydata, BoutReal* vdata, BoutReal* Jvdata);
+
+private:
+ BoutReal hcur; //< Current internal timestep
+
+ bool diagnose{false}; //< Output additional diagnostics
+
+ N_Vector uvec{nullptr}; //< Values
+ void* arkode_mem{nullptr}; //< ARKODE internal memory block
+ void* inner_arkode_mem{nullptr}; //< ARKODE internal memory block
+ MRIStepInnerStepper inner_stepper{nullptr}; //< inner stepper
+
+ BoutReal pre_Wtime_s{0.0}; //< Time in preconditioner
+ BoutReal pre_Wtime_f{0.0}; //< Time in preconditioner
+ int pre_ncalls_s{0}; //< Number of calls to preconditioner
+ int pre_ncalls_f{0}; //< Number of calls to preconditioner
+
+ /// Maximum number of steps to take between outputs
+ int mxsteps;
+ /// Integrator treatment enum: IMEX, Implicit or Explicit
+ MRI_Treatment treatment;
+ MRI_Treatment inner_treatment;
+ /// Use linear implicit solver (only evaluates jacobian inversion once)
+ bool set_linear;
+ bool inner_set_linear;
+ /// Solve explicit portion in fixed timestep mode. NOTE: This is not recommended except
+ /// for code comparison
+ bool fixed_step;
+ /// Order of the internal step
+ int order;
+ /// Timestep adaptivity function
+ MRI_AdapMethod adap_method;
+ /// Absolute tolerance
+ BoutReal abstol;
+ /// Relative tolerance
+ BoutReal reltol;
+ /// Use separate absolute tolerance for each field
+ bool use_vector_abstol;
+ /// Maximum timestep (only used if greater than zero)
+ bool use_precon;
+ bool inner_use_precon;
+ /// Number of Krylov basis vectors to use
+ int maxl;
+ int inner_maxl;
+ /// Use right preconditioning instead of left preconditioning
+ bool rightprec;
+
+ // Diagnostics from ARKODE MRI
+ int nsteps{0};
+ int nfe_evals{0};
+ int nfi_evals{0};
+ int nniters{0};
+ int npevals{0};
+ int nliters{0};
+ int inner_nsteps{0};
+ int inner_nfe_evals{0};
+ int inner_nfi_evals{0};
+ int inner_nniters{0};
+ int inner_npevals{0};
+ int inner_nliters{0};
+
+ void set_abstol_values(BoutReal* abstolvec_data, std::vector& f2dtols,
+ std::vector& f3dtols);
+ void loop_abstol_values_op(Ind2D i2d, BoutReal* abstolvec_data, int& p,
+ std::vector& f2dtols,
+ std::vector& f3dtols, bool bndry);
+
+ /// SPGMR solver structure
+ SUNLinearSolver sun_solver{nullptr};
+ SUNLinearSolver inner_sun_solver{nullptr};
+ /// Solver for implicit stages
+ SUNNonlinearSolver nonlinear_solver{nullptr};
+ SUNNonlinearSolver inner_nonlinear_solver{nullptr};
+ /// Timestep controller
+ SUNAdaptController controller{nullptr};
+ SUNAdaptController inner_controller{nullptr};
+ /// Context for SUNDIALS memory allocations
+ sundials::Context suncontext;
+};
+
+#endif // BOUT_HAS_ARKODE
+#endif // BOUT_ARKODE_MRI_SOLVER_H
diff --git a/tests/integrated/test-kpr_mri/CMakeLists.txt b/tests/integrated/test-kpr_mri/CMakeLists.txt
new file mode 100644
index 0000000000..66a4474bb0
--- /dev/null
+++ b/tests/integrated/test-kpr_mri/CMakeLists.txt
@@ -0,0 +1,2 @@
+bout_add_integrated_test(test_kpr_mri SOURCES test_kpr_mri.cxx
+REQUIRES BOUT_HAS_SUNDIALS)
diff --git a/tests/integrated/test-kpr_mri/README.md b/tests/integrated/test-kpr_mri/README.md
new file mode 100644
index 0000000000..ffcef3ded9
--- /dev/null
+++ b/tests/integrated/test-kpr_mri/README.md
@@ -0,0 +1,37 @@
+test-kpr_mri
+===========
+
+ Multirate nonlinear Kvaerno-Prothero-Robinson ODE test problem:
+
+ [f]' = [ G e ] [(-1+f^2-r)/(2f)] + [ r'(t)/(2f) ]
+ [g] [ e -1 ] [(-2+g^2-s)/(2g)] [ s'(t)/(2*sqrt(2+s(t))) ]
+ = [ fs(t,f,g) ]
+ [ ff(t,f,g) ]
+
+ where r(t) = 0.5 * cos(t), s(t) = cos(w * t), 0 < t < 5.
+
+ This problem has analytical solution given by
+ f(t) = sqrt(1+r(t)), g(t) = sqrt(2+s(t)).
+
+ We use the parameters:
+ e = 0.5 (fast/slow coupling strength) [default]
+ G = -100 (stiffness at slow time scale) [default]
+ w = 100 (time-scale separation factor) [default]
+
+ The stiffness of the slow time scale is essentially determined
+ by G, for |G| > 50 it is 'stiff' and ideally suited to a
+ multirate method that is implicit at the slow time scale.
+
+MRI implementations of the functions are as follows:
+
+The slow explicit RHS function:
+ [-0.5 * sin(t)/(2 * f)]
+ [ 0 ]
+
+The slow implicit RHS function:
+ [G e] * [(-1 + f^2 - 0.5 * cos(t))/(2 * f) ]
+ [0 0] [(-2 + g^2 - cos(w * t))/(2 * g) ]
+
+The fast implicit RHS function:
+ [0 0] * [(-1 + f^2 - 0.5 * cos(t))/(2 * f)] + [ 0 ]
+ [e -1] [(-2 + g^2 - cos(w * t))/(2 * g) ] [-w * sin(w * t)/(2 * sqrt(2 + cos(w*t)))]
diff --git a/tests/integrated/test-kpr_mri/makefile b/tests/integrated/test-kpr_mri/makefile
new file mode 100644
index 0000000000..84db76e5c3
--- /dev/null
+++ b/tests/integrated/test-kpr_mri/makefile
@@ -0,0 +1,5 @@
+BOUT_TOP = ../../..
+
+SOURCEC = test_kpr_mri.cxx
+
+include $(BOUT_TOP)/make.config
diff --git a/tests/integrated/test-kpr_mri/runtest b/tests/integrated/test-kpr_mri/runtest
new file mode 100755
index 0000000000..47f335dc94
--- /dev/null
+++ b/tests/integrated/test-kpr_mri/runtest
@@ -0,0 +1,23 @@
+#!/usr/bin/env python3
+
+# requires: petsc
+
+from boututils.run_wrapper import shell_safe, launch_safe
+
+from sys import exit
+
+nthreads = 1
+nproc = 1
+
+print("Making solver test")
+shell_safe("make > make.log")
+
+print("Running solver test")
+status, out = launch_safe("./test_kpr_mri", nproc=nproc, mthread=nthreads, pipe=True)
+with open("run.log", "w") as f:
+ f.write(out)
+
+if status:
+ print(out)
+
+exit(status)
diff --git a/tests/integrated/test-kpr_mri/test_kpr_mri.cxx b/tests/integrated/test-kpr_mri/test_kpr_mri.cxx
new file mode 100644
index 0000000000..c071f25ef2
--- /dev/null
+++ b/tests/integrated/test-kpr_mri/test_kpr_mri.cxx
@@ -0,0 +1,183 @@
+#include "bout/physicsmodel.hxx"
+#include "bout/solver.hxx"
+
+#include
+#include
+#include
+#include
+
+// A simple phyics model with a manufactured true solution
+//
+class TestSolver : public PhysicsModel {
+public:
+ Field3D f, g;
+
+ BoutReal e = 0.5; /* fast/slow coupling strength */
+ BoutReal G = -100.0; /* stiffness at slow time scale */
+ BoutReal w = 100.0; /* time-scale separation factor */
+
+ int init(bool UNUSED(restarting)) override {
+ solver->add(f, "f");
+ solver->add(g, "g");
+
+ f = sqrt(3.0/2.0);
+ g = sqrt(3.0);
+
+ setSplitOperatorMRI();
+
+ return 0;
+ }
+
+ int rhs_se(BoutReal t) override {
+ /* fill in the slow explicit RHS function:
+ [-0.5*sin(t)/(2*f)]
+ [ 0 ] */
+ ddt(f) = -0.5*sin(t)/(2.0*f(1,1,0));
+ ddt(g) = 0.0;
+
+ return 0;
+ }
+
+ int rhs_si(BoutReal t) override {
+ /* fill in the slow implicit RHS function:
+ [G e]*[(-1+f^2-0.5*cos(t))/(2*f)]
+ [0 0] [(-2+g^2-cos(w*t))/(2*g) ] */
+ BoutReal tmp1 = (-1.0 + f(1,1,0) * f(1,1,0) - 0.5*cos(t)) / (2.0 * f(1,1,0));
+ BoutReal tmp2 = (-2.0 + g(1,1,0) * g(1,1,0) - cos(w*t)) / (2.0 * g(1,1,0));
+ ddt(f) = G * tmp1 + e * tmp2;
+ ddt(g) = 0.0;
+
+ return 0;
+ }
+
+ int rhs_fe(BoutReal UNUSED(t)) override {
+
+ ddt(f) = 0.0;
+ ddt(g) = 0.0;
+
+ return 0;
+ }
+
+ int rhs_fi(BoutReal t) override {
+ /* fill in the fast implicit RHS function:
+ [0 0]*[(-1+f^2-0.5*cos(t))/(2*f)] + [ 0 ]
+ [e -1] [(-2+g^2-cos(w*t))/(2*g) ] [-w*sin(w*t)/(2*sqrt(2+cos(w*t)))] */
+ BoutReal tmp1 = (-1.0 + f(1,1,0) * f(1,1,0) - 0.5*cos(t)) / (2.0 * f(1,1,0));
+ BoutReal tmp2 = (-2.0 + g(1,1,0) * g(1,1,0) - cos(w*t)) / (2.0 * g(1,1,0));
+ ddt(f) = 0.0;
+ ddt(g) = e * tmp1 - tmp2 - w * sin(w*t) / (2.0 * sqrt(2.0 + cos(w * t)));
+
+ return 0;
+ }
+
+ int rhs_s(BoutReal t) override {
+ /* fill in the RHS function:
+ [G e]*[(-1+f^2-0.5*cos(t))/(2*f)] + [-0.5*sin(t)/(2*f)]
+ [0 0] [(-2+g^2-cos(w*t))/(2*g) ] [ 0 ] */
+ BoutReal tmp1 = (-1.0 + f(1,1,0) * f(1,1,0) - 0.5*cos(t)) / (2.0 * f(1,1,0));
+ BoutReal tmp2 = (-2.0 + g(1,1,0) * g(1,1,0) - cos(w*t)) / (2.0 * g(1,1,0));
+ ddt(f) = G * tmp1 + e * tmp2 - 0.5*sin(t) / (2.0 * f(1,1,0));
+ ddt(g) = 0.0;
+
+ return 0;
+ }
+
+ int rhs_f(BoutReal t) override {
+ /* fill in the RHS function:
+ [0 0]*[(-1+f^2-0.5*cos(t))/(2*f)] + [ 0 ]
+ [e -1] [(-2+g^2-cos(w*t))/(2*g) ] [-w*sin(w*t)/(2*sqrt(2+cos(w*t)))] */
+ BoutReal tmp1 = (-1.0 + f(1,1,0) * f(1,1,0) - 0.5*cos(t)) / (2.0 * f(1,1,0));
+ BoutReal tmp2 = (-2.0 + g(1,1,0) * g(1,1,0) - cos(w*t)) / (2.0 * g(1,1,0));
+ ddt(f) = 0.0;
+ ddt(g) = e * tmp1 - tmp2 - w * sin(w*t) / (2.0 * sqrt(2.0 + cos(w * t)));
+
+ return 0;
+ }
+
+ // int rhs(BoutReal t) override {
+ // /* fill in the RHS function:
+ // [G e]*[(-1+f^2-0.5*cos(t))/(2*f)] + [-0.5*sin(t)/(2*f) ]
+ // [e -1] [(-2+g^2-cos(w*t))/(2*g) ] [-w*sin(w*t)/(2*sqrt(2+cos(w*t)))] */
+ // BoutReal tmp1 = (-1.0 + f(1,1,0) * f(1,1,0) - 0.5*cos(t)) / (2.0 * f(1,1,0));
+ // BoutReal tmp2 = (-2.0 + g(1,1,0) * g(1,1,0) - cos(w*t)) / (2.0 * g(1,1,0));
+
+ // ddt(f) = G * tmp1 + e * tmp2 - 0.5*sin(t) / (2.0 * f(1,1,0));
+ // ddt(g) = e * tmp1 - tmp2 - w * sin(w*t) / (2.0 * sqrt(2.0 + cos(w * t)));
+
+ // return 0;
+ // }
+
+ bool check_solution(BoutReal atol, BoutReal t) {
+ // Return true if correct solution
+ return ((std::abs(sqrt(0.5*cos(t) + 1.0) - f(1,1,0)) < atol) and (std::abs(sqrt(cos(w*t) + 2.0) - g(1,1,0)) < atol));
+ }
+
+ BoutReal compute_error(BoutReal t)
+ {
+ /* Compute the error with the true solution:
+ f(t) = sqrt(0.5*cos(t) + 1.0)
+ g(t) = sqrt(cos(w*t) + 2.0) */
+ return sqrt( pow(sqrt(0.5*cos(t) + 1.0) - f(1,1,0), 2.0) +
+ pow(sqrt(cos(w*t) + 2.0) - g(1,1,0), 2.0));
+ }
+
+ // Don't need any restarting, or options to control data paths
+ int postInit(bool) override { return 0; }
+};
+
+int main(int argc, char** argv) {
+ // Absolute tolerance for difference between the actual value and the
+ // expected value
+ constexpr BoutReal tolerance = 1.e-5;
+
+ // Our own output to stdout, as main library will only be writing to log files
+ Output output_test;
+
+ auto& root = Options::root();
+
+ root["mesh"]["MXG"] = 1;
+ root["mesh"]["MYG"] = 1;
+ root["mesh"]["nx"] = 3;
+ root["mesh"]["ny"] = 1;
+ root["mesh"]["nz"] = 1;
+
+ root["output"]["enabled"] = false;
+ root["restart_files"]["enabled"] = false;
+
+ Solver::setArgs(argc, argv);
+ BoutComm::setArgs(argc, argv);
+
+ bout::globals::mpi = new MpiWrapper();
+
+ bout::globals::mesh = Mesh::create();
+ bout::globals::mesh->load();
+
+ // Global options
+ root["nout"] = 100;
+ root["timestep"] = 0.05;
+
+ // Get specific options section for this solver. Can't just use default
+ // "solver" section, as we run into problems when solvers use the same
+ // name for an option with inconsistent defaults
+ auto options = Options::getRoot()->getSection("arkode_mri");
+ auto solver = std::unique_ptr{Solver::create("arkode_mri", options)};
+
+ TestSolver model{};
+ solver->setModel(&model);
+
+ BoutMonitor bout_monitor{};
+ solver->addMonitor(&bout_monitor, Solver::BACK);
+
+ solver->solve();
+
+ BoutReal error = model.compute_error(5.0);
+
+ std::cout << "error = " << error << std::endl;
+
+ if (model.check_solution(tolerance, 5.0)) {
+ output_test << " PASSED\n";
+ return 0;
+ }
+ output_test << " FAILED\n";
+ return 1;
+}