Skip to content

Commit

Permalink
Add mode argument back to SUNStepper full RHS
Browse files Browse the repository at this point in the history
  • Loading branch information
Steven-Roberts committed Oct 29, 2024
1 parent 3a26598 commit ad41aa0
Show file tree
Hide file tree
Showing 10 changed files with 96 additions and 24 deletions.
16 changes: 14 additions & 2 deletions include/sundials/sundials_stepper.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@
extern "C" {
#endif

#define ARK_FULLRHS_START 0
#define ARK_FULLRHS_END 1
#define ARK_FULLRHS_OTHER 2

typedef enum
{
SUN_FULLRHS_START,
SUN_FULLRHS_END,
SUN_FULLRHS_OTHER
} SUNFullRhsMode;

typedef int (*SUNRhsJacFn)(sunrealtype t, N_Vector y, N_Vector fy,
SUNMatrix Jac, void* user_data, N_Vector tmp1,
N_Vector tmp2, N_Vector tmp3);
Expand All @@ -33,7 +44,8 @@ typedef SUNErrCode (*SUNStepperEvolveFn)(SUNStepper stepper, sunrealtype tout,
N_Vector vret, sunrealtype* tret);

typedef SUNErrCode (*SUNStepperFullRhsFn)(SUNStepper stepper, sunrealtype t,
N_Vector v, N_Vector f);
N_Vector v, N_Vector f,
SUNFullRhsMode mode);

typedef SUNErrCode (*SUNStepperResetFn)(SUNStepper stepper, sunrealtype tR,
N_Vector vR);
Expand All @@ -60,7 +72,7 @@ SUNErrCode SUNStepper_Evolve(SUNStepper stepper, sunrealtype tout,

SUNDIALS_EXPORT
SUNErrCode SUNStepper_FullRhs(SUNStepper stepper, sunrealtype t, N_Vector v,
N_Vector f);
N_Vector f, SUNFullRhsMode mode);

SUNDIALS_EXPORT
SUNErrCode SUNStepper_Reset(SUNStepper stepper, sunrealtype tR, N_Vector vR);
Expand Down
19 changes: 17 additions & 2 deletions src/arkode/arkode_mristep.c
Original file line number Diff line number Diff line change
Expand Up @@ -2890,10 +2890,25 @@ int mriStepInnerStepper_FullRhs(MRIStepInnerStepper stepper, sunrealtype t,

int mriStepInnerStepper_FullRhsSUNStepper(MRIStepInnerStepper stepper,
sunrealtype t, N_Vector y, N_Vector f,
SUNDIALS_MAYBE_UNUSED int mode)
SUNDIALS_MAYBE_UNUSED int ark_mode)
{
SUNStepper sunstepper = (SUNStepper)stepper->content;
SUNErrCode err = sunstepper->ops->fullrhs(sunstepper, t, y, f);

int mode;
switch (ark_mode)
{
case ARK_FULLRHS_START:
mode = SUN_FULLRHS_START;
break;
case ARK_FULLRHS_END:
mode = SUN_FULLRHS_END;
break;
default:
mode = SUN_FULLRHS_OTHER;
break;
}

SUNErrCode err = sunstepper->ops->fullrhs(sunstepper, t, y, f, mode);
stepper->last_flag = sunstepper->last_flag;
if (err != SUN_SUCCESS) { return ARK_SUNSTEPPER_ERR; }
return ARK_SUCCESS;
Expand Down
21 changes: 19 additions & 2 deletions src/arkode/arkode_sunstepper.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,32 @@ static SUNErrCode arkSUNStepperEvolve(SUNStepper stepper, sunrealtype tout,
----------------------------------------------------------------------------*/

static SUNErrCode arkSUNStepperFullRhs(SUNStepper stepper, sunrealtype t,
N_Vector y, N_Vector f)
N_Vector y, N_Vector f, SUNFullRhsMode mode)
{
SUNFunctionBegin(stepper->sunctx);
/* extract the ARKODE memory struct */
void* arkode_mem;
SUNCheckCall(SUNStepper_GetContent(stepper, &arkode_mem));
ARKodeMem ark_mem = (ARKodeMem)arkode_mem;

stepper->last_flag = ark_mem->step_fullrhs(ark_mem, t, y, f, ARK_FULLRHS_OTHER);
int ark_mode;
switch (mode)
{
case SUN_FULLRHS_START:
ark_mode = ARK_FULLRHS_START;
break;
case SUN_FULLRHS_END:
ark_mode = ARK_FULLRHS_END;
break;
case SUN_FULLRHS_OTHER:
ark_mode = ARK_FULLRHS_OTHER;
break;
default:
ark_mode = -1;
break;
}

stepper->last_flag = ark_mem->step_fullrhs(ark_mem, t, y, f, ark_mode);
if (stepper->last_flag != ARK_SUCCESS) { return SUN_ERR_OP_FAIL; }

return SUN_SUCCESS;
Expand Down
3 changes: 0 additions & 3 deletions src/arkode/fmod_int32/farkode_mod.f90
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ module farkode_mod
integer(C_INT), parameter, public :: ARK_ADAPT_EXP_GUS = 3_C_INT
integer(C_INT), parameter, public :: ARK_ADAPT_IMP_GUS = 4_C_INT
integer(C_INT), parameter, public :: ARK_ADAPT_IMEX_GUS = 5_C_INT
integer(C_INT), parameter, public :: ARK_FULLRHS_START = 0_C_INT
integer(C_INT), parameter, public :: ARK_FULLRHS_END = 1_C_INT
integer(C_INT), parameter, public :: ARK_FULLRHS_OTHER = 2_C_INT
integer(C_INT), parameter, public :: ARK_INTERP_MAX_DEGREE = 5_C_INT
integer(C_INT), parameter, public :: ARK_INTERP_NONE = -1_C_INT
integer(C_INT), parameter, public :: ARK_INTERP_HERMITE = 0_C_INT
Expand Down
3 changes: 0 additions & 3 deletions src/arkode/fmod_int64/farkode_mod.f90
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,6 @@ module farkode_mod
integer(C_INT), parameter, public :: ARK_ADAPT_EXP_GUS = 3_C_INT
integer(C_INT), parameter, public :: ARK_ADAPT_IMP_GUS = 4_C_INT
integer(C_INT), parameter, public :: ARK_ADAPT_IMEX_GUS = 5_C_INT
integer(C_INT), parameter, public :: ARK_FULLRHS_START = 0_C_INT
integer(C_INT), parameter, public :: ARK_FULLRHS_END = 1_C_INT
integer(C_INT), parameter, public :: ARK_FULLRHS_OTHER = 2_C_INT
integer(C_INT), parameter, public :: ARK_INTERP_MAX_DEGREE = 5_C_INT
integer(C_INT), parameter, public :: ARK_INTERP_NONE = -1_C_INT
integer(C_INT), parameter, public :: ARK_INTERP_HERMITE = 0_C_INT
Expand Down
6 changes: 4 additions & 2 deletions src/sundials/fmod_int32/fsundials_core_mod.c
Original file line number Diff line number Diff line change
Expand Up @@ -2688,19 +2688,21 @@ SWIGEXPORT int _wrap_FSUNStepper_Evolve(void *farg1, double const *farg2, N_Vect
}


SWIGEXPORT int _wrap_FSUNStepper_FullRhs(void *farg1, double const *farg2, N_Vector farg3, N_Vector farg4) {
SWIGEXPORT int _wrap_FSUNStepper_FullRhs(void *farg1, double const *farg2, N_Vector farg3, N_Vector farg4, int const *farg5) {
int fresult ;
SUNStepper arg1 = (SUNStepper) 0 ;
sunrealtype arg2 ;
N_Vector arg3 = (N_Vector) 0 ;
N_Vector arg4 = (N_Vector) 0 ;
SUNFullRhsMode arg5 ;
SUNErrCode result;

arg1 = (SUNStepper)(farg1);
arg2 = (sunrealtype)(*farg2);
arg3 = (N_Vector)(farg3);
arg4 = (N_Vector)(farg4);
result = (SUNErrCode)SUNStepper_FullRhs(arg1,arg2,arg3,arg4);
arg5 = (SUNFullRhsMode)(*farg5);
result = (SUNErrCode)SUNStepper_FullRhs(arg1,arg2,arg3,arg4,arg5);
fresult = (SUNErrCode)(result);
return fresult;
}
Expand Down
21 changes: 18 additions & 3 deletions src/sundials/fmod_int32/fsundials_core_mod.f90
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,17 @@ module fsundials_core_mod
public :: FSUNAdaptController_SetErrorBias
public :: FSUNAdaptController_UpdateH
public :: FSUNAdaptController_Space
integer(C_INT), parameter, public :: ARK_FULLRHS_START = 0_C_INT
integer(C_INT), parameter, public :: ARK_FULLRHS_END = 1_C_INT
integer(C_INT), parameter, public :: ARK_FULLRHS_OTHER = 2_C_INT
! typedef enum SUNFullRhsMode
enum, bind(c)
enumerator :: SUN_FULLRHS_START
enumerator :: SUN_FULLRHS_END
enumerator :: SUN_FULLRHS_OTHER
end enum
integer, parameter, public :: SUNFullRhsMode = kind(SUN_FULLRHS_START)
public :: SUN_FULLRHS_START, SUN_FULLRHS_END, SUN_FULLRHS_OTHER
public :: FSUNStepper_Create
public :: FSUNStepper_Destroy
public :: FSUNStepper_Evolve
Expand Down Expand Up @@ -2086,14 +2097,15 @@ function swigc_FSUNStepper_Evolve(farg1, farg2, farg3, farg4) &
integer(C_INT) :: fresult
end function

function swigc_FSUNStepper_FullRhs(farg1, farg2, farg3, farg4) &
function swigc_FSUNStepper_FullRhs(farg1, farg2, farg3, farg4, farg5) &
bind(C, name="_wrap_FSUNStepper_FullRhs") &
result(fresult)
use, intrinsic :: ISO_C_BINDING
type(C_PTR), value :: farg1
real(C_DOUBLE), intent(in) :: farg2
type(C_PTR), value :: farg3
type(C_PTR), value :: farg4
integer(C_INT), intent(in) :: farg5
integer(C_INT) :: fresult
end function

Expand Down Expand Up @@ -5010,25 +5022,28 @@ function FSUNStepper_Evolve(stepper, tout, vret, tret) &
swig_result = fresult
end function

function FSUNStepper_FullRhs(stepper, t, v, f) &
function FSUNStepper_FullRhs(stepper, t, v, f, mode) &
result(swig_result)
use, intrinsic :: ISO_C_BINDING
integer(C_INT) :: swig_result
type(C_PTR) :: stepper
real(C_DOUBLE), intent(in) :: t
type(N_Vector), target, intent(inout) :: v
type(N_Vector), target, intent(inout) :: f
integer(SUNFullRhsMode), intent(in) :: mode
integer(C_INT) :: fresult
type(C_PTR) :: farg1
real(C_DOUBLE) :: farg2
type(C_PTR) :: farg3
type(C_PTR) :: farg4
integer(C_INT) :: farg5

farg1 = stepper
farg2 = t
farg3 = c_loc(v)
farg4 = c_loc(f)
fresult = swigc_FSUNStepper_FullRhs(farg1, farg2, farg3, farg4)
farg5 = mode
fresult = swigc_FSUNStepper_FullRhs(farg1, farg2, farg3, farg4, farg5)
swig_result = fresult
end function

Expand Down
6 changes: 4 additions & 2 deletions src/sundials/fmod_int64/fsundials_core_mod.c
Original file line number Diff line number Diff line change
Expand Up @@ -2688,19 +2688,21 @@ SWIGEXPORT int _wrap_FSUNStepper_Evolve(void *farg1, double const *farg2, N_Vect
}


SWIGEXPORT int _wrap_FSUNStepper_FullRhs(void *farg1, double const *farg2, N_Vector farg3, N_Vector farg4) {
SWIGEXPORT int _wrap_FSUNStepper_FullRhs(void *farg1, double const *farg2, N_Vector farg3, N_Vector farg4, int const *farg5) {
int fresult ;
SUNStepper arg1 = (SUNStepper) 0 ;
sunrealtype arg2 ;
N_Vector arg3 = (N_Vector) 0 ;
N_Vector arg4 = (N_Vector) 0 ;
SUNFullRhsMode arg5 ;
SUNErrCode result;

arg1 = (SUNStepper)(farg1);
arg2 = (sunrealtype)(*farg2);
arg3 = (N_Vector)(farg3);
arg4 = (N_Vector)(farg4);
result = (SUNErrCode)SUNStepper_FullRhs(arg1,arg2,arg3,arg4);
arg5 = (SUNFullRhsMode)(*farg5);
result = (SUNErrCode)SUNStepper_FullRhs(arg1,arg2,arg3,arg4,arg5);
fresult = (SUNErrCode)(result);
return fresult;
}
Expand Down
21 changes: 18 additions & 3 deletions src/sundials/fmod_int64/fsundials_core_mod.f90
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,17 @@ module fsundials_core_mod
public :: FSUNAdaptController_SetErrorBias
public :: FSUNAdaptController_UpdateH
public :: FSUNAdaptController_Space
integer(C_INT), parameter, public :: ARK_FULLRHS_START = 0_C_INT
integer(C_INT), parameter, public :: ARK_FULLRHS_END = 1_C_INT
integer(C_INT), parameter, public :: ARK_FULLRHS_OTHER = 2_C_INT
! typedef enum SUNFullRhsMode
enum, bind(c)
enumerator :: SUN_FULLRHS_START
enumerator :: SUN_FULLRHS_END
enumerator :: SUN_FULLRHS_OTHER
end enum
integer, parameter, public :: SUNFullRhsMode = kind(SUN_FULLRHS_START)
public :: SUN_FULLRHS_START, SUN_FULLRHS_END, SUN_FULLRHS_OTHER
public :: FSUNStepper_Create
public :: FSUNStepper_Destroy
public :: FSUNStepper_Evolve
Expand Down Expand Up @@ -2086,14 +2097,15 @@ function swigc_FSUNStepper_Evolve(farg1, farg2, farg3, farg4) &
integer(C_INT) :: fresult
end function

function swigc_FSUNStepper_FullRhs(farg1, farg2, farg3, farg4) &
function swigc_FSUNStepper_FullRhs(farg1, farg2, farg3, farg4, farg5) &
bind(C, name="_wrap_FSUNStepper_FullRhs") &
result(fresult)
use, intrinsic :: ISO_C_BINDING
type(C_PTR), value :: farg1
real(C_DOUBLE), intent(in) :: farg2
type(C_PTR), value :: farg3
type(C_PTR), value :: farg4
integer(C_INT), intent(in) :: farg5
integer(C_INT) :: fresult
end function

Expand Down Expand Up @@ -5010,25 +5022,28 @@ function FSUNStepper_Evolve(stepper, tout, vret, tret) &
swig_result = fresult
end function

function FSUNStepper_FullRhs(stepper, t, v, f) &
function FSUNStepper_FullRhs(stepper, t, v, f, mode) &
result(swig_result)
use, intrinsic :: ISO_C_BINDING
integer(C_INT) :: swig_result
type(C_PTR) :: stepper
real(C_DOUBLE), intent(in) :: t
type(N_Vector), target, intent(inout) :: v
type(N_Vector), target, intent(inout) :: f
integer(SUNFullRhsMode), intent(in) :: mode
integer(C_INT) :: fresult
type(C_PTR) :: farg1
real(C_DOUBLE) :: farg2
type(C_PTR) :: farg3
type(C_PTR) :: farg4
integer(C_INT) :: farg5

farg1 = stepper
farg2 = t
farg3 = c_loc(v)
farg4 = c_loc(f)
fresult = swigc_FSUNStepper_FullRhs(farg1, farg2, farg3, farg4)
farg5 = mode
fresult = swigc_FSUNStepper_FullRhs(farg1, farg2, farg3, farg4, farg5)
swig_result = fresult
end function

Expand Down
4 changes: 2 additions & 2 deletions src/sundials/sundials_stepper.c
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ SUNErrCode SUNStepper_Evolve(SUNStepper stepper, sunrealtype tout, N_Vector y,
}

SUNErrCode SUNStepper_FullRhs(SUNStepper stepper, sunrealtype t, N_Vector v,
N_Vector f)
N_Vector f, SUNFullRhsMode mode)
{
SUNFunctionBegin(stepper->sunctx);
if (stepper->ops->fullrhs) { return stepper->ops->fullrhs(stepper, t, v, f); }
if (stepper->ops->fullrhs) { return stepper->ops->fullrhs(stepper, t, v, f, mode); }
return SUN_ERR_NOT_IMPLEMENTED;
}

Expand Down

0 comments on commit ad41aa0

Please sign in to comment.