From 197b144c9eaad8916dabfa7377f48073c7e5a78f Mon Sep 17 00:00:00 2001
From: Drew Hubley
Date: Wed, 24 Apr 2024 20:45:39 -0300
Subject: [PATCH] Add xtensor fft implementation

---
 docs/source/xfft.rst     |  17 ++
 include/xtensor/xfft.hpp | 305 +++++++++++++++++++++++++++++++++++++++
 test/test_xfft.cpp       | 100 +++++++++++++
 3 files changed, 422 insertions(+)
 create mode 100644 docs/source/xfft.rst
 create mode 100644 include/xtensor/xfft.hpp
 create mode 100644 test/test_xfft.cpp

diff --git a/docs/source/xfft.rst b/docs/source/xfft.rst
new file mode 100644
index 000000000..78ba5ff7d
--- /dev/null
+++ b/docs/source/xfft.rst
@@ -0,0 +1,17 @@
+.. Copyright (c) 2016, Johan Mabille, Sylvain Corlay and Wolf Vollprecht
+   Distributed under the terms of the BSD 3-Clause License.
+   The full license is in the file LICENSE, distributed with this software.
+
+xfft
+====
+
+Defined in ``xtensor/xfft.hpp``
+
+.. doxygenfunction:: xt::fft::fft
+   :project: xtensor
+
+.. doxygenfunction:: xt::fft::ifft
+   :project: xtensor
+
+.. doxygenfunction:: xt::fft::convolve
+   :project: xtensor
diff --git a/include/xtensor/xfft.hpp b/include/xtensor/xfft.hpp
new file mode 100644
index 000000000..472302d07
--- /dev/null
+++ b/include/xtensor/xfft.hpp
@@ -0,0 +1,305 @@
+#ifndef XTENSOR_FFT_HPP
+#define XTENSOR_FFT_HPP
+
+#ifdef XTENSOR_USE_TBB
+#include <oneapi/tbb.h>
+#endif
+#include <complex>
+#include <cstddef>
+#include <stdexcept>
+#include <type_traits>
+
+#include <xtl/xcomplex.hpp>
+
+#include <xtensor/xarray.hpp>
+#include <xtensor/xaxis_slice_iterator.hpp>
+#include <xtensor/xbuilder.hpp>
+#include <xtensor/xcomplex.hpp>
+#include <xtensor/xmanipulation.hpp>
+#include <xtensor/xmath.hpp>
+#include <xtensor/xnoalias.hpp>
+#include <xtensor/xtensor.hpp>
+#include <xtensor/xview.hpp>
+
+namespace xt
+{
+    namespace fft
+    {
+        namespace detail
+        {
+            /**
+             * @brief Recursive radix-2 FFT of a 1D complex expression
+             * @param e a 1D complex expression whose size is a power of two
+             * @return a 1D xtensor holding the unscaled forward transform of e
+             * @throws std::runtime_error if the size of e is not a power of two
+             */
+            template <
+                class E,
+                typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
+            inline auto radix2(E&& e)
+            {
+                using namespace xt::placeholders;
+                using namespace std::complex_literals;
+                using value_type = typename std::decay_t<E>::value_type;
+                using precision = typename value_type::value_type;
+                const auto N = e.size();
+                // a power of two has exactly one bit set; zero is not a power of two
+                const bool powerOfTwo = (N != 0) && ((N & (N - 1)) == 0);
+                if (!powerOfTwo)
+                {
+                    // TODO: Replace implementation with dft
+                    XTENSOR_THROW(std::runtime_error, "FFT Implementation requires power of 2");
+                }
+                const auto pi = xt::numeric_constants<precision>::PI;
+                xt::xtensor<value_type, 1> ev = e;
+                if (N <= 1)
+                {
+                    // a single sample is its own transform
+                    return ev;
+                }
+                else
+                {
+#ifdef XTENSOR_USE_TBB
+                    // transform the even and odd samples concurrently
+                    xt::xtensor<value_type, 1> even;
+                    xt::xtensor<value_type, 1> odd;
+                    oneapi::tbb::parallel_invoke(
+                        [&]
+                        {
+                            even = radix2(xt::view(ev, xt::range(0, _, 2)));
+                        },
+                        [&]
+                        {
+                            odd = radix2(xt::view(ev, xt::range(1, _, 2)));
+                        }
+                    );
+#else
+                    auto even = radix2(xt::view(ev, xt::range(0, _, 2)));
+                    auto odd = radix2(xt::view(ev, xt::range(1, _, 2)));
+#endif
+
+                    // twiddle factors exp(-2 * pi * i * k / N) for k in [0, N / 2)
+                    auto range = xt::arange<double>(N / 2);
+                    auto exp = xt::exp(static_cast<value_type>(-2i) * pi * range / N);
+                    auto t = exp * odd;
+                    auto first_half = even + t;
+                    auto second_half = even - t;
+                    // TODO: should be a call to stack if performance was improved
+                    auto spectrum = xt::xtensor<value_type, 1>::from_shape({N});
+                    xt::view(spectrum, xt::range(0, N / 2)) = first_half;
+                    xt::view(spectrum, xt::range(N / 2, N)) = second_half;
+                    return spectrum;
+                }
+            }
+
+            /**
+             * @brief Bluestein (chirp-z) FFT of a 1D complex expression of any length
+             *
+             * The transform is rewritten as a convolution that is evaluated with
+             * power-of-two radix2 transforms on zero-padded buffers.
+             * @param data a non-empty 1D complex expression
+             * @return a 1D xtensor holding the unscaled forward transform of data
+             */
+            template <class E>
+            auto transform_bluestein(E&& data)
+            {
+                using value_type = typename std::decay_t<E>::value_type;
+                using precision = typename value_type::value_type;
+                using complex_type = std::complex<precision>;
+
+                // Find a power-of-2 convolution length m such that m >= n * 2 + 1
+                const std::size_t n = data.size();
+                std::size_t m = 1;
+                while (m < n * 2 + 1)
+                {
+                    m <<= 1;
+                }
+
+                // Trigonometric table; k * k is reduced modulo 2n in integer
+                // arithmetic so the angle stays accurate for large k
+                const auto k = xt::arange<std::size_t>(n);
+                xt::xtensor<std::size_t, 1> i = k * k;
+                i %= (n * 2);
+
+                const precision pi = xt::numeric_constants<precision>::PI;
+                auto angles = xt::eval(pi * i / n);
+                const auto j = complex_type(0, 1);
+                xt::xtensor<complex_type, 1> exp_table = xt::exp(-angles * j);
+
+                // Temporary vectors and preprocessing; both buffers must be
+                // zero-padded beyond the samples written below
+                xt::xtensor<complex_type, 1> av = xt::zeros<complex_type>({m});
+                xt::view(av, xt::range(0, n)) = data * exp_table;
+
+                xt::xtensor<complex_type, 1> bv = xt::zeros<complex_type>({m});
+                xt::view(bv, xt::range(0, n)) = ::xt::conj(exp_table);
+                // the last n - 1 entries mirror exp_table[n - 1] ... exp_table[1]
+                xt::view(bv, xt::range(m - n + 1, m)) = xt::view(
+                    ::xt::conj(xt::flip(exp_table)),
+                    xt::range(0, n - 1)
+                );
+
+                // Convolution; the inverse transform is conj(fft(conj(x))) / m
+                auto xv = radix2(av);
+                auto yv = radix2(bv);
+                auto spectrum_k = xv * yv;
+                auto complex_args = xt::conj(spectrum_k);
+                auto fft_res = radix2(complex_args);
+                auto cv = xt::conj(fft_res) / m;
+
+                return xt::eval(xt::view(cv, xt::range(0, n)) * exp_table);
+            }
+        }  // namespace detail
+
+        /**
+         * @brief 1D FFT of an Nd array along a specified axis
+         * @param e an Nd complex expression to be transformed to the fourier domain
+         * @param axis the axis along which to perform the 1D FFT
+         * @return a transformed xarray of the specified precision
+         * @throws std::runtime_error if e is empty along axis
+         */
+        template <
+            class E,
+            typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
+        inline auto fft(E&& e, std::ptrdiff_t axis = -1)
+        {
+            using value_type = typename std::decay_t<E>::value_type;
+            using precision = typename value_type::value_type;
+            const auto saxis = xt::normalize_axis(e.dimension(), axis);
+            const std::size_t N = e.shape(saxis);
+            if (N == 0)
+            {
+                XTENSOR_THROW(std::runtime_error, "Cannot take the FFT along an empty dimension");
+            }
+            // radix2 only handles powers of two; other lengths use Bluestein
+            const bool powerOfTwo = (N & (N - 1)) == 0;
+            xt::xarray<std::complex<precision>> out = xt::eval(e);
+            auto begin = xt::axis_slice_begin(out, saxis);
+            auto end = xt::axis_slice_end(out, saxis);
+            for (auto iter = begin; iter != end; ++iter)
+            {
+                if (powerOfTwo)
+                {
+                    xt::noalias(*iter) = detail::radix2(*iter);
+                }
+                else
+                {
+                    xt::noalias(*iter) = detail::transform_bluestein(*iter);
+                }
+            }
+            return out;
+        }
+
+        /**
+         * @brief 1D FFT of a real Nd array along a specified axis
+         *
+         * The input is cast to std::complex of its value type and forwarded to
+         * the complex overload.
+         * @param e an Nd real expression to be transformed to the fourier domain
+         * @param axis the axis along which to perform the 1D FFT
+         * @return a transformed xarray of the specified precision
+         */
+        template <
+            class E,
+            typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
+        inline auto fft(E&& e, std::ptrdiff_t axis = -1)
+        {
+            using value_type = typename std::decay<E>::type::value_type;
+            return fft(xt::cast<std::complex<value_type>>(e), axis);
+        }
+
+        /**
+         * @brief 1D inverse FFT of an Nd array along a specified axis
+         *
+         * The result is not divided by the transform length; callers that need
+         * the scaled inverse must divide by the shape along axis themselves.
+         * @param e an Nd complex expression to be transformed to the time domain
+         * @param axis the axis along which to perform the 1D inverse FFT
+         * @return the unscaled inverse transform as an xarray
+         * @throws std::runtime_error if e is empty along axis
+         */
+        template <
+            class E,
+            typename std::enable_if<xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
+        auto ifft(E&& e, std::ptrdiff_t axis = -1)
+        {
+            // check the length of the data on that axis; axis may be negative,
+            // so it has to be normalized before it is used as an index
+            const auto saxis = xt::normalize_axis(e.dimension(), axis);
+            const std::size_t n = e.shape(saxis);
+            if (n == 0)
+            {
+                XTENSOR_THROW(std::runtime_error, "Cannot take the iFFT along an empty dimension");
+            }
+            // ifft(x) = conj(fft(conj(x))), up to the 1 / n scaling
+            auto complex_args = xt::conj(e);
+            auto fft_res = xt::fft::fft(complex_args, axis);
+            fft_res = xt::conj(fft_res);
+            return fft_res;
+        }
+
+        /**
+         * @brief 1D inverse FFT of a real Nd array along a specified axis
+         * @param e an Nd real expression to be transformed to the time domain
+         * @param axis the axis along which to perform the 1D inverse FFT
+         * @return the unscaled inverse transform as an xarray
+         */
+        template <
+            class E,
+            typename std::enable_if<!xtl::is_complex<typename std::decay<E>::type::value_type>::value, bool>::type = true>
+        inline auto ifft(E&& e, std::ptrdiff_t axis = -1)
+        {
+            using value_type = typename std::decay<E>::type::value_type;
+            return ifft(xt::cast<std::complex<value_type>>(e), axis);
+        }
+
+        /**
+         * @brief performs a circular fft convolution xvec and yvec must
+         * be the same shape.
+         * @param xvec first array of the convolution
+         * @param yvec second array of the convolution
+         * @param axis axis along which to perform the convolution
+         * @return the circular convolution of xvec and yvec along axis
+         * @throws std::runtime_error if the number of dimensions or the lengths
+         * along axis differ
+         */
+        template <class E1, class E2>
+        auto convolve(E1&& xvec, E2&& yvec, std::ptrdiff_t axis = -1)
+        {
+            // we could broadcast but that could get complicated???
+            if (xvec.dimension() != yvec.dimension())
+            {
+                XTENSOR_THROW(std::runtime_error, "Mismatched dimensions");
+            }
+
+            auto saxis = xt::normalize_axis(xvec.dimension(), axis);
+            if (xvec.shape(saxis) != yvec.shape(saxis))
+            {
+                XTENSOR_THROW(std::runtime_error, "Mismatched lengths along slice axis");
+            }
+
+            const std::size_t n = xvec.shape(saxis);
+
+            auto xv = fft(xvec, axis);
+            auto yv = fft(yvec, axis);
+
+            // multiply the spectra slice by slice, storing the product in xv
+            auto iter_y = xt::axis_slice_begin(yv, saxis);
+            auto end_x = xt::axis_slice_end(xv, saxis);
+            for (auto iter_x = xt::axis_slice_begin(xv, saxis); iter_x != end_x; ++iter_x, ++iter_y)
+            {
+                (*iter_x) = (*iter_y) * (*iter_x);
+            }
+
+            auto outvec = ifft(xv, axis);
+
+            // Scaling (because this FFT implementation omits it)
+            outvec = outvec / n;
+
+            return outvec;
+        }
+
+    }  // namespace fft
+}  // namespace xt
+
+#endif
diff --git a/test/test_xfft.cpp b/test/test_xfft.cpp
new file mode 100644
index 000000000..d7fd78896
--- /dev/null
+++ b/test/test_xfft.cpp
@@ -0,0 +1,100 @@
+#include "xtensor/xarray.hpp"
+#include "xtensor/xfft.hpp"
+
+#include "test_common_macros.hpp"
+
+namespace xt
+{
+    TEST(xfft, fft_power_2)
+    {
+        size_t k = 2;
+        size_t n = 8192;
+        size_t A = 10;
+        auto x = xt::linspace<double>(0, static_cast<double>(n - 1), n);
+        xt::xarray<double> y = A * xt::sin(2 * xt::numeric_constants<double>::PI * x * k / n);
+        auto res = xt::fft::fft(y) / (n / 2);
+        REQUIRE(A == doctest::Approx(std::abs(res(k))).epsilon(.0001));
+    }
+
+    TEST(xfft, ifft_power_2)
+    {
+        size_t k = 2;
+        size_t n = 8;
+        size_t A = 10;
+        auto x = xt::linspace<double>(0, static_cast<double>(n - 1), n);
+        xt::xarray<double> y = A * xt::sin(2 * xt::numeric_constants<double>::PI * x * k / n);
+        auto res = xt::fft::ifft(y) / (n / 2);
+        REQUIRE(A == doctest::Approx(std::abs(res(k))).epsilon(.0001));
+    }
+
+    TEST(xfft, convolve_power_2)
+    {
+        xt::xarray<double> x = {1.0, 1.0, 1.0, 5.0};
+        xt::xarray<double> y = {5.0, 1.0, 1.0, 1.0};
+        xt::xarray<double> expected = {12, 12, 12, 28};
+
+        auto result = xt::fft::convolve(x, y);
+
+        for (size_t i = 0; i < x.size(); i++)
+        {
+            REQUIRE(expected(i) == doctest::Approx(std::abs(result(i))).epsilon(.0001));
+        }
+    }
+
+    TEST(xfft, fft_n_0_axis)
+    {
+        size_t k = 2;
+        size_t n = 10;
+        size_t A = 1;
+        size_t dim = 10;
+        auto x = xt::linspace<double>(0, n - 1, n) * xt::ones<double>({dim, n});
+        xt::xarray<double> y = A * xt::sin(2 * xt::numeric_constants<double>::PI * x * k / n);
+        y = xt::transpose(y);
+        auto res = xt::fft::fft(y, 0) / (n / 2.0);
+        REQUIRE(A == doctest::Approx(std::abs(res(k, 0))).epsilon(.0001));
+        REQUIRE(A == doctest::Approx(std::abs(res(k, 1))).epsilon(.0001));
+    }
+
+    TEST(xfft, fft_n_1_axis)
+    {
+        size_t k = 2;
+        size_t n = 15;
+        size_t A = 1;
+        size_t dim = 2;
+        auto x = xt::linspace<double>(0, n - 1, n) * xt::ones<double>({dim, n});
+        xt::xarray<double> y = A * xt::sin(2 * xt::numeric_constants<double>::PI * x * k / n);
+        auto res = xt::fft::fft(y) / (n / 2.0);
+        REQUIRE(A == doctest::Approx(std::abs(res(0, k))).epsilon(.0001));
+        REQUIRE(A == doctest::Approx(std::abs(res(1, k))).epsilon(.0001));
+    }
+
+    TEST(xfft, convolve_n)
+    {
+        xt::xarray<double> x = {1.0, 1.0, 1.0, 5.0, 1.0};
+        xt::xarray<double> y = {5.0, 1.0, 1.0, 1.0, 1.0};
+        xt::xarray<double> expected = {13, 13, 13, 29, 13};
+
+        auto result = xt::fft::convolve(x, y);
+
+        xt::xarray<double> abs = xt::abs(result);
+
+        for (size_t i = 0; i < abs.size(); i++)
+        {
+            REQUIRE(expected(i) == doctest::Approx(abs(i)).epsilon(.0001));
+        }
+    }
+
+    TEST(xfft, ifft_roundtrip_n)
+    {
+        // length 6 exercises the Bluestein path and the default (negative) axis
+        xt::xarray<double> x = {1.0, 2.0, 3.0, 4.0, 5.0, 6.0};
+
+        auto spectrum = xt::fft::fft(x);
+        auto result = xt::fft::ifft(spectrum);
+
+        for (size_t i = 0; i < x.size(); i++)
+        {
+            REQUIRE(x(i) == doctest::Approx(std::real(result(i)) / x.size()).epsilon(.0001));
+        }
+    }
+}