Skip to content

Commit

Permalink
Remove Matmul Operator in favor of Blockwise Dot
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Sep 24, 2023
1 parent 7c58661 commit 071eadd
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 184 deletions.
108 changes: 19 additions & 89 deletions pytensor/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@
stack,
switch,
)
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
from pytensor.tensor.shape import shape, specify_broadcastable
from pytensor.tensor.type import (
DenseTensorType,
TensorType,
complex_dtypes,
continuous_dtypes,
discrete_dtypes,
Expand Down Expand Up @@ -2868,93 +2868,7 @@ def logsumexp(x, axis=None, keepdims=False):
return log(sum(exp(x), axis=axis, keepdims=keepdims))


class MatMul(Op):
    """Op that evaluates ``np.matmul`` on two non-scalar tensor inputs.

    Parameters
    ----------
    dtype
        Forwarded to ``np.matmul`` in `perform` and used as the dtype of the
        output type built in `make_node`. It is the only entry of
        ``__props__``.
    """

    __props__ = ("dtype",)

    def __init__(self, dtype=None):
        self.dtype = dtype

    @staticmethod
    def _known_mismatch(length1, length2):
        """Return True only when both lengths are known and differ.

        Static shapes may contain ``None`` for a dimension whose length is
        not known when the graph is built; such a dimension cannot be proven
        incompatible, so it must not trigger a validation error.
        """
        return length1 is not None and length2 is not None and length1 != length2

    @classmethod
    def _get_output_shape(cls, x1, x2, shapes, validate=False):
        """Compute the shape of ``matmul(x1, x2)`` from the input shapes.

        Parameters
        ----------
        x1, x2
            The inputs; only their ``ndim`` is read here.
        shapes
            Pair ``(x1_shape, x2_shape)`` of indexable shapes.
        validate
            When true, the contracted dimensions are compared and a
            ``ValueError`` is raised if both are known and differ.

        Returns
        -------
        The output shape, assembled by slicing and concatenating the input
        shapes (``()`` for two 1d inputs).
        """
        x1_shape, x2_shape = shapes

        if x1.ndim == 1 and x2.ndim == 1:
            if validate and cls._known_mismatch(x1_shape[0], x2_shape[0]):
                raise ValueError("1d inputs must have the same length.")
            return ()
        elif x1.ndim == 1 and x2.ndim > 1:
            if validate and cls._known_mismatch(x1_shape[0], x2_shape[-2]):
                raise ValueError(
                    "length of input 1 must be equal the length "
                    "of the 2nd-last dimension of input 2"
                )
            return x2_shape[:-2] + x2_shape[-1:]
        elif x1.ndim > 1 and x2.ndim == 1:
            if validate and cls._known_mismatch(x1_shape[-1], x2_shape[0]):
                raise ValueError(
                    "length of input 2 must be equal the length "
                    "of the last dimension of input 1"
                )
            return x1_shape[:-1]
        elif x1.ndim == 2 and x2.ndim == 2:
            if validate and cls._known_mismatch(x1_shape[-1], x2_shape[0]):
                raise ValueError(
                    "number of columns of input 1 must be equal to "
                    "the number of rows of input 2"
                )
            return x1_shape[:-1] + x2_shape[-1:]
        elif x1.ndim > 2 and x2.ndim == 2:
            if validate and cls._known_mismatch(x1_shape[-1], x2_shape[0]):
                raise ValueError(
                    "number of rows of input 2 must be equal to "
                    "the length of the last dimension of input 1"
                )
            return x1_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
        elif x1.ndim == 2 and x2.ndim > 2:
            if validate and cls._known_mismatch(x1_shape[-1], x2_shape[-2]):
                raise ValueError(
                    "number of columns of input 1 must be equal "
                    "the length of the 2nd-last dimension of input 2"
                )
            return x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
        else:
            if validate:
                from pytensor.tensor.random.basic import broadcast_shapes

                # The batch dimensions are broadcast first so that a batch
                # mismatch is reported before a contracted-dimension mismatch.
                # NOTE(review): unclear whether this helper accepts `None`
                # entries in the batch shapes -- confirm.
                bshape = broadcast_shapes(x1_shape[:-2], x2_shape[:-2])
                if cls._known_mismatch(x1_shape[-1], x2_shape[-2]):
                    raise ValueError(
                        "length of the last dimension of input 1 must be equal "
                        "to the length of the 2nd-last dimension of input 2"
                    )
            else:
                from pytensor.tensor.extra_ops import broadcast_shape

                bshape = broadcast_shape(
                    x1_shape[:-2], x2_shape[:-2], arrays_are_shapes=True
                )
            return bshape + x1_shape[-2:-1] + x2_shape[-1:]

    def make_node(self, a, b):
        """Build the `Apply` node for ``matmul(a, b)``.

        Both inputs are converted with ``as_tensor_variable``. Scalar (0d)
        inputs are rejected with a ``ValueError``; the static output shape is
        derived, with validation, from the inputs' static type shapes.
        """
        a = as_tensor_variable(a)
        b = as_tensor_variable(b)

        if 0 in {a.ndim, b.ndim}:
            raise ValueError("inputs to `matmul` cannot be scalar.")

        out_shape = self._get_output_shape(
            a, b, (a.type.shape, b.type.shape), validate=True
        )
        out = TensorType(dtype=self.dtype, shape=out_shape)()
        return Apply(self, [a, b], [out])

    def perform(self, node, inputs, outputs):
        """Store ``np.matmul(x1, x2, dtype=self.dtype)`` in the output cell."""
        x1, x2 = inputs
        outputs[0][0] = np.matmul(x1, x2, dtype=self.dtype)

    def infer_shape(self, fgraph, node, shapes):
        """Return the output shape computed from ``shapes`` without validation."""
        x1, x2 = node.inputs
        return [self._get_output_shape(x1, x2, shapes)]
# `Blockwise` wrapper around `_dot` with core signature "(n,k),(k,m)->(n,m)".
# `matmul` uses it for every case except vector @ vector, which calls `_dot`
# directly.
_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)")


def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
Expand Down Expand Up @@ -2999,7 +2913,23 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
- Stacks of matrices are broadcast together as if the matrices were elements,
respecting the signature ``(n, k), (k, m) -> (n, m)``:
"""
return MatMul(dtype=dtype)(x1, x2)
x1 = as_tensor_variable(x1)
x2 = as_tensor_variable(x2)
if x1.type.ndim == 0 or x2.type.ndim == 0:
raise ValueError("matmul operand cannot be scalar")
if x1.type.ndim == 1 and x2.type.ndim == 1:
out = _dot(x1, x2)
elif x1.type.ndim == 1:
out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2)
elif x2.type.ndim == 1:
out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1)
else:
out = _matrix_matrix_matmul(x1, x2)

if dtype is not None:
out = out.astype(dtype)

return out


__all__ = [
Expand Down
9 changes: 9 additions & 0 deletions pytensor/tensor/rewriting/blockwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from pytensor.graph.replace import vectorize_node
from pytensor.graph.rewriting.basic import copy_stack_trace, out2in
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import _matrix_matrix_matmul
from pytensor.tensor.rewriting.basic import register_canonicalize


@node_rewriter([Blockwise])
Expand Down Expand Up @@ -40,3 +42,10 @@ def local_useless_unbatched_blockwise(fgraph, node):
"blockwise",
position=49,
)


# Registered under canonicalize so that redundant Blockwise wrappers are
# dropped early for Ops whose default form is not Blockwised.
@register_canonicalize
@node_rewriter(tracks=[_matrix_matrix_matmul])
def local_eager_useless_unbatched_blockwise(fgraph, node):
    """Apply `local_useless_unbatched_blockwise` to `_matrix_matrix_matmul` nodes.

    This rewriter tracks only `_matrix_matrix_matmul` and forwards its
    arguments unchanged to the underlying function of
    `local_useless_unbatched_blockwise`, returning whatever that returns.
    """
    unbatched_rewrite = local_useless_unbatched_blockwise.fn
    return unbatched_rewrite(fgraph, node)
12 changes: 8 additions & 4 deletions pytensor/tensor/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,8 +647,12 @@ def __rdot__(right, left):
return at.math.dense_dot(left, right)

dot = __dot__
__matmul__ = __dot__
__rmatmul__ = __rdot__

def __matmul__(left, right):
    """Implement ``left @ right`` by delegating to ``at.math.matmul``."""
    product = at.math.matmul(left, right)
    return product

def __rmatmul__(right, left):
    """Implement the reflected ``left @ right``.

    Python calls ``right.__rmatmul__(left)`` when the left operand does not
    handle ``@`` itself, so ``right`` is this variable and ``left`` is the
    other operand.
    """
    # Matrix multiplication is not commutative: the operands must be passed
    # in their written order, the same way `__rdot__` forwards `(left, right)`.
    return at.math.matmul(left, right)

def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""See :func:`pytensor.tensor.math.sum`."""
Expand Down Expand Up @@ -797,15 +801,15 @@ def choose(self, choices, mode="raise"):
"""
return at.basic.choose(self, choices, mode="raise")

def squeeze(self):
def squeeze(self, axis=None):
"""
Remove broadcastable dimensions from the shape of an array.
It returns the input array, but with the broadcastable dimensions
removed. This is always `x` itself or a view into `x`.
"""
return at.extra_ops.squeeze(self)
return at.extra_ops.squeeze(self, axis=axis)

def compress(self, a, axis=None):
"""Return selected slices only."""
Expand Down
93 changes: 7 additions & 86 deletions tests/tensor/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,11 @@
get_underlying_scalar_constant_value,
switch,
)
from pytensor.tensor.blas import Dot22
from pytensor.tensor.elemwise import CAReduce, Elemwise
from pytensor.tensor.math import (
Argmax,
Dot,
MatMul,
MaxAndArgmax,
Mean,
Prod,
Expand Down Expand Up @@ -3412,12 +3412,10 @@ def test_log1mexp_grad_lim():
assert grad_x_fn(-1e-308) != -np.inf


class TestMatMul(utt.InferShapeTester):
class TestMatMul:
def setup_method(self):
super().setup_method()
self.rng = np.random.default_rng(utt.fetch_seed())
self.op = matmul
self.op_class = MatMul

def _validate_output(self, a, b):
pytensor_sol = self.op(a, b).eval()
Expand Down Expand Up @@ -3467,85 +3465,8 @@ def test_dtype_param(self, dtype):
sol = self.op([1, 2, 3], [3, 2, 1], dtype=dtype)
assert sol.eval().dtype == dtype

@pytest.mark.parametrize(
    "x1_shape,x2_shape,exp_res,error_regex",
    [
        ((1,), (3,), None, "inputs must have the same length"),
        ((2,), (3, 1), None, "length of input 1.*2nd-last dimension of input 2"),
        ((2, 5), (3,), None, "length of input 2.*of the last dimension of input 1"),
        (
            (2, 5),
            (3, 4),
            None,
            "number of columns of input 1 .* number of rows of input 2",
        ),
        (
            (2, 1, 3),
            (5, 4),
            None,
            "number of rows of input 2 .* last dimension of input 1",
        ),
        (
            (2, 5),
            (2, 4, 3),
            None,
            "number of columns of input 1 .* 2nd-last dimension of input 2",
        ),
        (
            (3, 2, 4, 5),
            (1, 6, 7),
            None,
            "length of the last dimension of input 1 .* 2nd-last dimension of input 2",
        ),
        (
            (4, 5, 4),
            (3, 2, 2),
            None,
            "cannot be broadcast to a single shape",
        ),
        (
            (4, None, 2),
            (4, 2, None),
            (4, None, None),
            None,
        ),
    ],
)
def test_get_output_shape(self, x1_shape, x2_shape, exp_res, error_regex):
    """Check validated output shapes, or the error raised for bad input shapes.

    A case with ``error_regex`` set must raise a matching ``ValueError``;
    otherwise the computed shape must equal ``exp_res``.
    """
    lhs = tensor(dtype=np.float64, shape=x1_shape)
    rhs = tensor(dtype=np.float64, shape=x2_shape)

    def validated_shape():
        # Same call for the success and the failure path.
        return self.op_class._get_output_shape(
            lhs, rhs, (x1_shape, x2_shape), validate=True
        )

    if error_regex is None:
        assert validated_shape() == exp_res
        return

    with pytest.raises(ValueError, match=error_regex):
        validated_shape()

def test_infer_shape(self):
    """Run `_compile_and_check` on `matmul` for several input-shape pairs."""
    shape_pairs = [
        ((5,), (5,)),
        ((5,), (2, 5, 3)),
        ((2, 5, 3), (3,)),
        ((2, 5), (5, 4)),
        ((2, 5), (2, 5, 3)),
        ((2, 1, 3), (3, 4)),
        ((3, 2, 4, 5), (1, 5, 7)),
    ]
    for lhs_shape, rhs_shape in shape_pairs:
        # Symbolic inputs carrying the static shapes under test.
        lhs_var = tensor(dtype=config.floatX, shape=lhs_shape)
        rhs_var = tensor(dtype=config.floatX, shape=rhs_shape)
        # Random values of matching shape; drawn left operand first.
        lhs_val = self.rng.random(lhs_shape).astype(config.floatX)
        rhs_val = self.rng.random(rhs_shape).astype(config.floatX)

        self._compile_and_check(
            [lhs_var, rhs_var],
            [self.op(lhs_var, rhs_var)],
            [lhs_val, rhs_val],
            self.op_class,
        )
def test_dot22_opt(self):
    """Check that ``x @ y`` on two matrices compiles to a single `Dot22` node."""
    lhs, rhs = matrices("xy")
    compiled = function([lhs, rhs], lhs @ rhs, mode="FAST_RUN")
    # Unpacking fails unless the compiled graph has exactly one apply node.
    (only_node,) = compiled.maker.fgraph.apply_nodes
    assert isinstance(only_node.op, Dot22)
24 changes: 19 additions & 5 deletions tests/tensor/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from pytensor.compile.mode import get_default_mode
from pytensor.graph.basic import Constant, equal_computations
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import constant
from pytensor.tensor.basic import as_tensor, constant
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import dot, eq
from pytensor.tensor.math import dot, eq, matmul
from pytensor.tensor.shape import Shape
from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
from pytensor.tensor.type import (
Expand Down Expand Up @@ -79,16 +79,30 @@ def test_infix_dot_method():
X = dmatrix("X")
y = dvector("y")

res = X @ y
exp_res = X.dot(y)
res = X.dot(y)
exp_res = dot(X, y)
assert equal_computations([res], [exp_res])

X_val = np.arange(2 * 3).reshape((2, 3))
res = X_val @ y
res = as_tensor(X_val).dot(y)
exp_res = dot(X_val, y)
assert equal_computations([res], [exp_res])


def test_infix_matmul_method():
    """Check that the ``@`` operator builds the same graph as `matmul`."""
    mat = dmatrix("X")
    vec = dvector("y")

    # Symbolic matrix @ symbolic vector.
    infix_graph = mat @ vec
    expected_graph = matmul(mat, vec)
    assert equal_computations([infix_graph], [expected_graph])

    # Constant matrix built from a NumPy array @ symbolic vector.
    mat_val = np.arange(2 * 3).reshape((2, 3))
    infix_graph = as_tensor(mat_val) @ vec
    expected_graph = matmul(mat_val, vec)
    assert equal_computations([infix_graph], [expected_graph])


def test_empty_list_indexing():
ynp = np.zeros((2, 2))[:, []]
znp = np.zeros((2, 2))[:, ()]
Expand Down

0 comments on commit 071eadd

Please sign in to comment.