Skip to content

Commit

Permalink
ENH: Implement one_hot as a wrapped gufunc.
Browse files Browse the repository at this point in the history
  • Loading branch information
WarrenWeckesser committed Jul 30, 2024
1 parent 31b56eb commit c93470b
Show file tree
Hide file tree
Showing 8 changed files with 161 additions and 12 deletions.
45 changes: 33 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,7 @@ processed with the script in `ufunclab/tools/conv_template.py`.
| [`sosfilter_ic`](#sosfilter_ic) | SOS linear filter with initial condition |
| [`sosfilter_ic_contig`](#sosfilter_ic_contig) | SOS linear filter with contiguous array inputs |

*Wrapped generalized ufuncs*

These are Python functions that wrap a gufunc. The wrapper allows the
function to provide a capability that is not possible with a gufunc.

| Function | Description |
| -------- | ----------- |
| [`convert_to_base`](#convert_to_base) | Convert an integer to a given base. |
| [`nextn_greater`](#nextn_greater) | Next n values greater than the given x. |
| [`nextn_less`](#nextn_less) | Next n values less than the given x. |

*Other functions.*
*Other generalized ufuncs.*

| Function | Description |
| -------- | ----------- |
Expand All @@ -147,6 +136,18 @@ function to provide a capability that is not possible with a gufunc.
| [`tri_area_indexed`](#tri_area_indexed) | Area of triangles in n-dimensional space |
| [`multivariate_logbeta`](#multivariate_logbeta) | Logarithm of the multivariate beta function |

*Wrapped generalized ufuncs*

These are Python functions that wrap a gufunc. The wrapper allows the
function to provide a capability that is not possible with a gufunc.

| Function | Description |
| -------- | ----------- |
| [`convert_to_base`](#convert_to_base) | Convert an integer to a given base. |
| [`nextn_greater`](#nextn_greater) | Next n values greater than the given x. |
| [`nextn_less`](#nextn_less) | Next n values less than the given x. |
| [`one_hot`](#one_hot) | Create a 1-d array that is 1 at index k and 0 elsewhere. |

*Other tools*

| Function | Description |
Expand Down Expand Up @@ -1469,6 +1470,26 @@ array([2.4999998, 2.4999995, 2.4999993, 2.499999 , 2.4999988],
dtype=float32)
```

#### `one_hot`

`one_hot(k, n, out=None, axis=-1)` is a Python function that wraps
a gufunc with signature `()->(n)`. Given integers `k` and `n`,
it returns a 1-d integer array with length `n`, where the value is
1 at index `k` and 0 elsewhere. If `k` is less than 0 or greater
than `n - 1`, the array will be all zeros.

```
>>> from ufunclab import one_hot
>>> one_hot(3, 10)
array([0, 0, 0, 1, 0, 0, 0, 0, 0, 0])
>>> one_hot([3, 7, 8], 10)
array([[0, 0, 0, 1, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 1, 0]])
```

#### `cross2`

`cross2(u, v)` is a gufunc with signature `(2),(2)->()`. It computes
Expand Down
34 changes: 34 additions & 0 deletions src/one_hot/define_cxx_gufunc_extmod.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

from ufunc_config_types import UFuncExtMod, UFunc, UFuncSource


# Docstring text passed to the `one_hot` UFunc configuration below.
ONE_HOT_DOCSTRING = """\
one_hot(k, /, ...)
Fill the 1-d output with zeros except at k, where the output is 1.
The `out` parameter of this ufunc must be given. The length
of `out` determines n.
"""

# Names the C++ core function (`one_hot_core_calc`, defined in
# one_hot_gufunc.h) and the type signatures to generate for it.
# 'p' is the typecode that the code generator maps to NPY_INTP, so the
# single loop takes an intp input and produces intp output.
one_hot_src = UFuncSource(
funcname='one_hot_core_calc',
typesignatures=[
'p->p',
]
)

# Configuration of the gufunc itself.  The shape signature '()->(n)' means
# a scalar input and a 1-d output of length n; since n does not appear in
# any input, the caller must provide `out` (see ONE_HOT_DOCSTRING).
one_hot = UFunc(
name='one_hot',
header='one_hot_gufunc.h',
docstring=ONE_HOT_DOCSTRING,
signature='()->(n)',
sources=[one_hot_src],
)


# Configuration of the extension module `_one_hot`, which contains only
# the `one_hot` gufunc.
# NOTE(review): presumably the build tooling looks up the name `extmod`
# in this file -- confirm before renaming it.
extmod = UFuncExtMod(
module='_one_hot',
docstring=("This extension module defines the gufunc 'one_hot'."),
ufuncs=[one_hot],
)
43 changes: 43 additions & 0 deletions src/one_hot/one_hot_gufunc.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@

#ifndef ONE_HOT_GUFUNC_H
#define ONE_HOT_GUFUNC_H

#define PY_SSIZE_T_CLEAN
#include "Python.h"

#include <cmath>
#include <cstdio>

#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#include "numpy/ndarraytypes.h"

#include "../src/util/strided.hpp"


//
// `one_hot_core_calc` is the C++ core function
// for the gufunc `one_hot` with signature '()->(n)'
//
// The n elements of the output are set to 0, and then, if the input
// value x satisfies 0 <= x < n, the output element at index x is set
// to 1.  If x is out of range, the output is left as all zeros.
//
template<typename T>
static void
one_hot_core_calc(
        npy_intp n,         // core dimension n
        T *p_x,             // pointer to x
        T *p_out,           // pointer to out, a strided 1-d array
        npy_intp out_stride // stride of out, in bytes
)
{
    // Read the input once, before any element of the output is written.
    const T x = *p_x;

    if (out_stride == static_cast<npy_intp>(sizeof(T))) {
        // The output is contiguous, so it can be cleared in one call.
        memset(p_out, 0, n*sizeof(T));
    }
    else {
        for (npy_intp i = 0; i < n; ++i) {
            // Cast to T (not npy_intp) so the template is usable with
            // element types other than npy_intp.
            set(p_out, out_stride, i, static_cast<T>(0));
        }
    }
    if (0 <= x && x < n) {
        set(p_out, out_stride, x, static_cast<T>(1));
    }
}

#endif
2 changes: 2 additions & 0 deletions tools/cxxgen/generate_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
'M': 'NPY_DATETIME',
'm': 'NPY_TIMEDELTA',
'O': 'NPY_OBJECT',
'p': 'NPY_INTP',
'P': 'NPY_UINTP'
}


Expand Down
1 change: 1 addition & 0 deletions ufunclab/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
'convert_to_base': '._wrapped',
'nextn_greater': '._wrapped',
'nextn_less' : '._wrapped',
'one_hot': '._wrapped',
'logfactorial': '._logfact',
'loggamma1p': '._loggamma1p',
'issnan': '._issnan',
Expand Down
32 changes: 32 additions & 0 deletions ufunclab/_wrapped.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from ufunclab._convert_to_base import convert_to_base as _convert_to_base
from ufunclab._nextn import (nextn_less as _nextn_less,
nextn_greater as _nextn_greater)
from ufunclab._one_hot import one_hot as _one_hot


# XXX Except for `out` and `axis`, this function does not expose any of the
Expand Down Expand Up @@ -152,3 +153,34 @@ def nextn_less(x, n, out=None, axis=-1):


nextn_less.gufunc = _nextn_less


def one_hot(k, n, out=None, axis=-1):
    """
    Create a 1-d integer array of length n, all zero except for 1 at index k.

    Parameters
    ----------
    k : integer scalar or array_like of integers
        Index (or array of indices) at which the output is 1.  For a
        value of `k` that is less than 0 or greater than `n - 1`, the
        corresponding output is all zeros.
    n : int
        Length of the one-hot axis of the output.  Must be nonnegative.
    out : ndarray, optional
        Array in which to store the result.  Its shape must be the shape
        of `k` with a dimension of length `n` inserted at position `axis`.
        If not given, a new array with the same dtype as `k` is created.
    axis : int, optional
        Position of the length-n axis in the output.  The output has
        ``k.ndim + 1`` dimensions, so `axis` must satisfy
        ``-(k.ndim + 1) <= axis <= k.ndim``.  The default is -1 (last axis).

    Returns
    -------
    out : ndarray
        The array returned by the underlying gufunc, with shape equal to
        the shape of `k` with `n` inserted at position `axis`.

    Raises
    ------
    ValueError
        If `k` does not have an integer dtype, if `n` is not a nonnegative
        integer, or if `out` is given and does not have the required shape.
    AxisError
        If `axis` is out of range for the output array.
    """
    k = np.asarray(k)
    if k.dtype.char not in np.typecodes['AllInteger']:
        raise ValueError('k must be an integer scalar or array.')
    try:
        n = operator.index(n)
    except TypeError:
        raise ValueError(f'n must be an integer; got {n!r}') from None
    if n < 0:
        raise ValueError(f'n must be nonnegative; got {n!r}')
    k_shape = k.shape
    # The output has one more dimension than k, so the valid nonnegative
    # positions for the new axis are 0, 1, ..., k.ndim (inclusive).
    out_ndim = len(k_shape) + 1
    adjusted_axis = axis + out_ndim if axis < 0 else axis
    if not 0 <= adjusted_axis < out_ndim:
        raise AxisError(f'invalid axis {axis}')
    out_shape = (k_shape[:adjusted_axis] + (n,)
                 + k_shape[adjusted_axis:])
    if out is not None:
        if out.shape != out_shape:
            raise ValueError(f'out.shape must be {out_shape}; '
                             f'got {out.shape}.')
    else:
        # The gufunc cannot determine n from its input, so the output
        # array is always created here when the caller does not supply it.
        out = np.empty(out_shape, dtype=k.dtype)
    return _one_hot(k, out=out, axis=axis)


# Expose the underlying gufunc as an attribute of the wrapper function.
one_hot.gufunc = _one_hot
1 change: 1 addition & 0 deletions ufunclab/meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ gufunc_cxx_src_dirs = [
'minmax',
'multivariate_logbeta',
'nextn',
'one_hot',
'sosfilter',
'tri_area',
'vnorm',
Expand Down
15 changes: 15 additions & 0 deletions ufunclab/tests/test_one_hot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy as np
from numpy.testing import assert_equal
from ufunclab import one_hot


def test_basic_scalar():
    """A scalar k gives a 1-d array with a single 1 at index k."""
    expected = np.zeros(8, dtype=int)
    expected[3] = 1
    result = one_hot(3, 8)
    assert_equal(result, expected)


def test_basic_1d_k():
    """A 1-d sequence of indices gives one one-hot row per index."""
    indices = [3, 5, 6]
    expected = np.zeros((len(indices), 8), dtype=int)
    for row, col in enumerate(indices):
        expected[row, col] = 1
    result = one_hot(indices, 8)
    assert_equal(result, expected)

0 comments on commit c93470b

Please sign in to comment.