Skip to content

Commit

Permalink
Numpy actx: cache execuctor
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Sep 4, 2024
1 parent a63c263 commit ce9736c
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 10 deletions.
2 changes: 1 addition & 1 deletion arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ def to_numpy(self,

@abstractmethod
def call_loopy(self,
program: "loopy.TranslationUnit",
t_unit: "loopy.TranslationUnit",
**kwargs: Any) -> Dict[str, Array]:
"""Execute the :mod:`loopy` program *program* on the arguments
*kwargs*.
Expand Down
26 changes: 17 additions & 9 deletions arraycontext/impl/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from __future__ import annotations


"""
.. currentmodule:: arraycontext
Expand Down Expand Up @@ -30,6 +33,7 @@
THE SOFTWARE.
"""

from collections.abc import Mapping
from typing import Any, Dict

import numpy as np
Expand All @@ -39,6 +43,7 @@

from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainerOrScalar,
ArrayOrContainerOrScalarT,
Expand All @@ -62,10 +67,12 @@ class NumpyArrayContext(ArrayContext):
.. automethod:: __init__
"""

_loopy_transform_cache: dict[lp.TranslationUnit, lp.ExecutorBase]

def __init__(self) -> None:
super().__init__()
self._loopy_transform_cache: \
Dict[lp.TranslationUnit, lp.TranslationUnit] = {}
self._loopy_transform_cache = {}

array_types = (NumpyNonObjectArray,)

Expand All @@ -88,17 +95,18 @@ def to_numpy(self,
) -> NumpyOrContainerOrScalar:
return array

def call_loopy(self, t_unit, **kwargs):
def call_loopy(
self,
t_unit: lp.TranslationUnit, **kwargs: Any
) -> dict[str, Array]:
t_unit = t_unit.copy(target=lp.ExecutableCTarget())
try:
t_unit = self._loopy_transform_cache[t_unit]
executor = self._loopy_transform_cache[t_unit]
except KeyError:
orig_t_unit = t_unit
t_unit = self.transform_loopy_program(t_unit)
self._loopy_transform_cache[orig_t_unit] = t_unit
del orig_t_unit
executor = self.transform_loopy_program(t_unit).executor()
self._loopy_transform_cache[t_unit] = executor

_, result = t_unit(**kwargs)
_, result = executor(**kwargs)

return result

Expand Down

0 comments on commit ce9736c

Please sign in to comment.