Skip to content

Commit

Permalink
cse: add explicit scope
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed Oct 1, 2024
1 parent 03ceaa1 commit 9a62b35
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 14 deletions.
6 changes: 4 additions & 2 deletions pymbolic/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,10 +251,12 @@ def map_constant(self, expr):

def wrap_intermediate(x):
if len(x) > 1:
from pymbolic.primitives import CommonSubexpression
from pymbolic.primitives import CommonSubexpression, cse_scope

result = numpy.empty(len(x), dtype=object)
for i, x_i in enumerate(x):
result[i] = CommonSubexpression(x_i)
result[i] = CommonSubexpression(x_i, scope=cse_scope.EVALUATION)

return result
else:
return x
Expand Down
3 changes: 2 additions & 1 deletion pymbolic/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,8 @@ def map_common_subexpression(self, expr):
if type(result) is prim.CommonSubexpression:
result = result.child

return type(expr)(result, expr.prefix, **expr.get_extra_properties())
return type(expr)(result, expr.prefix, expr.scope,
**expr.get_extra_properties())

def map_substitution(self, expr):
return type(expr)(
Expand Down
3 changes: 2 additions & 1 deletion pymbolic/interop/symengine.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ def not_supported(self, expr):
self.rec(expr.args[0]), sympy_expr.prefix, sympy_expr.scope)
elif isinstance(expr, symengine.Function) and \
self.function_name(expr) == "CSE":
return prim.CommonSubexpression(self.rec(expr.args[0]))
return prim.CommonSubexpression(
self.rec(expr.args[0]), scope=prim.cse_scope.EVALUATION)
return SympyLikeToPymbolicMapper.not_supported(self, expr)

# }}}
Expand Down
2 changes: 1 addition & 1 deletion pymbolic/mapper/c_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class CCodeMapper(SimplifyingSortingStringifyMapper):
>>> import pymbolic.primitives as p
>>> x = p.Variable("x")
>>> CSE = p.CommonSubexpression
>>> u = CSE(3*x**2-5, "u")
>>> u = CSE(3*x**2-5, "u", p.cse_scope.EVALUATION)
>>> expr = u/(u+3)*(u+5)
>>> print(expr)
(CSE(3*x**2 + -5) / (CSE(3*x**2 + -5) + 3))*(CSE(3*x**2 + -5) + 5)
Expand Down
7 changes: 3 additions & 4 deletions pymbolic/mapper/cse_tagger.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


from pymbolic.mapper import IdentityMapper, WalkMapper
from pymbolic.primitives import CommonSubexpression
from pymbolic.primitives import CommonSubexpression, cse_scope


class CSEWalkMapper(WalkMapper):
Expand All @@ -43,10 +43,9 @@ def __init__(self, walk_mapper):

def map_call(self, expr):
if self.subexpr_histogram.get(expr, 0) > 1:
return CommonSubexpression(expr)
return CommonSubexpression(expr, scope=cse_scope.EVALUATION)
else:
return getattr(IdentityMapper, expr.mapper_method)(
self, expr)
return getattr(IdentityMapper, expr.mapper_method)(self, expr)

map_sum = map_call
map_product = map_call
Expand Down
4 changes: 2 additions & 2 deletions pymbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -1796,13 +1796,13 @@ def wrap_in_cse(expr, prefix=None):
if prefix is None:
return expr
if expr.prefix is None and type(expr) is CommonSubexpression:
return CommonSubexpression(expr.child, prefix)
return CommonSubexpression(expr.child, prefix, cse_scope.EVALUATION)

# existing prefix wins
return expr

else:
return CommonSubexpression(expr, prefix)
return CommonSubexpression(expr, prefix, cse_scope.EVALUATION)


def make_common_subexpression(field, prefix=None, scope=None):
Expand Down
8 changes: 5 additions & 3 deletions test/test_pymbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -723,7 +723,8 @@ def test_flop_counter():
y = prim.Variable("y")
z = prim.Variable("z")

subexpr = prim.CommonSubexpression(3 * (x**2 + y + z))
subexpr = prim.CommonSubexpression(3 * (x**2 + y + z),
scope=prim.cse_scope.EVALUATION)
expr = 3*subexpr + subexpr

from pymbolic.mapper.flop_counter import CSEAwareFlopCounter, FlopCounter
Expand Down Expand Up @@ -795,7 +796,7 @@ def test_diff_cse():
m = prim.Variable("math")

x = prim.Variable("x")
cse = prim.CommonSubexpression(x**2 + 1)
cse = prim.CommonSubexpression(x**2 + 1, scope=prim.cse_scope.EVALUATION)
expr = m.attr("exp")(cse)*m.attr("sin")(cse**2)

diff_result = differentiate(expr, x)
Expand Down Expand Up @@ -1009,7 +1010,8 @@ def test_nodecount():
y = prim.Variable("y")
z = prim.Variable("z")

subexpr = prim.CommonSubexpression(4 * (x**2 + y + z))
subexpr = prim.CommonSubexpression(4 * (x**2 + y + z),
scope=prim.cse_scope.EVALUATION)
expr = 3*subexpr + subexpr + subexpr + subexpr
expr = expr + expr + expr

Expand Down

0 comments on commit 9a62b35

Please sign in to comment.