Skip to content

Commit

Permalink
Use pytato's BoundProgram.bind_to_context on generated code in lazy
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Jul 27, 2023
1 parent 90240d4 commit c68e074
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
6 changes: 4 additions & 2 deletions arraycontext/impl/pytato/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,8 +534,10 @@ def _to_frozen(key: Tuple[Any, ...], ary) -> TaggableCLArray:
options=opts,
cl_device=self.queue.device,
function_name=function_name,
target=self.get_target())
pt_prg = pt_prg.with_transformed_program(self.transform_loopy_program)
target=self.get_target()
).bind_to_context(self.context)
pt_prg = pt_prg.with_transformed_translation_unit(
self.transform_loopy_program)
self._freeze_prg_cache[normalized_expr] = pt_prg
else:
transformed_dag, function_name = (
Expand Down
15 changes: 8 additions & 7 deletions arraycontext/impl/pytato/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ def _as_dict_of_named_arrays(keys, ary):
# {{{ LazilyPyOpenCLCompilingFunctionCaller

class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
actx: PytatoPyOpenCLArrayContext

@property
def compiled_function_returning_array_container_class(
self) -> Type["CompiledFunction"]:
Expand All @@ -391,7 +393,7 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
if prg_id is None:
prg_id = self.f

from pytato.target.loopy import BoundPyOpenCLProgram
from pytato.target.loopy import BoundPyOpenCLExecutable

self.actx._compile_trace_callback(
prg_id, "pre_transform_dag", dict_of_named_arrays)
Expand Down Expand Up @@ -422,8 +424,8 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
options=opts,
function_name=_prg_id_to_kernel_name(prg_id),
target=self.actx.get_target(),
)
assert isinstance(pytato_program, BoundPyOpenCLProgram)
).bind_to_context(self.actx.context) # pylint: disable=no-member
assert isinstance(pytato_program, BoundPyOpenCLExecutable)

self.actx._compile_trace_callback(
prg_id, "post_generate_loopy", pytato_program)
Expand All @@ -434,15 +436,14 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
with ProcessLogger(logger, f"transform_loopy_program for '{prg_id}'"):

pytato_program = (pytato_program
.with_transformed_program(
.with_transformed_translation_unit(
lambda x: x.with_kernel(
x.default_entrypoint
.tagged(FromArrayContextCompile()))))

pytato_program = (pytato_program
.with_transformed_program(self
.actx
.transform_loopy_program))
.with_transformed_translation_unit(
self.actx.transform_loopy_program))

self.actx._compile_trace_callback(
prg_id, "post_transform_loopy_program", pytato_program)
Expand Down

0 comments on commit c68e074

Please sign in to comment.