Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More array creation functions #244

Merged
merged 3 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions arraycontext/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
"""


import operator
from typing import Any

import numpy as np

from arraycontext.container import NotAnArrayContainerError, serialize_container
Expand Down Expand Up @@ -100,6 +103,89 @@ def conjugate(self, x):

conj = conjugate

# {{{ linspace

# based on
# https://github.com/numpy/numpy/blob/v1.25.0/numpy/core/function_base.py#L24-L182

def linspace(self, start, stop, num=50, endpoint=True, retstep=False, dtype=None,
axis=0):
num = operator.index(num)
if num < 0:
raise ValueError("Number of samples, %s, must be non-negative." % num)
div = (num - 1) if endpoint else num

# Convert float/complex array scalars to float, gh-3504
# and make sure one can use variables that have an __array_interface__,
# gh-6634

if isinstance(start, self._array_context.array_types):
raise NotImplementedError("start as an actx array")
if isinstance(stop, self._array_context.array_types):
raise NotImplementedError("stop as an actx array")

start = np.array(start) * 1.0
stop = np.array(stop) * 1.0

dt = np.result_type(start, stop, float(num))
if dtype is None:
dtype = dt
integer_dtype = False
else:
integer_dtype = np.issubdtype(dtype, np.integer)

delta = stop - start

y = self.arange(0, num, dtype=dt).reshape((-1,) + (1,) * delta.ndim)

if div > 0:
step = delta / div
#any_step_zero = _nx.asanyarray(step == 0).any()
any_step_zero = self._array_context.to_numpy((step == 0)).any()
if any_step_zero:
delta_actx = self._array_context.from_numpy(delta)

# Special handling for denormal numbers, gh-5437
y = y / div
y = y * delta_actx
else:
step_actx = self._array_context.from_numpy(step)
y = y * step_actx
else:
delta_actx = self._array_context.from_numpy(delta)
# sequences with 0 items or 1 item with endpoint=True (i.e. div <= 0)
# have an undefined step
step = np.NaN
# Multiply with delta to allow possible override of output class.
y = y * delta_actx

y += start

# FIXME reenable, without in-place ops
# if endpoint and num > 1:
# y[-1, ...] = stop

if axis != 0:
# y = _nx.moveaxis(y, 0, axis)
raise NotImplementedError("axis != 0")

if integer_dtype:
y = self.floor(y) # pylint: disable=no-member

# FIXME: Use astype
# https://github.com/inducer/pytato/issues/456
if retstep:
return y, step
#return y.astype(dtype), step
else:
return y
#return y.astype(dtype)

# }}}

def arange(self, *args: Any, **kwargs: Any):
raise NotImplementedError

# }}}


Expand Down Expand Up @@ -180,6 +266,7 @@ def norm(self, ary, ord=None):
return actx.np.sum(abs(ary)**ord)**(1/ord)
else:
raise NotImplementedError(f"unsupported value of 'ord': {ord}")

# }}}


Expand Down
4 changes: 3 additions & 1 deletion arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ def _copy(subary):

return self._array_context._rec_map_container(_copy, ary)

def arange(self, *args, **kwargs):
return cl_array.arange(self._array_context.queue, *args, **kwargs)

# }}}

# {{{ array manipulation routines
Expand Down Expand Up @@ -360,7 +363,6 @@ def where_inner(inner_crit, inner_then, inner_else):

# }}}


# }}}


Expand Down
7 changes: 7 additions & 0 deletions arraycontext/impl/pytato/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
THE SOFTWARE.
"""
from functools import partial, reduce
from typing import Any

import numpy as np

Expand Down Expand Up @@ -100,6 +101,12 @@ def _full_like(subary):
return self._array_context._rec_map_container(
_full_like, ary, default_scalar=fill_value)

def arange(self, *args: Any, **kwargs: Any):
return pt.arange(*args, **kwargs)

def full(self, shape, fill_value, dtype=None):
return pt.full(shape, fill_value, dtype)

# }}}

# {{{ array manipulation routines
Expand Down
22 changes: 22 additions & 0 deletions test/test_arraycontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -1580,6 +1580,28 @@ def test_compile_anonymous_function(actx_factory):
42)


@pytest.mark.parametrize(
("args", "kwargs"), [
((1, 2, 10), {}),
((1, 2, 10), {"endpoint": False}),
((1, 2, 10), {"endpoint": True}),
((2, -3, 20), {}),
((1, 5j, 20), {"dtype": np.complex128}),
((1, 5, 20), {"dtype": np.complex128}),
((1, 5, 20), {"dtype": np.int32}),
])
def test_linspace(actx_factory, args, kwargs):
if "Jax" in actx_factory.__class__.__name__:
pytest.xfail("jax actx does not have arange")

actx = actx_factory()

actx_linspace = actx.to_numpy(actx.np.linspace(*args, **kwargs))
np_linspace = np.linspace(*args, **kwargs)

assert np.allclose(actx_linspace, np_linspace)


if __name__ == "__main__":
import sys
if len(sys.argv) > 1:
Expand Down
Loading