From c9d11e498f0764c950bbbcb5dfc59fa0fb3f2f17 Mon Sep 17 00:00:00 2001 From: Matt Smith Date: Fri, 2 Aug 2024 12:35:59 -0500 Subject: [PATCH] Add `TransformMapper`/`TransformMapperWithExtraArgs` (#530) * add clone_for_callee to CopyMapperWithExtraArgs * add TransformMapper/TransformMapperWithExtraArgs * make CopyMapper/CopyMapperWithExtraArgs inherit from TransformMapper/TransformMapperWithExtraArgs * expand on purpose of TransformMapper in docstring --- pytato/transform/__init__.py | 136 ++++++++++++++++++++++------------- 1 file changed, 88 insertions(+), 48 deletions(-) diff --git a/pytato/transform/__init__.py b/pytato/transform/__init__.py index 0c437261c..ea790e9b6 100644 --- a/pytato/transform/__init__.py +++ b/pytato/transform/__init__.py @@ -89,7 +89,7 @@ MappedT = TypeVar("MappedT", Array, AbstractResultWithNamedArrays, ArrayOrNames) CombineT = TypeVar("CombineT") # used in CombineMapper -CopyMapperResultT = TypeVar("CopyMapperResultT", # used in CopyMapper +TransformMapperResultT = TypeVar("TransformMapperResultT", # used in TransformMapper Array, AbstractResultWithNamedArrays, ArrayOrNames) CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper IndexOrShapeExpr = TypeVar("IndexOrShapeExpr") @@ -101,6 +101,8 @@ .. autoclass:: Mapper .. autoclass:: CachedMapper +.. autoclass:: TransformMapper +.. autoclass:: TransformMapperWithExtraArgs .. autoclass:: CopyMapper .. autoclass:: CopyMapperWithExtraArgs .. autoclass:: CombineMapper @@ -150,7 +152,7 @@ .. class:: _SelfMapper A type variable used to represent the type of a mapper in - :meth:`CopyMapper.clone_for_callee`. + :meth:`TransformMapper.clone_for_callee`. """ transform_logger = logging.getLogger(__file__) @@ -230,6 +232,8 @@ class CachedMapper(Mapper, Generic[CachedMapperT]): """Mapper class that maps each node in the DAG exactly once. This loses some information compared to :class:`Mapper` as a node is visited only from one of its predecessors. + + .. automethod:: get_cache_key """ def __init__(self) -> None: @@ -256,24 +260,23 @@ def __call__(self, expr: ArrayOrNames) -> CachedMapperT: # }}} -# {{{ CopyMapper +# {{{ TransformMapper -class CopyMapper(CachedMapper[ArrayOrNames]): - """Performs a deep copy of a :class:`pytato.array.Array`. - The typical use of this mapper is to override individual ``map_`` methods - in subclasses to permit term rewriting on an expression graph. +class TransformMapper(CachedMapper[ArrayOrNames]): + """Base class for mappers that transform :class:`pytato.array.Array`\\ s into + other :class:`pytato.array.Array`\\ s. - .. automethod:: clone_for_callee + Enables certain operations that can only be done if the mapping results are also + arrays (e.g., calling :meth:`~CachedMapper.get_cache_key` on them). Does not + implement default mapper methods; for that, see :class:`CopyMapper`. - .. note:: - - This does not copy the data of a :class:`pytato.array.DataWrapper`. + .. automethod:: clone_for_callee """ if TYPE_CHECKING: - def rec(self, expr: CopyMapperResultT) -> CopyMapperResultT: - return cast(CopyMapperResultT, super().rec(expr)) + def rec(self, expr: TransformMapperResultT) -> TransformMapperResultT: + return cast(TransformMapperResultT, super().rec(expr)) - def __call__(self, expr: CopyMapperResultT) -> CopyMapperResultT: + def __call__(self, expr: TransformMapperResultT) -> TransformMapperResultT: return self.rec(expr) def clone_for_callee( @@ -284,6 +287,76 @@ def clone_for_callee( """ return type(self)() +# }}} + + +# {{{ TransformMapperWithExtraArgs + +class TransformMapperWithExtraArgs(CachedMapper[ArrayOrNames]): + """ + Similar to :class:`TransformMapper`, but each mapper method takes extra + ``*args``, ``**kwargs`` that are propagated along a path by default. + + The logic in :class:`TransformMapper` purposely does not take the extra + arguments to keep the cost of its each call frame low. + + .. automethod:: clone_for_callee + """ + def __init__(self) -> None: + super().__init__() + # type-ignored as '._cache' attribute is not coherent with the base + # class + self._cache: dict[tuple[ArrayOrNames, + tuple[Any, ...], + tuple[tuple[str, Any], ...] + ], + ArrayOrNames] = {} # type: ignore[assignment] + + def get_cache_key(self, + expr: ArrayOrNames, + *args: Any, **kwargs: Any) -> tuple[ArrayOrNames, + tuple[Any, ...], + tuple[tuple[str, Any], ...] + ]: + return (expr, args, tuple(sorted(kwargs.items()))) + + def rec(self, + expr: TransformMapperResultT, + *args: Any, **kwargs: Any) -> TransformMapperResultT: + key = self.get_cache_key(expr, *args, **kwargs) + try: + # type-ignore-reason: self._cache has ArrayOrNames as its values + return self._cache[key] # type: ignore[return-value] + except KeyError: + result = Mapper.rec(self, expr, + *args, + **kwargs) + self._cache[key] = result + # type-ignore-reason: Mapper.rec is imprecise + return result # type: ignore[no-any-return] + + def clone_for_callee( + self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper: + """ + Called to clone *self* before starting traversal of a + :class:`pytato.function.FunctionDefinition`. + """ + return type(self)() + +# }}} + + +# {{{ CopyMapper + +class CopyMapper(TransformMapper): + """Performs a deep copy of a :class:`pytato.array.Array`. + The typical use of this mapper is to override individual ``map_`` methods + in subclasses to permit term rewriting on an expression graph. + + .. note:: + + This does not copy the data of a :class:`pytato.array.DataWrapper`. + """ def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...] ) -> tuple[IndexOrShapeExpr, ...]: # type-ignore-reason: apparently mypy cannot substitute typevars @@ -467,7 +540,7 @@ def map_named_call_result(self, expr: NamedCallResult) -> Array: return call[expr.name] -class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]): +class CopyMapperWithExtraArgs(TransformMapperWithExtraArgs): """ Similar to :class:`CopyMapper`, but each mapper method takes extra ``*args``, ``**kwargs`` that are propagated along a path by default. @@ -475,39 +548,6 @@ class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]): The logic in :class:`CopyMapper` purposely does not take the extra arguments to keep the cost of its each call frame low. """ - def __init__(self) -> None: - super().__init__() - # type-ignored as '._cache' attribute is not coherent with the base - # class - self._cache: dict[tuple[ArrayOrNames, - tuple[Any, ...], - tuple[tuple[str, Any], ...] - ], - ArrayOrNames] = {} # type: ignore[assignment] - - def get_cache_key(self, - expr: ArrayOrNames, - *args: Any, **kwargs: Any) -> tuple[ArrayOrNames, - tuple[Any, ...], - tuple[tuple[str, Any], ...] - ]: - return (expr, args, tuple(sorted(kwargs.items()))) - - def rec(self, - expr: CopyMapperResultT, - *args: Any, **kwargs: Any) -> CopyMapperResultT: - key = self.get_cache_key(expr, *args, **kwargs) - try: - # type-ignore-reason: self._cache has ArrayOrNames as its values - return self._cache[key] # type: ignore[return-value] - except KeyError: - result = Mapper.rec(self, expr, - *args, - **kwargs) - self._cache[key] = result - # type-ignore-reason: Mapper.rec is imprecise - return result # type: ignore[no-any-return] - def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...], *args: Any, **kwargs: Any ) -> tuple[IndexOrShapeExpr, ...]: