diff --git a/arraycontext/impl/pytato/__init__.py b/arraycontext/impl/pytato/__init__.py index 8d3b678f..a32e8de0 100644 --- a/arraycontext/impl/pytato/__init__.py +++ b/arraycontext/impl/pytato/__init__.py @@ -534,8 +534,10 @@ def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray: options=opts, cl_device=self.queue.device, function_name=function_name, - target=self.get_target()) - pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program) + target=self.get_target() + ).bind_to_context(self.context) + pt_prg = pt_prg.with_transformed_translation_unit( + self.transform_loopy_program) self._freeze_prg_cache[normalized_expr] = pt_prg else: transformed_dag, function_name = ( diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index c4996879..5fe9e7b3 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -378,6 +378,8 @@ def _as_dict_of_named_arrays(keys, ary): # {{{ LazilyPyOpenCLCompilingFunctionCaller class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller): + actx: PytatoPyOpenCLArrayContext + @property def compiled_function_returning_array_container_class( self) -> Type["CompiledFunction"]: @@ -391,7 +393,7 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): if prg_id is None: prg_id = self.f - from pytato.target.loopy import BoundPyOpenCLProgram + from pytato.target.loopy import BoundPyOpenCLExecutable self.actx._compile_trace_callback( prg_id, "pre_transform_dag", dict_of_named_arrays) @@ -422,8 +424,8 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): options=opts, function_name=_prg_id_to_kernel_name(prg_id), target=self.actx.get_target(), - ) - assert isinstance(pytato_program, BoundPyOpenCLProgram) + ).bind_to_context(self.actx.context) # pylint: disable=no-member + assert isinstance(pytato_program, BoundPyOpenCLExecutable) self.actx._compile_trace_callback( prg_id, "post_generate_loopy", pytato_program) @@ -434,15 +436,14 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None): with ProcessLogger(logger, f"transform_loopy_program for '{prg_id}'"): pytato_program = (pytato_program - .with_transformed_program( + .with_transformed_translation_unit( lambda x: x.with_kernel( x.default_entrypoint .tagged(FromArrayContextCompile())))) pytato_program = (pytato_program - .with_transformed_program(self - .actx - .transform_loopy_program)) + .with_transformed_translation_unit( + self.actx.transform_loopy_program)) self.actx._compile_trace_callback( prg_id, "post_transform_loopy_program", pytato_program)