Skip to content

Commit

Permalink
Delete/simplify/update persistent hashing code for updated pymbolic
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Oct 3, 2024
1 parent 60741e5 commit 96340a9
Show file tree
Hide file tree
Showing 7 changed files with 14 additions and 149 deletions.
12 changes: 1 addition & 11 deletions loopy/kernel/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1023,16 +1023,6 @@ def __str__(self):
def __repr__(self):
return "<%s>" % self.__str__()

def update_persistent_hash_for_shape(self, key_hash, key_builder, shape):
if isinstance(shape, tuple):
for shape_i in shape:
if shape_i is None:
key_builder.rec(key_hash, shape_i)
else:
key_builder.update_for_pymbolic_expression(key_hash, shape_i)
else:
key_builder.rec(key_hash, shape)

def update_persistent_hash(self, key_hash, key_builder):
"""Custom hash computation function for use with
:class:`pytools.persistent_dict.PersistentDict`.
Expand All @@ -1041,7 +1031,7 @@ def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, type(self).__name__)
key_builder.rec(key_hash, self.name)
key_builder.rec(key_hash, self.dtype)
self.update_persistent_hash_for_shape(key_hash, key_builder, self.shape)
key_builder.rec(key_hash, self.shape)
key_builder.rec(key_hash, self.dim_tags)
key_builder.rec(key_hash, self.offset)
key_builder.rec(key_hash, self.dim_names)
Expand Down
5 changes: 2 additions & 3 deletions loopy/kernel/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,8 +853,7 @@ def update_persistent_hash(self, key_hash, key_builder):
"""

super().update_persistent_hash(key_hash, key_builder)
self.update_persistent_hash_for_shape(key_hash, key_builder,
self.storage_shape)
key_builder.rec(key_hash, self.storage_shape)
key_builder.rec(key_hash, self.base_indices)
key_builder.rec(key_hash, self.address_space)
key_builder.rec(key_hash, self.base_storage)
Expand Down Expand Up @@ -899,7 +898,7 @@ def copy(self, **kwargs: Any) -> SubstitutionRule:
def update_persistent_hash(self, key_hash, key_builder):
key_builder.rec(key_hash, self.name)
key_builder.rec(key_hash, self.arguments)
key_builder.update_for_pymbolic_expression(key_hash, self.expression)
key_builder.rec(key_hash, self.expression)


# }}}
Expand Down
2 changes: 1 addition & 1 deletion loopy/kernel/function_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def depends_on(self):
return frozenset(var.name for var in result)

def update_persistent_hash(self, key_hash, key_builder):
key_builder.update_for_pymbolic_expression(key_hash, self.shape)
key_builder.rec(key_hash, self.shape)
key_builder.rec(key_hash, self.address_space)
key_builder.rec(key_hash, self.dim_tags)

Expand Down
31 changes: 1 addition & 30 deletions loopy/kernel/instruction.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,6 @@ class InstructionBase(ImmutableRecord, Taggable):
"within_inames_is_final within_inames "
"priority".split())

# Names of fields that are pymbolic expressions. Needed for key building
pymbolic_fields = set("")

# Names of fields that are sets of pymbolic expressions. Needed for key building
pymbolic_set_fields = {"predicates"}

def __init__(self,
id: Optional[str],
happens_after: Union[
Expand Down Expand Up @@ -545,25 +539,7 @@ def _key_builder(self):
key_builder.update_for_class(self.__class__)

for field_name in self.fields:
field_value = getattr(self, field_name)
if field_name in self.pymbolic_fields:
key_builder.update_for_pymbolic_field(field_name, field_value)
elif field_name in self.pymbolic_set_fields:
# First sort the fields, as a canonical form
items = tuple(sorted(field_value, key=str))
key_builder.update_for_pymbolic_field(field_name, items)

# from CExpression
elif field_name == "iname_exprs":
from loopy.symbolic import EqualityPreservingStringifyMapper
key_builder.field_dict[field_name] = [
(iname, EqualityPreservingStringifyMapper()(expr)
.encode("utf-8"))
for iname, expr in self.iname_exprs
]

else:
key_builder.update_for_field(field_name, field_value)
key_builder.update_for_field(field_name, getattr(self, field_name))

return key_builder

Expand Down Expand Up @@ -841,7 +817,6 @@ class MultiAssignmentBase(InstructionBase):
"""An assignment instruction with an expression as a right-hand side."""

fields = InstructionBase.fields | {"expression"}
pymbolic_fields = InstructionBase.pymbolic_fields | {"expression"}

@memoize_method
def read_dependency_names(self):
Expand Down Expand Up @@ -933,7 +908,6 @@ class Assignment(MultiAssignmentBase):

fields = MultiAssignmentBase.fields | \
set("assignee temp_var_type atomicity".split())
pymbolic_fields = MultiAssignmentBase.pymbolic_fields | {"assignee"}

def __init__(self,
assignee: Union[str, ExpressionT],
Expand Down Expand Up @@ -1092,7 +1066,6 @@ class CallInstruction(MultiAssignmentBase):

fields = MultiAssignmentBase.fields | \
set("assignees temp_var_types".split())
pymbolic_fields = MultiAssignmentBase.pymbolic_fields | {"assignees"}

def __init__(self,
assignees, expression,
Expand Down Expand Up @@ -1404,8 +1377,6 @@ class CInstruction(InstructionBase):

fields = InstructionBase.fields | \
set("iname_exprs code read_variables assignees".split())
pymbolic_fields = InstructionBase.pymbolic_fields | \
set("assignees".split())

def __init__(self,
iname_exprs, code,
Expand Down
5 changes: 5 additions & 0 deletions loopy/library/reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from pymbolic import var
from pymbolic.primitives import expr_dataclass
from pytools.persistent_dict import Hash, KeyBuilder

from loopy.diagnostic import LoopyError
from loopy.kernel.function_interface import ScalarCallable
Expand Down Expand Up @@ -128,6 +129,10 @@ def __str__(self):

return result

def update_persistent_hash(self, key_hash: Hash, key_builder: KeyBuilder) -> None:
# They're all stateless.
key_builder.rec(key_hash, type(self))


class SumReductionOperation(ScalarReductionOperation):
def neutral_element(self, dtype, callables_table, target):
Expand Down
36 changes: 2 additions & 34 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,39 +370,6 @@ def map_fortran_division(self, expr, enclosing_prec):
return f"[FORTRANDIV]({result})"


class EqualityPreservingStringifyMapper(StringifyMapperBase):
"""
For the benefit of
:meth:`loopy.tools.LoopyEqKeyBuilder.update_for_pymbolic_field`,
this mapper satisfies the invariant
``mapper(expr_1) == mapper(expr_2)``
if and only if
``expr_1 == expr_2``
"""

def __init__(self):
super().__init__()

def map_constant(self, expr, enclosing_prec):
if isinstance(expr, np.generic):
# Explicitly typed: Emitted string must reflect type exactly.

# FIXME: This syntax cannot currently be parsed.

return "{}({})".format(type(expr).__name__, repr(expr))
else:
result = repr(expr)

from pymbolic.mapper.stringifier import PREC_SUM
if not (result.startswith("(") and result.endswith(")")) \
and ("-" in result or "+" in result) \
and (enclosing_prec > PREC_SUM):
return self.parenthesize(result)
else:
return result


class UnidirectionalUnifier(UnidirectionalUnifierBase):
def map_reduction(self, expr, other, unis):
if not isinstance(other, type(expr)):
Expand Down Expand Up @@ -1612,7 +1579,8 @@ def map_call(self, expr):
class LoopyParser(ParserBase):
lex_table = [
(_open_dbl_bracket, pytools.lex.RE(r"\[\[")),
] + ParserBase.lex_table
*ParserBase.lex_table
]

def parse_float(self, s):
match = TRAILING_FLOAT_TAG_RE.match(s)
Expand Down
72 changes: 2 additions & 70 deletions loopy/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,14 @@
from immutables import Map

import islpy as isl
from pymbolic.mapper.persistent_hash import (
PersistentHashWalkMapper as PersistentHashWalkMapperBase,
)
from pytools import ProcessLogger, memoize_method
from pytools.persistent_dict import (
KeyBuilder as KeyBuilderBase,
WriteOncePersistentDict,
)

from loopy.symbolic import (
from .symbolic import (
RuleAwareIdentityMapper,
UncachedWalkMapper as LoopyWalkMapper,
)


Expand All @@ -61,33 +57,6 @@ def update_persistent_hash(obj, key_hash, key_builder):

# {{{ custom KeyBuilder subclass

class PersistentHashWalkMapper(LoopyWalkMapper, PersistentHashWalkMapperBase):
"""A subclass of :class:`loopy.symbolic.WalkMapper` for constructing
persistent hash keys for use with
:class:`pytools.persistent_dict.PersistentDict`.
See also :meth:`LoopyKeyBuilder.update_for_pymbolic_expression`.
"""

def __init__(self, key_hash):
LoopyWalkMapper.__init__(self)
PersistentHashWalkMapperBase.__init__(self, key_hash)

def map_reduction(self, expr, *args):
if not self.visit(expr):
return

self.key_hash.update(type(expr.operation).__name__.encode("utf-8"))
self.rec(expr.expr, *args)

def map_foreign(self, expr, *args, **kwargs):
"""Mapper method dispatch for non-:mod:`pymbolic` objects."""
if expr is None:
self.key_hash.update(b"<None>")
else:
PersistentHashWalkMapperBase.map_foreign(self, expr, *args, **kwargs)


class LoopyKeyBuilder(KeyBuilderBase):
"""A custom :class:`pytools.persistent_dict.KeyBuilder` subclass
for objects within :mod:`loopy`.
Expand Down Expand Up @@ -115,29 +84,8 @@ def update_for_Map(self, key_hash, key): # noqa
else:
raise AssertionError()

def update_for_pymbolic_expression(self, key_hash, key):
if key is None:
self.update_for_NoneType(key_hash, key)
else:
PersistentHashWalkMapper(key_hash)(key)

update_for_PMap = update_for_dict # noqa: N815


class PymbolicExpressionHashWrapper:
def __init__(self, expression):
self.expression = expression

def __eq__(self, other):
return (type(self) is type(other)
and self.expression == other.expression)

def __ne__(self, other):
return not self.__eq__(other)

def update_persistent_hash(self, key_hash, key_builder):
key_builder.update_for_pymbolic_expression(key_hash, self.expression)

# }}}


Expand Down Expand Up @@ -173,11 +121,6 @@ def update_for_class(self, class_):
def update_for_field(self, field_name, value):
self.field_dict[field_name] = value

def update_for_pymbolic_field(self, field_name, value):
from loopy.symbolic import EqualityPreservingStringifyMapper
self.field_dict[field_name] = \
EqualityPreservingStringifyMapper()(value).encode("utf-8")

def key(self):
"""A key suitable for equality comparison."""
return (self.class_.__name__.encode("utf-8"), self.field_dict)
Expand Down Expand Up @@ -905,7 +848,6 @@ def clear_in_mem_caches() -> None:
def memoize_on_disk(func, key_builder_t=LoopyKeyBuilder):
from functools import wraps

import pymbolic.primitives as prim
from pytools.persistent_dict import WriteOncePersistentDict

from loopy.kernel import LoopKernel
Expand All @@ -930,17 +872,7 @@ def wrapper(*args, **kwargs):
or kwargs.pop("_no_memoize_on_disk", False)):
return func(*args, **kwargs)

def _get_persistent_hashable_arg(arg):
if isinstance(arg, prim.Expression):
return PymbolicExpressionHashWrapper(arg)
else:
return arg

cache_key = (func.__qualname__, func.__name__,
tuple(_get_persistent_hashable_arg(arg)
for arg in args),
{kw: _get_persistent_hashable_arg(arg)
for kw, arg in kwargs.items()})
cache_key = (func.__qualname__, func.__name__, args, kwargs)

try:
result = transform_cache[cache_key]
Expand Down

0 comments on commit 96340a9

Please sign in to comment.