Skip to content

Commit

Permalink
Introduce ReductionOperation class, accept 'initial' in reductions
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Dec 29, 2021
1 parent 45bd2a9 commit b3e359c
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 36 deletions.
3 changes: 2 additions & 1 deletion pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,9 @@ def map_einsum(self, expr: Einsum) -> Array:
args_as_pym_expr[0])

if redn_bounds:
from pytato.reductions import SumReductionOperation
inner_expr = Reduce(inner_expr,
"sum",
SumReductionOperation(),
redn_bounds)

return IndexLambda(expr=inner_expr,
Expand Down
162 changes: 147 additions & 15 deletions pytato/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
THE SOFTWARE.
"""

from typing import Optional, Tuple, Union, Sequence, Dict, List
from typing import Any, Optional, Tuple, Union, Sequence, Dict, List
from abc import ABC, abstractmethod

import numpy as np

from pytato.array import ShapeType, Array, make_index_lambda
from pytato.scalar_expr import ScalarExpression, Reduce, INT_CLASSES
import pymbolic.primitives as prim
Expand All @@ -43,11 +47,97 @@
.. autofunction:: prod
.. autofunction:: all
.. autofunction:: any
.. currentmodule:: pytato.reductions
.. autoclass:: ReductionOperation
.. autoclass:: SumReductionOperation
.. autoclass:: ProductReductionOperation
.. autoclass:: MaxReductionOperation
.. autoclass:: MinReductionOperation
.. autoclass:: AllReductionOperation
.. autoclass:: AnyReductionOperation
"""

# }}}


class _NoValue:
pass


# {{{ reduction operations

class ReductionOperation(ABC):
"""
.. automethod:: neutral_element
.. automethod:: __hash__
.. automethod:: __eq__
"""

@abstractmethod
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
pass

@abstractmethod
def __hash__(self) -> int:
pass

@abstractmethod
def __eq__(self, other: Any) -> bool:
pass


class _StatelessReductionOperation(ReductionOperation):
def __hash__(self) -> int:
return hash(type(self))

def __eq__(self, other: Any) -> bool:
return type(self) is type(other)


class SumReductionOperation(_StatelessReductionOperation):
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
return 0


class ProductReductionOperation(_StatelessReductionOperation):
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
return 1


class MaxReductionOperation(_StatelessReductionOperation):
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
if dtype.kind == "f":
return dtype.type(float("-inf"))
elif dtype.kind == "i":
return np.iinfo(dtype).min
else:
raise TypeError(f"unknown neutral element for max and {dtype}")


class MinReductionOperation(_StatelessReductionOperation):
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
if dtype.kind == "f":
return dtype.type(float("inf"))
elif dtype.kind == "i":
return np.iinfo(dtype).max
else:
raise TypeError(f"unknown neutral element for min and {dtype}")


class AllReductionOperation(_StatelessReductionOperation):
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
return np.bool_(True)


class AnyReductionOperation(_StatelessReductionOperation):
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
return np.bool_(False)

# }}}


# {{{ reductions

def _normalize_reduction_axes(
Expand Down Expand Up @@ -124,8 +214,9 @@ def _get_reduction_indices_bounds(shape: ShapeType,
return indices, pmap(redn_bounds) # type: ignore


def _make_reduction_lambda(op: str, a: Array,
axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
def _make_reduction_lambda(op: ReductionOperation, a: Array,
axis: Optional[Union[int, Tuple[int]]],
initial: Any) -> Array:
"""
Return a :class:`IndexLambda` that performs reduction over the *axis* axes
of *a* with the reduction op *op*.
Expand All @@ -137,9 +228,28 @@ def _make_reduction_lambda(op: str, a: Array,
:arg axis: The axes over which the reduction is to be performed. If axis is
*None*, perform reduction over all of *a*'s axes.
"""
new_shape, axes = _normalize_reduction_axes(a.shape, axis)
new_shape, reduction_axes = _normalize_reduction_axes(a.shape, axis)
del axis
indices, redn_bounds = _get_reduction_indices_bounds(a.shape, axes)
indices, redn_bounds = _get_reduction_indices_bounds(a.shape, reduction_axes)

if initial is _NoValue:
for iax in reduction_axes:
shape_iax = a.shape[iax]

from pytato.utils import are_shape_components_equal
if are_shape_components_equal(shape_iax, 0):
raise ValueError(
"zero-size reduction operation with no supplied "
"'initial' value")

if isinstance(iax, Array):
raise NotImplementedError(
"cannot statically determine emptiness of "
f"reduction axis {iax} (0-based)")

elif initial != op.neutral_element(a.dtype):
raise NotImplementedError("reduction with 'initial' not equal to the "
"neutral element")

return make_index_lambda(
Reduce(
Expand All @@ -151,52 +261,74 @@ def _make_reduction_lambda(op: str, a: Array,
a.dtype)


def sum(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
def sum(a: Array, axis: Optional[Union[int, Tuple[int]]] = None,
initial: Any = 0) -> Array:
"""
Sums array *a*'s elements along the *axis* axes.
:arg a: The :class:`pytato.Array` on which to perform the reduction.
:arg axis: The axes along which the elements are to be sum-reduced.
Defaults to all axes of the input array.
:arg initial: The value returned for an empty array, if supplied.
This value also serves as the base value onto which any additional
array entries are accumulated.
"""
return _make_reduction_lambda("sum", a, axis)
return _make_reduction_lambda(SumReductionOperation(), a, axis, initial)


def amax(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
def amax(a: Array, axis: Optional[Union[int, Tuple[int]]] = None, *,
initial: Any = _NoValue) -> Array:
"""
Returns the max of array *a*'s elements along the *axis* axes.
:arg a: The :class:`pytato.Array` on which to perform the reduction.
:arg axis: The axes along which the elements are to be max-reduced.
Defaults to all axes of the input array.
:arg initial: The value returned for an empty array, if supplied.
This value also serves as the base value onto which any additional
array entries are accumulated.
If not supplied, an :exc:`ValueError` will be raised
if the reduction is empty.
In that case, the reduction size must not be symbolic.
"""
return _make_reduction_lambda("max", a, axis)
return _make_reduction_lambda(MaxReductionOperation(), a, axis, initial)


def amin(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
def amin(a: Array, axis: Optional[Union[int, Tuple[int]]] = None,
initial: Any = _NoValue) -> Array:
"""
Returns the min of array *a*'s elements along the *axis* axes.
:arg a: The :class:`pytato.Array` on which to perform the reduction.
:arg axis: The axes along which the elements are to be min-reduced.
Defaults to all axes of the input array.
:arg initial: The value returned for an empty array, if supplied.
This value also serves as the base value onto which any additional
array entries are accumulated.
If not supplied, an :exc:`ValueError` will be raised
if the reduction is empty.
In that case, the reduction size must not be symbolic.
"""
return _make_reduction_lambda("min", a, axis)
return _make_reduction_lambda(MinReductionOperation(), a, axis, initial)


def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None,
initial: Any = 1) -> Array:
"""
Returns the product of array *a*'s elements along the *axis* axes.
:arg a: The :class:`pytato.Array` on which to perform the reduction.
:arg axis: The axes along which the elements are to be product-reduced.
Defaults to all axes of the input array.
:arg initial: The value returned for an empty array, if supplied.
This value also serves as the base value onto which any additional
array entries are accumulated.
"""
return _make_reduction_lambda("product", a, axis)
return _make_reduction_lambda(ProductReductionOperation(), a, axis, initial)


def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
Expand All @@ -208,7 +340,7 @@ def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
:arg axis: The axes along which the elements are to be product-reduced.
Defaults to all axes of the input array.
"""
return _make_reduction_lambda("all", a, axis)
return _make_reduction_lambda(AllReductionOperation(), a, axis, initial=True)


def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
Expand All @@ -220,7 +352,7 @@ def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
:arg axis: The axes along which the elements are to be product-reduced.
Defaults to all axes of the input array.
"""
return _make_reduction_lambda("any", a, axis)
return _make_reduction_lambda(AnyReductionOperation(), a, axis, initial=False)

# }}}

Expand Down
18 changes: 11 additions & 7 deletions pytato/scalar_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
"""

from numbers import Number
from typing import Any, Union, Mapping, FrozenSet, Set, Tuple, Optional
from typing import (
Any, Union, Mapping, FrozenSet, Set, Tuple, Optional, TYPE_CHECKING)

from pymbolic.mapper import (WalkMapper as WalkMapperBase, IdentityMapper as
IdentityMapperBase)
Expand All @@ -44,6 +45,10 @@
import numpy as np
import re

if TYPE_CHECKING:
from pytato.reductions import ReductionOperation


__doc__ = """
.. currentmodule:: pytato.scalar_expr
Expand Down Expand Up @@ -232,21 +237,20 @@ class Reduce(ExpressionBase):
.. attribute:: op
One of ``"sum"``, ``"product"``, ``"max"``, ``"min"``,``"all"``, ``"any"``.
A :class:`pytato.reductions.ReductionOperation`.
.. attribute:: bounds
A mapping from reduction inames to tuples ``(lower_bound, upper_bound)``
identifying half-open bounds intervals. Must be hashable.
"""
inner_expr: ScalarExpression
op: str
op: ReductionOperation
bounds: Mapping[str, Tuple[ScalarExpression, ScalarExpression]]

def __init__(self, inner_expr: ScalarExpression, op: str, bounds: Any) -> None:
def __init__(self, inner_expr: ScalarExpression,
op: ReductionOperation, bounds: Any) -> None:
self.inner_expr = inner_expr
if op not in {"sum", "product", "max", "min", "all", "any"}:
raise ValueError(f"unsupported op: {op}")
self.op = op
self.bounds = bounds

Expand All @@ -256,7 +260,7 @@ def __hash__(self) -> int:
tuple(self.bounds.keys()),
tuple(self.bounds.values())))

def __getinitargs__(self) -> Tuple[ScalarExpression, str, Any]:
def __getinitargs__(self) -> Tuple[ScalarExpression, ReductionOperation, Any]:
return (self.inner_expr, self.op, self.bounds)

mapper_method = "map_reduce"
Expand Down
27 changes: 14 additions & 13 deletions pytato/target/loopy/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from pytato.loopy import LoopyCall
from pytato.tags import ImplStored, _BaseNameTag, Named, PrefixNamed
from pytools.tag import Tag
import pytato.reductions as red

# set in doc/conf.py
if getattr(sys, "PYTATO_BUILDING_SPHINX_DOCS", False):
Expand Down Expand Up @@ -537,13 +538,13 @@ def _get_sub_array_ref(array: Array, name: str) -> "lp.symbolic.SubArrayRef":
REDUCTION_INDEX_RE = re.compile("_r(0|([1-9][0-9]*))")

# Maps Pytato reduction types to the corresponding Loopy reduction types.
PYTATO_REDUCTION_TO_LOOPY_REDUCTION = {
"sum": "sum",
"product": "product",
"max": "max",
"min": "min",
"all": "all",
"any": "any",
PYTATO_REDUCTION_TO_LOOPY_REDUCTION: Mapping[Type[red.ReductionOperation], str] = {
red.SumReductionOperation: "sum",
red.ProductReductionOperation: "product",
red.MaxReductionOperation: "max",
red.MinReductionOperation: "min",
red.AllReductionOperation: "all",
red.AnyReductionOperation: "any",
}


Expand Down Expand Up @@ -620,8 +621,13 @@ def map_reduce(self, expr: scalar_expr.Reduce,
from loopy.symbolic import Reduction as LoopyReduction
state = prstnt_ctx.state

try:
loopy_redn = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[type(expr.op)]
except KeyError:
raise NotImplementedError(expr.op)

unique_names_mapping = {
old_name: state.var_name_gen(f"_pt_{expr.op}" + old_name)
old_name: state.var_name_gen(f"_pt_{loopy_redn}" + old_name)
for old_name in expr.bounds}

inner_expr = loopy_substitute(expr.inner_expr,
Expand All @@ -633,11 +639,6 @@ def map_reduce(self, expr: scalar_expr.Reduce,
inner_expr = self.rec(inner_expr, prstnt_ctx,
local_ctx.copy(reduction_bounds=new_bounds))

try:
loopy_redn = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[expr.op]
except KeyError:
raise NotImplementedError(expr.op)

inner_expr = LoopyReduction(loopy_redn,
tuple(unique_names_mapping.values()),
inner_expr)
Expand Down

0 comments on commit b3e359c

Please sign in to comment.