diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 350cd077c9..56777eeb67 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -25,11 +25,11 @@
     stack,
     switch,
 )
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise, scalar_elemwise
 from pytensor.tensor.shape import shape, specify_broadcastable
 from pytensor.tensor.type import (
     DenseTensorType,
-    TensorType,
     complex_dtypes,
     continuous_dtypes,
     discrete_dtypes,
@@ -2868,93 +2868,7 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))
 
 
-class MatMul(Op):
-    __props__ = ("dtype",)
-
-    def __init__(self, dtype=None):
-        self.dtype = dtype
-
-    @classmethod
-    def _get_output_shape(cls, x1, x2, shapes, validate=False):
-        x1_shape, x2_shape = shapes
-
-        if x1.ndim == 1 and x2.ndim == 1:
-            if validate and x1_shape[0] != x2_shape[0]:
-                raise ValueError("1d inputs must have the same length.")
-            return ()
-        elif x1.ndim == 1 and x2.ndim > 1:
-            if validate and x1_shape[0] != x2_shape[-2]:
-                raise ValueError(
-                    "length of input 1 must be equal the length "
-                    "of the 2nd-last dimension of input 2"
-                )
-            return x2_shape[:-2] + x2_shape[-1:]
-        elif x1.ndim > 1 and x2.ndim == 1:
-            if validate and x1_shape[-1] != x2_shape[0]:
-                raise ValueError(
-                    "length of input 2 must be equal the length "
-                    "of the last dimension of input 1"
-                )
-            return x1_shape[:-1]
-        elif x1.ndim == 2 and x2.ndim == 2:
-            if validate and x1_shape[-1] != x2_shape[0]:
-                raise ValueError(
-                    "number of columns of input 1 must be equal to "
-                    "the number of rows of input 2"
-                )
-            return x1_shape[:-1] + x2_shape[-1:]
-        elif x1.ndim > 2 and x2.ndim == 2:
-            if validate and x1_shape[-1] != x2_shape[0]:
-                raise ValueError(
-                    "number of rows of input 2 must be equal to "
-                    "the length of the last dimension of input 1"
-                )
-            return x1_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
-        elif x1.ndim == 2 and x2.ndim > 2:
-            if validate and x1_shape[-1] != x2_shape[-2]:
-                raise ValueError(
-                    "number of columns of input 1 must be equal "
-                    "the length of the 2nd-last dimension of input 2"
-                )
-            return x2_shape[:-2] + x1_shape[-2:-1] + x2_shape[-1:]
-        else:
-            if validate:
-                from pytensor.tensor.random.basic import broadcast_shapes
-
-                bshape = broadcast_shapes(x1_shape[:-2], x2_shape[:-2])
-                if x1_shape[-1] != x2_shape[-2]:
-                    raise ValueError(
-                        "length of the last dimension of input 1 must be equal "
-                        "to the length of the 2nd-last dimension of input 2"
-                    )
-            else:
-                from pytensor.tensor.extra_ops import broadcast_shape
-
-                bshape = broadcast_shape(
-                    x1_shape[:-2], x2_shape[:-2], arrays_are_shapes=True
-                )
-            return bshape + x1_shape[-2:-1] + x2_shape[-1:]
-
-    def make_node(self, a, b):
-        a = as_tensor_variable(a)
-        b = as_tensor_variable(b)
-
-        if 0 in {a.ndim, b.ndim}:
-            raise ValueError("inputs to `matmul` cannot be scalar.")
-
-        out_shape = self._get_output_shape(
-            a, b, (a.type.shape, b.type.shape), validate=True
-        )
-        out = TensorType(dtype=self.dtype, shape=out_shape)()
-        return Apply(self, [a, b], [out])
-
-    def perform(self, node, inputs, outputs):
-        x1, x2 = inputs
-        outputs[0][0] = np.matmul(x1, x2, dtype=self.dtype)
-
-    def infer_shape(self, fgraph, node, shapes):
-        x1, x2 = node.inputs
-        return [self._get_output_shape(x1, x2, shapes)]
+_matrix_matrix_matmul = Blockwise(_dot, signature="(n,k),(k,m)->(n,m)")
 
 
 def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None):
@@ -2999,7 +2913,23 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None
"ArrayLike", dtype: Optional["DTypeLike"] = None - Stacks of matrices are broadcast together as if the matrices were elements, respecting the signature ``(n, k), (k, m) -> (n, m)``: """ - return MatMul(dtype=dtype)(x1, x2) + x1 = as_tensor_variable(x1) + x2 = as_tensor_variable(x2) + if x1.type.ndim == 0 or x2.type.ndim == 0: + raise ValueError("matmul operand cannot be scalar") + if x1.type.ndim == 1 and x2.type.ndim == 1: + out = _dot(x1, x2) + elif x1.type.ndim == 1: + out = _matrix_matrix_matmul(x1[None], x2).squeeze(-2) + elif x2.type.ndim == 1: + out = _matrix_matrix_matmul(x1, x2[:, None]).squeeze(-1) + else: + out = _matrix_matrix_matmul(x1, x2) + + if dtype is not None: + out = out.astype(dtype) + + return out __all__ = [ diff --git a/pytensor/tensor/rewriting/blockwise.py b/pytensor/tensor/rewriting/blockwise.py index 2533eb7aaa..101aeec368 100644 --- a/pytensor/tensor/rewriting/blockwise.py +++ b/pytensor/tensor/rewriting/blockwise.py @@ -3,6 +3,8 @@ from pytensor.graph.replace import vectorize_node from pytensor.graph.rewriting.basic import copy_stack_trace, out2in from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.math import _matrix_matrix_matmul +from pytensor.tensor.rewriting.basic import register_canonicalize @node_rewriter([Blockwise]) @@ -40,3 +42,10 @@ def local_useless_unbatched_blockwise(fgraph, node): "blockwise", position=49, ) + + +# Avoid redundant cases early on for Ops whose default form is not Blockwised +@register_canonicalize +@node_rewriter(tracks=[_matrix_matrix_matmul]) +def local_eager_useless_unbatched_blockwise(fgraph, node): + return local_useless_unbatched_blockwise.fn(fgraph, node) diff --git a/pytensor/tensor/variable.py b/pytensor/tensor/variable.py index 8494d94a0b..0f41e1c1d2 100644 --- a/pytensor/tensor/variable.py +++ b/pytensor/tensor/variable.py @@ -647,8 +647,12 @@ def __rdot__(right, left): return at.math.dense_dot(left, right) dot = __dot__ - __matmul__ = __dot__ - __rmatmul__ = __rdot__ + + def __matmul__(left, right): + return at.math.matmul(left, right) + + def __rmatmul__(right, left): + return at.math.matmul(right, left) def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None): """See :func:`pytensor.tensor.math.sum`.""" @@ -797,7 +801,7 @@ def choose(self, choices, mode="raise"): """ return at.basic.choose(self, choices, mode="raise") - def squeeze(self): + def squeeze(self, axis=None): """ Remove broadcastable dimensions from the shape of an array. @@ -805,7 +809,7 @@ def squeeze(self): removed. This is always `x` itself or a view into `x`. 
""" - return at.extra_ops.squeeze(self) + return at.extra_ops.squeeze(self, axis=axis) def compress(self, a, axis=None): """Return selected slices only.""" diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 8ee5ac4544..fccb02e215 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -30,11 +30,11 @@ get_underlying_scalar_constant_value, switch, ) +from pytensor.tensor.blas import Dot22 from pytensor.tensor.elemwise import CAReduce, Elemwise from pytensor.tensor.math import ( Argmax, Dot, - MatMul, MaxAndArgmax, Mean, Prod, @@ -3412,12 +3412,10 @@ def test_log1mexp_grad_lim(): assert grad_x_fn(-1e-308) != -np.inf -class TestMatMul(utt.InferShapeTester): +class TestMatMul: def setup_method(self): - super().setup_method() self.rng = np.random.default_rng(utt.fetch_seed()) self.op = matmul - self.op_class = MatMul def _validate_output(self, a, b): pytensor_sol = self.op(a, b).eval() @@ -3467,85 +3465,8 @@ def test_dtype_param(self, dtype): sol = self.op([1, 2, 3], [3, 2, 1], dtype=dtype) assert sol.eval().dtype == dtype - @pytest.mark.parametrize( - "x1_shape,x2_shape,exp_res,error_regex", - [ - ((1,), (3,), None, "inputs must have the same length"), - ((2,), (3, 1), None, "length of input 1.*2nd-last dimension of input 2"), - ((2, 5), (3,), None, "length of input 2.*of the last dimension of input 1"), - ( - (2, 5), - (3, 4), - None, - "number of columns of input 1 .* number of rows of input 2", - ), - ( - (2, 1, 3), - (5, 4), - None, - "number of rows of input 2 .* last dimension of input 1", - ), - ( - (2, 5), - (2, 4, 3), - None, - "number of columns of input 1 .* 2nd-last dimension of input 2", - ), - ( - (3, 2, 4, 5), - (1, 6, 7), - None, - "length of the last dimension of input 1 .* 2nd-last dimension of input 2", - ), - ( - (4, 5, 4), - (3, 2, 2), - None, - "cannot be broadcast to a single shape", - ), - ( - (4, None, 2), - (4, 2, None), - (4, None, None), - None, - ), - ], - ) - def test_get_output_shape(self, x1_shape, x2_shape, exp_res, error_regex): - x1 = tensor(dtype=np.float64, shape=x1_shape) - x2 = tensor(dtype=np.float64, shape=x2_shape) - - if error_regex is not None: - with pytest.raises(ValueError, match=error_regex): - self.op_class._get_output_shape( - x1, x2, (x1_shape, x2_shape), validate=True - ) - else: - assert ( - self.op_class._get_output_shape( - x1, x2, (x1_shape, x2_shape), validate=True - ) - == exp_res - ) - - def test_infer_shape(self): - for shape_x1, shape_x2 in [ - ((5,), (5,)), - ((5,), (2, 5, 3)), - ((2, 5, 3), (3,)), - ((2, 5), (5, 4)), - ((2, 5), (2, 5, 3)), - ((2, 1, 3), (3, 4)), - ((3, 2, 4, 5), (1, 5, 7)), - ]: - a = tensor(dtype=config.floatX, shape=shape_x1) - b = tensor(dtype=config.floatX, shape=shape_x2) - x1 = self.rng.random(shape_x1).astype(config.floatX) - x2 = self.rng.random(shape_x2).astype(config.floatX) - - self._compile_and_check( - [a, b], - [self.op(a, b)], - [x1, x2], - self.op_class, - ) + def test_dot22_opt(self): + x, y = matrices("xy") + fn = function([x, y], x @ y, mode="FAST_RUN") + [node] = fn.maker.fgraph.apply_nodes + assert isinstance(node.op, Dot22) diff --git a/tests/tensor/test_variable.py b/tests/tensor/test_variable.py index fac4d2f435..4d0d6b46d6 100644 --- a/tests/tensor/test_variable.py +++ b/tests/tensor/test_variable.py @@ -10,9 +10,9 @@ from pytensor.compile.mode import get_default_mode from pytensor.graph.basic import Constant, equal_computations from pytensor.tensor import get_vector_length -from pytensor.tensor.basic import constant +from pytensor.tensor.basic import 
 from pytensor.tensor.elemwise import DimShuffle
-from pytensor.tensor.math import dot, eq
+from pytensor.tensor.math import dot, eq, matmul
 from pytensor.tensor.shape import Shape
 from pytensor.tensor.subtensor import AdvancedSubtensor, Subtensor
 from pytensor.tensor.type import (
@@ -79,16 +79,30 @@ def test_infix_dot_method():
     X = dmatrix("X")
     y = dvector("y")
 
-    res = X @ y
-    exp_res = X.dot(y)
+    res = X.dot(y)
+    exp_res = dot(X, y)
     assert equal_computations([res], [exp_res])
 
     X_val = np.arange(2 * 3).reshape((2, 3))
-    res = X_val @ y
+    res = as_tensor(X_val).dot(y)
     exp_res = dot(X_val, y)
     assert equal_computations([res], [exp_res])
 
 
+def test_infix_matmul_method():
+    X = dmatrix("X")
+    y = dvector("y")
+
+    res = X @ y
+    exp_res = matmul(X, y)
+    assert equal_computations([res], [exp_res])
+
+    X_val = np.arange(2 * 3).reshape((2, 3))
+    res = as_tensor(X_val) @ y
+    exp_res = matmul(X_val, y)
+    assert equal_computations([res], [exp_res])
+
+
 def test_empty_list_indexing():
     ynp = np.zeros((2, 2))[:, []]
     znp = np.zeros((2, 2))[:, ()]
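
Illustrative sketch (not part of the patch): the new `matmul` above reuses `_dot` through `Blockwise` for the core `(n,k),(k,m)->(n,m)` case and handles 1d operands by inserting and then squeezing a dummy dimension, mirroring NumPy's `np.matmul` semantics; this is also why `squeeze` grows an `axis` argument in the `variable.py` hunk. A minimal check of that dispatch, assuming a standard PyTensor install (the variable names `A` and `x` are hypothetical):

    import numpy as np
    import pytensor.tensor as pt

    # Batched stack of matrices (b, n, k) times a 1d operand (k,):
    # matmul promotes x to (k, 1), applies the Blockwise dot, and
    # squeezes the trailing dummy dimension back out.
    A = pt.dtensor3("A")
    x = pt.dvector("x")
    out = A @ x  # __matmul__ now routes to pt.math.matmul

    rng = np.random.default_rng(0)
    A_val = rng.normal(size=(2, 3, 4))
    x_val = np.arange(4, dtype="float64")
    np.testing.assert_allclose(out.eval({A: A_val, x: x_val}), A_val @ x_val)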