Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add explicit scopes to CSE #150

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pymbolic/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,12 @@ def map_constant(self, expr):

def wrap_intermediate(x):
if len(x) > 1:
from pymbolic.primitives import CommonSubexpression
from pymbolic.primitives import CommonSubexpression, cse_scope

result = numpy.empty(len(x), dtype=object)
for i, x_i in enumerate(x):
result[i] = CommonSubexpression(x_i)
result[i] = CommonSubexpression(x_i, scope=cse_scope.EVALUATION)

return result
else:
return x
Expand Down
3 changes: 2 additions & 1 deletion pymbolic/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def map_common_subexpression(self, expr):
if type(result) is prim.CommonSubexpression:
result = result.child

return type(expr)(result, expr.prefix, **expr.get_extra_properties())
return type(expr)(result, expr.prefix, expr.scope,
**expr.get_extra_properties())

def map_substitution(self, expr):
return type(expr)(
Expand Down
3 changes: 2 additions & 1 deletion pymbolic/interop/symengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def not_supported(self, expr):
self.rec(expr.args[0]), sympy_expr.prefix, sympy_expr.scope)
elif isinstance(expr, symengine.Function) and \
self.function_name(expr) == "CSE":
return prim.CommonSubexpression(self.rec(expr.args[0]))
return prim.CommonSubexpression(
self.rec(expr.args[0]), scope=prim.cse_scope.EVALUATION)
return SympyLikeToPymbolicMapper.not_supported(self, expr)

# }}}
Expand Down
4 changes: 2 additions & 2 deletions pymbolic/mapper/c_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,15 @@ class CCodeMapper(SimplifyingSortingStringifyMapper):
.. doctest::

>>> import pymbolic.primitives as p
>>> CSE = p.make_common_subexpression
>>> x = p.Variable("x")
>>> CSE = p.CommonSubexpression
>>> u = CSE(3*x**2-5, "u")
>>> expr = u/(u+3)*(u+5)
>>> print(expr)
(CSE(3*x**2 + -5) / (CSE(3*x**2 + -5) + 3))*(CSE(3*x**2 + -5) + 5)

Notice that if we were to directly generate code from this, the
subexpression *u* would be evaluated multiple times.
subexpression *u* would not be evaluated multiple times.
Comment on lines 54 to +55
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

? Is this right?


.. doctest::

Expand Down
7 changes: 3 additions & 4 deletions pymbolic/mapper/cse_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


from pymbolic.mapper import IdentityMapper, WalkMapper
from pymbolic.primitives import CommonSubexpression
from pymbolic.primitives import CommonSubexpression, cse_scope


class CSEWalkMapper(WalkMapper):
Expand All @@ -43,10 +43,9 @@ def __init__(self, walk_mapper):

def map_call(self, expr):
if self.subexpr_histogram.get(expr, 0) > 1:
return CommonSubexpression(expr)
return CommonSubexpression(expr, scope=cse_scope.EVALUATION)
else:
return getattr(IdentityMapper, expr.mapper_method)(
self, expr)
return getattr(IdentityMapper, expr.mapper_method)(self, expr)

map_sum = map_call
map_product = map_call
Expand Down
98 changes: 57 additions & 41 deletions pymbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,8 +777,7 @@ def __eq__(self, other) -> bool:
depr_key = (type(self), "__eq__")
if depr_key not in self._deprecation_warnings_issued:
warn(f"Expression.__eq__ is used by {self.__class__}. This is deprecated. "
"Use equality comparison supplied by expr_dataclass"
"instead. "
"Use equality comparison supplied by expr_dataclass instead. "
"This will stop working in 2025.",
DeprecationWarning, stacklevel=2)
self._deprecation_warnings_issued.add(depr_key)
Expand Down Expand Up @@ -994,7 +993,8 @@ def {cls.__name__}_eq(self, other):

return self.__class__ == other.__class__ and {comparison}

cls.__eq__ = {cls.__name__}_eq
if {hash}:
cls.__eq__ = {cls.__name__}_eq
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also guards the __eq__ override with hash. Was playing with IntG and it also needs to overwrite equality (seems mostly used in tests?).

Not sure if it's worth adding a separate eq flag (to match dataclass)?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would these better default to "__eq__" in cls.__dict__? It seems like the dataclass decorator also does similar checks to not overwrite the user implementation.



def {cls.__name__}_hash(self):
Expand Down Expand Up @@ -1024,8 +1024,9 @@ def {cls.__name__}_hash(self):
def {cls.__name__}_init_arg_names(self):
depr_key = (type(self), "init_arg_names")
if depr_key not in self._deprecation_warnings_issued:
warn("__getinitargs__ is deprecated and will be removed in 2025. "
"Use dataclasses.fields instead.",
warn("Attribute 'init_arg_names' of {cls} is deprecated and will "
"not have a default implementation starting from 2025. "
"Use 'dataclasses.fields' instead.",
DeprecationWarning, stacklevel=2)

self._deprecation_warnings_issued.add(depr_key)
Expand All @@ -1038,8 +1039,9 @@ def {cls.__name__}_init_arg_names(self):
def {cls.__name__}_getinitargs(self):
depr_key = (type(self), "__getinitargs__")
if depr_key not in self._deprecation_warnings_issued:
warn("__getinitargs__ is deprecated and will be removed in 2025. "
"Use dataclasses.fields instead.",
warn("Method '__getinitargs__' of {cls} is deprecated and will "
"not have a default implementation starting from 2025. "
"Use 'dataclasses.fields' instead.",
DeprecationWarning, stacklevel=2)

self._deprecation_warnings_issued.add(depr_key)
Expand Down Expand Up @@ -1947,89 +1949,103 @@ def is_zero(value):
return not is_nonzero(value)


def wrap_in_cse(expr, prefix=None):
def wrap_in_cse(expr: Expression,
prefix: str | None = None,
scope: str | None = None) -> Expression:
if isinstance(expr, (Variable, Subscript)):
return expr

if scope is None:
scope = cse_scope.EVALUATION

if isinstance(expr, CommonSubexpression):
if prefix is None:
return expr

if expr.prefix is None and type(expr) is CommonSubexpression:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use same logic as make_cse below.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I'm not sure what you had in mind here? At least at a first glance, these seem to have different logic: make_cse ignores the prefix when re-wrapping and wrap_in_cse ignores the scope when re-wrapping.

return CommonSubexpression(expr.child, prefix)
return CommonSubexpression(expr.child, prefix, scope)

# existing prefix wins
return expr

else:
return CommonSubexpression(expr, prefix)
return CommonSubexpression(expr, prefix, scope)


def make_common_subexpression(field, prefix=None, scope=None):
"""Wrap *field* in a :class:`CommonSubexpression` with
*prefix*. If *field* is a :mod:`numpy` object array,
each individual entry is instead wrapped. If *field* is a
:class:`pymbolic.geometric_algebra.MultiVector`, each
coefficient is individually wrapped.
def make_common_subexpression(expr: ExpressionT,
prefix: str | None = None,
scope: str | None = None) -> ExpressionT:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left scope=None here because setting it to scope=cse_scope.EVALUATION shows the actual string in the docs (i.e. "pymbolic_eval"), which are not shown in the docs for cse_scope itself.. very confusing :\

"""Wrap *expr* in a :class:`CommonSubexpression` with *prefix*.

See :class:`CommonSubexpression` for the meaning of *prefix*
and *scope*.
If *expr* is a :mod:`numpy` object array, each individual entry is instead
wrapped. If *expr* is a :class:`pymbolic.geometric_algebra.MultiVector`, each
coefficient is individually wrapped. In general, the function tries to avoid
re-wrapping existing :class:`CommonSubexpression` if the same scope is given.

See :class:`CommonSubexpression` for the meaning of *prefix* and *scope*. The
scope defaults to :attr:`cse_scope.EVALUATION`.
"""

if isinstance(field, CommonSubexpression) and (
scope is None or scope == cse_scope.EVALUATION
or field.scope == scope):
if scope is None:
scope = cse_scope.EVALUATION

if (isinstance(expr, CommonSubexpression)
and (scope == cse_scope.EVALUATION or expr.scope == scope)):
# Don't re-wrap
return field
return expr

try:
import numpy
have_obj_array = (
isinstance(field, numpy.ndarray)
and field.dtype.char == "O")
logical_shape = (
field.shape
if isinstance(field, numpy.ndarray)
else ())

if isinstance(expr, numpy.ndarray) and expr.dtype.char == "O":
is_obj_array = True
logical_shape = expr.shape
else:
is_obj_array = False
logical_shape = ()
except ImportError:
have_obj_array = False
is_obj_array = False
logical_shape = ()

from pymbolic.geometric_algebra import MultiVector
if isinstance(field, MultiVector):

if isinstance(expr, MultiVector):
new_data = {}
for bits, coeff in field.data.items():
for bits, coeff in expr.data.items():
if prefix is not None:
blade_str = field.space.blade_bits_to_str(bits, "")
blade_str = expr.space.blade_bits_to_str(bits, "")
component_prefix = prefix+"_"+blade_str
else:
component_prefix = None

new_data[bits] = make_common_subexpression(
coeff, component_prefix, scope)

return MultiVector(new_data, field.space)
return MultiVector(new_data, expr.space)

elif is_obj_array and logical_shape != ():
assert isinstance(expr, numpy.ndarray)

elif have_obj_array and logical_shape != ():
result = numpy.zeros(logical_shape, dtype=object)
for i in numpy.ndindex(logical_shape):
if prefix is not None:
component_prefix = prefix+"_".join(str(i_i) for i_i in i)
else:
component_prefix = None

if is_constant(field[i]):
result[i] = field[i]
if is_constant(expr[i]):
result[i] = expr[i]
else:
result[i] = make_common_subexpression(
field[i], component_prefix, scope)
expr[i], component_prefix, scope)

return result

else:
if is_constant(field):
return field
if is_constant(expr):
return expr
else:
return CommonSubexpression(field, prefix, scope)
return CommonSubexpression(expr, prefix, scope)


def make_sym_vector(name, components, var_factory=Variable):
Expand Down
6 changes: 3 additions & 3 deletions test/test_pymbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,7 @@ def test_flop_counter():
y = prim.Variable("y")
z = prim.Variable("z")

subexpr = prim.CommonSubexpression(3 * (x**2 + y + z))
subexpr = prim.make_common_subexpression(3 * (x**2 + y + z))
expr = 3*subexpr + subexpr

from pymbolic.mapper.flop_counter import CSEAwareFlopCounter, FlopCounter
Expand Down Expand Up @@ -795,7 +795,7 @@ def test_diff_cse():
m = prim.Variable("math")

x = prim.Variable("x")
cse = prim.CommonSubexpression(x**2 + 1)
cse = prim.make_common_subexpression(x**2 + 1)
expr = m.attr("exp")(cse)*m.attr("sin")(cse**2)

diff_result = differentiate(expr, x)
Expand Down Expand Up @@ -1009,7 +1009,7 @@ def test_nodecount():
y = prim.Variable("y")
z = prim.Variable("z")

subexpr = prim.CommonSubexpression(4 * (x**2 + y + z))
subexpr = prim.make_common_subexpression(4 * (x**2 + y + z))
expr = 3*subexpr + subexpr + subexpr + subexpr
expr = expr + expr + expr

Expand Down
Loading