diff --git a/arraycontext/fake_numpy.py b/arraycontext/fake_numpy.py index a73716a1..9f9df57c 100644 --- a/arraycontext/fake_numpy.py +++ b/arraycontext/fake_numpy.py @@ -23,6 +23,9 @@ """ +import operator +from typing import Any + import numpy as np from arraycontext.container import NotAnArrayContainerError, serialize_container @@ -100,6 +103,89 @@ def conjugate(self, x): conj = conjugate + # {{{ linspace + + # based on + # https://github.com/numpy/numpy/blob/v1.25.0/numpy/core/function_base.py#L24-L182 + + def linspace(self, start, stop, num=50, endpoint=True, retstep=False, dtype=None, + axis=0): + num = operator.index(num) + if num < 0: + raise ValueError("Number of samples, %s, must be non-negative." % num) + div = (num - 1) if endpoint else num + + # Convert float/complex array scalars to float, gh-3504 + # and make sure one can use variables that have an __array_interface__, + # gh-6634 + + if isinstance(start, self._array_context.array_types): + raise NotImplementedError("start as an actx array") + if isinstance(stop, self._array_context.array_types): + raise NotImplementedError("stop as an actx array") + + start = np.array(start) * 1.0 + stop = np.array(stop) * 1.0 + + dt = np.result_type(start, stop, float(num)) + if dtype is None: + dtype = dt + integer_dtype = False + else: + integer_dtype = np.issubdtype(dtype, np.integer) + + delta = stop - start + + y = self.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim) + + if div > 0: + step = delta / div + #any_step_zero = _nx.asanyarray(step == 0).any() + any_step_zero = self._array_context.to_numpy((step == 0)).any() + if any_step_zero: + delta_actx = self._array_context.from_numpy(delta) + + # Special handling for denormal numbers, gh-5437 + y = y / div + y = y * delta_actx + else: + step_actx = self._array_context.from_numpy(step) + y = y * step_actx + else: + delta_actx = self._array_context.from_numpy(delta) + # sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0) + # have an undefined step + step = np.NaN + # Multiply with delta to allow possible override of output class. + y = y * delta_actx + + y += start + + # FIXME reenable, without in-place ops + # if endpoint and num > 1: + # y[-1, ...] = stop + + if axis != 0: + # y = _nx.moveaxis(y, 0, axis) + raise NotImplementedError("axis != 0") + + if integer_dtype: + y = self.floor(y) + + # FIXME: Use astype + # https://github.com/inducer/pytato/issues/456 + if retstep: + return y, step + #return y.astype(dtype), step + else: + return y + #return y.astype(dtype) + + # }}} + + def arange(self, *args: Any, **kwargs: Any): + raise NotImplementedError + # }}} @@ -180,6 +266,7 @@ def norm(self, ary, ord=None): return actx.np.sum(abs(ary)**ord)**(1/ord) else: raise NotImplementedError(f"unsupported value of 'ord': {ord}") + # }}} diff --git a/arraycontext/impl/pytato/fake_numpy.py b/arraycontext/impl/pytato/fake_numpy.py index 7b01a137..21cab42e 100644 --- a/arraycontext/impl/pytato/fake_numpy.py +++ b/arraycontext/impl/pytato/fake_numpy.py @@ -22,6 +22,7 @@ THE SOFTWARE. """ from functools import partial, reduce +from typing import Any import numpy as np diff --git a/test/test_arraycontext.py b/test/test_arraycontext.py index 396414fc..e53f4295 100644 --- a/test/test_arraycontext.py +++ b/test/test_arraycontext.py @@ -1580,6 +1580,28 @@ def test_compile_anonymous_function(actx_factory): 42) +@pytest.mark.parametrize( + ("args", "kwargs"), [ + ((1, 2, 10), {}), + ((1, 2, 10), {"endpoint": False}), + ((1, 2, 10), {"endpoint": True}), + ((2, -3, 20), {}), + ((1, 5j, 20), {"dtype": np.complex128}), + ((1, 5, 20), {"dtype": np.complex128}), + ((1, 5, 20), {"dtype": np.int32}), + ]) +def test_linspace(actx_factory, args, kwargs): + if "Jax" in actx_factory.__class__.__name__: + pytest.xfail("jax actx does not have arange") + + actx = actx_factory() + + actx_linspace = actx.to_numpy(actx.np.linspace(*args, **kwargs)) + np_linspace = np.linspace(*args, **kwargs) + + assert np.allclose(actx_linspace, np_linspace) + + if __name__ == "__main__": import sys if len(sys.argv) > 1: