Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare for non-simplifying pymbolic #871

Merged
merged 3 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ on:
schedule:
- cron: '17 3 * * 0'

concurrency:
group: ${{ github.head_ref || github.ref_name }}
cancel-in-progress: true

jobs:
ruff:
name: Ruff
Expand Down
3 changes: 2 additions & 1 deletion loopy/codegen/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from loopy.codegen.control import build_loop_nest
from loopy.codegen.result import merge_codegen_results
from loopy.diagnostic import LoopyError, warn
from loopy.symbolic import flatten


# {{{ conditional-reducing slab decomposition
Expand Down Expand Up @@ -309,7 +310,7 @@ def set_up_hw_parallel_loops(codegen_state, schedule_index, next_func,
codegen_state = codegen_state.intersect(slab)

from loopy.symbolic import pw_aff_to_expr
hw_axis_expr = hw_axis_expr + pw_aff_to_expr(lower_bound)
hw_axis_expr = flatten(hw_axis_expr + pw_aff_to_expr(lower_bound))

# }}}

Expand Down
4 changes: 2 additions & 2 deletions loopy/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

import numpy as np

from pymbolic.mapper import RecursiveMapper
from pymbolic.mapper import Mapper

from loopy.codegen import UnvectorizableError
from loopy.diagnostic import LoopyError
Expand Down Expand Up @@ -55,7 +55,7 @@ def dtype_to_type_context(target, dtype):

# {{{ vectorizability checker

class VectorizabilityChecker(RecursiveMapper):
class VectorizabilityChecker(Mapper):
"""The return value from this mapper is a :class:`bool` indicating whether
the result of the expression is vectorized along :attr:`vec_iname`.
If the expression is not vectorizable, the mapper raises
Expand Down
18 changes: 15 additions & 3 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
)
from pymbolic.mapper.dependency import CachedDependencyMapper as DependencyMapperBase
from pymbolic.mapper.evaluator import CachedEvaluationMapper as EvaluationMapperBase
from pymbolic.mapper.flattener import FlattenMapper as FlattenMapperBase
from pymbolic.mapper.stringifier import StringifyMapper as StringifyMapperBase
from pymbolic.mapper.substitutor import (
CachedSubstitutionMapper as SubstitutionMapperBase,
Expand Down Expand Up @@ -195,6 +196,14 @@ def map_resolved_function(self, expr, *args, **kwargs):
map_fortran_division = IdentityMapperBase.map_quotient


class FlattenMapper(FlattenMapperBase, IdentityMapperMixin):
pass


def flatten(expr: ExpressionT) -> ExpressionT:
return FlattenMapper()(expr)


class IdentityMapper(IdentityMapperBase, IdentityMapperMixin):
pass

Expand Down Expand Up @@ -1833,7 +1842,7 @@ def aff_to_expr(aff: isl.Aff) -> ExpressionT:
if coeff:
result += coeff*aff_to_expr(aff.get_div(i))

return result // denom
return flatten(result // denom)


def pw_aff_to_expr(pw_aff: isl.PwAff, int_ok: bool = False) -> ExpressionT:
Expand Down Expand Up @@ -2178,14 +2187,17 @@ def qpolynomial_to_expr(qpoly):
assert all(isinstance(num, int) for num in numerators)
assert isinstance(common_denominator, int)

# FIXME: Delete if in favor of the general case once we depend on pymbolic 2024.1.
if common_denominator == 1:
return sum(num * monomial
res = sum(num * monomial
for num, monomial in zip(numerators, monomials))
else:
return FloorDiv(sum(num * monomial
res = FloorDiv(sum(num * monomial
for num, monomial in zip(numerators, monomials)),
common_denominator)

return flatten(res)

# }}}


Expand Down
9 changes: 4 additions & 5 deletions loopy/target/c/codegen/expression.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import islpy as isl
import pymbolic.primitives as p
from pymbolic import var
from pymbolic.mapper import IdentityMapper, RecursiveMapper
from pymbolic.mapper import IdentityMapper, Mapper
from pymbolic.mapper.stringifier import (
PREC_BITWISE_AND,
PREC_BITWISE_OR,
Expand Down Expand Up @@ -124,9 +124,8 @@ def wrap_in_typecast(self, actual_type: LoopyType, needed_type: LoopyType, s):

return s

def rec(self, expr, type_context=None, needed_type: Optional[LoopyType] = None):
result = RecursiveMapper.rec(self, expr, type_context)

def rec(self, expr, type_context=None, needed_type: Optional[LoopyType] = None): # type: ignore[override]
result = Mapper.rec(self, expr, type_context)
if needed_type is None:
return result
else:
Expand Down Expand Up @@ -604,7 +603,7 @@ def map_nan(self, expr, type_context):

# {{{ C expression to code mapper

class CExpressionToCodeMapper(RecursiveMapper):
class CExpressionToCodeMapper(Mapper):

# {{{ helpers

Expand Down
Loading