Skip to content

Commit

Permalink
Fix TensorVariable __rmatmul__
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 4, 2023
1 parent 326cb2e commit 58fb850
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pytensor/tensor/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,7 +652,7 @@ def __matmul__(left, right):
return at.math.matmul(left, right)

def __rmatmul__(right, left):
return at.math.matmul(right, left)
return at.math.matmul(left, right)

def sum(self, axis=None, dtype=None, keepdims=False, acc_dtype=None):
"""See :func:`pytensor.tensor.math.sum`."""
Expand Down
9 changes: 7 additions & 2 deletions tests/tensor/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from pytensor.compile.mode import get_default_mode
from pytensor.graph.basic import Constant, equal_computations
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import as_tensor, constant
from pytensor.tensor.basic import constant
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import dot, eq, matmul
from pytensor.tensor.shape import Shape
Expand Down Expand Up @@ -98,10 +98,15 @@ def test_infix_matmul_method():
assert equal_computations([res], [exp_res])

X_val = np.arange(2 * 3).reshape((2, 3))
res = as_tensor(X_val) @ y
res = X_val @ y
exp_res = matmul(X_val, y)
assert equal_computations([res], [exp_res])

y_val = np.arange(3)
res = X @ y_val
exp_res = matmul(X, y_val)
assert equal_computations([res], [exp_res])


def test_empty_list_indexing():
ynp = np.zeros((2, 2))[:, []]
Expand Down

0 comments on commit 58fb850

Please sign in to comment.