From 81b39bc1f058adae0ee72f0e2ec1e5b373fb1222 Mon Sep 17 00:00:00 2001
From: Nick
Date: Mon, 22 Jul 2024 08:24:56 -0500
Subject: [PATCH] Add an advection example.

---
 arraycontext/parameter_study/transform.py | 236 ++++++++++------------
 examples/advection.py                     |  81 ++++++++
 2 files changed, 184 insertions(+), 133 deletions(-)
 create mode 100644 examples/advection.py

diff --git a/arraycontext/parameter_study/transform.py b/arraycontext/parameter_study/transform.py
index 003abbf3..19921f56 100644
--- a/arraycontext/parameter_study/transform.py
+++ b/arraycontext/parameter_study/transform.py
@@ -44,6 +44,7 @@
     Sequence,
     Set,
     Tuple,
+    Union,
 )
 
 import numpy as np
@@ -101,16 +102,14 @@ class ParameterStudyAxisTag(UniqueTag):
 
 
 class ExpansionMapper(CopyMapper):
-    # def __init__(self, dependency_map: Dict[Array,Tag]):
-    #    super().__init__()
-    #    self.depends = dependency_map
     def __init__(self, actual_input_shapes: Mapping[str, ShapeType],
                  actual_input_axes: Mapping[str, FrozenSet[Axis]]):
        super().__init__()
        self.actual_input_shapes = actual_input_shapes
        self.actual_input_axes = actual_input_axes
 
-    def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: Array,
+    def single_predecessor_updates(self,
+                                   curr_expr: Array,
                                    new_expr: Array) -> Tuple[ShapeType, AxesT]:
 
        # Initialize with something for the typing.
@@ -120,17 +119,17 @@ def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: A
            return shape_to_append, new_axes
 
        # Now we may need to change.
-        changed = False
        for i in range(len(new_expr.axes)):
            axis_tags = list(new_expr.axes[i].tags)
            already_added = False
-            for j, tag in enumerate(axis_tags):
+            for tag in axis_tags:
                # Should be relatively few tags on each axis $O(1)$.
                if isinstance(tag, ParameterStudyAxisTag):
                    new_axes = new_axes + (new_expr.axes[i],)
                    shape_to_append = shape_to_append + (new_expr.shape[i],)
                    if already_added:
-                        raise ValueError("An individual axis may only be tagged with one ParameterStudyAxisTag.")
+                        raise ValueError("An individual axis may only be "
+                                         "tagged with one ParameterStudyAxisTag.")
                    already_added = True
 
        # Remove initialized extraneous data
@@ -138,7 +137,7 @@ def does_single_predecessor_require_rewrite_of_this_operation(self, curr_expr: A
 
     def map_roll(self, expr: Roll) -> Array:
        new_array = self.rec(expr.array)
-        _, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
+        _, new_axes = self.single_predecessor_updates(expr.array,
                                                      new_array)
        return Roll(array=new_array,
                    shift=expr.shift,
@@ -149,7 +148,7 @@ def map_roll(self, expr: Roll) -> Array:
 
     def map_axis_permutation(self, expr: AxisPermutation) -> Array:
        new_array = self.rec(expr.array)
-        postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
+        postpend_shape, new_axes = self.single_predecessor_updates(expr.array,
                                                                    new_array)
        # Include the axes we are adding to the system.
        axis_permute = expr.axis_permutation + tuple([i + len(expr.axis_permutation)
@@ -163,8 +162,8 @@ def map_axis_permutation(self, expr: AxisPermutation) -> Array:
 
     def _map_index_base(self, expr: IndexBase) -> Array:
        new_array = self.rec(expr.array)
-        postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
-                                                                                                  new_array)
+        _, new_axes = self.single_predecessor_updates(expr.array,
+                                                      new_array)
        return type(expr)(new_array,
                          indices=self.rec_idx_or_size_tuple(expr.indices),
                          # May need to modify indices
@@ -174,10 +173,11 @@ def _map_index_base(self, expr: IndexBase) -> Array:
 
     def map_reshape(self, expr: Reshape) -> Array:
        new_array = self.rec(expr.array)
-        postpend_shape, new_axes = self.does_single_predecessor_require_rewrite_of_this_operation(expr.array,
-                                                                                                  new_array)
+        postpend_shape, new_axes = self.single_predecessor_updates(expr.array,
+                                                                   new_array)
        return Reshape(new_array,
-                       newshape=self.rec_idx_or_size_tuple(expr.newshape + postpend_shape),
+                       newshape=self.rec_idx_or_size_tuple(expr.newshape +
+                                                           postpend_shape),
                       order=expr.order,
                       axes=expr.axes + new_axes,
                       tags=expr.tags,
@@ -201,109 +201,71 @@ def map_placeholder(self, expr: Placeholder) -> Array:
 
     # {{{ Operations with multiple predecessors.
 
-    def _get_active_studies_from_multiple_predecessors(self, new_arrays: Tuple[Array, ...]) -> Tuple[Tuple[Axis, ...],
-                                                                                                     Set[ParameterStudyAxisTag],
-                                                                                                     Dict[Array,
-                                                                                                          Tuple[ParameterStudyAxisTag, ...]]]:
+    def _studies_from_multiple_pred(self,
+            new_arrays: Tuple[Array, ...]) -> Tuple[AxesT,
+                    Set[ParameterStudyAxisTag],
+                    Dict[Array, Tuple[ParameterStudyAxisTag, ...]]]:
        new_axes_for_end: Tuple[Axis, ...] = ()
-        active_studies: Set[ParameterStudyAxisTag] = set()
+        cur_studies: Set[ParameterStudyAxisTag] = set()
        studies_by_array: Dict[Array, Tuple[ParameterStudyAxisTag, ...]] = {}
 
-        for ind, array in enumerate(new_arrays):
+        for array in new_arrays:
            for axis in array.axes:
                axis_tags = axis.tags_of_type(ParameterStudyAxisTag)
                if axis_tags:
                    axis_tags = list(axis_tags)
                    assert len(axis_tags) == 1
                    if array in studies_by_array.keys():
-                        studies_by_array[array] = studies_by_array[array] + (axis_tags[0],)
+                        studies_by_array[array] = studies_by_array[array] + \
+                                (axis_tags[0],)
                    else:
                        studies_by_array[array] = (axis_tags[0],)
 
-                    if axis_tags[0] not in active_studies:
-                        active_studies.add(axis_tags[0])
+                    if axis_tags[0] not in cur_studies:
+                        cur_studies.add(axis_tags[0])
                        new_axes_for_end = new_axes_for_end + (axis,)
 
-        return new_axes_for_end, active_studies, studies_by_array
+        return new_axes_for_end, cur_studies, studies_by_array
 
     def map_stack(self, expr: Stack) -> Array:
-        # TODO: Fix
-        single_instance_input_shape = expr.arrays[0].shape
-        new_arrays = tuple(self.rec(arr) for arr in expr.arrays)
-
-        new_axes_for_end, active_studies, studies_by_array = self._get_active_studies_from_multiple_predecessors(new_arrays)
-
-        study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {}
-
-        new_shape_of_predecessors = single_instance_input_shape
-        new_axes = expr.axes
-
-        for study in active_studies:
-            if isinstance(study, ParameterStudyAxisTag):
-                # Just defensive programming
-                # The active studies are added to the end of the bindings.
-                study_to_axis_number[study] = len(new_shape_of_predecessors)
-                new_shape_of_predecessors = new_shape_of_predecessors + (study.axis_size,)
-                new_axes = new_axes + (Axis(tags=frozenset((study,))),)
-                # This assumes that the axis only has 1 tag,
-                # because there should be no dependence across instances.
-
-        # This is going to be expensive.
-
-        # Now we need to update the expressions.
-        # Now that we have the appropriate shape, we need to update the input arrays to match.
-        cp_map = CopyMapper()
-        corrected_new_arrays: Tuple[Array, ...] = ()
-        for ind, array in enumerate(new_arrays):
-            tmp = cp_map(array)  # Get a copy of the array.
-            if len(array.axes) < len(new_axes):
-                # We need to grow the array to the new size.
-                for study in active_studies:
-                    if study not in studies_by_array[array]:
-                        build: Tuple[Array, ...] = tuple([cp_map(tmp) for _ in range(study.axis_size)])
-                        tmp = Stack(arrays=build, axis=len(tmp.axes),
-                                    axes=tmp.axes + (Axis(tags=frozenset((study,))),),
-                                    tags=tmp.tags, non_equality_tags=tmp.non_equality_tags)
-            elif len(array.axes) > len(new_axes):
-                raise ValueError(f"Input array is too big. Expected at most: {len(new_axes)} Found: {len(array.axes)} axes.")
-
-            # Now we need to correct to the appropriate shape with an axis permutation.
-            # These are known to be in the right place.
-            permute: Tuple[int, ...] = tuple([i for i in range(len(single_instance_input_shape))])
-
-            for iaxis, axis in enumerate(tmp.axes):
-                axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag))
-                if axis_tags:
-                    assert len(axis_tags) == 1
-                    permute = permute + (study_to_axis_number[axis_tags[0]],)
-            assert len(permute) == len(new_shape_of_predecessors)
-            corrected_new_arrays = corrected_new_arrays + (AxisPermutation(tmp, permute, tags=tmp.tags,
-                                                                           axes=tmp.axes, non_equality_tags=tmp.non_equality_tags),)
-
-        return Stack(arrays=corrected_new_arrays, axis=expr.axis, axes=expr.axes + new_axes_for_end,
-                     tags=expr.tags, non_equality_tags=expr.non_equality_tags)
+        new_arrays, new_axes = self._mult_pred_same_shape(expr)
+
+        # Note: the axes returned by _mult_pred_same_shape already include
+        # expr.axes, so they are used as-is here.
+        return Stack(arrays=new_arrays,
+                     axis=expr.axis,
+                     axes=new_axes,
+                     tags=expr.tags,
+                     non_equality_tags=expr.non_equality_tags)
 
-    def map_concatenate(self, expr: Concatenate) -> Array:
-        single_instance_input_shape = expr.arrays[0].shape
-        # Note that one of the axes within the first single_instance_input_shape
-        # will not match in size across all inputs.
+    def map_concatenate(self, expr: Concatenate) -> Array:
+        new_arrays, new_axes = self._mult_pred_same_shape(expr)
+
+        return Concatenate(arrays=new_arrays,
+                           axis=expr.axis,
+                           axes=new_axes,
+                           tags=expr.tags,
+                           non_equality_tags=expr.non_equality_tags)
+
+    def _mult_pred_same_shape(self, expr: Union[Stack, Concatenate]) -> Tuple[
+            Tuple[Array, ...], AxesT]:
+        sing_inst_in_shape = expr.arrays[0].shape
        new_arrays = tuple(self.rec(arr) for arr in expr.arrays)
 
-        new_axes_for_end, active_studies, studies_by_array = self._get_active_studies_from_multiple_predecessors(new_arrays)
+        _, cur_studies, studies_by_array = self._studies_from_multiple_pred(new_arrays)
 
        study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {}
 
-        new_shape_of_predecessors = single_instance_input_shape
+        new_shape_of_predecessors = sing_inst_in_shape
        new_axes = expr.axes
 
-        for study in active_studies:
+        for study in cur_studies:
            if isinstance(study, ParameterStudyAxisTag):
                # Just defensive programming
                # The active studies are added to the end of the bindings.
                study_to_axis_number[study] = len(new_shape_of_predecessors)
-                new_shape_of_predecessors = new_shape_of_predecessors + (study.axis_size,)
+                new_shape_of_predecessors = new_shape_of_predecessors + \
+                        (study.axis_size,)
                new_axes = new_axes + (Axis(tags=frozenset((study,))),)
                # This assumes that the axis only has 1 tag,
                # because there should be no dependence across instances.
@@ -311,42 +273,51 @@ def map_concatenate(self, expr: Concatenate) -> Array:
 
        # This is going to be expensive.
 
        # Now we need to update the expressions.
-        # Now that we have the appropriate shape, we need to update the input arrays to match.
+        # Now that we have the appropriate shape,
+        # we need to update the input arrays to match.
+        cp_map = CopyMapper()
        corrected_new_arrays: Tuple[Array, ...] = ()
-        for ind, array in enumerate(new_arrays):
+        for array in new_arrays:
            tmp = cp_map(array)  # Get a copy of the array.
            if len(array.axes) < len(new_axes):
                # We need to grow the array to the new size.
-                for study in active_studies:
+                for study in cur_studies:
                    if study not in studies_by_array[array]:
-                        build: Tuple[Array, ...] = tuple([cp_map(tmp) for _ in range(study.axis_size)])
+                        build: Tuple[Array, ...] = tuple([cp_map(tmp)
+                                for _ in range(study.axis_size)])
                        tmp = Stack(arrays=build, axis=len(tmp.axes),
-                                    axes=tmp.axes + (Axis(tags=frozenset((study,))),),
-                                    tags=tmp.tags, non_equality_tags=tmp.non_equality_tags)
+                                    axes=tmp.axes
+                                    + (Axis(tags=frozenset((study,))),),
+                                    tags=tmp.tags,
+                                    non_equality_tags=tmp.non_equality_tags)
            elif len(array.axes) > len(new_axes):
-                raise ValueError(f"Input array is too big. Expected at most: {len(new_axes)} Found: {len(array.axes)} axes.")
+                raise ValueError("Input array is too big. "
+                                 f"Expected at most: {len(new_axes)} "
+                                 f"Found: {len(array.axes)} axes.")
 
            # Now we need to correct to the appropriate shape with an axis permutation.
            # These are known to be in the right place.
-            permute: Tuple[int, ...] = tuple([i for i in range(len(single_instance_input_shape))])
+            permute: Tuple[int, ...] = tuple(range(len(sing_inst_in_shape)))
 
-            for iaxis, axis in enumerate(tmp.axes):
+            for axis in tmp.axes:
                axis_tags = list(axis.tags_of_type(ParameterStudyAxisTag))
                if axis_tags:
                    assert len(axis_tags) == 1
                    permute = permute + (study_to_axis_number[axis_tags[0]],)
            assert len(permute) == len(new_shape_of_predecessors)
-            corrected_new_arrays = corrected_new_arrays + (AxisPermutation(tmp, permute, tags=tmp.tags,
-                                                                           axes=tmp.axes, non_equality_tags=tmp.non_equality_tags),)
+            corrected_new_arrays = corrected_new_arrays + \
+                    (AxisPermutation(tmp, permute, tags=tmp.tags,
+                                     axes=tmp.axes,
+                                     non_equality_tags=tmp.non_equality_tags),)
 
-        return Concatenate(arrays=corrected_new_arrays, axis=expr.axis, axes=expr.axes + new_axes_for_end,
-                           tags=expr.tags, non_equality_tags=expr.non_equality_tags)
+        return corrected_new_arrays, new_axes
 
     def map_index_lambda(self, expr: IndexLambda) -> Array:
        # Update bindings first.
        new_bindings: Dict[str, Array] = {name: self.rec(bnd)
-                                          for name, bnd in sorted(expr.bindings.items())}
+                                          for name, bnd in
+                                          sorted(expr.bindings.items())}
 
        # Determine the new parameter studies that are being conducted.
        from pytools import unique
@@ -364,13 +335,13 @@ def map_index_lambda(self, expr: IndexLambda) -> Array:
                        studies_by_variable[name][tag] = True
                        all_axis_tags = all_axis_tags + (tag,)
 
-        active_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags))
+        cur_studies: Sequence[ParameterStudyAxisTag] = list(unique(all_axis_tags))
        study_to_axis_number: Dict[ParameterStudyAxisTag, int] = {}
 
        new_shape = expr.shape
        new_axes = expr.axes
 
-        for study in active_studies:
+        for study in cur_studies:
            if isinstance(study, ParameterStudyAxisTag):
                # Just defensive programming
                # The active studies are added to the end of the bindings.
@@ -381,7 +352,8 @@ def map_index_lambda(self, expr: IndexLambda) -> Array:
                # because there should be no dependence across instances.
 
        # Now we need to update the expressions.
-        scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable, study_to_axis_number)
+        scalar_expr = ParamAxisExpander()(expr.expr, studies_by_variable,
+                                          study_to_axis_number)
 
        return IndexLambda(expr=scalar_expr,
                           bindings=immutabledict(new_bindings),
@@ -395,7 +367,7 @@ def map_index_lambda(self, expr: IndexLambda) -> Array:
 
     def map_einsum(self, expr: Einsum) -> Array:
        new_arrays = tuple([self.rec(arg) for arg in expr.args])
-        new_axes_for_end, active_studies, studies_by_array = self._get_active_studies_from_multiple_predecessors(new_arrays)
+        new_axes_for_end, cur_studies, _ = self._studies_from_multiple_pred(new_arrays)
 
        # Access Descriptors hold the Einsum notation.
        new_access_descriptors = list(expr.access_descriptors)
@@ -403,7 +375,7 @@ def map_einsum(self, expr: Einsum) -> Array:
 
        new_shape = expr.shape
 
-        for study in active_studies:
+        for study in cur_studies:
            if isinstance(study, ParameterStudyAxisTag):
                # Just defensive programming
                # The active studies are added to the end.
@@ -418,22 +390,21 @@ def map_einsum(self, expr: Einsum) -> Array:
                    new_access_descriptors[ind] = new_access_descriptors[ind] + \
                        (EinsumElementwiseAxis(dim=study_to_axis_number[axis_tags[0]]),)
 
-        out = Einsum(tuple(new_access_descriptors), new_arrays, axes=expr.axes + new_axes_for_end,
+        return Einsum(tuple(new_access_descriptors), new_arrays,
+                      axes=expr.axes + new_axes_for_end,
                      redn_axis_to_redn_descr=expr.redn_axis_to_redn_descr,
                      index_to_access_descr=expr.index_to_access_descr,
                      tags=expr.tags,
                      non_equality_tags=expr.non_equality_tags)
-        breakpoint()
-        return out
-
-        return super().map_einsum(expr)
 
     # }}} Operations with multiple predecessors.
 
 
 class ParamAxisExpander(IdentityMapper):
 
-    def map_subscript(self, expr: prim.Subscript, studies_by_variable: Mapping[str, Mapping[ParameterStudyAxisTag, bool]],
+    def map_subscript(self, expr: prim.Subscript,
+                      studies_by_variable: Mapping[str,
+                              Mapping[ParameterStudyAxisTag, bool]],
                      study_to_axis_number: Mapping[ParameterStudyAxisTag, int]):
        # We know that we are not changing the variable that we are indexing into.
        # This is stored in the aggregate member of the class Subscript.
@@ -447,9 +418,10 @@ def map_subscript(self, expr: prim.Subscript, studies_by_variable: Mapping[str,
            new_vars: Tuple[prim.Variable, ...] = ()
 
-            for key, val in sorted(study_to_axis_number.items(), key=lambda item: item[1]):
+            for key, num in sorted(study_to_axis_number.items(),
+                                   key=lambda item: item[1]):
                if key in studies_by_variable[name]:
-                    new_vars = new_vars + (prim.Variable(f"_{study_to_axis_number[key]}"),)
+                    new_vars = new_vars + (prim.Variable(f"_{num}"),)
 
            if isinstance(index, tuple):
                index = index + new_vars
@@ -477,7 +449,8 @@ class ParamStudyPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
 
     def compile(self, f: Callable[..., Any]) -> Callable[..., Any]:
        return ParamStudyLazyPyOpenCLFunctionCaller(self, f)
 
-    def transform_loopy_program(self, t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
+    def transform_loopy_program(self,
+                                t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
        # Update in a subclass if you want.
        return t_unit
 
@@ -497,10 +470,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
        Returns the result of :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`'s
        function application on *args*.
 
-        Before applying :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`, it is compiled
-        to a :mod:`pytato` DAG that would apply
-        :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f` with *args* in a lazy-sense.
-        The intermediary pytato DAG for *args* is memoized in *self*.
+        Before applying :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`,
+        it is compiled to a :mod:`pytato` DAG that would apply
+        :attr:`~ParamStudyLazyPyOpenCLFunctionCaller.f`
+        with *args* in a lazy-sense. The intermediary pytato DAG for *args* is
+        memoized in *self*.
        """
        arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(
            args, kwargs)
@@ -519,10 +493,10 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
            arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}"
            for arg_id in arg_id_to_arg}
 
-        output_template = self.f(
-                *[_get_f_placeholder_args_for_param_study(arg, iarg,
-                    input_id_to_name_in_program, self.actx)
-                    for iarg, arg in enumerate(args)],
+        placeholder_args = [_get_f_placeholder_args_for_param_study(arg, iarg,
+                                input_id_to_name_in_program, self.actx)
+                            for iarg, arg in enumerate(args)]
+        output_template = self.f(*placeholder_args,
                **{kw: _get_f_placeholder_args_for_param_study(arg, kw,
                    input_id_to_name_in_program, self.actx)
                    for kw, arg in kwargs.items()})
@@ -548,30 +522,25 @@ def _as_dict_of_named_arrays(keys, ary):
 
        rec_keyed_map_array_container(_as_dict_of_named_arrays, output_template)
 
-        # input_shapes = {input_id_to_name_in_program[i]: arg_id_to_descr[i].shape for i in arg_id_to_descr.keys() if hasattr(arg_id_to_descr[i], "shape")}
-        # input_axes = {input_id_to_name_in_program[i]: arg_id_to_arg[i].axes for i in arg_id_to_descr.keys()}
        input_shapes = {}
        input_axes = {}
        for key, val in arg_id_to_descr.items():
            if isinstance(val, LeafArrayDescriptor):
                input_shapes[input_id_to_name_in_program[key]] = val.shape
                input_axes[input_id_to_name_in_program[key]] = arg_id_to_arg[key].axes
-        my_expansion_map = ExpansionMapper(input_shapes, input_axes)  # Get the dependencies
-        breakpoint()
+        expand_map = ExpansionMapper(input_shapes, input_axes)
+        # Get the dependencies
 
-        pt_dict_of_named_arrays = pt.make_dict_of_named_arrays(dict_of_named_arrays)
-
-        breakpoint()
+        sing_inst_outs = pt.make_dict_of_named_arrays(dict_of_named_arrays)
 
        # Use the normal compiler now.
-        compiled_func = self._dag_to_compiled_func(my_expansion_map(pt_dict_of_named_arrays),
+        compiled_func = self._dag_to_compiled_func(expand_map(sing_inst_outs),
                input_id_to_name_in_program=input_id_to_name_in_program,
                output_id_to_name_in_program=output_id_to_name_in_program,
                output_template=output_template)
 
-        breakpoint()
        self.program_cache[arg_id_to_descr] = compiled_func
        return compiled_func(arg_id_to_arg)
 
@@ -582,7 +551,7 @@ def _cut_if_in_param_study(name, arg) -> Array:
        if it is tagged with a `ParameterStudyAxisTag`
        to ensure the survival of the information those tags will be converted
        to temporary Array Tags of the same type. The placeholder will not
-        have the axes marked with a `ParameterStudyAxisTag` tag. 
+        have the axes marked with a `ParameterStudyAxisTag` tag.
     """
     ndim: int = len(arg.shape)
     newshape = []
@@ -592,7 +561,8 @@ def _cut_if_in_param_study(name, arg) -> Array:
        if not axis_tags:
            update_axes = update_axes + (arg.axes[i],)
            newshape.append(arg.shape[i])
-    update_axes = update_axes[1:]  # remove the first one that was placed there for typing.
+    # remove the first one that was placed there for typing.
+    update_axes = update_axes[1:]
     update_tags: FrozenSet[Tag] = arg.tags
     return pt.make_placeholder(name, newshape, arg.dtype, axes=update_axes,
                                tags=update_tags)
@@ -600,10 +570,10 @@ def _get_f_placeholder_args_for_param_study(arg, kw, arg_id_to_name, actx):
     """
-    Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`. 
+    Helper for :class:`BaseLazilyCompilingFunctionCaller.__call__`.
     Returns the placeholder version of an argument to
     :attr:`ParamStudyLazyPyOpenCLFunctionCaller.f`.
-    
+
     Note this will modify the shape of the placeholder
     to remove any parameter study axes until the trace
     can be completed.
     """
diff --git a/examples/advection.py b/examples/advection.py
new file mode 100644
index 00000000..ab4405d5
--- /dev/null
+++ b/examples/advection.py
@@ -0,0 +1,81 @@
+from dataclasses import dataclass
+
+import numpy as np  # for the data types
+
+import pyopencl as cl
+import pytest
+
+from arraycontext.parameter_study import (
+    pack_for_parameter_study,
+    unpack_parameter_study,
+)
+from arraycontext.parameter_study.transform import (
+    ParameterStudyAxisTag,
+    ParamStudyPytatoPyOpenCLArrayContext,
+)
+
+
+ctx = cl.create_some_context(interactive=False)
+queue = cl.CommandQueue(ctx)
+actx = ParamStudyPytatoPyOpenCLArrayContext(queue)
+
+
+@dataclass(frozen=True)
+class ParameterStudyForX(ParameterStudyAxisTag):
+    pass
+
+
+@dataclass(frozen=True)
+class ParameterStudyForY(ParameterStudyAxisTag):
+    pass
+
+
+def test_one_time_step_advection():
+
+    if not isinstance(actx, ParamStudyPytatoPyOpenCLArrayContext):
+        pytest.skip("only parameter study array contexts are supported")
+
+    seed = 12345
+    rng = np.random.default_rng(seed)
+
+    base_shape = np.prod((15, 5))
+    x0 = actx.from_numpy(rng.random(base_shape))
+    x1 = actx.from_numpy(rng.random(base_shape))
+    x2 = actx.from_numpy(rng.random(base_shape))
+    x3 = actx.from_numpy(rng.random(base_shape))
+
+    speed_shape = (1,)
+    y0 = actx.from_numpy(rng.random(speed_shape))
+    y1 = actx.from_numpy(rng.random(speed_shape))
+    y2 = actx.from_numpy(rng.random(speed_shape))
+    y3 = actx.from_numpy(rng.random(speed_shape))
+
+    ht = 0.0001
+    hx = 0.005
+    inds = actx.np.arange(base_shape, dtype=int)
+    Kp1 = actx.np.roll(inds, -1)
+    Km1 = actx.np.roll(inds, 1)
+
+    def rhs(fields, wave_speed):
+        # 2nd-order centered difference in space, forward Euler in time.
+        return fields + wave_speed * (-1) * (ht / (2 * hx)) * \
+                (fields[Kp1] - fields[Km1])
+
+    # The field and speed instances are paired, so both inputs are packed
+    # under the same study tag and share one trailing study axis.
+    pack_x = pack_for_parameter_study(actx, ParameterStudyForX, (4,),
+                                      x0, x1, x2, x3)
+    assert pack_x.shape == (75, 4)
+
+    pack_y = pack_for_parameter_study(actx, ParameterStudyForX, (4,),
+                                      y0, y1, y2, y3)
+    assert pack_y.shape == (1, 4)
+
+    compiled_rhs = actx.compile(rhs)
+
+    output = compiled_rhs(pack_x, pack_y)
+
+    assert output.shape == (75, 4)
+
+    output_x = unpack_parameter_study(output, ParameterStudyForX)
+    assert len(output_x) == 1  # Only 1 study associated with this variable.
+    assert len(output_x[0]) == 4  # 4 inputs for the parameter study.
+
+
+if __name__ == "__main__":
+    test_one_time_step_advection()
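Note (not part of the patch): the example above exercises a
pack -> compile -> unpack round trip. For reference, a minimal sketch of
that round trip on a toy function, assuming the pack_for_parameter_study /
unpack_parameter_study signatures used in examples/advection.py; the
MyStudy tag and the doubling function are hypothetical stand-ins.

    from dataclasses import dataclass

    import numpy as np
    import pyopencl as cl

    from arraycontext.parameter_study import (
        pack_for_parameter_study,
        unpack_parameter_study,
    )
    from arraycontext.parameter_study.transform import (
        ParameterStudyAxisTag,
        ParamStudyPytatoPyOpenCLArrayContext,
    )


    @dataclass(frozen=True)
    class MyStudy(ParameterStudyAxisTag):
        # Hypothetical tag: one subclass per independent parameter study.
        pass


    ctx = cl.create_some_context(interactive=False)
    queue = cl.CommandQueue(ctx)
    actx = ParamStudyPytatoPyOpenCLArrayContext(queue)

    # Two instances with single-instance shape (10,), packed along a new
    # trailing axis of size 2 that is tagged with MyStudy.
    a0 = actx.from_numpy(np.zeros(10))
    a1 = actx.from_numpy(np.ones(10))
    packed = pack_for_parameter_study(actx, MyStudy, (2,), a0, a1)
    assert packed.shape == (10, 2)

    # The function is traced once on a single-instance placeholder;
    # ExpansionMapper then widens the traced DAG so that both instances
    # run in one compiled program.
    doubled = actx.compile(lambda x: 2 * x)(packed)

    # Split the packed result back into per-instance outputs.
    out_by_instance = unpack_parameter_study(doubled, MyStudy)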