Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add EqualityMapper to follow pymbolic #607

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 67 additions & 22 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
CSECachingMapperMixin,
)
import immutables
from pymbolic.mapper.equality import (
EqualityMapper as EqualityMapperBase)
from pymbolic.mapper.evaluator import \
CachedEvaluationMapper as EvaluationMapperBase
from pymbolic.mapper.substitutor import \
Expand Down Expand Up @@ -502,6 +504,60 @@ def map_substitution(self, name, rule, arguments):

return self.rec(expr)


class EqualityMapper(EqualityMapperBase):
def map_loopy_function_identifier(self, expr, other) -> bool:
return True

def map_reduction(self, expr, other) -> bool:
return (
expr.operation == other.operation
and expr.allow_simultaneous == other.allow_simultaneous
and self.rec(expr.expr, other.expr)
and all(iname == other_iname
for iname, other_iname in zip(expr.inames, other.inames)))

def map_group_hw_index(self, expr, other) -> bool:
return expr.axis == other.axis

map_local_hw_index = map_group_hw_index

def map_linear_subscript(self, expr, other) -> bool:
return (
self.rec(expr.index, other.index)
and self.rec(expr.aggregate, other.aggregate))

def map_rule_argument(self, expr, other) -> bool:
return expr.index == other.index

def map_resolved_function(self, expr, other) -> bool:
return self.rec(expr.function, other.function)

def map_sub_array_ref(self, expr, other) -> bool:
return (
len(expr.swept_inames) == len(other.swept_inames)
and self.rec(expr.subscript, other.subscript)
and all(self.rec(iname, other_iname)
for iname, other_iname in zip(
expr.swept_inames,
other.swept_inames))
)

def map_tagged_variable(self, expr, other) -> bool:
return (
expr.name == other.name
and all(tag == other_tag
for tag, other_tag in zip(expr.tags, other.tags))
)

def map_type_cast(self, expr, other) -> bool:
return (
expr.type == other.type
and self.rec(expr.child, other.child))

def map_fortran_division(self, expr, other) -> bool:
return self.map_quotient(expr, other)

# }}}


Expand All @@ -515,15 +571,18 @@ def stringifier(self):
def make_stringifier(self, originating_stringifier=None):
return StringifyMapper()

def make_equality_mapper(self):
return EqualityMapper()


class Literal(LoopyExpressionBase):
"""A literal to be used during code generation.

.. note::

Only used in the output of
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
(and similar mappers). Not for use in :mod:`loopy` source representation.
"""

def __init__(self, s):
Expand All @@ -543,8 +602,8 @@ class ArrayLiteral(LoopyExpressionBase):
.. note::

Only used in the output of
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
(and similar mappers). Not for use in :mod:`loopy` source representation.
"""

def __init__(self, children):
Expand Down Expand Up @@ -573,8 +632,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
.. note::

Only used in the output of
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
(and similar mappers). Not for use in :mod:`loopy` source representation.
"""
mapper_method = "map_group_hw_index"

Expand All @@ -584,8 +643,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
.. note::

Only used in the output of
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in Loopy source representation.
:class:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
similar mappers). Not for use in :mod:`loopy` source representation.
"""
mapper_method = "map_local_hw_index"

Expand Down Expand Up @@ -792,12 +851,6 @@ def __getinitargs__(self):
def get_hash(self):
return hash((self.__class__, self.operation, self.inames, self.expr))

def is_equal(self, other):
return (other.__class__ == self.__class__
and other.operation == self.operation
and other.inames == self.inames
and other.expr == self.expr)

@property
def is_tuple_typed(self):
return self.operation.arg_count > 1
Expand Down Expand Up @@ -994,14 +1047,6 @@ def __getinitargs__(self):
def get_hash(self):
return hash((self.__class__, self.swept_inames, self.subscript))

def is_equal(self, other):
"""
Returns *True* iff the sub-array refs have identical expressions.
"""
return (other.__class__ == self.__class__
and other.subscript == self.subscript
and other.swept_inames == self.swept_inames)

def make_stringifier(self, originating_stringifier=None):
return StringifyMapper()

Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1
git+https://github.com/inducer/islpy.git#egg=islpy
git+https://github.com/inducer/cgen.git#egg=cgen
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
git+https://github.com/inducer/genpy.git#egg=genpy
git+https://github.com/inducer/codepy.git#egg=codepy

Expand Down