Skip to content

Commit

Permalink
Add explicit check for failing ndarray.dot(TensorVariable)
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Oct 4, 2023
1 parent 2858e69 commit 686725c
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions tests/tensor/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,20 @@ def test_numpy_method(fct, value):
utt.assert_allclose(np.nan_to_num(f(value)), np.nan_to_num(fct(value)))


def test_infix_dot_method():
def test_dot_method():
X = dmatrix("X")
y = dvector("y")

res = X.dot(y)
exp_res = dot(X, y)
assert equal_computations([res], [exp_res])

# This doesn't work. Numpy calls TensorVariable.__rmul__ at some point and everything is messed up
X_val = np.arange(2 * 3).reshape((2, 3))
res = as_tensor(X_val).dot(y)
res = X_val.dot(y)
exp_res = dot(X_val, y)
assert equal_computations([res], [exp_res])
with pytest.raises(AssertionError):
assert equal_computations([res], [exp_res])


def test_infix_matmul_method():
Expand Down

0 comments on commit 686725c

Please sign in to comment.