diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py index 645a57e31..c64c2ea67 100644 --- a/loopy/codegen/loop.py +++ b/loopy/codegen/loop.py @@ -28,6 +28,7 @@ from loopy.codegen.control import build_loop_nest from loopy.codegen.result import merge_codegen_results from loopy.diagnostic import LoopyError, warn +from loopy.symbolic import flatten # {{{ conditional-reducing slab decomposition @@ -309,7 +310,7 @@ def set_up_hw_parallel_loops(codegen_state, schedule_index, next_func, codegen_state = codegen_state.intersect(slab) from loopy.symbolic import pw_aff_to_expr - hw_axis_expr = hw_axis_expr + pw_aff_to_expr(lower_bound) + hw_axis_expr = flatten(hw_axis_expr + pw_aff_to_expr(lower_bound)) # }}} diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 22dbd3bf5..6727423a8 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -54,6 +54,7 @@ ) from pymbolic.mapper.dependency import CachedDependencyMapper as DependencyMapperBase from pymbolic.mapper.evaluator import CachedEvaluationMapper as EvaluationMapperBase +from pymbolic.mapper.flattener import FlattenMapper as FlattenMapperBase from pymbolic.mapper.stringifier import StringifyMapper as StringifyMapperBase from pymbolic.mapper.substitutor import ( CachedSubstitutionMapper as SubstitutionMapperBase, @@ -195,6 +196,14 @@ def map_resolved_function(self, expr, *args, **kwargs): map_fortran_division = IdentityMapperBase.map_quotient +class FlattenMapper(FlattenMapperBase, IdentityMapperMixin): + pass + + +def flatten(expr: ExpressionT) -> ExpressionT: + return FlattenMapper()(expr) + + class IdentityMapper(IdentityMapperBase, IdentityMapperMixin): pass @@ -1833,7 +1842,7 @@ def aff_to_expr(aff: isl.Aff) -> ExpressionT: if coeff: result += coeff*aff_to_expr(aff.get_div(i)) - return result // denom + return flatten(result // denom) def pw_aff_to_expr(pw_aff: isl.PwAff, int_ok: bool = False) -> ExpressionT: @@ -2178,14 +2187,17 @@ def qpolynomial_to_expr(qpoly): assert all(isinstance(num, int) for num in 
numerators) assert isinstance(common_denominator, int) + # FIXME: Delete the `if` branch in favor of the general case once we depend on pymbolic 2024.1. if common_denominator == 1: - return sum(num * monomial + res = sum(num * monomial for num, monomial in zip(numerators, monomials)) else: - return FloorDiv(sum(num * monomial + res = FloorDiv(sum(num * monomial for num, monomial in zip(numerators, monomials)), common_denominator) + return flatten(res) + # }}}