diff --git a/examples/arkode/C_serial/ark_analytic.c b/examples/arkode/C_serial/ark_analytic.c index b67c1e602f..0ef18aed30 100644 --- a/examples/arkode/C_serial/ark_analytic.c +++ b/examples/arkode/C_serial/ark_analytic.c @@ -54,15 +54,21 @@ static int f(sunrealtype t, N_Vector y, N_Vector ydot, void* user_data); static int Jac(sunrealtype t, N_Vector y, N_Vector fy, SUNMatrix J, void* user_data, N_Vector tmp1, N_Vector tmp2, N_Vector tmp3); -/* >>> Function to access delta <<< */ -int Access(long int step, int stage, int iter, N_Vector delta, void* user_data) +/* >>> Function to access delta and linear residuals <<< */ +int Access(long int step, int stage, int iter, N_Vector delta, + N_Vector ls_res, N_Vector ls_res_r, void* user_data) { printf("Step %li = \n", step); printf("Stage %i = \n", stage); printf("Iter %i = \n", iter); printf("Delta:\n"); N_VPrintFile(delta, stdout); + printf("LS Res:\n"); + N_VPrintFile(ls_res, stdout); + printf("LS Res Relaxed:\n"); + N_VPrintFile(ls_res_r, stdout); printf("\n"); + return 0; } /* Private function to check function return values */ @@ -142,7 +148,7 @@ int main(void) if (check_flag(&flag, "ARKodeSetLinear", 1)) { return 1; } /* >>> Attach function to access delta <<< */ - flag = ARKStepSetAccessDeltaFn(arkode_mem, Access); + flag = ARKStepSetAccessFn(arkode_mem, Access); if (check_flag(&flag, "ARKStepSetAccessDeltaFn", 1)) { return 1; } /* Open output stream for results, output comment line */ diff --git a/include/arkode/arkode_arkstep.h b/include/arkode/arkode_arkstep.h index ab2d9e2369..afb3cbb4a4 100644 --- a/include/arkode/arkode_arkstep.h +++ b/include/arkode/arkode_arkstep.h @@ -30,8 +30,9 @@ extern "C" { #endif /* Callback to access delta */ -typedef int (*ARKStepAccessDeltaFn)(long int step, int stage, int iter, - N_Vector delta, void* user_data); +typedef int (*ARKStepAccessFn)(long int step, int stage, int iter, + N_Vector delta, N_Vector ls_res, + N_Vector ls_res2, void* user_data); /* ----------------- * ARKStep Constants @@ -106,7 +107,7 @@ SUNDIALS_EXPORT int ARKStepCreateMRIStepInnerStepper(void* arkode_mem, MRIStepInnerStepper* stepper); SUNDIALS_EXPORT -int ARKStepSetAccessDeltaFn(void* arkode_mem, ARKStepAccessDeltaFn access_fn); +int ARKStepSetAccessFn(void* arkode_mem, ARKStepAccessFn access_fn); /* -------------------------------------------------------------------------- * Deprecated Functions -- all are superseded by shared ARKODE-level routines diff --git a/src/arkode/arkode_arkstep_impl.h b/src/arkode/arkode_arkstep_impl.h index 6e8e3f4174..303f4b51c2 100644 --- a/src/arkode/arkode_arkstep_impl.h +++ b/src/arkode/arkode_arkstep_impl.h @@ -161,7 +161,7 @@ typedef struct ARKodeARKStepMemRec sunrealtype* stage_times; /* workspace for applying forcing */ sunrealtype* stage_coefs; /* workspace for applying forcing */ - ARKStepAccessDeltaFn access_fn; + ARKStepAccessFn access_fn; }* ARKodeARKStepMem; diff --git a/src/arkode/arkode_arkstep_io.c b/src/arkode/arkode_arkstep_io.c index 7f7f7c463f..396d12f171 100644 --- a/src/arkode/arkode_arkstep_io.c +++ b/src/arkode/arkode_arkstep_io.c @@ -2408,7 +2408,7 @@ int ARKStepGetNumRelaxSolveIters(void* arkode_mem, long int* iters) EOF ===============================================================*/ -int ARKStepSetAccessDeltaFn(void* arkode_mem, ARKStepAccessDeltaFn access_fn) +int ARKStepSetAccessFn(void* arkode_mem, ARKStepAccessFn access_fn) { ARKodeMem ark_mem; ARKodeARKStepMem step_mem; diff --git a/src/arkode/arkode_arkstep_nls.c b/src/arkode/arkode_arkstep_nls.c index df08411828..bd004a6d1a 100644 --- a/src/arkode/arkode_arkstep_nls.c +++ b/src/arkode/arkode_arkstep_nls.c @@ -26,13 +26,19 @@ int arkAccessDeltaFn(int iter, N_Vector delta, void* arkode_mem) { + int retval = 0; ARKodeMem ark_mem = (ARKodeMem)arkode_mem; ARKodeARKStepMem step_mem; - int retval = arkStep_AccessStepMem(ark_mem, __func__, &step_mem); + retval = arkStep_AccessStepMem(ark_mem, __func__, &step_mem); if (retval != ARK_SUCCESS) { return retval; } + ARKLsMem arkls_mem; + retval = arkLs_AccessLMem(ark_mem, __func__, &arkls_mem); + if (retval != ARK_SUCCESS) { return (retval); } + retval = step_mem->access_fn(ark_mem->nst, step_mem->istage, iter, delta, + arkls_mem->ytemp, arkls_mem->ytemp2, ark_mem->user_data); if (retval != ARK_SUCCESS) { return retval; } diff --git a/src/arkode/arkode_ls.c b/src/arkode/arkode_ls.c index 7032f62e40..056e6a257b 100644 --- a/src/arkode/arkode_ls.c +++ b/src/arkode/arkode_ls.c @@ -24,6 +24,8 @@ #include "arkode_impl.h" #include "arkode_ls_impl.h" +#include "sundials/sundials_matrix.h" +#include "sundials/sundials_nvector.h" /* constants */ #define MIN_INC_MULT SUN_RCONST(1000.0) @@ -264,6 +266,15 @@ int ARKodeSetLinearSolver(void* arkode_mem, SUNLinearSolver LS, SUNMatrix A) return (ARKLS_MEM_FAIL); } + if (!arkAllocVec(ark_mem, ark_mem->tempv1, &(arkls_mem->ytemp2))) + { + arkProcessError(ark_mem, ARKLS_MEM_FAIL, __LINE__, __func__, __FILE__, + MSG_LS_MEM_FAIL); + free(arkls_mem); + arkls_mem = NULL; + return (ARKLS_MEM_FAIL); + } + if (!arkAllocVec(ark_mem, ark_mem->tempv1, &(arkls_mem->x))) { arkProcessError(ark_mem, ARKLS_MEM_FAIL, __LINE__, __func__, __FILE__, @@ -3383,9 +3394,12 @@ int arkLsSolve(ARKodeMem ark_mem, N_Vector b, sunrealtype tnow, N_Vector ynow, } } - /* Call solver, and copy x to b */ + /* Call solver */ retval = SUNLinSolSolve(arkls_mem->LS, arkls_mem->A, arkls_mem->x, b, delta); - N_VScale(ONE, arkls_mem->x, b); + + /* compute the residual r = Ax - b */ + SUNMatMatvec(arkls_mem->savedJ, arkls_mem->x, arkls_mem->ytemp); + N_VLinearSum(ONE, arkls_mem->ytemp, -ONE, b, arkls_mem->ytemp); /* If using a direct or matrix-iterative solver, scale the correction to account for change in gamma (this is only beneficial if M==I) */ @@ -3399,9 +3413,16 @@ int arkLsSolve(ARKodeMem ark_mem, N_Vector b, sunrealtype tnow, N_Vector ynow, __FILE__, "An error occurred in ark_step_getgammas"); return (arkls_mem->last_flag); } - if (gamrat != ONE) { N_VScale(TWO / (ONE + gamrat), b, b); } + if (gamrat != ONE) { N_VScale(TWO / (ONE + gamrat), arkls_mem->x, arkls_mem->x); } } + /* compute the relaxed residual r = Ax - b */ + SUNMatMatvec(arkls_mem->savedJ, arkls_mem->x, arkls_mem->ytemp2); + N_VLinearSum(ONE, arkls_mem->ytemp2, -ONE, b, arkls_mem->ytemp2); + + /* copy x to b */ + N_VScale(ONE, arkls_mem->x, b); + /* Retrieve statistics from iterative linear solvers */ resnorm = ZERO; nli_inc = 0; @@ -3498,6 +3519,11 @@ int arkLsFree(ARKodeMem ark_mem) N_VDestroy(arkls_mem->ytemp); arkls_mem->ytemp = NULL; } + if (arkls_mem->ytemp2) + { + N_VDestroy(arkls_mem->ytemp2); + arkls_mem->ytemp2 = NULL; + } if (arkls_mem->x) { N_VDestroy(arkls_mem->x); diff --git a/src/arkode/arkode_ls_impl.h b/src/arkode/arkode_ls_impl.h index 8a296d83d5..f66f6de6a4 100644 --- a/src/arkode/arkode_ls_impl.h +++ b/src/arkode/arkode_ls_impl.h @@ -68,6 +68,7 @@ typedef struct ARKLsMemRec SUNMatrix A; /* A = M - gamma * df/dy */ SUNMatrix savedJ; /* savedJ = old Jacobian */ N_Vector ytemp; /* temp vector passed to jtimes and psolve */ + N_Vector ytemp2; N_Vector x; /* solution vector used by SUNLinearSolver */ N_Vector ycur; /* ptr to current y vector in ARKLs solve */ N_Vector fcur; /* ptr to current fcur = fI(tcur, ycur) */