diff --git a/pymbolic/mapper/__init__.py b/pymbolic/mapper/__init__.py index 22315f9..459f5a3 100644 --- a/pymbolic/mapper/__init__.py +++ b/pymbolic/mapper/__init__.py @@ -24,11 +24,20 @@ """ from abc import ABC, abstractmethod -from typing import Any +from typing import TYPE_CHECKING, Generic, Hashable, Iterable, TypeVar +from warnings import warn from immutabledict import immutabledict +from typing_extensions import ParamSpec, TypeAlias, TypeIs import pymbolic.primitives as primitives +from pymbolic.typing import ExpressionT + + +if TYPE_CHECKING: + import numpy as np + + from pymbolic.rational import Rational __doc__ = """ @@ -96,14 +105,20 @@ """ -try: - import numpy +if TYPE_CHECKING: + import numpy as np + + def is_numpy_array(val) -> TypeIs[np.ndarray]: + return isinstance(val, np.ndarray) +else: + try: + import numpy as np - def is_numpy_array(val): - return isinstance(val, numpy.ndarray) -except ImportError: - def is_numpy_array(ary): - return False + def is_numpy_array(val): + return isinstance(val, np.ndarray) + except ImportError: + def is_numpy_array(ary): + return False class UnsupportedExpressionError(ValueError): @@ -112,7 +127,11 @@ class UnsupportedExpressionError(ValueError): # {{{ mapper base -class Mapper: +ResultT = TypeVar("ResultT") +P = ParamSpec("P") + + +class Mapper(Generic[ResultT, P]): """A visitor for trees of :class:`pymbolic.Expression` subclasses. Each expression-derived object is dispatched to the method named by the :attr:`pymbolic.Expression.mapper_method` @@ -120,7 +139,8 @@ class Mapper: *mapper_method* in the method resolution order of the object. """ - def handle_unsupported_expression(self, expr, *args, **kwargs): + def handle_unsupported_expression(self, + expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Mapper method that is invoked for :class:`pymbolic.Expression` subclasses for which a mapper method does not exist in this mapper. @@ -130,7 +150,8 @@ def handle_unsupported_expression(self, expr, *args, **kwargs): "{} cannot handle expressions of type {}".format( type(self), type(expr))) - def __call__(self, expr, *args, **kwargs): + def __call__(self, + expr: ExpressionT, *args: P.args, **kwargs: P.kwargs) -> ResultT: """Dispatch *expr* to its corresponding mapper method. Pass on ``*args`` and ``**kwargs`` unmodified. @@ -162,7 +183,8 @@ def __call__(self, expr, *args, **kwargs): rec = __call__ - def rec_fallback(self, expr, *args, **kwargs): + def rec_fallback(self, + expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: if isinstance(expr, primitives.Expression): for cls in type(expr).__mro__[1:]: method_name = getattr(cls, "mapper_method", None) @@ -175,56 +197,95 @@ def rec_fallback(self, expr, *args, **kwargs): else: return self.map_foreign(expr, *args, **kwargs) - def map_algebraic_leaf(self, expr, *args, **kwargs): + def map_algebraic_leaf(self, + expr: primitives.AlgebraicLeaf, + *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_variable(self, expr, *args, **kwargs): + def map_variable(self, + expr: primitives.Variable, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: primitives.Subscript, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_call(self, expr, *args, **kwargs): + def map_call(self, + expr: primitives.Call, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_lookup(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: primitives.CallWithKwargs, + *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_if_positive(self, expr, *args, **kwargs): + def map_lookup(self, + expr: primitives.Lookup, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_rational(self, expr, *args, **kwargs): - return self.map_quotient(expr, *args, **kwargs) + def map_if(self, + expr: primitives.If, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError - def map_quotient(self, expr, *args, **kwargs): + def map_rational(self, + expr: Rational, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_constant(self, expr, *args, **kwargs): + def map_quotient(self, + expr: primitives.Quotient, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_list(self, expr, *args, **kwargs): + def map_floor_div(self, + expr: primitives.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_tuple(self, expr, *args, **kwargs): + def map_remainder(self, + expr: primitives.Remainder, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_numpy_array(self, expr, *args, **kwargs): + def map_constant(self, + expr: object, *args: P.args, **kwargs: P.kwargs) -> ResultT: raise NotImplementedError - def map_nan(self, expr, *args, **kwargs): + def map_list(self, + expr: list[ExpressionT], *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_tuple(self, + expr: tuple[ExpressionT, ...], + *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_numpy_array(self, + expr: np.ndarray, *args: P.args, **kwargs: P.kwargs) -> ResultT: + raise NotImplementedError + + def map_nan(self, + expr: primitives.NaN, + *args: P.args, + **kwargs: P.kwargs + ) -> ResultT: return self.map_algebraic_leaf(expr, *args, **kwargs) - def map_foreign(self, expr, *args, **kwargs): + def map_foreign(self, + expr: object, + *args: P.args, + **kwargs: P.kwargs + ) -> ResultT: """Mapper method dispatch for non-:mod:`pymbolic` objects.""" if isinstance(expr, primitives.VALID_CONSTANT_CLASSES): return self.map_constant(expr, *args, **kwargs) elif is_numpy_array(expr): return self.map_numpy_array(expr, *args, **kwargs) - elif isinstance(expr, list): - return self.map_list(expr, *args, **kwargs) elif isinstance(expr, tuple): return self.map_tuple(expr, *args, **kwargs) + elif isinstance(expr, list): + warn("List found in expression graph. " + "This is deprecated and will stop working in 2025. " + "Use tuples instead.", DeprecationWarning, stacklevel=2 + ) + return self.map_list(expr, *args, **kwargs) else: raise ValueError( "{} encountered invalid foreign object: {}".format( @@ -234,17 +295,24 @@ def map_foreign(self, expr, *args, **kwargs): _NOT_IN_CACHE = object() -class CachedMapper(Mapper): +CacheKeyT: TypeAlias = Hashable + + +class CachedMapper(Mapper[ResultT, P]): """ A mapper that memoizes the mapped result for the expressions traversed. .. automethod:: get_cache_key """ - def __init__(self): - self._cache: dict[Any, Any] = {} + def __init__(self) -> None: + self._cache: dict[CacheKeyT, ResultT] = {} Mapper.__init__(self) - def get_cache_key(self, expr, *args, **kwargs): + def get_cache_key(self, + expr: ExpressionT, + *args: P.args, + **kwargs: P.kwargs + ) -> CacheKeyT: """ Returns the key corresponding to which the result of a mapper method is stored in the cache. @@ -260,7 +328,11 @@ def get_cache_key(self, expr, *args, **kwargs): # and "4 == 4.0", but their traversal results cannot be re-used. return (type(expr), expr, args, immutabledict(kwargs)) - def __call__(self, expr, *args, **kwargs): + def __call__(self, + expr: ExpressionT, + *args, + **kwargs + ) -> ResultT: result = self._cache.get( (cache_key := self.get_cache_key(expr, *args, **kwargs)), _NOT_IN_CACHE) @@ -286,7 +358,10 @@ def __call__(self, expr, *args, **kwargs): # {{{ combine mapper -class CombineMapper(Mapper): +CombineArgT = TypeVar("CombineArgT") + + +class CombineMapper(Mapper[ResultT, P], Generic[ResultT, P, CombineArgT]): """A mapper whose goal it is to *combine* all branches of the expression tree into one final result. The default implementation of all mapper methods simply recurse (:meth:`Mapper.rec`) on all branches emanating from @@ -304,16 +379,19 @@ class CombineMapper(Mapper): :class:`pymbolic.mapper.dependency.DependencyMapper` is another example. """ - def combine(self, values): + def combine(self, values: Iterable[CombineArgT]) -> ResultT: raise NotImplementedError - def map_call(self, expr, *args, **kwargs): + def map_call(self, + expr: primitives.Call, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.function, *args, **kwargs), *[self.rec(child, *args, **kwargs) for child in expr.parameters] )) - def map_call_with_kwargs(self, expr, *args, **kwargs): + def map_call_with_kwargs(self, + expr: primitives.CallWithKwargs, + *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.function, *args, **kwargs), *[self.rec(child, *args, **kwargs) for child in expr.parameters], @@ -321,38 +399,49 @@ def map_call_with_kwargs(self, expr, *args, **kwargs): for child in expr.kw_parameters.values()] )) - def map_subscript(self, expr, *args, **kwargs): + def map_subscript(self, + expr: primitives.Subscript, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine( [self.rec(expr.aggregate, *args, **kwargs), self.rec(expr.index, *args, **kwargs)]) - def map_lookup(self, expr, *args, **kwargs): + def map_lookup(self, + expr: primitives.Lookup, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.rec(expr.aggregate, *args, **kwargs) - def map_sum(self, expr, *args, **kwargs): + def map_sum(self, + expr: primitives.Sum, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(self.rec(child, *args, **kwargs) for child in expr.children) - map_product = map_sum + def map_product(self, + expr: primitives.Product, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(self.rec(child, *args, **kwargs) + for child in expr.children) - def map_quotient(self, expr, *args, **kwargs): + def map_quotient(self, + expr: primitives.Quotient, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( self.rec(expr.numerator, *args, **kwargs), self.rec(expr.denominator, *args, **kwargs))) - map_floor_div = map_quotient - map_remainder = map_quotient + def map_floor_div(self, + expr: primitives.FloorDiv, *args: P.args, **kwargs: P.kwargs) -> ResultT: + return self.combine(( + self.rec(expr.numerator, *args, **kwargs), + self.rec(expr.denominator, *args, **kwargs))) - def map_power(self, expr, *args, **kwargs): + def map_remainder(self, + expr: primitives.Remainder, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( - self.rec(expr.base, *args, **kwargs), - self.rec(expr.exponent, *args, **kwargs))) + self.rec(expr.numerator, *args, **kwargs), + self.rec(expr.denominator, *args, **kwargs))) - def map_polynomial(self, expr, *args, **kwargs): + def map_power(self, + expr: primitives.Power, *args: P.args, **kwargs: P.kwargs) -> ResultT: return self.combine(( - self.rec(expr.base, *args, **kwargs), - *[self.rec(coeff, *args, **kwargs) for exp, coeff in expr.data] - )) + self.rec(expr.base, *args, **kwargs), + self.rec(expr.exponent, *args, **kwargs))) def map_left_shift(self, expr, *args, **kwargs): return self.combine((