From 182b6ca4e0271ddd7362396cca2b652f6ea70c54 Mon Sep 17 00:00:00 2001 From: Alexandru Fikl Date: Mon, 15 Nov 2021 22:08:25 -0600 Subject: [PATCH] add EqualityMapper to follow pymbolic --- loopy/symbolic.py | 101 ++++++++++++++++++++++++++++++++++------------ requirements.txt | 2 +- 2 files changed, 76 insertions(+), 27 deletions(-) diff --git a/loopy/symbolic.py b/loopy/symbolic.py index 4c56016d0..19ae7eaef 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -43,6 +43,8 @@ CSECachingMapperMixin, ) import immutables +from pymbolic.mapper.equality import ( + EqualityMapper as EqualityMapperBase) from pymbolic.mapper.evaluator import \ CachedEvaluationMapper as EvaluationMapperBase from pymbolic.mapper.substitutor import \ @@ -60,8 +62,9 @@ from pymbolic.parser import Parser as ParserBase from loopy.diagnostic import LoopyError -from loopy.diagnostic import (ExpressionToAffineConversionError, - UnableToDetermineAccessRangeError) +from loopy.diagnostic import ( + ExpressionToAffineConversionError, + UnableToDetermineAccessRangeError) import islpy as isl @@ -117,8 +120,11 @@ def map_literal(self, expr, *args, **kwargs): return expr def map_array_literal(self, expr, *args, **kwargs): - return type(expr)(tuple(self.rec(ch, *args, **kwargs) - for ch in expr.children)) + children = [self.rec(ch, *args, **kwargs) for ch in expr.children] + if all(ch is orig for ch, orig in zip(children, expr.children)): + return expr + + return type(expr)(tuple(children)) def map_group_hw_index(self, expr, *args, **kwargs): return expr @@ -501,6 +507,60 @@ def map_substitution(self, name, rule, arguments): return self.rec(expr) + +class EqualityMapper(EqualityMapperBase): + def map_loopy_function_identifier(self, expr, other) -> bool: + return True + + def map_reduction(self, expr, other) -> bool: + return ( + expr.operation == other.operation + and expr.allow_simultaneous == other.allow_simultaneous + and self.rec(expr.expr, other.expr) + and all(iname == other_iname + for iname, other_iname in zip(expr.inames, other.inames))) + + def map_group_hw_index(self, expr, other) -> bool: + return expr.axis == other.axis + + map_local_hw_index = map_group_hw_index + + def map_linear_subscript(self, expr, other) -> bool: + return ( + self.rec(expr.index, other.index) + and self.rec(expr.aggregate, other.aggregate)) + + def map_rule_argument(self, expr, other) -> bool: + return expr.index == other.index + + def map_resolved_function(self, expr, other) -> bool: + return self.rec(expr.function, other.function) + + def map_sub_array_ref(self, expr, other) -> bool: + return ( + len(expr.swept_inames) == len(other.swept_inames) + and self.rec(expr.subscript, other.subscript) + and all(self.rec(iname, other_iname) + for iname, other_iname in zip( + expr.swept_inames, + other.swept_inames)) + ) + + def map_tagged_variable(self, expr, other) -> bool: + return ( + expr.name == other.name + and all(tag == other_tag + for tag, other_tag in zip(expr.tags, other.tags)) + ) + + def map_type_cast(self, expr, other) -> bool: + return ( + expr.type == other.type + and self.rec(expr.child, other.child)) + + def map_fortran_division(self, expr, other) -> bool: + return self.map_quotient(expr, other) + # }}} @@ -514,6 +574,9 @@ def stringifier(self): def make_stringifier(self, originating_stringifier=None): return StringifyMapper() + def make_equality_mapper(self): + return EqualityMapper() + class Literal(LoopyExpressionBase): """A literal to be used during code generation. @@ -521,8 +584,8 @@ class Literal(LoopyExpressionBase): .. note:: Only used in the output of - :mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and - similar mappers). Not for use in Loopy source representation. + :class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` + (and similar mappers). Not for use in :mod:`loopy` source representation. """ def __init__(self, s): @@ -542,8 +605,8 @@ class ArrayLiteral(LoopyExpressionBase): .. note:: Only used in the output of - :mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and - similar mappers). Not for use in Loopy source representation. + :class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` + (and similar mappers). Not for use in :mod:`loopy` source representation. """ def __init__(self, children): @@ -572,8 +635,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex): .. note:: Only used in the output of - :mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and - similar mappers). Not for use in Loopy source representation. + :class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` + (and similar mappers). Not for use in :mod:`loopy` source representation. """ mapper_method = "map_group_hw_index" @@ -583,8 +646,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex): .. note:: Only used in the output of - :mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and - similar mappers). Not for use in Loopy source representation. + :class:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and + similar mappers). Not for use in :mod:`loopy` source representation. """ mapper_method = "map_local_hw_index" @@ -791,12 +854,6 @@ def __getinitargs__(self): def get_hash(self): return hash((self.__class__, self.operation, self.inames, self.expr)) - def is_equal(self, other): - return (other.__class__ == self.__class__ - and other.operation == self.operation - and other.inames == self.inames - and other.expr == self.expr) - @property def is_tuple_typed(self): return self.operation.arg_count > 1 @@ -994,14 +1051,6 @@ def __getinitargs__(self): def get_hash(self): return hash((self.__class__, self.swept_inames, self.subscript)) - def is_equal(self, other): - """ - Returns *True* iff the sub-array refs have identical expressions. - """ - return (other.__class__ == self.__class__ - and other.subscript == self.subscript - and other.swept_inames == self.swept_inames) - def make_stringifier(self, originating_stringifier=None): return StringifyMapper() diff --git a/requirements.txt b/requirements.txt index c44f010c3..82d76f5db 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,7 +2,7 @@ git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1 git+https://github.com/inducer/islpy.git#egg=islpy git+https://github.com/inducer/cgen.git#egg=cgen git+https://github.com/inducer/pyopencl.git#egg=pyopencl -git+https://github.com/inducer/pymbolic.git#egg=pymbolic +git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic git+https://github.com/inducer/genpy.git#egg=genpy git+https://github.com/inducer/codepy.git#egg=codepy