Skip to content

Commit

Permalink
MAINT: Update test_cross.py to avoid a warning from numpy.cross in Nu…
Browse files Browse the repository at this point in the history
…mPy 2.0
  • Loading branch information
WarrenWeckesser committed Oct 26, 2023
1 parent 8cd7c77 commit ebc19bb
Showing 1 changed file with 19 additions and 2 deletions.
21 changes: 19 additions & 2 deletions ufunclab/tests/test_cross.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,23 @@
from ufunclab import cross3, cross2


# In numpy 2.0, the handling of length-2 vectors in np.cross
# is deprecated. To avoid the deprecation warning, this function
# is used as a wrapper of np.cross. This wrapper does not provide
# the parameter `axisc`.
def numpy_cross2(a, b, axisa=-1, axisb=-1):
a = np.asarray(a)
b = np.asarray(b)
apad = [(0, 0)]*a.ndim
apad[axisa] = (0, 1)
bpad = [(0, 0)]*b.ndim
bpad[axisb] = (0, 1)
a3 = np.pad(a, apad)
b3 = np.pad(b, bpad)
c = np.cross(a3, b3, axisa=axisa, axisb=axisb)
return c[..., -1]


@pytest.mark.parametrize('u, v', [([1, 2, 3], [5, 3, 1]),
([1.5, 0.5, -1.5], [2.0, 9.0, -3.0]),
([1+2j, 3, -4j], [3-1j, 2j, 6])])
Expand Down Expand Up @@ -49,14 +66,14 @@ def test_cross3_nontrivial_axes():
([1+2j, 3], [3-1j, 2j])])
def test_cross2_basic(u, v):
w = cross2(u, v)
assert_equal(w, np.cross(u, v))
assert_equal(w, numpy_cross2(u, v))


def test_cross2_broadcasting():
x = np.arange(70).reshape(7, 2, 5)
y = np.arange(10).reshape(5, 2)
z = cross2(x, y, axes=[(1,), (1,)])
assert_equal(z, np.cross(x, y, axisa=1, axisb=1))
assert_equal(z, numpy_cross2(x, y, axisa=1, axisb=1))


def test_cross2_object():
Expand Down

0 comments on commit ebc19bb

Please sign in to comment.