From 81341eb65757019b2d47c8bc8604053b8e1d8281 Mon Sep 17 00:00:00 2001 From: Alex Merose Date: Tue, 23 Jul 2024 22:24:12 +0100 Subject: [PATCH] Adding Jit functionality to Plan finialization. Created an affordance to apply a jit to blockwise functions after operator fusion. This will let the user better use accelerators. #508 needs to be merged first. --- cubed/core/plan.py | 35 +++++++++++++++++++++++++-- cubed/tests/test_executor_features.py | 27 +++++++++++++++++++++ 2 files changed, 60 insertions(+), 2 deletions(-) diff --git a/cubed/core/plan.py b/cubed/core/plan.py index f9d51455..1078d73e 100644 --- a/cubed/core/plan.py +++ b/cubed/core/plan.py @@ -1,3 +1,4 @@ +import dataclasses import inspect import tempfile import uuid @@ -21,6 +22,8 @@ sym_counter = 0 +Decorator = Callable[[Callable], Callable] + def gensym(name="op"): global sym_counter @@ -182,13 +185,40 @@ def _create_lazy_zarr_arrays(self, dag): return dag + def _compile_blockwise(self, dag, jit_function: Decorator) -> nx.MultiDiGraph: + """JIT-compiles the functions from all blockwise ops by mutating the input dag.""" + # Recommended: make a copy of the dag before calling this function. + for n in dag.nodes: + node = dag.nodes[n] + + if "primitive_op" not in node: + continue + + if not isinstance(node["pipeline"].config, BlockwiseSpec): + continue + + # node is a blockwise primitive_op. + # maybe we should investigate some sort of optics library for frozen dataclasses... + new_pipeline = dataclasses.replace( + node["pipeline"], + config=dataclasses.replace( + node["pipeline"].config, + function=jit_function(node["pipeline"].config.function) + ) + ) + node["pipeline"] = new_pipeline + + return dag + @lru_cache def _finalize_dag( - self, optimize_graph: bool = True, optimize_function=None + self, optimize_graph: bool = True, optimize_function=None, jit_function: Optional[Decorator] = None, ) -> nx.MultiDiGraph: dag = self.optimize(optimize_function).dag if optimize_graph else self.dag # create a copy since _create_lazy_zarr_arrays mutates the dag dag = dag.copy() + if callable(jit_function): + dag = self._compile_blockwise(dag, jit_function) dag = self._create_lazy_zarr_arrays(dag) return nx.freeze(dag) @@ -198,11 +228,12 @@ def execute( callbacks=None, optimize_graph=True, optimize_function=None, + jit_function=None, resume=None, spec=None, **kwargs, ): - dag = self._finalize_dag(optimize_graph, optimize_function) + dag = self._finalize_dag(optimize_graph, optimize_function, jit_function) compute_id = f"compute-{datetime.now().strftime('%Y%m%dT%H%M%S.%f')}" diff --git a/cubed/tests/test_executor_features.py b/cubed/tests/test_executor_features.py index 60bb397b..0b697187 100644 --- a/cubed/tests/test_executor_features.py +++ b/cubed/tests/test_executor_features.py @@ -1,4 +1,5 @@ import contextlib +import os import platform import fsspec @@ -264,3 +265,29 @@ def test_check_runtime_memory_modal(spec, modal_executor): match=r"Runtime memory \(2097152000\) is less than allowed_mem \(4000000000\)", ): c.compute(executor=modal_executor) + + +JIT_FUNCTIONS = [lambda fn: fn] + +try: + from numba import jit as numba_jit + JIT_FUNCTIONS.append(numba_jit) +except ModuleNotFoundError: + pass + +try: + if 'jax' in os.environ.get('CUBED_BACKEND_ARRAY_API_MODULE', ''): + from jax import jit as jax_jit + JIT_FUNCTIONS.append(jax_jit) +except ModuleNotFoundError: + pass + + +@pytest.mark.parametrize("jit_function", JIT_FUNCTIONS) +def test_check_jit_compliation(spec, executor, jit_function): + a = xp.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]], chunks=(2, 2), spec=spec) + b = xp.asarray([[1, 1, 1], [1, 1, 1], [1, 1, 1]], chunks=(2, 2), spec=spec) + c = xp.add(a, b) + assert_array_equal( + c.compute(executor=executor, jit_function=jit_function), np.array([[2, 3, 4], [5, 6, 7], [8, 9, 10]]) + ) \ No newline at end of file