Skip to content

Commit

Permalink
port deprecated numpy functions
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Mar 13, 2024
1 parent 8a1f66a commit 80a8c11
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion arraycontext/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def linspace(self, start, stop, num=50, endpoint=True, retstep=False, dtype=None
delta_actx = self._array_context.from_numpy(delta)
# sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0)
# have an undefined step
step = np.NaN
step = np.nan
# Multiply with delta to allow possible override of output class.
y = y * delta_actx

Expand Down
4 changes: 2 additions & 2 deletions arraycontext/impl/jax/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,8 @@ def vdot(self, x, y, dtype=None):
from arraycontext import rec_multimap_reduce_array_container

def _rec_vdot(ary1, ary2):
common_dtype = np.find_common_type((ary1.dtype, ary2.dtype), ())
if dtype not in [None, common_dtype]:
common_dtype = np.result_type(ary1, ary2)
if dtype not in (None, common_dtype):
raise NotImplementedError(
f"{type(self).__name__} cannot take dtype in vdot.")

Expand Down

0 comments on commit 80a8c11

Please sign in to comment.