Skip to content

Commit

Permalink
[WIP] Type the mappers
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Oct 7, 2024
1 parent f3ab2de commit 4190844
Showing 1 changed file with 141 additions and 52 deletions.
193 changes: 141 additions & 52 deletions pymbolic/mapper/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,20 @@
"""

from abc import ABC, abstractmethod
from typing import Any
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, TypeVar
from warnings import warn

from immutabledict import immutabledict
from typing_extensions import ParamSpec, TypeAlias, TypeIs

import pymbolic.primitives as primitives
from pymbolic.typing import ExpressionT


if TYPE_CHECKING:
import numpy as np

from pymbolic.rational import Rational


__doc__ = """
Expand Down Expand Up @@ -96,14 +105,20 @@
"""


try:
import numpy
if TYPE_CHECKING:
import numpy as np

def is_numpy_array(val) -> TypeIs[np.ndarray]:
return isinstance(val, np.ndarray)
else:
try:
import numpy as np

def is_numpy_array(val):
return isinstance(val, numpy.ndarray)
except ImportError:
def is_numpy_array(ary):
return False
def is_numpy_array(val):
return isinstance(val, np.ndarray)
except ImportError:
def is_numpy_array(ary):
return False


class UnsupportedExpressionError(ValueError):
Expand All @@ -112,15 +127,20 @@ class UnsupportedExpressionError(ValueError):

# {{{ mapper base

class Mapper:
ResultT = TypeVar("ResultT")
P = ParamSpec("P")


class Mapper(Generic[ResultT, P]):
"""A visitor for trees of :class:`pymbolic.Expression`
subclasses. Each expression-derived object is dispatched to the
method named by the :attr:`pymbolic.Expression.mapper_method`
attribute and if not found, the methods named by the class attribute
*mapper_method* in the method resolution order of the object.
"""

def handle_unsupported_expression(self, expr, *args, **kwargs):
def handle_unsupported_expression(self,
expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT:
"""Mapper method that is invoked for
:class:`pymbolic.Expression` subclasses for which a mapper
method does not exist in this mapper.
Expand All @@ -130,7 +150,8 @@ def handle_unsupported_expression(self, expr, *args, **kwargs):
"{} cannot handle expressions of type {}".format(
type(self), type(expr)))

def __call__(self, expr, *args, **kwargs):
def __call__(self,
expr: ExpressionT, *args: P.args, **kwargs: P.kwargs) -> ResultT:
"""Dispatch *expr* to its corresponding mapper method. Pass on
``*args`` and ``**kwargs`` unmodified.
Expand Down Expand Up @@ -162,7 +183,8 @@ def __call__(self, expr, *args, **kwargs):

rec = __call__

def rec_fallback(self, expr, *args, **kwargs):
def rec_fallback(self,
expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT:
if isinstance(expr, primitives.Expression):
for cls in type(expr).__mro__[1:]:
method_name = getattr(cls, "mapper_method", None)
Expand All @@ -175,56 +197,95 @@ def rec_fallback(self, expr, *args, **kwargs):
else:
return self.map_foreign(expr, *args, **kwargs)

def map_algebraic_leaf(self, expr, *args, **kwargs):
def map_algebraic_leaf(self,
expr: primitives.AlgebraicLeaf,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError

def map_variable(self, expr, *args, **kwargs):
def map_variable(self,
expr: primitives.Variable, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_algebraic_leaf(expr, *args, **kwargs)

def map_subscript(self, expr, *args, **kwargs):
def map_subscript(self,
expr: primitives.Subscript, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_algebraic_leaf(expr, *args, **kwargs)

def map_call(self, expr, *args, **kwargs):
def map_call(self,
expr: primitives.Call, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_algebraic_leaf(expr, *args, **kwargs)

def map_lookup(self, expr, *args, **kwargs):
def map_call_with_kwargs(self,
expr: primitives.CallWithKwargs,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_algebraic_leaf(expr, *args, **kwargs)

def map_if_positive(self, expr, *args, **kwargs):
def map_lookup(self,
expr: primitives.Lookup, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.map_algebraic_leaf(expr, *args, **kwargs)

def map_rational(self, expr, *args, **kwargs):
return self.map_quotient(expr, *args, **kwargs)
def map_if(self,
expr: primitives.If, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError

def map_quotient(self, expr, *args, **kwargs):
def map_rational(self,
expr: Rational, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError

def map_constant(self, expr, *args, **kwargs):
def map_quotient(self,
expr: primitives.Quotient, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError

def map_list(self, expr, *args, **kwargs):
def map_floor_div(self,
expr: primitives.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError

def map_tuple(self, expr, *args, **kwargs):
def map_remainder(self,
expr: primitives.Remainder, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError

def map_numpy_array(self, expr, *args, **kwargs):
def map_constant(self,
expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError

def map_nan(self, expr, *args, **kwargs):
def map_list(self,
expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError

def map_tuple(self,
expr: tuple[ExpressionT, ...],
*args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError

def map_numpy_array(self,
expr: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> ResultT:
raise NotImplementedError

def map_nan(self,
expr: primitives.NaN,
*args: P.args,
**kwargs: P.kwargs
) -> ResultT:
return self.map_algebraic_leaf(expr, *args, **kwargs)

def map_foreign(self, expr, *args, **kwargs):
def map_foreign(self,
expr: object,
*args: P.args,
**kwargs: P.kwargs
) -> ResultT:
"""Mapper method dispatch for non-:mod:`pymbolic` objects."""

if isinstance(expr, primitives.VALID_CONSTANT_CLASSES):
return self.map_constant(expr, *args, **kwargs)
elif is_numpy_array(expr):
return self.map_numpy_array(expr, *args, **kwargs)
elif isinstance(expr, list):
return self.map_list(expr, *args, **kwargs)
elif isinstance(expr, tuple):
return self.map_tuple(expr, *args, **kwargs)
elif isinstance(expr, list):
warn("List found in expression graph. "
"This is deprecated and will stop working in 2025. "
"Use tuples instead.", DeprecationWarning, stacklevel=2
)
return self.map_list(expr, *args, **kwargs)
else:
raise ValueError(
"{} encountered invalid foreign object: {}".format(
Expand All @@ -234,17 +295,24 @@ def map_foreign(self, expr, *args, **kwargs):
_NOT_IN_CACHE = object()


class CachedMapper(Mapper):
CacheKeyT: TypeAlias = Hashable


class CachedMapper(Mapper[ResultT, P]):
"""
A mapper that memoizes the mapped result for the expressions traversed.
.. automethod:: get_cache_key
"""
def __init__(self):
self._cache: dict[Any, Any] = {}
def __init__(self) -> None:
self._cache: dict[CacheKeyT, ResultT] = {}
Mapper.__init__(self)

def get_cache_key(self, expr, *args, **kwargs):
def get_cache_key(self,
expr: ExpressionT,
*args: P.args,
**kwargs: P.kwargs
) -> CacheKeyT:
"""
Returns the key corresponding to which the result of a mapper method is
stored in the cache.
Expand All @@ -260,7 +328,11 @@ def get_cache_key(self, expr, *args, **kwargs):
# and "4 == 4.0", but their traversal results cannot be re-used.
return (type(expr), expr, args, immutabledict(kwargs))

def __call__(self, expr, *args, **kwargs):
def __call__(self,
expr: ExpressionT,
*args,
**kwargs
) -> ResultT:
result = self._cache.get(
(cache_key := self.get_cache_key(expr, *args, **kwargs)),
_NOT_IN_CACHE)
Expand All @@ -286,7 +358,10 @@ def __call__(self, expr, *args, **kwargs):

# {{{ combine mapper

class CombineMapper(Mapper):
CombineArgT = TypeVar("CombineArgT")


class CombineMapper(Mapper[ResultT, P], Generic[ResultT, P, CombineArgT]):
"""A mapper whose goal it is to *combine* all branches of the expression
tree into one final result. The default implementation of all mapper
methods simply recurse (:meth:`Mapper.rec`) on all branches emanating from
Expand All @@ -304,55 +379,69 @@ class CombineMapper(Mapper):
:class:`pymbolic.mapper.dependency.DependencyMapper` is another example.
"""

def combine(self, values):
def combine(self, values: Iterable[CombineArgT]) -> ResultT:
raise NotImplementedError

def map_call(self, expr, *args, **kwargs):
def map_call(self,
expr: primitives.Call, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.combine((
self.rec(expr.function, *args, **kwargs),
*[self.rec(child, *args, **kwargs) for child in expr.parameters]
))

def map_call_with_kwargs(self, expr, *args, **kwargs):
def map_call_with_kwargs(self,
expr: primitives.CallWithKwargs,
*args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.combine((
self.rec(expr.function, *args, **kwargs),
*[self.rec(child, *args, **kwargs) for child in expr.parameters],
*[self.rec(child, *args, **kwargs)
for child in expr.kw_parameters.values()]
))

def map_subscript(self, expr, *args, **kwargs):
def map_subscript(self,
expr: primitives.Subscript, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.combine(
[self.rec(expr.aggregate, *args, **kwargs),
self.rec(expr.index, *args, **kwargs)])

def map_lookup(self, expr, *args, **kwargs):
def map_lookup(self,
expr: primitives.Lookup, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.rec(expr.aggregate, *args, **kwargs)

def map_sum(self, expr, *args, **kwargs):
def map_sum(self,
expr: primitives.Sum, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.combine(self.rec(child, *args, **kwargs)
for child in expr.children)

map_product = map_sum
def map_product(self,
expr: primitives.Product, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.combine(self.rec(child, *args, **kwargs)
for child in expr.children)

def map_quotient(self, expr, *args, **kwargs):
def map_quotient(self,
expr: primitives.Quotient, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.combine((
self.rec(expr.numerator, *args, **kwargs),
self.rec(expr.denominator, *args, **kwargs)))

map_floor_div = map_quotient
map_remainder = map_quotient
def map_floor_div(self,
expr: primitives.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.combine((
self.rec(expr.numerator, *args, **kwargs),
self.rec(expr.denominator, *args, **kwargs)))

def map_power(self, expr, *args, **kwargs):
def map_remainder(self,
expr: primitives.Remainder, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.combine((
self.rec(expr.base, *args, **kwargs),
self.rec(expr.exponent, *args, **kwargs)))
self.rec(expr.numerator, *args, **kwargs),
self.rec(expr.denominator, *args, **kwargs)))

def map_polynomial(self, expr, *args, **kwargs):
def map_power(self,
expr: primitives.Power, *args: P.args, **kwargs: P.kwargs) -> ResultT:
return self.combine((
self.rec(expr.base, *args, **kwargs),
*[self.rec(coeff, *args, **kwargs) for exp, coeff in expr.data]
))
self.rec(expr.base, *args, **kwargs),
self.rec(expr.exponent, *args, **kwargs)))

def map_left_shift(self, expr, *args, **kwargs):
return self.combine((
Expand Down

0 comments on commit 4190844

Please sign in to comment.