Skip to content

Commit

Permalink
ENH: Handle complex weights in bincount().
Browse files Browse the repository at this point in the history
  • Loading branch information
WarrenWeckesser committed Sep 21, 2024
1 parent 9456a6d commit fd1fcfe
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 10 deletions.
49 changes: 49 additions & 0 deletions src/bincount/bincount_gufunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@
#define PY_SSIZE_T_CLEAN
#include "Python.h"

#include <complex>

#define NPY_NO_DEPRECATED_API NPY_API_VERSION
#include "numpy/ndarraytypes.h"
#include "numpy/npy_math.h"

#include "../src/util/strided.hpp"

Expand Down Expand Up @@ -69,4 +72,50 @@ bincountw_core_calc(
}
}


void static inline
npy_complex_add_inplace(npy_cfloat *a, npy_cfloat b)
{
*((float *) a) += npy_crealf(b);
*((float *) a + 1) += npy_cimagf(b);
}

void static inline
npy_complex_add_inplace(npy_cdouble *a, npy_cdouble b)
{
*((double *) a) += npy_creal(b);
*((double *) a + 1) += npy_cimag(b);
}

//
// `bincountw_complex_core_calc` is the C++ core function
// for the gufunc `bincountw` with signature '(n),(n)->(m)'
// that handles complex weights.
//
template<typename T, typename W>
static void
bincountw_complex_core_calc(
npy_intp n, // core dimension n
npy_intp m, // core dimension m
T *p_x, // pointer to x
npy_intp x_stride,
W *p_w, // pointer to w
npy_intp w_stride,
W *p_out, // pointer to out, a strided 1-d array
npy_intp out_stride
)
{
// Note that the output array is not initialized to 0.
// This allows repeated calls to accumulate results.

for (npy_intp i = 0; i < n; ++i) {
T k = get(p_x, x_stride, i);
if (k >= 0 && static_cast<npy_intp>(k) < m) {
W w = get(p_w, w_stride, i);
W *p = (W *)((char *)p_out + k*out_stride);
npy_complex_add_inplace(p, w);
}
}
}

#endif
20 changes: 16 additions & 4 deletions src/bincount/define_cxx_gufunc_extmod.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,17 @@
np.dtype('int16'), np.dtype('uint16'),
np.dtype('int32'), np.dtype('uint32'),
np.dtype('int64'), np.dtype('uint64')]
float_types = [np.dtype('f'), np.dtype('d')]
input_types = product(int_types, int_types + float_types)

bincount_types = [f'{t.char}->p' for t in int_types]
bincountw_types = [f'{in1.char}{in2.char}->{in2.char}' for (in1, in2) in input_types]

float_types = [np.dtype('f'), np.dtype('d')]
complex_types = [np.dtype('F'), np.dtype('D')]
bincountw_input_types = product(int_types, int_types + float_types)
bincountw_types = [f'{in1.char}{in2.char}->{in2.char}'
for (in1, in2) in bincountw_input_types]
bincountw_complex_input_types = product(int_types, complex_types)
bincountw_complex_types = [f'{in1.char}{in2.char}->{in2.char}'
for (in1, in2) in bincountw_complex_input_types]

bincount_src = UFuncSource(
funcname='bincount_core_calc',
Expand All @@ -46,12 +53,17 @@
typesignatures=bincountw_types,
)

bincountw_complex_src = UFuncSource(
funcname='bincountw_complex_core_calc',
typesignatures=bincountw_complex_types,
)

bincountw = UFunc(
name='bincountw',
header='bincount_gufunc.h',
docstring=BINCOUNTW_DOCSTRING,
signature='(n),(n)->(m)',
sources=[bincountw_src],
sources=[bincountw_src, bincountw_complex_src],
)

extmod = UFuncExtMod(
Expand Down
43 changes: 37 additions & 6 deletions ufunclab/tests/test_bincount.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,20 @@
from ufunclab import bincount


def test_bincount_1d_default_m():
def test_1d_default_m():
x = np.array([0, 0, 3, 4, 3, 4, 0, 3, 0])
y = bincount(x)
assert_equal(y, [4, 0, 0, 3, 2])


def test_bincount_1d_given_m():
def test_1d_given_m():
x = np.array([0, 0, 3, 4, 3, 4, 0, 3, 0])
y = bincount(x, 8)
assert_equal(y, [4, 0, 0, 3, 2, 0, 0, 0])


@pytest.mark.parametrize('m', [None, 2, 5, 8])
def test_bincount_nd_default_m(m):
def test_nd_default_m(m):
x = np.array([[0, 4, 4, 3, 2, 1],
[1, 1, 2, 2, 3, 3],
[4, 4, 4, 4, 4, 2]])
Expand All @@ -36,7 +36,7 @@ def test_bincount_nd_default_m(m):


@pytest.mark.parametrize('m', [None, 2, 5, 8])
def test_bincount_axis(m):
def test_axis(m):
x = np.array([[0, 4, 4],
[1, 1, 2],
[4, 4, 4],
Expand All @@ -57,14 +57,14 @@ def test_bincount_axis(m):
assert_equal(y, expected)


def test_bincount_weights():
def test_weights():
x = np.array([3, 1, 1, 0, 3])
w = np.array([4, 9, 1, 3, 5])
b = bincount(x, weights=w)
assert_equal(b, [3, 10, 0, 9])


def test_bincount_weights_axis():
def test_weights_axis():
x = np.array([[3, 1, 1],
[2, 3, 3],
[1, 2, 2],
Expand All @@ -76,3 +76,34 @@ def test_bincount_weights_axis():
[2.0, 3.0, 7.0]])
b = bincount(x, weights=w, axis=0)
assert_equal(b, expected)


@pytest.mark.parametrize('dtype', [np.dtype('F'), np.dtype('D')])
def test_complex_weights(dtype):
x = np.array([3, 1, 1, 0, 3])
w = np.array([4+1j, 9-1j, 1+0.5j, 3+0.25j, 5.0], dtype=dtype)
b = bincount(x, weights=w)
assert b.dtype == w.dtype
assert_equal(b, [3 + 0.25j, 10 - 0.5j, 0.0, 9.0 + 1.0j])


def test_complex_weights_axis():
x = np.array([[3, 1, 1],
[2, 3, 3],
[1, 2, 2],
[2, 2, 3]])
w = np.array([2.0 + 2j, 3.0 + 7j, 5.0, 4.0 - 1j])
b = bincount(x, weights=w, axis=0)
br = bincount(x, weights=w.real, axis=0)
bi = bincount(x, weights=w.imag, axis=0)
expected_br = np.array([[0.0, 0.0, 0.0],
[5.0, 2.0, 2.0],
[7.0, 9.0, 5.0],
[2.0, 3.0, 7.0]])
expected_bi = np.array([[0.0, 0.0, 0.0],
[0.0, 2.0, 2.0],
[6.0, -1.0, 0.0],
[2.0, 7.0, 6.0]])
assert_equal(br, expected_br)
assert_equal(bi, expected_bi)
assert_equal(b, br + 1j*bi)

0 comments on commit fd1fcfe

Please sign in to comment.