From fd1fcfeae627a2b7cacbbae535594b110d9b99da Mon Sep 17 00:00:00 2001 From: Warren Weckesser Date: Sat, 21 Sep 2024 14:50:47 -0400 Subject: [PATCH] ENH: Handle complex weights in bincount(). --- src/bincount/bincount_gufunc.h | 49 ++++++++++++++++++++++++ src/bincount/define_cxx_gufunc_extmod.py | 20 ++++++++-- ufunclab/tests/test_bincount.py | 43 ++++++++++++++++++--- 3 files changed, 102 insertions(+), 10 deletions(-) diff --git a/src/bincount/bincount_gufunc.h b/src/bincount/bincount_gufunc.h index 578f1d2..8596408 100644 --- a/src/bincount/bincount_gufunc.h +++ b/src/bincount/bincount_gufunc.h @@ -5,8 +5,11 @@ #define PY_SSIZE_T_CLEAN #include "Python.h" +#include + #define NPY_NO_DEPRECATED_API NPY_API_VERSION #include "numpy/ndarraytypes.h" +#include "numpy/npy_math.h" #include "../src/util/strided.hpp" @@ -69,4 +72,50 @@ bincountw_core_calc( } } + +void static inline +npy_complex_add_inplace(npy_cfloat *a, npy_cfloat b) +{ + *((float *) a) += npy_crealf(b); + *((float *) a + 1) += npy_cimagf(b); +} + +void static inline +npy_complex_add_inplace(npy_cdouble *a, npy_cdouble b) +{ + *((double *) a) += npy_creal(b); + *((double *) a + 1) += npy_cimag(b); +} + +// +// `bincountw_complex_core_calc` is the C++ core function +// for the gufunc `bincountw` with signature '(n),(n)->(m)' +// that handles complex weights. +// +template +static void +bincountw_complex_core_calc( + npy_intp n, // core dimension n + npy_intp m, // core dimension m + T *p_x, // pointer to x + npy_intp x_stride, + W *p_w, // pointer to w + npy_intp w_stride, + W *p_out, // pointer to out, a strided 1-d array + npy_intp out_stride +) +{ + // Note that the output array is not initialized to 0. + // This allows repeated calls to accumulate results. + + for (npy_intp i = 0; i < n; ++i) { + T k = get(p_x, x_stride, i); + if (k >= 0 && static_cast(k) < m) { + W w = get(p_w, w_stride, i); + W *p = (W *)((char *)p_out + k*out_stride); + npy_complex_add_inplace(p, w); + } + } +} + #endif diff --git a/src/bincount/define_cxx_gufunc_extmod.py b/src/bincount/define_cxx_gufunc_extmod.py index 4e3a2e4..0786eaf 100644 --- a/src/bincount/define_cxx_gufunc_extmod.py +++ b/src/bincount/define_cxx_gufunc_extmod.py @@ -23,10 +23,17 @@ np.dtype('int16'), np.dtype('uint16'), np.dtype('int32'), np.dtype('uint32'), np.dtype('int64'), np.dtype('uint64')] -float_types = [np.dtype('f'), np.dtype('d')] -input_types = product(int_types, int_types + float_types) + bincount_types = [f'{t.char}->p' for t in int_types] -bincountw_types = [f'{in1.char}{in2.char}->{in2.char}' for (in1, in2) in input_types] + +float_types = [np.dtype('f'), np.dtype('d')] +complex_types = [np.dtype('F'), np.dtype('D')] +bincountw_input_types = product(int_types, int_types + float_types) +bincountw_types = [f'{in1.char}{in2.char}->{in2.char}' + for (in1, in2) in bincountw_input_types] +bincountw_complex_input_types = product(int_types, complex_types) +bincountw_complex_types = [f'{in1.char}{in2.char}->{in2.char}' + for (in1, in2) in bincountw_complex_input_types] bincount_src = UFuncSource( funcname='bincount_core_calc', @@ -46,12 +53,17 @@ typesignatures=bincountw_types, ) +bincountw_complex_src = UFuncSource( + funcname='bincountw_complex_core_calc', + typesignatures=bincountw_complex_types, +) + bincountw = UFunc( name='bincountw', header='bincount_gufunc.h', docstring=BINCOUNTW_DOCSTRING, signature='(n),(n)->(m)', - sources=[bincountw_src], + sources=[bincountw_src, bincountw_complex_src], ) extmod = UFuncExtMod( diff --git a/ufunclab/tests/test_bincount.py b/ufunclab/tests/test_bincount.py index 33e1054..3473b34 100644 --- a/ufunclab/tests/test_bincount.py +++ b/ufunclab/tests/test_bincount.py @@ -4,20 +4,20 @@ from ufunclab import bincount -def test_bincount_1d_default_m(): +def test_1d_default_m(): x = np.array([0, 0, 3, 4, 3, 4, 0, 3, 0]) y = bincount(x) assert_equal(y, [4, 0, 0, 3, 2]) -def test_bincount_1d_given_m(): +def test_1d_given_m(): x = np.array([0, 0, 3, 4, 3, 4, 0, 3, 0]) y = bincount(x, 8) assert_equal(y, [4, 0, 0, 3, 2, 0, 0, 0]) @pytest.mark.parametrize('m', [None, 2, 5, 8]) -def test_bincount_nd_default_m(m): +def test_nd_default_m(m): x = np.array([[0, 4, 4, 3, 2, 1], [1, 1, 2, 2, 3, 3], [4, 4, 4, 4, 4, 2]]) @@ -36,7 +36,7 @@ def test_bincount_nd_default_m(m): @pytest.mark.parametrize('m', [None, 2, 5, 8]) -def test_bincount_axis(m): +def test_axis(m): x = np.array([[0, 4, 4], [1, 1, 2], [4, 4, 4], @@ -57,14 +57,14 @@ def test_bincount_axis(m): assert_equal(y, expected) -def test_bincount_weights(): +def test_weights(): x = np.array([3, 1, 1, 0, 3]) w = np.array([4, 9, 1, 3, 5]) b = bincount(x, weights=w) assert_equal(b, [3, 10, 0, 9]) -def test_bincount_weights_axis(): +def test_weights_axis(): x = np.array([[3, 1, 1], [2, 3, 3], [1, 2, 2], @@ -76,3 +76,34 @@ def test_bincount_weights_axis(): [2.0, 3.0, 7.0]]) b = bincount(x, weights=w, axis=0) assert_equal(b, expected) + + +@pytest.mark.parametrize('dtype', [np.dtype('F'), np.dtype('D')]) +def test_complex_weights(dtype): + x = np.array([3, 1, 1, 0, 3]) + w = np.array([4+1j, 9-1j, 1+0.5j, 3+0.25j, 5.0], dtype=dtype) + b = bincount(x, weights=w) + assert b.dtype == w.dtype + assert_equal(b, [3 + 0.25j, 10 - 0.5j, 0.0, 9.0 + 1.0j]) + + +def test_complex_weights_axis(): + x = np.array([[3, 1, 1], + [2, 3, 3], + [1, 2, 2], + [2, 2, 3]]) + w = np.array([2.0 + 2j, 3.0 + 7j, 5.0, 4.0 - 1j]) + b = bincount(x, weights=w, axis=0) + br = bincount(x, weights=w.real, axis=0) + bi = bincount(x, weights=w.imag, axis=0) + expected_br = np.array([[0.0, 0.0, 0.0], + [5.0, 2.0, 2.0], + [7.0, 9.0, 5.0], + [2.0, 3.0, 7.0]]) + expected_bi = np.array([[0.0, 0.0, 0.0], + [0.0, 2.0, 2.0], + [6.0, -1.0, 0.0], + [2.0, 7.0, 6.0]]) + assert_equal(br, expected_br) + assert_equal(bi, expected_bi) + assert_equal(b, br + 1j*bi)