From e39fda374fd33a2a7a7017b3b4ac0795811a075b Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Wed, 23 Oct 2024 11:45:43 +0200
Subject: [PATCH] Make blockwise perform method node dependent

---
Reviewer note (this section is ignored by `git am`): the gufunc cache moves
from the Op attribute `_gufunc` to `node.tag.gufunc`, so it is created once
per Apply node. The `__getstate__` override that reset `_gufunc` and the
`copy` import it needed are removed with it. A test with a dtype-dependent
core `perform` is added.

 pytensor/tensor/blockwise.py   | 55 ++++++++++++++++++----------------
 tests/tensor/test_blockwise.py | 35 ++++++++++++++++++++++
 2 files changed, 64 insertions(+), 26 deletions(-)

diff --git a/pytensor/tensor/blockwise.py b/pytensor/tensor/blockwise.py
index 7fa1313cba..01c47c3b80 100644
--- a/pytensor/tensor/blockwise.py
+++ b/pytensor/tensor/blockwise.py
@@ -1,5 +1,4 @@
 from collections.abc import Sequence
-from copy import copy
 from typing import Any, cast
 
 import numpy as np
@@ -79,7 +78,6 @@ def __init__(
         self.name = name
         self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
         self.gufunc_spec = gufunc_spec
-        self._gufunc = None
         if destroy_map is not None:
             self.destroy_map = destroy_map
             if self.destroy_map != core_op.destroy_map:
@@ -91,11 +89,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-    def __getstate__(self):
-        d = copy(self.__dict__)
-        d["_gufunc"] = None
-        return d
-
     def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
         core_input_types = []
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
@@ -296,32 +289,40 @@ def L_op(self, inputs, outs, ograds):
 
         return rval
 
-    def _create_gufunc(self, node):
+    def _create_node_gufunc(self, node) -> None:
+        """Define (or retrieve) the node gufunc used in `perform`.
+
+        If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
+        Otherwise, we default to `np.vectorize` of the core_op `perform` method for a dummy node.
+
+        The gufunc is stored in the tag of the node.
+        """
         gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
 
         if gufunc_spec is not None:
-            self._gufunc = import_func_from_string(gufunc_spec[0])
-            if self._gufunc:
-                return self._gufunc
-            else:
+            gufunc = import_func_from_string(gufunc_spec[0])
+            if gufunc is None:
                 raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
 
-        n_outs = len(self.outputs_sig)
-        core_node = self._create_dummy_core_node(node.inputs)
+        else:
+            # Wrap core_op perform method in numpy vectorize
+            n_outs = len(self.outputs_sig)
+            core_node = self._create_dummy_core_node(node.inputs)
 
-        def core_func(*inner_inputs):
-            inner_outputs = [[None] for _ in range(n_outs)]
+            def core_func(*inner_inputs):
+                inner_outputs = [[None] for _ in range(n_outs)]
 
-            inner_inputs = [np.asarray(inp) for inp in inner_inputs]
-            self.core_op.perform(core_node, inner_inputs, inner_outputs)
+                inner_inputs = [np.asarray(inp) for inp in inner_inputs]
+                self.core_op.perform(core_node, inner_inputs, inner_outputs)
 
-            if len(inner_outputs) == 1:
-                return inner_outputs[0][0]
-            else:
-                return tuple(r[0] for r in inner_outputs)
+                if len(inner_outputs) == 1:
+                    return inner_outputs[0][0]
+                else:
+                    return tuple(r[0] for r in inner_outputs)
+
+            gufunc = np.vectorize(core_func, signature=self.signature)
 
-        self._gufunc = np.vectorize(core_func, signature=self.signature)
-        return self._gufunc
+        node.tag.gufunc = gufunc
 
     def _check_runtime_broadcast(self, node, inputs):
         batch_ndim = self.batch_ndim(node)
@@ -340,10 +341,12 @@ def _check_runtime_broadcast(self, node, inputs):
                 )
 
     def perform(self, node, inputs, output_storage):
-        gufunc = self._gufunc
+        gufunc = getattr(node.tag, "gufunc", None)
 
         if gufunc is None:
-            gufunc = self._create_gufunc(node)
+            # Cache it once per node
+            self._create_node_gufunc(node)
+            gufunc = node.tag.gufunc
 
         self._check_runtime_broadcast(node, inputs)
 
diff --git a/tests/tensor/test_blockwise.py b/tests/tensor/test_blockwise.py
index f6783cf945..bd69d809a3 100644
--- a/tests/tensor/test_blockwise.py
+++ b/tests/tensor/test_blockwise.py
@@ -28,6 +28,41 @@
 from pytensor.tensor.utils import _parse_gufunc_signature
 
 
+def test_perform_method_per_node():
+    """Confirm that Blockwise uses one perform method per node.
+
+    This is important if the perform method requires node information (such as dtypes)
+    """
+
+    class NodeDependentPerformOp(Op):
+        def make_node(self, x):
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, outputs):
+            [x] = inputs
+            if node.inputs[0].type.dtype.startswith("float"):
+                y = x + 1
+            else:
+                y = x - 1
+            outputs[0][0] = y
+
+    blockwise_op = Blockwise(core_op=NodeDependentPerformOp(), signature="()->()")
+    x = tensor("x", shape=(3,), dtype="float32")
+    y = tensor("y", shape=(3,), dtype="int32")
+
+    out_x = blockwise_op(x)
+    out_y = blockwise_op(y)
+    fn = pytensor.function([x, y], [out_x, out_y])
+    [op1, op2] = [node.op for node in fn.maker.fgraph.apply_nodes]
+    # Confirm both nodes have the same Op
+    assert op1 is blockwise_op
+    assert op1 is op2
+
+    res_out_x, res_out_y = fn(np.zeros(3, dtype="float32"), np.zeros(3, dtype="int32"))
+    np.testing.assert_array_equal(res_out_x, np.ones(3, dtype="float32"))
+    np.testing.assert_array_equal(res_out_y, -np.ones(3, dtype="int32"))
+
+
 def test_vectorize_blockwise():
     mat = tensor(shape=(None, None))
     tns = tensor(shape=(None, None, None))