diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 6080c577..0068a8e9 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -32,7 +32,7 @@ ArrayContext, dataclass_array_container, with_container_arithmetic, serialize_container, deserialize_container, - freeze, thaw, + freeze, thaw, with_array_context, FirstAxisIsElementsTag, PyOpenCLArrayContext, PytatoPyOpenCLArrayContext, @@ -188,22 +188,9 @@ def _raise_index_inconsistency(i, stream_i): for i, (stream_i, v) in enumerate(iterable))) -@freeze.register(DOFArray) -def _freeze_dofarray(ary, actx=None): - assert actx is None - return type(ary)( - None, - tuple(ary.array_context.freeze(subary) for subary in ary.data)) - - -@thaw.register(DOFArray) -def _thaw_dofarray(ary, actx): - if ary.array_context is not None: - raise ValueError("cannot thaw DOFArray that already has an array context") - - return type(ary)( - actx, - tuple(actx.thaw(subary) for subary in ary.data)) +@with_array_context.register(DOFArray) +def _with_actx_dofarray(ary, actx): + return type(ary)(actx, ary.data) # }}} @@ -1200,6 +1187,11 @@ class Velocity2D: array_context: ArrayContext +@with_array_context.register(Velocity2D) +def _with_actx_velocity_2d(ary, actx): + return type(ary)(ary.u, ary.v, actx) + + def scale_and_orthogonalize(alpha, vel): from arraycontext import rec_map_array_container actx = vel.array_context