Make blockwise perform method node dependent
ricardoV94 committed Oct 23, 2024
1 parent a377c22 commit 50b4c4b
Showing 2 changed files with 40 additions and 11 deletions.
16 changes: 5 additions & 11 deletions pytensor/tensor/blockwise.py
@@ -1,5 +1,4 @@
 from collections.abc import Sequence
-from copy import copy
 from typing import Any, cast
 
 import numpy as np
@@ -79,7 +78,6 @@ def __init__(
         self.name = name
         self.inputs_sig, self.outputs_sig = _parse_gufunc_signature(signature)
         self.gufunc_spec = gufunc_spec
-        self._gufunc = None
         if destroy_map is not None:
             self.destroy_map = destroy_map
         if self.destroy_map != core_op.destroy_map:
@@ -91,11 +89,6 @@ def __init__(
 
         super().__init__(**kwargs)
 
-    def __getstate__(self):
-        d = copy(self.__dict__)
-        d["_gufunc"] = None
-        return d
-
     def _create_dummy_core_node(self, inputs: Sequence[TensorVariable]) -> Apply:
         core_input_types = []
         for i, (inp, sig) in enumerate(zip(inputs, self.inputs_sig)):
@@ -320,8 +313,7 @@ def core_func(*inner_inputs):
             else:
                 return tuple(r[0] for r in inner_outputs)
 
-        self._gufunc = np.vectorize(core_func, signature=self.signature)
-        return self._gufunc
+        node.tag.gufunc = np.vectorize(core_func, signature=self.signature)
 
     def _check_runtime_broadcast(self, node, inputs):
         batch_ndim = self.batch_ndim(node)
@@ -340,10 +332,12 @@ def _check_runtime_broadcast(self, node, inputs):
                 )
 
     def perform(self, node, inputs, output_storage):
-        gufunc = self._gufunc
+        gufunc = getattr(node.tag, "gufunc", None)
 
         if gufunc is None:
-            gufunc = self._create_gufunc(node)
+            # Cache it once per node
+            self._create_gufunc(node)
+            gufunc = node.tag.gufunc
 
         self._check_runtime_broadcast(node, inputs)
 
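The change above moves the cached np.vectorize callable from the Op instance (the old self._gufunc attribute) onto the Apply node's tag (node.tag.gufunc), so a single Blockwise instance can serve several nodes whose perform behaviour depends on node information such as dtypes. Below is a minimal, self-contained sketch of that per-node caching pattern; ToyBlockwise, the SimpleNamespace stand-ins for nodes, and the +1/-1 core functions are illustrative only and not part of pytensor's API.

from types import SimpleNamespace

import numpy as np


class ToyBlockwise:
    """Toy stand-in for Blockwise: the vectorized callable is cached per node, not per Op."""

    def _create_gufunc(self, node):
        # Build a core function specialized to this node's dtype and stash it on the node.
        if node.dtype.startswith("float"):
            def core_func(x):
                return x + 1
        else:
            def core_func(x):
                return x - 1
        node.tag.gufunc = np.vectorize(core_func, signature="()->()")

    def perform(self, node, x):
        gufunc = getattr(node.tag, "gufunc", None)
        if gufunc is None:
            # Cache it once per node, so nodes sharing this Op don't overwrite each other.
            self._create_gufunc(node)
            gufunc = node.tag.gufunc
        return gufunc(x)


op = ToyBlockwise()
float_node = SimpleNamespace(dtype="float32", tag=SimpleNamespace())
int_node = SimpleNamespace(dtype="int32", tag=SimpleNamespace())
print(op.perform(float_node, np.zeros(3, dtype="float32")))  # [1. 1. 1.]
print(op.perform(int_node, np.zeros(3, dtype="int32")))      # [-1 -1 -1]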
35 changes: 35 additions & 0 deletions tests/tensor/test_blockwise.py
@@ -28,6 +28,41 @@
 from pytensor.tensor.utils import _parse_gufunc_signature
 
 
+def test_perform_method_per_node():
+    """Confirm that Blockwise uses one perform method per node.
+
+    This is important if the perform method requires node information (such as dtypes)
+    """
+
+    class NodeDependentPerformOp(Op):
+        def make_node(self, x):
+            return Apply(self, [x], [x.type()])
+
+        def perform(self, node, inputs, outputs):
+            [x] = inputs
+            if node.inputs[0].type.dtype.startswith("float"):
+                y = x + 1
+            else:
+                y = x - 1
+            outputs[0][0] = y
+
+    blockwise_op = Blockwise(core_op=NodeDependentPerformOp(), signature="()->()")
+    x = tensor("x", shape=(3,), dtype="float32")
+    y = tensor("y", shape=(3,), dtype="int32")
+
+    out_x = blockwise_op(x)
+    out_y = blockwise_op(y)
+    fn = pytensor.function([x, y], [out_x, out_y])
+    [op1, op2] = [node.op for node in fn.maker.fgraph.apply_nodes]
+    # Confirm both nodes have the same Op
+    assert op1 is blockwise_op
+    assert op1 is op2
+
+    res_out_x, res_out_y = fn(np.zeros(3, dtype="float32"), np.zeros(3, dtype="int32"))
+    np.testing.assert_array_equal(res_out_x, np.ones(3, dtype="float32"))
+    np.testing.assert_array_equal(res_out_y, -np.ones(3, dtype="int32"))
+
+
 def test_vectorize_blockwise():
     mat = tensor(shape=(None, None))
     tns = tensor(shape=(None, None, None))
