diff --git a/arraycontext/impl/pytato/compile.py b/arraycontext/impl/pytato/compile.py index 5fe9e7b3..dd21ad4c 100644 --- a/arraycontext/impl/pytato/compile.py +++ b/arraycontext/impl/pytato/compile.py @@ -36,7 +36,7 @@ from typing import Any, Callable, Dict, FrozenSet, Mapping, Tuple, Type import numpy as np -from pyrsistent import PMap, pmap +from immutabledict import immutabledict import pytato as pt from pytools import ProcessLogger @@ -131,11 +131,9 @@ def _rec_str(key: Any) -> str: def _get_arg_id_to_arg_and_arg_id_to_descr(args: Tuple[Any, ...], kwargs: Mapping[str, Any] - ) -> "Tuple[PMap[Tuple[Any, ...],\ - Any],\ - PMap[Tuple[Any, ...],\ - AbstractInputDescriptor]\ - ]": + ) -> \ + Tuple[Mapping[Tuple[Any, ...], Any], + Mapping[Tuple[Any, ...], AbstractInputDescriptor]]: """ Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Extracts mappings from argument id to argument values and from argument id to @@ -171,7 +169,7 @@ def id_collector(keys, ary): " either a scalar, pt.Array or an array container. Got" f" '{arg}'.") - return pmap(arg_id_to_arg), pmap(arg_id_to_descr) + return immutabledict(arg_id_to_arg), immutabledict(arg_id_to_descr) def _to_input_for_compiled(ary: ArrayT, actx: PytatoPyOpenCLArrayContext): @@ -259,7 +257,7 @@ class BaseLazilyCompilingFunctionCaller: actx: _BasePytatoArrayContext f: Callable[..., Any] - program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]", + program_cache: Dict[Mapping[Tuple[Any, ...], AbstractInputDescriptor], "CompiledFunction"] = field(default_factory=lambda: {}) # {{{ abstract interface diff --git a/setup.py b/setup.py index b9437965..65631303 100644 --- a/setup.py +++ b/setup.py @@ -42,7 +42,7 @@ def main(): # https://github.com/inducer/arraycontext/pull/147 "pytools>=2022.1.3", - + "immutabledict", "loopy>=2019.1", ], extras_require={