diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index 43ec6b69..2de78d1b 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -155,7 +155,7 @@ def linspace(self, start, stop, num=50, endpoint=True, retstep=False, dtype=None delta_actx = self._array_context.from_numpy(delta) # sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0) # have an undefined step - step = np.NaN + step = np.nan # Multiply with delta to allow possible override of output class. y = y * delta_actx diff --git a/arraycontext/impl/jax/fake_numpy.py b/arraycontext/impl/jax/fake_numpy.py index 27799cfe..d20448a4 100644 --- a/arraycontext/impl/jax/fake_numpy.py +++ b/arraycontext/impl/jax/fake_numpy.py @@ -130,8 +130,8 @@ def vdot(self, x, y, dtype=None): from arraycontext import rec_multimap_reduce_array_container def _rec_vdot(ary1, ary2): - common_dtype = np.find_common_type((ary1.dtype, ary2.dtype), ()) - if dtype not in [None, common_dtype]: + common_dtype = np.result_type(ary1, ary2) + if dtype not in (None, common_dtype): raise NotImplementedError( f"{type(self).__name__} cannot take dtype in vdot.")