Skip to content

Commit

Permalink
ENH: Wrap nextn_greater and nextn_less so n can be a parameter.
Browse files Browse the repository at this point in the history
  • Loading branch information
WarrenWeckesser committed Jul 29, 2024
1 parent e32ef5f commit 31b56eb
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 26 deletions.
32 changes: 10 additions & 22 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,13 @@ function to provide a capability that is not possible with a gufunc.
| Function | Description |
| -------- | ----------- |
| [`convert_to_base`](#convert_to_base) | Convert an integer to a given base. |
| [`nextn_greater`](#nextn_greater) | Next n values greater than the given x. |
| [`nextn_less`](#nextn_less) | Next n values greater than the given x. |

*Other functions.*

| Function | Description |
| -------- | ----------- |
| [`nextn_greater`](#nextn_greater) | Next n values greater than the given x. |
| [`nextn_less`](#nextn_less) | Next n values greater than the given x. |
| [`cross2`](#cross2) | 2-d vector cross product (returns scalar) |
| [`cross3`](#cross3) | 3-d vector cross product |
| [`linear_interp1d`](#linear_interp1d) | Linear interpolation, like `numpy.interp` |
Expand Down Expand Up @@ -1439,44 +1439,32 @@ the weighted Jaccard index, which is defined to be
```
#### `nextn_greater`

`nextn_greater` is a gufunc with signature `()->(n)`. Given a floating
point scalar `x`, it computes the next `n` values greater than `x`.

The `out` parameter must be given, as it determines `n`.
`nextn_greater(x, n, out=None, axis=-1)` is a Python function that wraps a
gufunc with signature `()->(n)`. Given a floating point scalar `x`, it
computes the next `n` values greater than `x`.

```
>>> import numpy as np
>>> from ufunclab import nextn_greater
>>> x = np.float32(2.5)
>>> out = np.zeros(5, dtype=x.dtype)
>>> nextn_greater(x, out=xn)
array([2.5000002, 2.5000005, 2.5000007, 2.500001 , 2.5000012],
dtype=float32)
>>> xn
>>> nextn_greater(x, 5)
array([2.5000002, 2.5000005, 2.5000007, 2.500001 , 2.5000012],
dtype=float32)
```

#### `nextn_less`

`nextn_less` is a gufunc with signature `()->(n)`. Given a floating
point scalar `x`, it computes the next `n` values less than `x`.

The `out` parameter must be given, as it determines `n`.
`nextn_less(x, n, out=None, axis=-1)` is a Python function that wraps a
gufunc with signature `()->(n)`. Given a floating point scalar `x`, it
computes the next `n` values less than `x`.

```
>>> import numpy as np
>>> from ufunclab import nextn_less
>>> x = np.float32(2.5)
>>> out = np.zeros(5, dtype=x.dtype)
>>> nextn_less(x, out=xn)
array([2.4999998, 2.4999995, 2.4999993, 2.499999 , 2.4999988],
dtype=float32)
>>> xn
>>> nextn_less(x, 5)
array([2.4999998, 2.4999995, 2.4999993, 2.499999 , 2.4999988],
dtype=float32)
```
Expand Down
4 changes: 2 additions & 2 deletions ufunclab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
# The keys of this dict are in modules that are lazy-loaded.
_name_to_module = {
'convert_to_base': '._wrapped',
'nextn_greater': '._wrapped',
'nextn_less' : '._wrapped',
'logfactorial': '._logfact',
'loggamma1p': '._loggamma1p',
'issnan': '._issnan',
Expand Down Expand Up @@ -80,8 +82,6 @@
'smoothstep5': '._step',
'next_greater': '._next',
'next_less': '._next',
'nextn_greater': '._nextn',
'nextn_less' : '._nextn',
'gendot': '._gendot_wrap',
'ufunc_inspector': '._ufunc_inspector',
'__version__': '._version',
Expand Down
72 changes: 72 additions & 0 deletions ufunclab/_wrapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
except ImportError:
from numpy import AxisError
from ufunclab._convert_to_base import convert_to_base as _convert_to_base
from ufunclab._nextn import (nextn_less as _nextn_less,
nextn_greater as _nextn_greater)


# XXX Except for `out` and `axis`, this function does not expose any of the
Expand Down Expand Up @@ -80,3 +82,73 @@ def convert_to_base(k, base, ndigits, out=None, axis=-1):


convert_to_base.gufunc = _convert_to_base


def nextn_greater(x, n, out=None, axis=-1):
"""
Return the next n floating point values greater than x.
x must be one of the real floating point types np.float32,
np.float64 or np.longdouble.
"""
x = np.asarray(x)
if x.dtype.char not in 'fdg':
raise ValueError('x must be an array of np.float32, np.float64 or '
'np.longdouble.')
try:
n = operator.index(n)
except TypeError:
raise ValueError(f'n must be an integer; got {n!r}')
x_shape = x.shape
adjusted_axis = axis
if adjusted_axis < 0:
adjusted_axis += 1 + len(x_shape)
if adjusted_axis < 0 or adjusted_axis > 1 + len(x_shape):
raise AxisError(f'invalid axis {axis}')
out_shape = (x_shape[:adjusted_axis] + (n,)
+ x_shape[adjusted_axis:])
if out is not None:
if out.shape != out_shape:
raise ValueError(f'out.shape must be {out_shape}; '
f'got {out.shape}.')
else:
out = np.empty(out_shape, dtype=x.dtype)
return _nextn_greater(x, out=out, axis=axis)


nextn_greater.gufunc = _nextn_greater


def nextn_less(x, n, out=None, axis=-1):
"""
Return the next n floating point values less than x.
x must be one of the real floating point types np.float32,
np.float64 or np.longdouble.
"""
x = np.asarray(x)
if x.dtype.char not in 'fdg':
raise ValueError('x must be a scalar or array of np.float32, '
'np.float64 or np.longdouble.')
try:
n = operator.index(n)
except TypeError:
raise ValueError(f'n must be an integer; got {n!r}')
x_shape = x.shape
adjusted_axis = axis
if adjusted_axis < 0:
adjusted_axis += 1 + len(x_shape)
if adjusted_axis < 0 or adjusted_axis > 1 + len(x_shape):
raise AxisError(f'invalid axis {axis}')
out_shape = (x_shape[:adjusted_axis] + (n,)
+ x_shape[adjusted_axis:])
if out is not None:
if out.shape != out_shape:
raise ValueError(f'out.shape must be {out_shape}; '
f'got {out.shape}.')
else:
out = np.empty(out_shape, dtype=x.dtype)
return _nextn_less(x, out=out, axis=axis)


nextn_less.gufunc = _nextn_less
19 changes: 17 additions & 2 deletions ufunclab/tests/test_nextn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,27 @@
@pytest.mark.parametrize('dt', [np.dtype('float32'),
np.dtype('float64'),
np.dtype('longdouble')])
def test_nextn_less(func, to, dt):
def test_nextn_gufunc(func, to, dt):
to = dt.type(to)
x = dt.type(2.5)
n = 5
out = np.zeros(n, dtype=dt)
xn = func(x, out=out)
xn = func.gufunc(x, out=out)
for k in range(n):
x = np.nextafter(x, to)
assert_equal(xn[k], x)


@pytest.mark.parametrize('func, to', [(nextn_less, -np.inf),
(nextn_greater, np.inf)])
@pytest.mark.parametrize('dt', [np.dtype('float32'),
np.dtype('float64'),
np.dtype('longdouble')])
def test_nextn(func, to, dt):
to = dt.type(to)
x = dt.type(2.5)
n = 5
xn = func(x, n)
for k in range(n):
x = np.nextafter(x, to)
assert_equal(xn[k], x)

0 comments on commit 31b56eb

Please sign in to comment.