Skip to content

Commit

Permalink
Implement np.linspace in fake_numpy
Browse files Browse the repository at this point in the history
  • Loading branch information
inducer committed Sep 18, 2023
1 parent b89fdac commit d57a560
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 0 deletions.
87 changes: 87 additions & 0 deletions arraycontext/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
"""


import operator
from typing import Any

import numpy as np

from arraycontext.container import NotAnArrayContainerError, serialize_container
Expand Down Expand Up @@ -100,6 +103,89 @@ def conjugate(self, x):

conj = conjugate

# {{{ linspace

# based on
# https://github.com/numpy/numpy/blob/v1.25.0/numpy/core/function_base.py#L24-L182

def linspace(self, start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis=0):
num = operator.index(num)
if num < 0:
raise ValueError("Number of samples, %s, must be non-negative." % num)
div = (num - 1) if endpoint else num

# Convert float/complex array scalars to float, gh-3504
# and make sure one can use variables that have an __array_interface__,
# gh-6634

if isinstance(start, self._array_context.array_types):
raise NotImplementedError("start as an actx array")
if isinstance(stop, self._array_context.array_types):
raise NotImplementedError("stop as an actx array")

start = np.array(start) * 1.0
stop = np.array(stop) * 1.0

dt = np.result_type(start, stop, float(num))
if dtype is None:
dtype = dt
integer_dtype = False
else:
integer_dtype = np.issubdtype(dtype, np.integer)

delta = stop - start

y = self.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim)

if div > 0:
step = delta / div
#any_step_zero = _nx.asanyarray(step == 0).any()
any_step_zero = self._array_context.to_numpy((step == 0)).any()
if any_step_zero:
delta_actx = self._array_context.from_numpy(delta)

# Special handling for denormal numbers, gh-5437
y = y / div
y = y * delta_actx
else:
step_actx = self._array_context.from_numpy(step)
y = y * step_actx
else:
delta_actx = self._array_context.from_numpy(delta)
# sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0)
# have an undefined step
step = np.NaN
# Multiply with delta to allow possible override of output class.
y = y * delta_actx

y += start

# FIXME reenable, without in-place ops
# if endpoint and num > 1:
# y[-1, ...] = stop

if axis != 0:
# y = _nx.moveaxis(y, 0, axis)
raise NotImplementedError("axis != 0")

if integer_dtype:
y = self.floor(y)

# FIXME: Use astype
# https://github.com/inducer/pytato/issues/456
if retstep:
return y, step
#return y.astype(dtype), step
else:
return y
#return y.astype(dtype)

# }}}

def arange(self, *args: Any, **kwargs: Any):
raise NotImplementedError

# }}}


Expand Down Expand Up @@ -180,6 +266,7 @@ def norm(self, ary, ord=None):
return actx.np.sum(abs(ary)**ord)**(1/ord)
else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}")

# }}}


Expand Down
1 change: 1 addition & 0 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
THE SOFTWARE.
"""
from functools import partial, reduce
from typing import Any

import numpy as np

Expand Down
22 changes: 22 additions & 0 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,28 @@ def test_compile_anonymous_function(actx_factory):
42)


@pytest.mark.parametrize(
("args", "kwargs"), [
((1, 2, 10), {}),
((1, 2, 10), {"endpoint": False}),
((1, 2, 10), {"endpoint": True}),
((2, -3, 20), {}),
((1, 5j, 20), {"dtype": np.complex128}),
((1, 5, 20), {"dtype": np.complex128}),
((1, 5, 20), {"dtype": np.int32}),
])
def test_linspace(actx_factory, args, kwargs):
if "Jax" in actx_factory.__class__.__name__:
pytest.xfail("jax actx does not have arange")

actx = actx_factory()

actx_linspace = actx.to_numpy(actx.np.linspace(*args, **kwargs))
np_linspace = np.linspace(*args, **kwargs)

assert np.allclose(actx_linspace, np_linspace)


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down

0 comments on commit d57a560

Please sign in to comment.