diff --git a/arraycontext/context.py b/arraycontext/context.py index d296f8f7..30f58cb1 100644 --- a/arraycontext/context.py +++ b/arraycontext/context.py @@ -339,7 +339,7 @@ def to_numpy(self, @abstractmethod def call_loopy(self, - program: "loopy.TranslationUnit", + t_unit: "loopy.TranslationUnit", **kwargs: Any) -> Dict[str, Array]: """Execute the :mod:`loopy` program *program* on the arguments *kwargs*. diff --git a/arraycontext/impl/numpy/__init__.py b/arraycontext/impl/numpy/__init__.py index 77b7b49f..f8ba95e3 100644 --- a/arraycontext/impl/numpy/__init__.py +++ b/arraycontext/impl/numpy/__init__.py @@ -1,3 +1,6 @@ +from __future__ import annotations + + """ .. currentmodule:: arraycontext @@ -30,7 +33,7 @@ THE SOFTWARE. """ -from typing import Any, Dict +from typing import Any import numpy as np @@ -39,6 +42,7 @@ from arraycontext.container.traversal import rec_map_array_container, with_array_context from arraycontext.context import ( + Array, ArrayContext, ArrayOrContainerOrScalar, ArrayOrContainerOrScalarT, @@ -62,10 +66,12 @@ class NumpyArrayContext(ArrayContext): .. automethod:: __init__ """ + + _loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase] + def __init__(self) -> None: super().__init__() - self._loopy_transform_cache: \ - Dict[lp.TranslationUnit, lp.TranslationUnit] = {} + self._loopy_transform_cache = {} array_types = (NumpyNonObjectArray,) @@ -88,17 +94,18 @@ def to_numpy(self, ) -> NumpyOrContainerOrScalar: return array - def call_loopy(self, t_unit, **kwargs): + def call_loopy( + self, + t_unit: lp.TranslationUnit, **kwargs: Any + ) -> dict[str, Array]: t_unit = t_unit.copy(target=lp.ExecutableCTarget()) try: - t_unit = self._loopy_transform_cache[t_unit] + executor = self._loopy_transform_cache[t_unit] except KeyError: - orig_t_unit = t_unit - t_unit = self.transform_loopy_program(t_unit) - self._loopy_transform_cache[orig_t_unit] = t_unit - del orig_t_unit + executor = self.transform_loopy_program(t_unit).executor() + self._loopy_transform_cache[t_unit] = executor - _, result = t_unit(**kwargs) + _, result = executor(**kwargs) return result