Skip to content

Commit

Permalink
Add TransformMapper/TransformMapperWithExtraArgs (#530)
Browse files Browse the repository at this point in the history
* add clone_for_callee to CopyMapperWithExtraArgs

* add TransformMapper/TransformMapperWithExtraArgs

* make CopyMapper/CopyMapperWithExtraArgs inherit from TransformMapper/TransformMapperWithExtraArgs

* expand on purpose of TransformMapper in docstring
  • Loading branch information
majosm authored Aug 2, 2024
1 parent 86733af commit c9d11e4
Showing 1 changed file with 88 additions and 48 deletions.
136 changes: 88 additions & 48 deletions pytato/transform/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@
MappedT = TypeVar("MappedT",
Array, AbstractResultWithNamedArrays, ArrayOrNames)
CombineT = TypeVar("CombineT") # used in CombineMapper
CopyMapperResultT = TypeVar("CopyMapperResultT", # used in CopyMapper
TransformMapperResultT = TypeVar("TransformMapperResultT", # used in TransformMapper
Array, AbstractResultWithNamedArrays, ArrayOrNames)
CachedMapperT = TypeVar("CachedMapperT") # used in CachedMapper
IndexOrShapeExpr = TypeVar("IndexOrShapeExpr")
Expand All @@ -101,6 +101,8 @@
.. autoclass:: Mapper
.. autoclass:: CachedMapper
.. autoclass:: TransformMapper
.. autoclass:: TransformMapperWithExtraArgs
.. autoclass:: CopyMapper
.. autoclass:: CopyMapperWithExtraArgs
.. autoclass:: CombineMapper
Expand Down Expand Up @@ -150,7 +152,7 @@
.. class:: _SelfMapper
A type variable used to represent the type of a mapper in
:meth:`CopyMapper.clone_for_callee`.
:meth:`TransformMapper.clone_for_callee`.
"""

transform_logger = logging.getLogger(__file__)
Expand Down Expand Up @@ -230,6 +232,8 @@ class CachedMapper(Mapper, Generic[CachedMapperT]):
"""Mapper class that maps each node in the DAG exactly once. This loses some
information compared to :class:`Mapper` as a node is visited only from
one of its predecessors.
.. automethod:: get_cache_key
"""

def __init__(self) -> None:
Expand All @@ -256,24 +260,23 @@ def __call__(self, expr: ArrayOrNames) -> CachedMapperT:
# }}}


# {{{ CopyMapper
# {{{ TransformMapper

class CopyMapper(CachedMapper[ArrayOrNames]):
"""Performs a deep copy of a :class:`pytato.array.Array`.
The typical use of this mapper is to override individual ``map_`` methods
in subclasses to permit term rewriting on an expression graph.
class TransformMapper(CachedMapper[ArrayOrNames]):
"""Base class for mappers that transform :class:`pytato.array.Array`\\ s into
other :class:`pytato.array.Array`\\ s.
.. automethod:: clone_for_callee
Enables certain operations that can only be done if the mapping results are also
arrays (e.g., calling :meth:`~CachedMapper.get_cache_key` on them). Does not
implement default mapper methods; for that, see :class:`CopyMapper`.
.. note::
This does not copy the data of a :class:`pytato.array.DataWrapper`.
.. automethod:: clone_for_callee
"""
if TYPE_CHECKING:
def rec(self, expr: CopyMapperResultT) -> CopyMapperResultT:
return cast(CopyMapperResultT, super().rec(expr))
def rec(self, expr: TransformMapperResultT) -> TransformMapperResultT:
return cast(TransformMapperResultT, super().rec(expr))

def __call__(self, expr: CopyMapperResultT) -> CopyMapperResultT:
def __call__(self, expr: TransformMapperResultT) -> TransformMapperResultT:
return self.rec(expr)

def clone_for_callee(
Expand All @@ -284,6 +287,76 @@ def clone_for_callee(
"""
return type(self)()

# }}}


# {{{ TransformMapperWithExtraArgs

class TransformMapperWithExtraArgs(CachedMapper[ArrayOrNames]):
"""
Similar to :class:`TransformMapper`, but each mapper method takes extra
``*args``, ``**kwargs`` that are propagated along a path by default.
The logic in :class:`TransformMapper` purposely does not take the extra
arguments to keep the cost of its each call frame low.
.. automethod:: clone_for_callee
"""
def __init__(self) -> None:
super().__init__()
# type-ignored as '._cache' attribute is not coherent with the base
# class
self._cache: dict[tuple[ArrayOrNames,
tuple[Any, ...],
tuple[tuple[str, Any], ...]
],
ArrayOrNames] = {} # type: ignore[assignment]

def get_cache_key(self,
expr: ArrayOrNames,
*args: Any, **kwargs: Any) -> tuple[ArrayOrNames,
tuple[Any, ...],
tuple[tuple[str, Any], ...]
]:
return (expr, args, tuple(sorted(kwargs.items())))

def rec(self,
expr: TransformMapperResultT,
*args: Any, **kwargs: Any) -> TransformMapperResultT:
key = self.get_cache_key(expr, *args, **kwargs)
try:
# type-ignore-reason: self._cache has ArrayOrNames as its values
return self._cache[key] # type: ignore[return-value]
except KeyError:
result = Mapper.rec(self, expr,
*args,
**kwargs)
self._cache[key] = result
# type-ignore-reason: Mapper.rec is imprecise
return result # type: ignore[no-any-return]

def clone_for_callee(
self: _SelfMapper, function: FunctionDefinition) -> _SelfMapper:
"""
Called to clone *self* before starting traversal of a
:class:`pytato.function.FunctionDefinition`.
"""
return type(self)()

# }}}


# {{{ CopyMapper

class CopyMapper(TransformMapper):
"""Performs a deep copy of a :class:`pytato.array.Array`.
The typical use of this mapper is to override individual ``map_`` methods
in subclasses to permit term rewriting on an expression graph.
.. note::
This does not copy the data of a :class:`pytato.array.DataWrapper`.
"""
def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...]
) -> tuple[IndexOrShapeExpr, ...]:
# type-ignore-reason: apparently mypy cannot substitute typevars
Expand Down Expand Up @@ -467,47 +540,14 @@ def map_named_call_result(self, expr: NamedCallResult) -> Array:
return call[expr.name]


class CopyMapperWithExtraArgs(CachedMapper[ArrayOrNames]):
class CopyMapperWithExtraArgs(TransformMapperWithExtraArgs):
"""
Similar to :class:`CopyMapper`, but each mapper method takes extra
``*args``, ``**kwargs`` that are propagated along a path by default.
The logic in :class:`CopyMapper` purposely does not take the extra
arguments to keep the cost of its each call frame low.
"""
def __init__(self) -> None:
super().__init__()
# type-ignored as '._cache' attribute is not coherent with the base
# class
self._cache: dict[tuple[ArrayOrNames,
tuple[Any, ...],
tuple[tuple[str, Any], ...]
],
ArrayOrNames] = {} # type: ignore[assignment]

def get_cache_key(self,
expr: ArrayOrNames,
*args: Any, **kwargs: Any) -> tuple[ArrayOrNames,
tuple[Any, ...],
tuple[tuple[str, Any], ...]
]:
return (expr, args, tuple(sorted(kwargs.items())))

def rec(self,
expr: CopyMapperResultT,
*args: Any, **kwargs: Any) -> CopyMapperResultT:
key = self.get_cache_key(expr, *args, **kwargs)
try:
# type-ignore-reason: self._cache has ArrayOrNames as its values
return self._cache[key] # type: ignore[return-value]
except KeyError:
result = Mapper.rec(self, expr,
*args,
**kwargs)
self._cache[key] = result
# type-ignore-reason: Mapper.rec is imprecise
return result # type: ignore[no-any-return]

def rec_idx_or_size_tuple(self, situp: tuple[IndexOrShapeExpr, ...],
*args: Any, **kwargs: Any
) -> tuple[IndexOrShapeExpr, ...]:
Expand Down

0 comments on commit c9d11e4

Please sign in to comment.