Use Blockwise for matmul #452

Merged (2 commits, Sep 24, 2023)
1 change: 1 addition & 0 deletions pytensor/ifelse.py
@@ -482,6 +482,7 @@ def cond_make_inplace(fgraph, node):
at.basic.Alloc,
at.elemwise.Elemwise,
at.elemwise.DimShuffle,
at.blockwise.Blockwise,
)


108 changes: 19 additions & 89 deletions pytensor/tensor/math.py
@@ -25,11 +25,11 @@
stack,
switch,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.type import (
DenseTensorType,
TensorType,
complex_dtypes,
continuous_dtypes,
discrete_dtypes,
@@ -2868,93 +2868,7 @@ def logsumexp(x, axis=None, keepdims=False):
return log(sum(exp(x), axis=axis, keepdims=keepdims))


class MatMul(Op):
__props__ = ("dtype",)

def __init__(self, dtype=None):
self.dtype = dtype

@classmethod
def _get_output_shape(cls, x1, x2, shapes, validate=False):
x1_shape, x2_shape = shapes

if x1.ndim == 1 and x2.ndim == 1:
if validate and x1_shape[0] != x2_shape[0]:
raise ValueError("1d inputs must have the same length.")
return ()
elif x1.ndim == 1 and x2.ndim > 1:
if validate and x1_shape[0] != x2_shape[-2]:
raise ValueError(
"length of input 1 must be equal the length "
"of the 2nd-last dimension of input 2"
)
return x2_shape[:-2] + x2_shape[-1:]
elif x1.ndim > 1 and x2.ndim == 1:
if validate and x1_shape[-1] != x2_shape[0]:
raise ValueError(
"length of input 2 must be equal the length "
"of the last dimension of input 1"
)
return x1_shape[:-1]
elif x1.ndim == 2 and x2.ndim == 2:
if validate and x1_shape[-1] != x2_shape[0]:
raise ValueError(
"number of columns of input 1 must be equal to "
"the number of rows of input 2"
)
return x1_shape[:-1] + x2_shape[-1:]
elif x1.ndim > 2 and x2.ndim == 2:
if validate and x1_shape[-1] != x2_shape[0]:
raise ValueError(
"number of rows of input 2 must be equal to "
"the length of the last dimension of input 1"
)
return x1_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
elif x1.ndim == 2 and x2.ndim > 2:
if validate and x1_shape[-1] != x2_shape[-2]:
raise ValueError(
"number of columns of input 1 must be equal "
"the length of the 2nd-last dimension of input 2"
)
return x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
else:
if validate:
from pytensor.tensor.random.basic import broadcast_shapes

bshape = broadcast_shapes(x1_shape[:-2], x2_shape[:-2])
if x1_shape[-1] != x2_shape[-2]:
raise ValueError(
"length of the last dimension of input 1 must be equal "
"to the length of the 2nd-last dimension of input 2"
)
else:
from pytensor.tensor.extra_ops import broadcast_shape

bshape = broadcast_shape(
x1_shape[:-2], x2_shape[:-2], arrays_are_shapes=True
)
return bshape + x1_shape[-2:-1] + x2_shape[-1:]

def make_node(self, a, b):
a = as_tensor_variable(a)
b = as_tensor_variable(b)

if 0 in {a.ndim, b.ndim}:
raise ValueError("inputs to `matmul` cannot be scalar.")

out_shape = self._get_output_shape(
a, b, (a.type.shape, b.type.shape), validate=True
)
out = TensorType(dtype=self.dtype, shape=out_shape)()
return Apply(self, [a, b], [out])

def perform(self, node, inputs, outputs):
x1, x2 = inputs
outputs[0][0] = np.matmul(x1, x2, dtype=self.dtype)

def infer_shape(self, fgraph, node, shapes):
x1, x2 = node.inputs
return [self._get_output_shape(x1, x2, shapes)]
_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)")
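
The ``(n,k),(k,m)->(n,m)`` signature follows NumPy's gufunc convention: the wrapped core op (``_dot``) acts on the last two dimensions of each input, and any leading dimensions are treated as batch dimensions and broadcast together. A minimal NumPy sketch of the semantics this Blockwise is meant to reproduce (illustrative only, not code from this PR):

import numpy as np

rng = np.random.default_rng(0)
a = rng.random((4, 2, 3))  # a stack of four (2, 3) matrices
b = rng.random((4, 3, 5))  # a stack of four (3, 5) matrices

c = np.matmul(a, b)  # the core (n,k),(k,m)->(n,m) dot applied per batch element
assert c.shape == (4, 2, 5)
assert np.allclose(c[0], a[0] @ b[0])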


def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
@@ -2999,7 +2913,23 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
- Stacks of matrices are broadcast together as if the matrices were elements,
respecting the signature ``(n, k), (k, m) -> (n, m)``:
"""
return MatMul(dtype=dtype)(x1, x2)
x1 = as_tensor_variable(x1)
x2 = as_tensor_variable(x2)
if x1.type.ndim == 0 or x2.type.ndim == 0:
raise ValueError("matmul operand cannot be scalar")
if x1.type.ndim == 1 and x2.type.ndim == 1:
out = _dot(x1, x2)
elif x1.type.ndim == 1:
out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
elif x2.type.ndim == 1:
out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
Review comment (Member):
Is all this better than a separate _matrix_vector_matmul function? I only ask because BLAS makes the distinction.

Reply (Member, author):
This should be fine. If we ever get to optimizing this further in the jax/numba backends, we should be able to tell which case is which by inspecting the inputs' static types.
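
For illustration, the promotion these two branches implement can be checked against NumPy: a 1-D operand is promoted to a matrix with a length-1 axis, the batched matmul runs, and the added axis is squeezed away. A minimal sketch with hypothetical values, using ``np.matmul`` in place of ``_matrix_matrix_matmul``:

import numpy as np

x1 = np.arange(3.0)                     # 1-D left operand, shape (3,)
x2 = np.arange(12.0).reshape(3, 4)      # matrix, shape (3, 4)

direct = x1 @ x2                        # NumPy handles the 1-D case natively
promoted = (x1[None] @ x2).squeeze(-2)  # promote to (1, 3), matmul, drop the added axis
assert np.allclose(direct, promoted)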

else:
out = _matrix_matrix_matmul(x1, x2)

if dtype is not None:
out = out.astype(dtype)

return out


__all__ = [
9 changes: 9 additions & 0 deletions pytensor/tensor/rewriting/blockwise.py
@@ -3,6 +3,8 @@
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import _matrix_matrix_matmul
from pytensor.tensor.rewriting.basic import register_canonicalize


@node_rewriter([Blockwise])
@@ -40,3 +42,10 @@ def local_useless_unbatched_blockwise(fgraph, node):
"blockwise",
position=49,
)


# Avoid redundant cases early on for Ops whose default form is not Blockwised
@register_canonicalize
@node_rewriter(tracks=[_matrix_matrix_matmul])
def local_eager_useless_unbatched_blockwise(fgraph, node):
return local_useless_unbatched_blockwise.fn(fgraph, node)
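
This eager rewrite matters because ``matmul`` now always builds a ``Blockwise(_dot)`` node, even for plain 2-D operands with no batch dimensions to vectorize over. Canonicalization strips the redundant wrapper so the usual dot/BLAS specializations can apply; the new ``test_dot22_opt`` test further down checks that ``x @ y`` ends up as ``Dot22`` under ``FAST_RUN``. A hedged sketch of how one might observe this (assumes a working PyTensor install; the exact optimized graph is illustrative):

import pytensor
from pytensor.printing import debugprint
from pytensor.tensor import matrices
from pytensor.tensor.math import matmul

x, y = matrices("xy")
# With 2-D inputs there are no batch dimensions, so the Blockwise wrapper added
# by matmul is redundant; canonicalization removes it and FAST_RUN can then
# specialize the remaining dot to a BLAS op.
fn = pytensor.function([x, y], matmul(x, y), mode="FAST_RUN")
debugprint(fn)  # expect a BLAS dot (e.g. Dot22) rather than a Blockwise node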
12 changes: 8 additions & 4 deletions pytensor/tensor/variable.py
@@ -647,8 +647,12 @@ def __rdot__(right, left):
return at.math.dense_dot(left, right)

dot = __dot__
__matmul__ = __dot__
__rmatmul__ = __rdot__

def __matmul__(left, right):
return at.math.matmul(left, right)

def __rmatmul__(right, left):
return at.math.matmul(right, left)

def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""See :func:`pytensor.tensor.math.sum`."""
@@ -797,15 +801,15 @@ def choose(self, choices, mode="raise"):
"""
return at.basic.choose(self, choices, mode="raise")

def squeeze(self):
def squeeze(self, axis=None):
"""
Remove broadcastable dimensions from the shape of an array.

It returns the input array, but with the broadcastable dimensions
removed. This is always `x` itself or a view into `x`.

"""
return at.extra_ops.squeeze(self)
return at.extra_ops.squeeze(self, axis=axis)

def compress(self, a, axis=None):
"""Return selected slices only."""
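The ``squeeze`` change above is what lets ``matmul`` call ``.squeeze(-2)`` and ``.squeeze(-1)``: the method now forwards an optional ``axis`` to ``at.extra_ops.squeeze`` so only the requested length-1 dimension is dropped, mirroring ``np.squeeze``. A small sketch of the intended behaviour (hypothetical variable names, assuming a working PyTensor install):

from pytensor.tensor import matrix

x = matrix("x")        # 2-D symbolic variable
y = x[None]            # add a leading length-1 (broadcastable) axis -> ndim 3
z = y.squeeze(axis=0)  # drop only that axis, as np.squeeze(arr, axis=0) would
assert z.type.ndim == 2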
93 changes: 7 additions & 86 deletions tests/tensor/test_math.py
@@ -30,11 +30,11 @@
get_underlying_scalar_constant_value,
switch,
)
from pytensor.tensor.blas import Dot22
from pytensor.tensor.elemwise import CAReduce, Elemwise
from pytensor.tensor.math import (
Argmax,
Dot,
MatMul,
MaxAndArgmax,
Mean,
Prod,
@@ -3412,12 +3412,10 @@ def test_log1mexp_grad_lim():
assert grad_x_fn(-1e-308) != -np.inf


class TestMatMul(utt.InferShapeTester):
class TestMatMul:
def setup_method(self):
super().setup_method()
self.rng = np.random.default_rng(utt.fetch_seed())
self.op = matmul
self.op_class = MatMul

def _validate_output(self, a, b):
pytensor_sol = self.op(a, b).eval()
@@ -3467,85 +3465,8 @@ def test_dtype_param(self, dtype):
sol = self.op([1, 2, 3], [3, 2, 1], dtype=dtype)
assert sol.eval().dtype == dtype

@pytest.mark.parametrize(
"x1_shape,x2_shape,exp_res,error_regex",
[
((1,), (3,), None, "inputs must have the same length"),
((2,), (3, 1), None, "length of input 1.*2nd-last dimension of input 2"),
((2, 5), (3,), None, "length of input 2.*of the last dimension of input 1"),
(
(2, 5),
(3, 4),
None,
"number of columns of input 1 .* number of rows of input 2",
),
(
(2, 1, 3),
(5, 4),
None,
"number of rows of input 2 .* last dimension of input 1",
),
(
(2, 5),
(2, 4, 3),
None,
"number of columns of input 1 .* 2nd-last dimension of input 2",
),
(
(3, 2, 4, 5),
(1, 6, 7),
None,
"length of the last dimension of input 1 .* 2nd-last dimension of input 2",
),
(
(4, 5, 4),
(3, 2, 2),
None,
"cannot be broadcast to a single shape",
),
(
(4, None, 2),
(4, 2, None),
(4, None, None),
None,
),
],
)
def test_get_output_shape(self, x1_shape, x2_shape, exp_res, error_regex):
x1 = tensor(dtype=np.float64, shape=x1_shape)
x2 = tensor(dtype=np.float64, shape=x2_shape)

if error_regex is not None:
with pytest.raises(ValueError, match=error_regex):
self.op_class._get_output_shape(
x1, x2, (x1_shape, x2_shape), validate=True
)
else:
assert (
self.op_class._get_output_shape(
x1, x2, (x1_shape, x2_shape), validate=True
)
== exp_res
)

def test_infer_shape(self):
for shape_x1, shape_x2 in [
((5,), (5,)),
((5,), (2, 5, 3)),
((2, 5, 3), (3,)),
((2, 5), (5, 4)),
((2, 5), (2, 5, 3)),
((2, 1, 3), (3, 4)),
((3, 2, 4, 5), (1, 5, 7)),
]:
a = tensor(dtype=config.floatX, shape=shape_x1)
b = tensor(dtype=config.floatX, shape=shape_x2)
x1 = self.rng.random(shape_x1).astype(config.floatX)
x2 = self.rng.random(shape_x2).astype(config.floatX)

self._compile_and_check(
[a, b],
[self.op(a, b)],
[x1, x2],
self.op_class,
)
def test_dot22_opt(self):
x, y = matrices("xy")
fn = function([x, y], x @ y, mode="FAST_RUN")
[node] = fn.maker.fgraph.apply_nodes
assert isinstance(node.op, Dot22)
24 changes: 19 additions & 5 deletions tests/tensor/test_variable.py
@@ -10,9 +10,9 @@
from pytensor.compile.mode import get_default_mode
from pytensor.graph.basic import Constant, equal_computations
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import constant
from pytensor.tensor.basic import as_tensor, constant
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import dot, eq
from pytensor.tensor.math import dot, eq, matmul
from pytensor.tensor.shape import Shape
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
from pytensor.tensor.type import (
@@ -79,16 +79,30 @@ def test_infix_dot_method():
X = dmatrix("X")
y = dvector("y")

res = X @ y
exp_res = X.dot(y)
res = X.dot(y)
exp_res = dot(X, y)
assert equal_computations([res], [exp_res])

X_val = np.arange(2 * 3).reshape((2, 3))
res = X_val @ y
res = as_tensor(X_val).dot(y)
exp_res = dot(X_val, y)
assert equal_computations([res], [exp_res])


def test_infix_matmul_method():
X = dmatrix("X")
y = dvector("y")

res = X @ y
exp_res = matmul(X, y)
assert equal_computations([res], [exp_res])

X_val = np.arange(2 * 3).reshape((2, 3))
res = as_tensor(X_val) @ y
exp_res = matmul(X_val, y)
assert equal_computations([res], [exp_res])


def test_empty_list_indexing():
ynp = np.zeros((2, 2))[:, []]
znp = np.zeros((2, 2))[:, ()]