Skip to content

Commit

Permalink
add EqualityMapper to follow pymbolic
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl committed May 14, 2022
1 parent 9c1e37d commit 182b6ca
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 27 deletions.
101 changes: 75 additions & 26 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
CSECachingMapperMixin,
)
import immutables
from pymbolic.mapper.equality import (
EqualityMapper as EqualityMapperBase)
from pymbolic.mapper.evaluator import \
CachedEvaluationMapper as EvaluationMapperBase
from pymbolic.mapper.substitutor import \
Expand All @@ -60,8 +62,9 @@

from pymbolic.parser import Parser as ParserBase
from loopy.diagnostic import LoopyError
from loopy.diagnostic import (ExpressionToAffineConversionError,
UnableToDetermineAccessRangeError)
from loopy.diagnostic import (
ExpressionToAffineConversionError,
UnableToDetermineAccessRangeError)


import islpy as isl
Expand Down Expand Up @@ -117,8 +120,11 @@ def map_literal(self, expr, *args, **kwargs):
return expr

def map_array_literal(self, expr, *args, **kwargs):
return type(expr)(tuple(self.rec(ch, *args, **kwargs)
for ch in expr.children))
children = [self.rec(ch, *args, **kwargs) for ch in expr.children]
if all(ch is orig for ch, orig in zip(children, expr.children)):
return expr

return type(expr)(tuple(children))

def map_group_hw_index(self, expr, *args, **kwargs):
return expr
Expand Down Expand Up @@ -501,6 +507,60 @@ def map_substitution(self, name, rule, arguments):

return self.rec(expr)


class EqualityMapper(EqualityMapperBase):
def map_loopy_function_identifier(self, expr, other) -> bool:
return True

def map_reduction(self, expr, other) -> bool:
return (
expr.operation == other.operation
and expr.allow_simultaneous == other.allow_simultaneous
and self.rec(expr.expr, other.expr)
and all(iname == other_iname
for iname, other_iname in zip(expr.inames, other.inames)))

def map_group_hw_index(self, expr, other) -> bool:
return expr.axis == other.axis

map_local_hw_index = map_group_hw_index

def map_linear_subscript(self, expr, other) -> bool:
return (
self.rec(expr.index, other.index)
and self.rec(expr.aggregate, other.aggregate))

def map_rule_argument(self, expr, other) -> bool:
return expr.index == other.index

def map_resolved_function(self, expr, other) -> bool:
return self.rec(expr.function, other.function)

def map_sub_array_ref(self, expr, other) -> bool:
return (
len(expr.swept_inames) == len(other.swept_inames)
and self.rec(expr.subscript, other.subscript)
and all(self.rec(iname, other_iname)
for iname, other_iname in zip(
expr.swept_inames,
other.swept_inames))
)

def map_tagged_variable(self, expr, other) -> bool:
return (
expr.name == other.name
and all(tag == other_tag
for tag, other_tag in zip(expr.tags, other.tags))
)

def map_type_cast(self, expr, other) -> bool:
return (
expr.type == other.type
and self.rec(expr.child, other.child))

def map_fortran_division(self, expr, other) -> bool:
return self.map_quotient(expr, other)

# }}}


Expand All @@ -514,15 +574,18 @@ def stringifier(self):
def make_stringifier(self, originating_stringifier=None):
return StringifyMapper()

def make_equality_mapper(self):
return EqualityMapper()


class Literal(LoopyExpressionBase):
"""A literal to be used during code generation.
.. note::
Only used in the output of
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
(and similar mappers). Not for use in :mod:`loopy` source representation.
"""

def __init__(self, s):
Expand All @@ -542,8 +605,8 @@ class ArrayLiteral(LoopyExpressionBase):
.. note::
Only used in the output of
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
(and similar mappers). Not for use in :mod:`loopy` source representation.
"""

def __init__(self, children):
Expand Down Expand Up @@ -572,8 +635,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
.. note::
Only used in the output of
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
(and similar mappers). Not for use in :mod:`loopy` source representation.
"""
mapper_method = "map_group_hw_index"

Expand All @@ -583,8 +646,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
.. note::
Only used in the output of
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in :mod:`loopy` source representation.
"""
mapper_method = "map_local_hw_index"

Expand Down Expand Up @@ -791,12 +854,6 @@ def __getinitargs__(self):
def get_hash(self):
return hash((self.__class__, self.operation, self.inames, self.expr))

def is_equal(self, other):
return (other.__class__ == self.__class__
and other.operation == self.operation
and other.inames == self.inames
and other.expr == self.expr)

@property
def is_tuple_typed(self):
return self.operation.arg_count > 1
Expand Down Expand Up @@ -994,14 +1051,6 @@ def __getinitargs__(self):
def get_hash(self):
return hash((self.__class__, self.swept_inames, self.subscript))

def is_equal(self, other):
"""
Returns *True* iff the sub-array refs have identical expressions.
"""
return (other.__class__ == self.__class__
and other.subscript == self.subscript
and other.swept_inames == self.swept_inames)

def make_stringifier(self, originating_stringifier=None):
return StringifyMapper()

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1
git+https://github.com/inducer/islpy.git#egg=islpy
git+https://github.com/inducer/cgen.git#egg=cgen
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
git+https://github.com/inducer/genpy.git#egg=genpy
git+https://github.com/inducer/codepy.git#egg=codepy

Expand Down

0 comments on commit 182b6ca

Please sign in to comment.