diff --git a/README.md b/README.md index bfdbbc4..4e8d174 100644 --- a/README.md +++ b/README.md @@ -133,13 +133,13 @@ function to provide a capability that is not possible with a gufunc. | Function | Description | | -------- | ----------- | | [`convert_to_base`](#convert_to_base) | Convert an integer to a given base. | +| [`nextn_greater`](#nextn_greater) | Next n values greater than the given x. | +| [`nextn_less`](#nextn_less) | Next n values greater than the given x. | *Other functions.* | Function | Description | | -------- | ----------- | -| [`nextn_greater`](#nextn_greater) | Next n values greater than the given x. | -| [`nextn_less`](#nextn_less) | Next n values greater than the given x. | | [`cross2`](#cross2) | 2-d vector cross product (returns scalar) | | [`cross3`](#cross3) | 3-d vector cross product | | [`linear_interp1d`](#linear_interp1d) | Linear interpolation, like `numpy.interp` | @@ -1439,44 +1439,32 @@ the weighted Jaccard index, which is defined to be ``` #### `nextn_greater` -`nextn_greater` is a gufunc with signature `()->(n)`. Given a floating -point scalar `x`, it computes the next `n` values greater than `x`. - -The `out` parameter must be given, as it determines `n`. +`nextn_greater(x, n, out=None, axis=-1)` is a Python function that wraps a +gufunc with signature `()->(n)`. Given a floating point scalar `x`, it +computes the next `n` values greater than `x`. ``` >>> import numpy as np >>> from ufunclab import nextn_greater >>> x = np.float32(2.5) ->>> out = np.zeros(5, dtype=x.dtype) - ->>> nextn_greater(x, out=xn) -array([2.5000002, 2.5000005, 2.5000007, 2.500001 , 2.5000012], - dtype=float32) ->>> xn +>>> nextn_greater(x, 5) array([2.5000002, 2.5000005, 2.5000007, 2.500001 , 2.5000012], dtype=float32) ``` #### `nextn_less` -`nextn_less` is a gufunc with signature `()->(n)`. Given a floating -point scalar `x`, it computes the next `n` values less than `x`. - -The `out` parameter must be given, as it determines `n`. +`nextn_less(x, n, out=None, axis=-1)` is a Python function that wraps a +gufunc with signature `()->(n)`. Given a floating point scalar `x`, it +computes the next `n` values less than `x`. ``` >>> import numpy as np >>> from ufunclab import nextn_less >>> x = np.float32(2.5) ->>> out = np.zeros(5, dtype=x.dtype) - ->>> nextn_less(x, out=xn) -array([2.4999998, 2.4999995, 2.4999993, 2.499999 , 2.4999988], - dtype=float32) ->>> xn +>>> nextn_less(x, 5) array([2.4999998, 2.4999995, 2.4999993, 2.499999 , 2.4999988], dtype=float32) ``` diff --git a/ufunclab/__init__.py b/ufunclab/__init__.py index 99075f1..ff5d8c2 100644 --- a/ufunclab/__init__.py +++ b/ufunclab/__init__.py @@ -18,6 +18,8 @@ # The keys of this dict are in modules that are lazy-loaded. _name_to_module = { 'convert_to_base': '._wrapped', + 'nextn_greater': '._wrapped', + 'nextn_less' : '._wrapped', 'logfactorial': '._logfact', 'loggamma1p': '._loggamma1p', 'issnan': '._issnan', @@ -80,8 +82,6 @@ 'smoothstep5': '._step', 'next_greater': '._next', 'next_less': '._next', - 'nextn_greater': '._nextn', - 'nextn_less' : '._nextn', 'gendot': '._gendot_wrap', 'ufunc_inspector': '._ufunc_inspector', '__version__': '._version', diff --git a/ufunclab/_wrapped.py b/ufunclab/_wrapped.py index c04886a..fdb9b3c 100644 --- a/ufunclab/_wrapped.py +++ b/ufunclab/_wrapped.py @@ -9,6 +9,8 @@ except ImportError: from numpy import AxisError from ufunclab._convert_to_base import convert_to_base as _convert_to_base +from ufunclab._nextn import (nextn_less as _nextn_less, + nextn_greater as _nextn_greater) # XXX Except for `out` and `axis`, this function does not expose any of the @@ -80,3 +82,73 @@ def convert_to_base(k, base, ndigits, out=None, axis=-1): convert_to_base.gufunc = _convert_to_base + + +def nextn_greater(x, n, out=None, axis=-1): + """ + Return the next n floating point values greater than x. + + x must be one of the real floating point types np.float32, + np.float64 or np.longdouble. + """ + x = np.asarray(x) + if x.dtype.char not in 'fdg': + raise ValueError('x must be an array of np.float32, np.float64 or ' + 'np.longdouble.') + try: + n = operator.index(n) + except TypeError: + raise ValueError(f'n must be an integer; got {n!r}') + x_shape = x.shape + adjusted_axis = axis + if adjusted_axis < 0: + adjusted_axis += 1 + len(x_shape) + if adjusted_axis < 0 or adjusted_axis > 1 + len(x_shape): + raise AxisError(f'invalid axis {axis}') + out_shape = (x_shape[:adjusted_axis] + (n,) + + x_shape[adjusted_axis:]) + if out is not None: + if out.shape != out_shape: + raise ValueError(f'out.shape must be {out_shape}; ' + f'got {out.shape}.') + else: + out = np.empty(out_shape, dtype=x.dtype) + return _nextn_greater(x, out=out, axis=axis) + + +nextn_greater.gufunc = _nextn_greater + + +def nextn_less(x, n, out=None, axis=-1): + """ + Return the next n floating point values less than x. + + x must be one of the real floating point types np.float32, + np.float64 or np.longdouble. + """ + x = np.asarray(x) + if x.dtype.char not in 'fdg': + raise ValueError('x must be a scalar or array of np.float32, ' + 'np.float64 or np.longdouble.') + try: + n = operator.index(n) + except TypeError: + raise ValueError(f'n must be an integer; got {n!r}') + x_shape = x.shape + adjusted_axis = axis + if adjusted_axis < 0: + adjusted_axis += 1 + len(x_shape) + if adjusted_axis < 0 or adjusted_axis > 1 + len(x_shape): + raise AxisError(f'invalid axis {axis}') + out_shape = (x_shape[:adjusted_axis] + (n,) + + x_shape[adjusted_axis:]) + if out is not None: + if out.shape != out_shape: + raise ValueError(f'out.shape must be {out_shape}; ' + f'got {out.shape}.') + else: + out = np.empty(out_shape, dtype=x.dtype) + return _nextn_less(x, out=out, axis=axis) + + +nextn_less.gufunc = _nextn_less diff --git a/ufunclab/tests/test_nextn.py b/ufunclab/tests/test_nextn.py index 2bf8649..aef4167 100644 --- a/ufunclab/tests/test_nextn.py +++ b/ufunclab/tests/test_nextn.py @@ -9,12 +9,27 @@ @pytest.mark.parametrize('dt', [np.dtype('float32'), np.dtype('float64'), np.dtype('longdouble')]) -def test_nextn_less(func, to, dt): +def test_nextn_gufunc(func, to, dt): to = dt.type(to) x = dt.type(2.5) n = 5 out = np.zeros(n, dtype=dt) - xn = func(x, out=out) + xn = func.gufunc(x, out=out) + for k in range(n): + x = np.nextafter(x, to) + assert_equal(xn[k], x) + + +@pytest.mark.parametrize('func, to', [(nextn_less, -np.inf), + (nextn_greater, np.inf)]) +@pytest.mark.parametrize('dt', [np.dtype('float32'), + np.dtype('float64'), + np.dtype('longdouble')]) +def test_nextn(func, to, dt): + to = dt.type(to) + x = dt.type(2.5) + n = 5 + xn = func(x, n) for k in range(n): x = np.nextafter(x, to) assert_equal(xn[k], x)