Skip to content

Commit

Permalink
Call flatten() on expressionss that are assumed to be simplified
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Oct 22, 2024
1 parent 94d64dd commit 3009c04
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
3 changes: 2 additions & 1 deletion loopy/codegen/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from loopy.codegen.control import build_loop_nest
from loopy.codegen.result import merge_codegen_results
from loopy.diagnostic import LoopyError, warn
from loopy.symbolic import flatten


# {{{ conditional-reducing slab decomposition
Expand Down Expand Up @@ -309,7 +310,7 @@ def set_up_hw_parallel_loops(codegen_state, schedule_index, next_func,
codegen_state = codegen_state.intersect(slab)

from loopy.symbolic import pw_aff_to_expr
hw_axis_expr = hw_axis_expr + pw_aff_to_expr(lower_bound)
hw_axis_expr = flatten(hw_axis_expr + pw_aff_to_expr(lower_bound))

# }}}

Expand Down
18 changes: 15 additions & 3 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
)
from pymbolic.mapper.dependency import CachedDependencyMapper as DependencyMapperBase
from pymbolic.mapper.evaluator import CachedEvaluationMapper as EvaluationMapperBase
from pymbolic.mapper.flattener import FlattenMapper as FlattenMapperBase
from pymbolic.mapper.stringifier import StringifyMapper as StringifyMapperBase
from pymbolic.mapper.substitutor import (
CachedSubstitutionMapper as SubstitutionMapperBase,
Expand Down Expand Up @@ -195,6 +196,14 @@ def map_resolved_function(self, expr, *args, **kwargs):
map_fortran_division = IdentityMapperBase.map_quotient


class FlattenMapper(FlattenMapperBase, IdentityMapperMixin):
pass


def flatten(expr: ExpressionT) -> ExpressionT:
return FlattenMapper()(expr)


class IdentityMapper(IdentityMapperBase, IdentityMapperMixin):
pass

Expand Down Expand Up @@ -1833,7 +1842,7 @@ def aff_to_expr(aff: isl.Aff) -> ExpressionT:
if coeff:
result += coeff*aff_to_expr(aff.get_div(i))

return result // denom
return flatten(result // denom)


def pw_aff_to_expr(pw_aff: isl.PwAff, int_ok: bool = False) -> ExpressionT:
Expand Down Expand Up @@ -2178,14 +2187,17 @@ def qpolynomial_to_expr(qpoly):
assert all(isinstance(num, int) for num in numerators)
assert isinstance(common_denominator, int)

# FIXME: Delete if in favor of the general case once we depend on pymbolic 2024.1.
if common_denominator == 1:
return sum(num * monomial
res = sum(num * monomial
for num, monomial in zip(numerators, monomials))
else:
return FloorDiv(sum(num * monomial
res = FloorDiv(sum(num * monomial
for num, monomial in zip(numerators, monomials)),
common_denominator)

return flatten(res)

# }}}


Expand Down

0 comments on commit 3009c04

Please sign in to comment.