Skip to content

Commit

Permalink
@expr_dataclass: don't require cls to be Expression subclass
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Oct 21, 2024
1 parent ff79a6f commit 2c129f2
Showing 1 changed file with 25 additions and 14 deletions.
39 changes: 25 additions & 14 deletions pymbolic/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,10 @@
ClassVar,
Mapping,
NoReturn,
Protocol,
Type,
TypeVar,
cast,
)
from warnings import warn

Expand Down Expand Up @@ -889,9 +892,13 @@ def __iter__(self):
)


class _HasMapperMethod(Protocol):
mapper_method: ClassVar[str]


def _augment_expression_dataclass(
cls: type[DataclassInstance],
hash: bool,
generate_hash: bool,
) -> None:
attr_tuple = ", ".join(f"self.{fld.name}" for fld in fields(cls))
if attr_tuple:
Expand Down Expand Up @@ -924,8 +931,9 @@ def {cls.__name__}_eq(self, other):
return True
if self.__class__ is not other.__class__:
return False
if hash(self) != hash(other):
return False
if {generate_hash}:
if hash(self) != hash(other):
return False
if self.__class__ is not cls and self.init_arg_names != {fld_name_tuple}:
warn(f"{{self.__class__}} is derived from {cls}, which is now "
f"a dataclass. {{self.__class__}} should be converted to being "
Expand Down Expand Up @@ -960,7 +968,7 @@ def {cls.__name__}_hash(self):
object.__setattr__(self, "_hash_value", hash_val)
return hash_val
if {hash}:
if {generate_hash}:
cls.__hash__ = {cls.__name__}_hash
Expand Down Expand Up @@ -1026,23 +1034,23 @@ def {cls.__name__}_setstate(self, state):

# {{{ assign mapper_method

assert issubclass(cls, Expression)
mm_cls = cast(Type[_HasMapperMethod], cls)

snake_clsname = _CAMEL_TO_SNAKE_RE.sub("_", cls.__name__).lower()
snake_clsname = _CAMEL_TO_SNAKE_RE.sub("_", mm_cls.__name__).lower()
default_mapper_method_name = f"map_{snake_clsname}"

# This covers two cases: the class does not have the attribute in the first
# place, or it inherits a value but does not set it itself.
sets_mapper_method = "mapper_method" in cls.__dict__
sets_mapper_method = "mapper_method" in mm_cls.__dict__

if sets_mapper_method:
if default_mapper_method_name == cls.mapper_method:
warn(f"Explicit mapper_method on {cls} not needed, default matches "
if default_mapper_method_name == mm_cls.mapper_method:
warn(f"Explicit mapper_method on {mm_cls} not needed, default matches "
"explicit assignment. Just delete the explicit assignment.",
stacklevel=3)

if not sets_mapper_method:
cls.mapper_method = intern(default_mapper_method_name)
mm_cls.mapper_method = intern(default_mapper_method_name)

# }}}

Expand All @@ -1053,18 +1061,21 @@ def {cls.__name__}_setstate(self, state):
@dataclass_transform(frozen_default=True)
def expr_dataclass(
init: bool = True,
hash: bool = True
hash: bool = True,
) -> Callable[[type[_T]], type[_T]]:
"""A class decorator that makes the class a :func:`~dataclasses.dataclass`
r"""A class decorator that makes the class a :func:`~dataclasses.dataclass`
while also adding functionality needed for :class:`Expression` nodes.
Specifically, it adds cached hashing, equality comparisons
with ``self is other`` shortcuts as well as some methods/attributes
for backward compatibility (e.g. ``__getinitargs__``, ``init_arg_names``)
for backward compatibility (e.g. ``__getinitargs__``, ``init_arg_names``).
It also adds a :attr:`Expression.mapper_method` based on the class name
if not already present. If :attr:`~Expression.mapper_method` is inherited,
it will be viewed as unset and replaced.
Note that the class to which this decorator is applied need not be
a subclass of :class:`~pymbolic.Expression`.
.. versionadded:: 2024.1
"""
def map_cls(cls: type[_T]) -> type[_T]:
Expand All @@ -1078,7 +1089,7 @@ def map_cls(cls: type[_T]) -> type[_T]:
# It should just understand that?
_augment_expression_dataclass(
dc_cls, # type: ignore[arg-type]
hash=hash
generate_hash=hash,
)
return dc_cls

Expand Down

0 comments on commit 2c129f2

Please sign in to comment.