Skip to content

Commit

Permalink
Make blockwise perform method node dependent
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 24, 2024
1 parent a377c22 commit e39fda3
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 26 deletions.
55 changes: 29 additions & 26 deletions pytensor/tensor/blockwise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from collections.abc import Sequence
from copy import copy
from typing import Any, cast

import numpy as np
Expand Down Expand Up @@ -79,7 +78,6 @@ def __init__(
self.name = name
self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
self.gufunc_spec = gufunc_spec
self._gufunc = None
if destroy_map is not None:
self.destroy_map = destroy_map
if self.destroy_map != core_op.destroy_map:
Expand All @@ -91,11 +89,6 @@ def __init__(

super().__init__(**kwargs)

def __getstate__(self):
d = copy(self.__dict__)
d["_gufunc"] = None
return d

def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
core_input_types = []
for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
Expand Down Expand Up @@ -296,32 +289,40 @@ def L_op(self, inputs, outs, ograds):

return rval

def _create_gufunc(self, node):
def _create_node_gufunc(self, node) -> None:
"""Define (or retrieve) the node gufunc used in `perform`.
If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
Otherwise, we default to `np.vectorize` of the core_op `perform` method for a dummy node.
The gufunc is stored in the tag of the node.
"""
gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)

if gufunc_spec is not None:
self._gufunc = import_func_from_string(gufunc_spec[0])
if self._gufunc:
return self._gufunc
else:
gufunc = import_func_from_string(gufunc_spec[0])
if gufunc is None:
raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")

n_outs = len(self.outputs_sig)
core_node = self._create_dummy_core_node(node.inputs)
else:
# Wrap core_op perform method in numpy vectorize
n_outs = len(self.outputs_sig)
core_node = self._create_dummy_core_node(node.inputs)

def core_func(*inner_inputs):
inner_outputs = [[None] for _ in range(n_outs)]
def core_func(*inner_inputs):
inner_outputs = [[None] for _ in range(n_outs)]

inner_inputs = [np.asarray(inp) for inp in inner_inputs]
self.core_op.perform(core_node, inner_inputs, inner_outputs)
inner_inputs = [np.asarray(inp) for inp in inner_inputs]
self.core_op.perform(core_node, inner_inputs, inner_outputs)

if len(inner_outputs) == 1:
return inner_outputs[0][0]
else:
return tuple(r[0] for r in inner_outputs)
if len(inner_outputs) == 1:
return inner_outputs[0][0]
else:
return tuple(r[0] for r in inner_outputs)

gufunc = np.vectorize(core_func, signature=self.signature)

self._gufunc = np.vectorize(core_func, signature=self.signature)
return self._gufunc
node.tag.gufunc = gufunc

def _check_runtime_broadcast(self, node, inputs):
batch_ndim = self.batch_ndim(node)
Expand All @@ -340,10 +341,12 @@ def _check_runtime_broadcast(self, node, inputs):
)

def perform(self, node, inputs, output_storage):
gufunc = self._gufunc
gufunc = getattr(node.tag, "gufunc", None)

if gufunc is None:
gufunc = self._create_gufunc(node)
# Cache it once per node
self._create_node_gufunc(node)
gufunc = node.tag.gufunc

self._check_runtime_broadcast(node, inputs)

Expand Down
35 changes: 35 additions & 0 deletions tests/tensor/test_blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,41 @@
from pytensor.tensor.utils import _parse_gufunc_signature


def test_perform_method_per_node():
"""Confirm that Blockwise uses one perform method per node.
This is important if the perform method requires node information (such as dtypes)
"""

class NodeDependentPerformOp(Op):
def make_node(self, x):
return Apply(self, [x], [x.type()])

def perform(self, node, inputs, outputs):
[x] = inputs
if node.inputs[0].type.dtype.startswith("float"):
y = x + 1
else:
y = x - 1
outputs[0][0] = y

blockwise_op = Blockwise(core_op=NodeDependentPerformOp(), signature="()->()")
x = tensor("x", shape=(3,), dtype="float32")
y = tensor("y", shape=(3,), dtype="int32")

out_x = blockwise_op(x)
out_y = blockwise_op(y)
fn = pytensor.function([x, y], [out_x, out_y])
[op1, op2] = [node.op for node in fn.maker.fgraph.apply_nodes]
# Confirm both nodes have the same Op
assert op1 is blockwise_op
assert op1 is op2

res_out_x, res_out_y = fn(np.zeros(3, dtype="float32"), np.zeros(3, dtype="int32"))
np.testing.assert_array_equal(res_out_x, np.ones(3, dtype="float32"))
np.testing.assert_array_equal(res_out_y, -np.ones(3, dtype="int32"))


def test_vectorize_blockwise():
mat = tensor(shape=(None, None))
tns = tensor(shape=(None, None, None))
Expand Down

0 comments on commit e39fda3

Please sign in to comment.