diff --git a/pytato/codegen.py b/pytato/codegen.py index 95bc5a6d5..838eaa378 100644 --- a/pytato/codegen.py +++ b/pytato/codegen.py @@ -334,8 +334,9 @@ def map_einsum(self, expr: Einsum) -> Array: args_as_pym_expr[0]) if redn_bounds: + from pytato.reductions import SumReductionOperation inner_expr = Reduce(inner_expr, - "sum", + SumReductionOperation(), redn_bounds) return IndexLambda(expr=inner_expr, diff --git a/pytato/reductions.py b/pytato/reductions.py index ed0b01269..f4f3e29f1 100644 --- a/pytato/reductions.py +++ b/pytato/reductions.py @@ -27,7 +27,11 @@ THE SOFTWARE. """ -from typing import Optional, Tuple, Union, Sequence, Dict, List +from typing import Any, Optional, Tuple, Union, Sequence, Dict, List +from abc import ABC, abstractmethod + +import numpy as np + from pytato.array import ShapeType, Array, make_index_lambda from pytato.scalar_expr import ScalarExpression, Reduce, INT_CLASSES import pymbolic.primitives as prim @@ -43,11 +47,97 @@ .. autofunction:: prod .. autofunction:: all .. autofunction:: any + +.. currentmodule:: pytato.reductions + +.. autoclass:: ReductionOperation +.. autoclass:: SumReductionOperation +.. autoclass:: ProductReductionOperation +.. autoclass:: MaxReductionOperation +.. autoclass:: MinReductionOperation +.. autoclass:: AllReductionOperation +.. autoclass:: AnyReductionOperation """ # }}} +class _NoValue: + pass + + +# {{{ reduction operations + +class ReductionOperation(ABC): + """ + .. automethod:: neutral_element + .. automethod:: __hash__ + .. automethod:: __eq__ + """ + + @abstractmethod + def neutral_element(self, dtype: np.dtype[Any]) -> Any: + pass + + @abstractmethod + def __hash__(self) -> int: + pass + + @abstractmethod + def __eq__(self, other: Any) -> bool: + pass + + +class _StatelessReductionOperation(ReductionOperation): + def __hash__(self) -> int: + return hash(type(self)) + + def __eq__(self, other: Any) -> bool: + return type(self) is type(other) + + +class SumReductionOperation(_StatelessReductionOperation): + def neutral_element(self, dtype: np.dtype[Any]) -> Any: + return 0 + + +class ProductReductionOperation(_StatelessReductionOperation): + def neutral_element(self, dtype: np.dtype[Any]) -> Any: + return 1 + + +class MaxReductionOperation(_StatelessReductionOperation): + def neutral_element(self, dtype: np.dtype[Any]) -> Any: + if dtype.kind == "f": + return dtype.type(float("-inf")) + elif dtype.kind == "i": + return np.iinfo(dtype).min + else: + raise TypeError(f"unknown neutral element for max and {dtype}") + + +class MinReductionOperation(_StatelessReductionOperation): + def neutral_element(self, dtype: np.dtype[Any]) -> Any: + if dtype.kind == "f": + return dtype.type(float("inf")) + elif dtype.kind == "i": + return np.iinfo(dtype).max + else: + raise TypeError(f"unknown neutral element for min and {dtype}") + + +class AllReductionOperation(_StatelessReductionOperation): + def neutral_element(self, dtype: np.dtype[Any]) -> Any: + return np.bool_(True) + + +class AnyReductionOperation(_StatelessReductionOperation): + def neutral_element(self, dtype: np.dtype[Any]) -> Any: + return np.bool_(False) + +# }}} + + # {{{ reductions def _normalize_reduction_axes( @@ -124,8 +214,9 @@ def _get_reduction_indices_bounds(shape: ShapeType, return indices, pmap(redn_bounds) # type: ignore -def _make_reduction_lambda(op: str, a: Array, - axis: Optional[Union[int, Tuple[int]]] = None) -> Array: +def _make_reduction_lambda(op: ReductionOperation, a: Array, + axis: Optional[Union[int, Tuple[int]]], + initial: Any) -> Array: """ Return a :class:`IndexLambda` that performs reduction over the *axis* axes of *a* with the reduction op *op*. @@ -137,9 +228,28 @@ def _make_reduction_lambda(op: str, a: Array, :arg axis: The axes over which the reduction is to be performed. If axis is *None*, perform reduction over all of *a*'s axes. """ - new_shape, axes = _normalize_reduction_axes(a.shape, axis) + new_shape, reduction_axes = _normalize_reduction_axes(a.shape, axis) del axis - indices, redn_bounds = _get_reduction_indices_bounds(a.shape, axes) + indices, redn_bounds = _get_reduction_indices_bounds(a.shape, reduction_axes) + + if initial is _NoValue: + for iax in reduction_axes: + shape_iax = a.shape[iax] + + from pytato.utils import are_shape_components_equal + if are_shape_components_equal(shape_iax, 0): + raise ValueError( + "zero-size reduction operation with no supplied " + "'initial' value") + + if isinstance(iax, Array): + raise NotImplementedError( + "cannot statically determine emptiness of " + f"reduction axis {iax} (0-based)") + + elif initial != op.neutral_element(a.dtype): + raise NotImplementedError("reduction with 'initial' not equal to the " + "neutral element") return make_index_lambda( Reduce( @@ -151,7 +261,8 @@ def _make_reduction_lambda(op: str, a: Array, a.dtype) -def sum(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: +def sum(a: Array, axis: Optional[Union[int, Tuple[int]]] = None, + initial: Any = _NoValue) -> Array: """ Sums array *a*'s elements along the *axis* axes. @@ -159,11 +270,16 @@ def sum(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: :arg axis: The axes along which the elements are to be sum-reduced. Defaults to all axes of the input array. + :arg initial: The value returned for an empty array, if supplied. + If not supplied, an :exc:`ValueError` will be raised + if the reduction is empty. + In that case, the reduction size must not be symbolic. """ - return _make_reduction_lambda("sum", a, axis) + return _make_reduction_lambda(SumReductionOperation(), a, axis, initial) -def amax(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: +def amax(a: Array, axis: Optional[Union[int, Tuple[int]]] = None, *, + initial: Any = _NoValue) -> Array: """ Returns the max of array *a*'s elements along the *axis* axes. @@ -171,11 +287,16 @@ def amax(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: :arg axis: The axes along which the elements are to be max-reduced. Defaults to all axes of the input array. + :arg initial: The value returned for an empty array, if supplied. + If not supplied, an :exc:`ValueError` will be raised + if the reduction is empty. + In that case, the reduction size must not be symbolic. """ - return _make_reduction_lambda("max", a, axis) + return _make_reduction_lambda(MaxReductionOperation(), a, axis, initial) -def amin(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: +def amin(a: Array, axis: Optional[Union[int, Tuple[int]]] = None, + initial: Any = _NoValue) -> Array: """ Returns the min of array *a*'s elements along the *axis* axes. @@ -183,11 +304,16 @@ def amin(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: :arg axis: The axes along which the elements are to be min-reduced. Defaults to all axes of the input array. + :arg initial: The value returned for an empty array, if supplied. + If not supplied, an :exc:`ValueError` will be raised + if the reduction is empty. + In that case, the reduction size must not be symbolic. """ - return _make_reduction_lambda("min", a, axis) + return _make_reduction_lambda(MinReductionOperation(), a, axis, initial) -def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: +def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None, + initial: Any = 1) -> Array: """ Returns the product of array *a*'s elements along the *axis* axes. @@ -195,11 +321,15 @@ def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: :arg axis: The axes along which the elements are to be product-reduced. Defaults to all axes of the input array. + :arg initial: The value returned for an empty array, if supplied. + If not supplied, an :exc:`ValueError` will be raised + if the reduction is empty. + In that case, the reduction size must not be symbolic. """ - return _make_reduction_lambda("product", a, axis) + return _make_reduction_lambda(ProductReductionOperation(), a, axis, initial) -def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: +def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None): """ Returns the logical-and array *a*'s elements along the *axis* axes. @@ -208,10 +338,10 @@ def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: :arg axis: The axes along which the elements are to be product-reduced. Defaults to all axes of the input array. """ - return _make_reduction_lambda("all", a, axis) + return _make_reduction_lambda(AllReductionOperation(), a, axis, initial=True) -def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: +def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None): """ Returns the logical-or of array *a*'s elements along the *axis* axes. @@ -220,7 +350,7 @@ def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: :arg axis: The axes along which the elements are to be product-reduced. Defaults to all axes of the input array. """ - return _make_reduction_lambda("any", a, axis) + return _make_reduction_lambda(AnyReductionOperation(), a, axis, initial=False) # }}} diff --git a/pytato/scalar_expr.py b/pytato/scalar_expr.py index dc02a4939..d0b215e71 100644 --- a/pytato/scalar_expr.py +++ b/pytato/scalar_expr.py @@ -25,7 +25,8 @@ """ from numbers import Number -from typing import Any, Union, Mapping, FrozenSet, Set, Tuple, Optional +from typing import ( + Any, Union, Mapping, FrozenSet, Set, Tuple, Optional, TYPE_CHECKING) from pymbolic.mapper import (WalkMapper as WalkMapperBase, IdentityMapper as IdentityMapperBase) @@ -44,6 +45,10 @@ import numpy as np import re +if TYPE_CHECKING: + from pytato.reductions import ReductionOperation + + __doc__ = """ .. currentmodule:: pytato.scalar_expr @@ -232,7 +237,7 @@ class Reduce(ExpressionBase): .. attribute:: op - One of ``"sum"``, ``"product"``, ``"max"``, ``"min"``,``"all"``, ``"any"``. + A :class:`pytato.reductions.ReductionOperation`. .. attribute:: bounds @@ -240,13 +245,12 @@ class Reduce(ExpressionBase): identifying half-open bounds intervals. Must be hashable. """ inner_expr: ScalarExpression - op: str + op: ReductionOperation bounds: Mapping[str, Tuple[ScalarExpression, ScalarExpression]] - def __init__(self, inner_expr: ScalarExpression, op: str, bounds: Any) -> None: + def __init__(self, inner_expr: ScalarExpression, + op: ReductionOperation, bounds: Any) -> None: self.inner_expr = inner_expr - if op not in {"sum", "product", "max", "min", "all", "any"}: - raise ValueError(f"unsupported op: {op}") self.op = op self.bounds = bounds @@ -256,7 +260,7 @@ def __hash__(self) -> int: tuple(self.bounds.keys()), tuple(self.bounds.values()))) - def __getinitargs__(self) -> Tuple[ScalarExpression, str, Any]: + def __getinitargs__(self) -> Tuple[ScalarExpression, ReductionOperation, Any]: return (self.inner_expr, self.op, self.bounds) mapper_method = "map_reduce" diff --git a/pytato/target/loopy/codegen.py b/pytato/target/loopy/codegen.py index 820e19886..6c7b13e30 100644 --- a/pytato/target/loopy/codegen.py +++ b/pytato/target/loopy/codegen.py @@ -48,6 +48,7 @@ from pytato.loopy import LoopyCall from pytato.tags import ImplStored, _BaseNameTag, Named, PrefixNamed from pytools.tag import Tag +import pytato.reductions as red # set in doc/conf.py if getattr(sys, "PYTATO_BUILDING_SPHINX_DOCS", False): @@ -537,13 +538,13 @@ def _get_sub_array_ref(array: Array, name: str) -> "lp.symbolic.SubArrayRef": REDUCTION_INDEX_RE = re.compile("_r(0|([1-9][0-9]*))") # Maps Pytato reduction types to the corresponding Loopy reduction types. -PYTATO_REDUCTION_TO_LOOPY_REDUCTION = { - "sum": "sum", - "product": "product", - "max": "max", - "min": "min", - "all": "all", - "any": "any", +PYTATO_REDUCTION_TO_LOOPY_REDUCTION: Mapping[Type[red.ReductionOperation], str] = { + red.SumReductionOperation: "sum", + red.ProductReductionOperation: "product", + red.MaxReductionOperation: "max", + red.MinReductionOperation: "min", + red.AllReductionOperation: "all", + red.AnyReductionOperation: "any", } @@ -620,8 +621,13 @@ def map_reduce(self, expr: scalar_expr.Reduce, from loopy.symbolic import Reduction as LoopyReduction state = prstnt_ctx.state + try: + loopy_redn = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[type(expr.op)] + except KeyError: + raise NotImplementedError(expr.op) + unique_names_mapping = { - old_name: state.var_name_gen(f"_pt_{expr.op}" + old_name) + old_name: state.var_name_gen(f"_pt_{loopy_redn}" + old_name) for old_name in expr.bounds} inner_expr = loopy_substitute(expr.inner_expr, @@ -633,11 +639,6 @@ def map_reduce(self, expr: scalar_expr.Reduce, inner_expr = self.rec(inner_expr, prstnt_ctx, local_ctx.copy(reduction_bounds=new_bounds)) - try: - loopy_redn = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[expr.op] - except KeyError: - raise NotImplementedError(expr.op) - inner_expr = LoopyReduction(loopy_redn, tuple(unique_names_mapping.values()), inner_expr)