From 4e94b009543d7ff8377514b170b9e1b36f3f3469 Mon Sep 17 00:00:00 2001 From: sarah el kazdadi Date: Mon, 6 Feb 2023 12:47:05 +0100 Subject: [PATCH] feat: implement 128bit fft --- Cargo.toml | 13 +- LICENSE | 7 +- README.md | 57 +- benches/bench.rs | 308 ++++++ benches/fft.rs | 190 ---- benches/lib.rs | 3 - src/fft128/f128_impl.rs | 1075 +++++++++++++++++++ src/fft128/mod.rs | 2204 +++++++++++++++++++++++++++++++++++++++ src/lib.rs | 7 + src/ordered.rs | 2 +- src/unordered.rs | 8 +- 11 files changed, 3641 insertions(+), 233 deletions(-) create mode 100644 benches/bench.rs delete mode 100644 benches/fft.rs delete mode 100644 benches/lib.rs create mode 100644 src/fft128/f128_impl.rs create mode 100644 src/fft128/mod.rs diff --git a/Cargo.toml b/Cargo.toml index c6675ff..c52fb55 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,13 +13,16 @@ keywords = ["fft"] [dependencies] num-complex = "0.4" dyn-stack = { version = "0.8", default-features = false } +pulp = "0.10" +bytemuck = "1.13" aligned-vec = { version = "0.5", default-features = false } serde = { version = "1.0", optional = true, default-features = false } [features] -default = ["std"] -nightly = [] +default = ["std", "fft128"] +nightly = ["pulp/nightly", "bytemuck/nightly_stdsimd"] std = [] +fft128 = [] serde = ["dep:serde", "num-complex/serde"] [dev-dependencies] @@ -28,9 +31,13 @@ rustfft = "6.0" fftw-sys = { version = "0.6", default-features = false, features = ["system"] } rand = "0.8" bincode = "1.3" +more-asserts = "0.3.1" + +[target.'cfg(target_os = "linux")'.dev-dependencies] +rug = "1.19.0" [[bench]] -name = "fft" +name = "bench" harness = false [package.metadata.docs.rs] diff --git a/LICENSE b/LICENSE index f1ec11c..62fdc0b 100644 --- a/LICENSE +++ b/LICENSE @@ -16,7 +16,7 @@ materials provided with the distribution. 3. Neither the name of ZAMA nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. -NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE*. +NO EXPRESS OR IMPLIED LICENSES TO ANY PARTY'S PATENT RIGHTS ARE GRANTED BY THIS LICENSE. THIS SOFTWARE IS PROVIDED BY THE ZAMA AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL @@ -26,8 +26,3 @@ OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CA ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -*In addition to the rights carried by this license, ZAMA grants to the user a non-exclusive, -free and non-commercial license on all patents filed in its name relating to the open-source -code (the "Patents") for the sole purpose of evaluation, development, research, prototyping -and experimentation. diff --git a/README.md b/README.md index f3002c6..2cec244 100644 --- a/README.md +++ b/README.md @@ -3,33 +3,38 @@ that processes vectors of sizes that are powers of two. It was made to be used as a backend in Zama's `concrete` library. This library provides two FFT modules: - - The ordered module FFT applies a forward/inverse FFT that takes its input in standard - order, and outputs the result in standard order. For more detail on what the FFT - computes, check the ordered module-level documentation. - - The unordered module FFT applies a forward FFT that takes its input in standard order, - and outputs the result in a certain permuted order that may depend on the FFT plan. On the - other hand, the inverse FFT takes its input in that same permuted order and outputs its result - in standard order. This is useful for cases where the order of the coefficients in the - Fourier domain is not important. An example is using the Fourier transform for vector - convolution. The only operations that are performed in the Fourier domain are elementwise, and - so the order of the coefficients does not affect the results. + +- The ordered module FFT applies a forward/inverse FFT that takes its input in standard + order, and outputs the result in standard order. For more detail on what the FFT + computes, check the ordered module-level documentation. +- The unordered module FFT applies a forward FFT that takes its input in standard order, + and outputs the result in a certain permuted order that may depend on the FFT plan. On the + other hand, the inverse FFT takes its input in that same permuted order and outputs its result + in standard order. This is useful for cases where the order of the coefficients in the + Fourier domain is not important. An example is using the Fourier transform for vector + convolution. The only operations that are performed in the Fourier domain are elementwise, and + so the order of the coefficients does not affect the results. + +Additionally, an optional 128-bit negacyclic FFT module is provided. ## Features - - `std` (default): This enables runtime arch detection for accelerated SIMD - instructions, and an FFT plan that measures the various implementations to - choose the fastest one at runtime. - - `nightly`: This enables unstable Rust features to further speed up the FFT, - by enabling AVX512F instructions on CPUs that support them. This feature - requires a nightly Rust - toolchain. - - `serde`: This enables serialization and deserialization functions for the - unordered plan. These allow for data in the Fourier domain to be serialized - from the permuted order to the standard order, and deserialized from the - standard order to the permuted order. This is needed since the inverse - transform must be used with the same plan that computed/deserialized the - forward transform (or more specifically, a plan with the same internal base - FFT size). +- `std` (default): This enables runtime arch detection for accelerated SIMD + instructions, and an FFT plan that measures the various implementations to + choose the fastest one at runtime. +- `fft128` (default): This flag provides access to the 128-bit FFT, which is accessible in the + `concrete_fft::fft128` module. +- `nightly`: This enables unstable Rust features to further speed up the FFT, + by enabling AVX512F instructions on CPUs that support them. This feature + requires a nightly Rust + toolchain. +- `serde`: This enables serialization and deserialization functions for the + unordered plan. These allow for data in the Fourier domain to be serialized + from the permuted order to the standard order, and deserialized from the + standard order to the permuted order. This is needed since the inverse + transform must be used with the same plan that computed/deserialized the + forward transform (or more specifically, a plan with the same internal base + FFT size). ## Example @@ -65,8 +70,8 @@ for (actual, expected) in transformed_inv.iter().map(|z| z / N as f64).zip(data) ## Links - - [Zama](https://www.zama.ai/) - - [Concrete](https://github.com/zama-ai/concrete) +- [Zama](https://www.zama.ai/) +- [Concrete](https://github.com/zama-ai/concrete) ## License diff --git a/benches/bench.rs b/benches/bench.rs new file mode 100644 index 0000000..53aac52 --- /dev/null +++ b/benches/bench.rs @@ -0,0 +1,308 @@ +use concrete_fft::c64; +use core::ptr::NonNull; +use criterion::{criterion_group, criterion_main, Criterion}; +use dyn_stack::{DynStack, ReborrowMut, StackReq}; + +struct FftwAlloc { + bytes: NonNull, +} + +impl Drop for FftwAlloc { + fn drop(&mut self) { + unsafe { + fftw_sys::fftw_free(self.bytes.as_ptr()); + } + } +} + +impl FftwAlloc { + pub fn new(size_bytes: usize) -> FftwAlloc { + unsafe { + let bytes = fftw_sys::fftw_malloc(size_bytes); + if bytes.is_null() { + use std::alloc::{handle_alloc_error, Layout}; + handle_alloc_error(Layout::from_size_align_unchecked(size_bytes, 1)); + } + FftwAlloc { + bytes: NonNull::new_unchecked(bytes), + } + } + } +} + +pub struct PlanInterleavedC64 { + plan: fftw_sys::fftw_plan, + n: usize, +} + +impl Drop for PlanInterleavedC64 { + fn drop(&mut self) { + unsafe { + fftw_sys::fftw_destroy_plan(self.plan); + } + } +} + +pub enum Sign { + Forward, + Backward, +} + +impl PlanInterleavedC64 { + pub fn new(n: usize, sign: Sign) -> Self { + let size_bytes = n.checked_mul(core::mem::size_of::()).unwrap(); + let src = FftwAlloc::new(size_bytes); + let dst = FftwAlloc::new(size_bytes); + unsafe { + let p = fftw_sys::fftw_plan_dft_1d( + n.try_into().unwrap(), + src.bytes.as_ptr() as _, + dst.bytes.as_ptr() as _, + match sign { + Sign::Forward => fftw_sys::FFTW_FORWARD as _, + Sign::Backward => fftw_sys::FFTW_BACKWARD as _, + }, + fftw_sys::FFTW_MEASURE, + ); + PlanInterleavedC64 { plan: p, n } + } + } + + pub fn print(&self) { + unsafe { + fftw_sys::fftw_print_plan(self.plan); + } + } + + pub fn execute(&self, src: &mut [c64], dst: &mut [c64]) { + assert_eq!(src.len(), self.n); + assert_eq!(dst.len(), self.n); + let src = src.as_mut_ptr(); + let dst = dst.as_mut_ptr(); + unsafe { + use fftw_sys::{fftw_alignment_of, fftw_execute_dft}; + assert_eq!(fftw_alignment_of(src as _), 0); + assert_eq!(fftw_alignment_of(dst as _), 0); + fftw_execute_dft(self.plan, src as _, dst as _); + } + } +} + +pub fn criterion_benchmark(c: &mut Criterion) { + for n in [ + 1 << 8, + 1 << 9, + 1 << 10, + 1 << 11, + 1 << 12, + 1 << 13, + 1 << 14, + 1 << 15, + 1 << 16, + ] { + let mut mem = dyn_stack::GlobalMemBuffer::new( + StackReq::new_aligned::(n, 64) // scratch + .and( + StackReq::new_aligned::(2 * n, 64).or(StackReq::new_aligned::(n, 64)), // src | twiddles + ) + .and(StackReq::new_aligned::(n, 64)), // dst + ); + let mut stack = DynStack::new(&mut mem); + let z = c64::new(0.0, 0.0); + + { + let mut scratch = []; + let bench_duration = std::time::Duration::from_millis(10); + + { + let (mut dst, stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + let (mut src, _) = stack.make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("rustfft-fwd-{n}"), |b| { + use rustfft::FftPlannerAvx; + let mut planner = FftPlannerAvx::::new().unwrap(); + let fwd_rustfft = planner.plan_fft_forward(n); + + b.iter(|| { + fwd_rustfft.process_outofplace_with_scratch( + &mut src, + &mut dst, + &mut scratch, + ) + }) + }); + + c.bench_function(&format!("fftw-fwd-{n}"), |b| { + let fwd_fftw = PlanInterleavedC64::new(n, Sign::Forward); + + b.iter(|| { + fwd_fftw.execute(&mut src, &mut dst); + }) + }); + } + { + let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("concrete-fwd-{n}"), |b| { + let ordered = concrete_fft::ordered::Plan::new( + n, + concrete_fft::ordered::Method::Measure(bench_duration), + ); + + b.iter(|| ordered.fwd(&mut dst, stack.rb_mut())) + }); + } + { + let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("unordered-fwd-{n}"), |b| { + let unordered = concrete_fft::unordered::Plan::new( + n, + concrete_fft::unordered::Method::Measure(bench_duration), + ); + + b.iter(|| unordered.fwd(&mut dst, stack.rb_mut())); + }); + } + { + let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("unordered-inv-{n}"), |b| { + let unordered = concrete_fft::unordered::Plan::new( + n, + concrete_fft::unordered::Method::Measure(bench_duration), + ); + + b.iter(|| unordered.inv(&mut dst, stack.rb_mut())); + }); + } + } + + // memcpy + { + let (mut dst, stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); + let (src, _) = stack.make_aligned_with::(n, 64, |_| z); + + c.bench_function(&format!("memcpy-{n}"), |b| { + b.iter(|| unsafe { + std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr(), n); + }) + }); + } + } + + use concrete_fft::fft128::*; + for n in [64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384] { + let twid_re0 = vec![0.0; n]; + let twid_re1 = vec![0.0; n]; + let twid_im0 = vec![0.0; n]; + let twid_im1 = vec![0.0; n]; + + let mut data_re0 = vec![0.0; n]; + let mut data_re1 = vec![0.0; n]; + let mut data_im0 = vec![0.0; n]; + let mut data_im1 = vec![0.0; n]; + + c.bench_function(&format!("concrete-fft128-fwd-{n}"), |bench| { + bench.iter(|| { + negacyclic_fwd_fft_scalar( + &mut data_re0, + &mut data_re1, + &mut data_im0, + &mut data_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + }); + }); + + c.bench_function(&format!("concrete-fft128-inv-{n}"), |bench| { + bench.iter(|| { + negacyclic_inv_fft_scalar( + &mut data_re0, + &mut data_re1, + &mut data_im0, + &mut data_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + }); + }); + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + if let Some(simd) = Avx::try_new() { + c.bench_function(&format!("concrete-fft128-avx-fwd-{n}"), |bench| { + bench.iter(|| { + negacyclic_fwd_fft_avxfma( + simd, + &mut data_re0, + &mut data_re1, + &mut data_im0, + &mut data_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + }); + }); + c.bench_function(&format!("concrete-fft128-avx-inv-{n}"), |bench| { + bench.iter(|| { + negacyclic_inv_fft_avxfma( + simd, + &mut data_re0, + &mut data_re1, + &mut data_im0, + &mut data_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + }); + }); + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + #[cfg(feature = "nightly")] + if let Some(simd) = Avx512::try_new() { + c.bench_function(&format!("concrete-fft128-avx512-fwd-{n}"), |bench| { + bench.iter(|| { + negacyclic_fwd_fft_avx512( + simd, + &mut data_re0, + &mut data_re1, + &mut data_im0, + &mut data_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + }); + }); + c.bench_function(&format!("concrete-fft128-avx512-inv-{n}"), |bench| { + bench.iter(|| { + negacyclic_inv_fft_avx512( + simd, + &mut data_re0, + &mut data_re1, + &mut data_im0, + &mut data_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + }); + }); + } + } +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/benches/fft.rs b/benches/fft.rs deleted file mode 100644 index 7d93b6b..0000000 --- a/benches/fft.rs +++ /dev/null @@ -1,190 +0,0 @@ -use concrete_fft::c64; -use core::ptr::NonNull; -use criterion::{criterion_group, criterion_main, Criterion}; -use dyn_stack::{DynStack, ReborrowMut, StackReq}; - -struct FftwAlloc { - bytes: NonNull, -} - -impl Drop for FftwAlloc { - fn drop(&mut self) { - unsafe { - fftw_sys::fftw_free(self.bytes.as_ptr()); - } - } -} - -impl FftwAlloc { - pub fn new(size_bytes: usize) -> FftwAlloc { - unsafe { - let bytes = fftw_sys::fftw_malloc(size_bytes); - if bytes.is_null() { - use std::alloc::{handle_alloc_error, Layout}; - handle_alloc_error(Layout::from_size_align_unchecked(size_bytes, 1)); - } - FftwAlloc { - bytes: NonNull::new_unchecked(bytes), - } - } - } -} - -pub struct PlanInterleavedC64 { - plan: fftw_sys::fftw_plan, - n: usize, -} - -impl Drop for PlanInterleavedC64 { - fn drop(&mut self) { - unsafe { - fftw_sys::fftw_destroy_plan(self.plan); - } - } -} - -pub enum Sign { - Forward, - Backward, -} - -impl PlanInterleavedC64 { - pub fn new(n: usize, sign: Sign) -> Self { - let size_bytes = n.checked_mul(core::mem::size_of::()).unwrap(); - let src = FftwAlloc::new(size_bytes); - let dst = FftwAlloc::new(size_bytes); - unsafe { - let p = fftw_sys::fftw_plan_dft_1d( - n.try_into().unwrap(), - src.bytes.as_ptr() as _, - dst.bytes.as_ptr() as _, - match sign { - Sign::Forward => fftw_sys::FFTW_FORWARD as _, - Sign::Backward => fftw_sys::FFTW_BACKWARD as _, - }, - fftw_sys::FFTW_MEASURE, - ); - PlanInterleavedC64 { plan: p, n } - } - } - - pub fn print(&self) { - unsafe { - fftw_sys::fftw_print_plan(self.plan); - } - } - - pub fn execute(&self, src: &mut [c64], dst: &mut [c64]) { - assert_eq!(src.len(), self.n); - assert_eq!(dst.len(), self.n); - let src = src.as_mut_ptr(); - let dst = dst.as_mut_ptr(); - unsafe { - use fftw_sys::{fftw_alignment_of, fftw_execute_dft}; - assert_eq!(fftw_alignment_of(src as _), 0); - assert_eq!(fftw_alignment_of(dst as _), 0); - fftw_execute_dft(self.plan, src as _, dst as _); - } - } -} - -pub fn criterion_benchmark(c: &mut Criterion) { - for n in [ - 1 << 8, - 1 << 9, - 1 << 10, - 1 << 11, - 1 << 12, - 1 << 13, - 1 << 14, - 1 << 15, - 1 << 16, - ] { - let mut mem = dyn_stack::GlobalMemBuffer::new( - StackReq::new_aligned::(n, 64) // scratch - .and( - StackReq::new_aligned::(2 * n, 64).or(StackReq::new_aligned::(n, 64)), // src | twiddles - ) - .and(StackReq::new_aligned::(n, 64)), // dst - ); - let mut stack = DynStack::new(&mut mem); - let z = c64::new(0.0, 0.0); - - { - use rustfft::FftPlannerAvx; - let mut planner = FftPlannerAvx::::new().unwrap(); - - let fwd_rustfft = planner.plan_fft_forward(n); - let mut scratch = []; - - let fwd_fftw = PlanInterleavedC64::new(n, Sign::Forward); - - let bench_duration = std::time::Duration::from_millis(10); - let ordered = concrete_fft::ordered::Plan::new( - n, - concrete_fft::ordered::Method::Measure(bench_duration), - ); - let unordered = concrete_fft::unordered::Plan::new( - n, - concrete_fft::unordered::Method::Measure(bench_duration), - ); - - { - let (mut dst, stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); - let (mut src, _) = stack.make_aligned_with::(n, 64, |_| z); - - c.bench_function(&format!("rustfft-fwd-{}", n), |b| { - b.iter(|| { - fwd_rustfft.process_outofplace_with_scratch( - &mut src, - &mut dst, - &mut scratch, - ) - }) - }); - - c.bench_function(&format!("fftw-fwd-{}", n), |b| { - b.iter(|| { - fwd_fftw.execute(&mut src, &mut dst); - }) - }); - } - { - let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); - - c.bench_function(&format!("concrete-fwd-{}", n), |b| { - b.iter(|| ordered.fwd(&mut *dst, stack.rb_mut())) - }); - } - { - let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); - - c.bench_function(&format!("unordered-fwd-{}", n), |b| { - b.iter(|| unordered.fwd(&mut dst, stack.rb_mut())); - }); - } - { - let (mut dst, mut stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); - - c.bench_function(&format!("unordered-inv-{}", n), |b| { - b.iter(|| unordered.inv(&mut dst, stack.rb_mut())); - }); - } - } - - // memcpy - { - let (mut dst, stack) = stack.rb_mut().make_aligned_with::(n, 64, |_| z); - let (src, _) = stack.make_aligned_with::(n, 64, |_| z); - - c.bench_function(&format!("memcpy-{}", n), |b| { - b.iter(|| unsafe { - std::ptr::copy_nonoverlapping(src.as_ptr(), dst.as_mut_ptr(), n); - }) - }); - } - } -} - -criterion_group!(benches, criterion_benchmark); -criterion_main!(benches); diff --git a/benches/lib.rs b/benches/lib.rs deleted file mode 100644 index 2ee2f80..0000000 --- a/benches/lib.rs +++ /dev/null @@ -1,3 +0,0 @@ -#![allow(dead_code)] - -mod fft; diff --git a/src/fft128/f128_impl.rs b/src/fft128/f128_impl.rs new file mode 100644 index 0000000..6b426f0 --- /dev/null +++ b/src/fft128/f128_impl.rs @@ -0,0 +1,1075 @@ +use super::f128; + +/// Computes $\operatorname{fl}(a+b)$ and $\operatorname{err}(a+b)$. +/// Assumes $|a| \geq |b|$. +#[inline(always)] +fn quick_two_sum(a: f64, b: f64) -> (f64, f64) { + let s = a + b; + (s, b - (s - a)) +} + +/// Computes $\operatorname{fl}(a-b)$ and $\operatorname{err}(a-b)$. +/// Assumes $|a| \geq |b|$. +#[allow(dead_code)] +#[inline(always)] +fn quick_two_diff(a: f64, b: f64) -> (f64, f64) { + let s = a - b; + (s, (a - s) - b) +} + +/// Computes $\operatorname{fl}(a+b)$ and $\operatorname{err}(a+b)$. +#[inline(always)] +fn two_sum(a: f64, b: f64) -> (f64, f64) { + let s = a + b; + let bb = s - a; + (s, (a - (s - bb)) + (b - bb)) +} + +/// Computes $\operatorname{fl}(a-b)$ and $\operatorname{err}(a-b)$. +#[inline(always)] +fn two_diff(a: f64, b: f64) -> (f64, f64) { + let s = a - b; + let bb = s - a; + (s, (a - (s - bb)) - (b + bb)) +} + +#[inline(always)] +fn two_prod(a: f64, b: f64) -> (f64, f64) { + let p = a * b; + (p, f64::mul_add(a, b, -p)) +} + +use core::{ + cmp::Ordering, + convert::From, + ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign}, +}; + +impl From for f128 { + #[inline(always)] + fn from(value: f64) -> Self { + Self(value, 0.0) + } +} + +impl Add for f128 { + type Output = f128; + + #[inline(always)] + fn add(self, rhs: f128) -> Self::Output { + f128::add_f128_f128(self, rhs) + } +} + +impl Add for f128 { + type Output = f128; + + #[inline(always)] + fn add(self, rhs: f64) -> Self::Output { + f128::add_f128_f64(self, rhs) + } +} + +impl Add for f64 { + type Output = f128; + + #[inline(always)] + fn add(self, rhs: f128) -> Self::Output { + f128::add_f64_f128(self, rhs) + } +} + +impl AddAssign for f128 { + #[inline(always)] + fn add_assign(&mut self, rhs: f64) { + *self = *self + rhs + } +} + +impl AddAssign for f128 { + #[inline(always)] + fn add_assign(&mut self, rhs: f128) { + *self = *self + rhs + } +} + +impl Sub for f128 { + type Output = f128; + + #[inline(always)] + fn sub(self, rhs: f128) -> Self::Output { + f128::sub_f128_f128(self, rhs) + } +} + +impl Sub for f128 { + type Output = f128; + + #[inline(always)] + fn sub(self, rhs: f64) -> Self::Output { + f128::sub_f128_f64(self, rhs) + } +} + +impl Sub for f64 { + type Output = f128; + + #[inline(always)] + fn sub(self, rhs: f128) -> Self::Output { + f128::sub_f64_f128(self, rhs) + } +} + +impl SubAssign for f128 { + #[inline(always)] + fn sub_assign(&mut self, rhs: f64) { + *self = *self - rhs + } +} + +impl SubAssign for f128 { + #[inline(always)] + fn sub_assign(&mut self, rhs: f128) { + *self = *self - rhs + } +} + +impl Mul for f128 { + type Output = f128; + + #[inline(always)] + fn mul(self, rhs: f128) -> Self::Output { + f128::mul_f128_f128(self, rhs) + } +} + +impl Mul for f128 { + type Output = f128; + + #[inline(always)] + fn mul(self, rhs: f64) -> Self::Output { + f128::mul_f128_f64(self, rhs) + } +} + +impl Mul for f64 { + type Output = f128; + + #[inline(always)] + fn mul(self, rhs: f128) -> Self::Output { + f128::mul_f64_f128(self, rhs) + } +} + +impl MulAssign for f128 { + #[inline(always)] + fn mul_assign(&mut self, rhs: f64) { + *self = *self * rhs + } +} + +impl MulAssign for f128 { + #[inline(always)] + fn mul_assign(&mut self, rhs: f128) { + *self = *self * rhs + } +} + +impl Div for f128 { + type Output = f128; + + #[inline(always)] + fn div(self, rhs: f128) -> Self::Output { + f128::div_f128_f128(self, rhs) + } +} + +impl Div for f128 { + type Output = f128; + + #[inline(always)] + fn div(self, rhs: f64) -> Self::Output { + f128::div_f128_f64(self, rhs) + } +} + +impl Div for f64 { + type Output = f128; + + #[inline(always)] + fn div(self, rhs: f128) -> Self::Output { + f128::div_f64_f128(self, rhs) + } +} + +impl DivAssign for f128 { + #[inline(always)] + fn div_assign(&mut self, rhs: f64) { + *self = *self / rhs + } +} + +impl DivAssign for f128 { + #[inline(always)] + fn div_assign(&mut self, rhs: f128) { + *self = *self / rhs + } +} + +impl Neg for f128 { + type Output = f128; + + #[inline(always)] + fn neg(self) -> Self::Output { + Self(-self.0, -self.1) + } +} + +impl PartialEq for f128 { + #[inline(always)] + fn eq(&self, other: &f128) -> bool { + matches!((self.0 == other.0, self.1 == other.1), (true, true)) + } +} + +impl PartialEq for f128 { + #[inline(always)] + fn eq(&self, other: &f64) -> bool { + (*self).eq(&f128(*other, 0.0)) + } +} + +impl PartialEq for f64 { + #[inline(always)] + fn eq(&self, other: &f128) -> bool { + (*other).eq(self) + } +} + +impl PartialOrd for f128 { + #[inline(always)] + fn partial_cmp(&self, other: &f128) -> Option { + let first_cmp = self.0.partial_cmp(&other.0); + let second_cmp = self.1.partial_cmp(&other.1); + + match first_cmp { + Some(Ordering::Equal) => second_cmp, + _ => first_cmp, + } + } +} + +impl PartialOrd for f128 { + #[inline(always)] + fn partial_cmp(&self, other: &f64) -> Option { + (*self).partial_cmp(&f128(*other, 0.0)) + } +} + +impl PartialOrd for f64 { + #[inline(always)] + fn partial_cmp(&self, other: &f128) -> Option { + f128(*self, 0.0).partial_cmp(other) + } +} + +impl f128 { + /// Adds `a` and `b` and returns the result. + #[inline(always)] + pub fn add_f64_f64(a: f64, b: f64) -> Self { + let (s, e) = two_sum(a, b); + Self(s, e) + } + + /// Adds `a` and `b` and returns the result. + #[inline(always)] + pub fn add_f128_f64(a: f128, b: f64) -> Self { + let (s1, s2) = two_sum(a.0, b); + let s2 = s2 + a.1; + let (s1, s2) = quick_two_sum(s1, s2); + Self(s1, s2) + } + + /// Adds `a` and `b` and returns the result. + #[inline(always)] + pub fn add_f64_f128(a: f64, b: f128) -> Self { + Self::add_f128_f64(b, a) + } + + /// Adds `a` and `b` and returns the result. + /// This function has a slightly higher error bound than [`Self::add_f128_f128`] + #[inline(always)] + pub fn add_estimate_f128_f128(a: f128, b: f128) -> Self { + let (s, e) = two_sum(a.0, b.0); + let e = e + (a.1 + b.1); + let (s, e) = quick_two_sum(s, e); + Self(s, e) + } + + /// Adds `a` and `b` and returns the result. + #[inline(always)] + pub fn add_f128_f128(a: f128, b: f128) -> Self { + let (s1, s2) = two_sum(a.0, b.0); + let (t1, t2) = two_sum(a.1, b.1); + + let s2 = s2 + t1; + let (s1, s2) = quick_two_sum(s1, s2); + let s2 = s2 + t2; + let (s1, s2) = quick_two_sum(s1, s2); + Self(s1, s2) + } + + /// Subtracts `b` from `a` and returns the result. + #[inline(always)] + pub fn sub_f64_f64(a: f64, b: f64) -> Self { + let (s, e) = two_diff(a, b); + Self(s, e) + } + + /// Subtracts `b` from `a` and returns the result. + #[inline(always)] + pub fn sub_f128_f64(a: f128, b: f64) -> Self { + let (s1, s2) = two_diff(a.0, b); + let s2 = s2 + a.1; + let (s1, s2) = quick_two_sum(s1, s2); + Self(s1, s2) + } + + /// Subtracts `b` from `a` and returns the result. + #[inline(always)] + pub fn sub_f64_f128(a: f64, b: f128) -> Self { + let (s1, s2) = two_diff(a, b.0); + let s2 = s2 - b.1; + let (s1, s2) = quick_two_sum(s1, s2); + Self(s1, s2) + } + + /// Subtracts `b` from `a` and returns the result. + /// This function has a slightly higher error bound than [`Self::sub_f128_f128`] + #[inline(always)] + pub fn sub_estimate_f128_f128(a: f128, b: f128) -> Self { + let (s, e) = two_diff(a.0, b.0); + let e = e + a.1; + let e = e - b.1; + let (s, e) = quick_two_sum(s, e); + Self(s, e) + } + + /// Subtracts `b` from `a` and returns the result. + #[inline(always)] + pub fn sub_f128_f128(a: f128, b: f128) -> Self { + let (s1, s2) = two_diff(a.0, b.0); + let (t1, t2) = two_diff(a.1, b.1); + + let s2 = s2 + t1; + let (s1, s2) = quick_two_sum(s1, s2); + let s2 = s2 + t2; + let (s1, s2) = quick_two_sum(s1, s2); + Self(s1, s2) + } + + /// Multiplies `a` and `b` and returns the result. + #[inline(always)] + pub fn mul_f64_f64(a: f64, b: f64) -> Self { + let (p, e) = two_prod(a, b); + Self(p, e) + } + + /// Multiplies `a` and `b` and returns the result. + #[inline(always)] + pub fn mul_f128_f64(a: f128, b: f64) -> Self { + let (p1, p2) = two_prod(a.0, b); + let p2 = p2 + (a.1 * b); + let (p1, p2) = quick_two_sum(p1, p2); + Self(p1, p2) + } + + /// Multiplies `a` and `b` and returns the result. + #[inline(always)] + pub fn mul_f64_f128(a: f64, b: f128) -> Self { + Self::mul_f128_f64(b, a) + } + + /// Multiplies `a` and `b` and returns the result. + #[inline(always)] + pub fn mul_f128_f128(a: f128, b: f128) -> Self { + let (p1, p2) = two_prod(a.0, b.0); + let p2 = p2 + (a.0 * b.1 + a.1 * b.0); + let (p1, p2) = quick_two_sum(p1, p2); + Self(p1, p2) + } + + /// Squares `self` and returns the result. + #[inline(always)] + pub fn sqr(self) -> Self { + let (p1, p2) = two_prod(self.0, self.0); + let p2 = p2 + 2.0 * (self.0 * self.1); + let (p1, p2) = quick_two_sum(p1, p2); + Self(p1, p2) + } + + /// Divides `a` by `b` and returns the result. + #[inline(always)] + pub fn div_f64_f64(a: f64, b: f64) -> Self { + let q1 = a / b; + + // Compute a - q1 * b + let (p1, p2) = two_prod(q1, b); + let (s, e) = two_diff(a, p1); + let e = e - p2; + + // get next approximation + let q2 = (s + e) / b; + + let (s, e) = quick_two_sum(q1, q2); + f128(s, e) + } + + /// Divides `a` by `b` and returns the result. + #[inline(always)] + pub fn div_f128_f64(a: f128, b: f64) -> Self { + // approximate quotient + let q1 = a.0 / b; + + // Compute a - q1 * b + let (p1, p2) = two_prod(q1, b); + let (s, e) = two_diff(a.0, p1); + let e = e + a.1; + let e = e - p2; + + // get next approximation + let q2 = (s + e) / b; + + // renormalize + let (r0, r1) = quick_two_sum(q1, q2); + Self(r0, r1) + } + + /// Divides `a` by `b` and returns the result. + #[inline(always)] + pub fn div_f64_f128(a: f64, b: f128) -> Self { + Self::div_f128_f128(a.into(), b) + } + + /// Divides `a` by `b` and returns the result. + /// This function has a slightly higher error bound than [`Self::div_f128_f128`] + #[inline(always)] + pub fn div_estimate_f128_f128(a: f128, b: f128) -> Self { + // approximate quotient + let q1 = a.0 / b.0; + + // compute a - q1 * b + let r = b * q1; + let (s1, s2) = two_diff(a.0, r.0); + let s2 = s2 - r.1; + let s2 = s2 + a.1; + + // get next approximation + let q2 = (s1 + s2) / b.0; + + // renormalize + let (r0, r1) = quick_two_sum(q1, q2); + Self(r0, r1) + } + + /// Divides `a` by `b` and returns the result. + #[inline(always)] + pub fn div_f128_f128(a: f128, b: f128) -> Self { + // approximate quotient + let q1 = a.0 / b.0; + + let r = a - b * q1; + + let q2 = r.0 / b.0; + let r = r - q2 * b; + + let q3 = r.0 / b.0; + + let (q1, q2) = quick_two_sum(q1, q2); + Self(q1, q2) + q3 + } + + /// Casts `self` to an `f64`. + #[inline(always)] + pub fn to_f64(self) -> f64 { + self.0 + } + + /// Checks if `self` is `NaN`. + #[inline(always)] + pub fn is_nan(self) -> bool { + !matches!((self.0.is_nan(), self.1.is_nan()), (false, false)) + } + + /// Returns the absolute value of `self`. + #[inline(always)] + pub fn abs(self) -> Self { + if self.0 < 0.0 { + -self + } else { + self + } + } + + fn sincospi_taylor(self) -> (Self, Self) { + let mut sinc = Self::PI; + let mut cos = f128(1.0, 0.0); + + let sqr = self.sqr(); + let mut pow = f128(1.0, 0.0); + for (s, c) in Self::SINPI_TAYLOR + .iter() + .copied() + .zip(Self::COSPI_TAYLOR.iter().copied()) + { + pow *= sqr; + sinc += s * pow; + cos += c * pow; + } + + (sinc * self, cos) + } + + /// Takes and input in `(-1.0, 1.0)`, and returns the sine and cosine of `self`. + pub fn sincospi(self) -> (Self, Self) { + #[allow(clippy::manual_range_contains)] + if self > 1.0 || self < -1.0 { + panic!("only inputs in [-1, 1] are currently supported, received: {self:?}"); + } + // approximately reduce modulo 1/2 + let p = (self.0 * 2.0).round(); + let r = self - p * 0.5; + + // approximately reduce modulo 1/16 + let q = (r.0 * 16.0).round(); + let r = r - q * (1.0 / 16.0); + + let p = p as isize; + let q = q as isize; + + let q_abs = q.unsigned_abs(); + + let (sin_r, cos_r) = r.sincospi_taylor(); + + let (s, c) = if q == 0 { + (sin_r, cos_r) + } else { + let u = Self::COS_K_PI_OVER_16_TABLE[q_abs - 1]; + let v = Self::SIN_K_PI_OVER_16_TABLE[q_abs - 1]; + if q > 0 { + (u * sin_r + v * cos_r, u * cos_r - v * sin_r) + } else { + (u * sin_r - v * cos_r, u * cos_r + v * sin_r) + } + }; + + if p == 0 { + (s, c) + } else if p == 1 { + (c, -s) + } else if p == -1 { + (-c, s) + } else { + (-s, -c) + } + } +} + +#[allow(clippy::approx_constant)] +impl f128 { + pub const PI: Self = f128(3.141592653589793, 1.2246467991473532e-16); + + const SINPI_TAYLOR: &'static [Self; 9] = &[ + f128(-5.16771278004997, 2.2665622825789447e-16), + f128(2.5501640398773455, -7.931006345326556e-17), + f128(-0.5992645293207921, 2.845026112698218e-17), + f128(0.08214588661112823, -3.847292805297656e-18), + f128(-0.0073704309457143504, -3.328281165603432e-19), + f128(0.00046630280576761255, 1.0704561733683463e-20), + f128(-2.1915353447830217e-5, 1.4648526682685598e-21), + f128(7.952054001475513e-7, 1.736540361519021e-23), + f128(-2.2948428997269873e-8, -7.376346207041088e-26), + ]; + + const COSPI_TAYLOR: &'static [Self; 9] = &[ + f128(-4.934802200544679, -3.1326477543698557e-16), + f128(4.0587121264167685, -2.6602000824298645e-16), + f128(-1.3352627688545895, 3.1815237892149862e-18), + f128(0.2353306303588932, -1.2583065576724427e-18), + f128(-0.02580689139001406, 1.170191067939226e-18), + f128(0.0019295743094039231, -9.669517939986956e-20), + f128(-0.0001046381049248457, -2.421206183964864e-21), + f128(4.303069587032947e-6, -2.864010082936791e-22), + f128(-1.3878952462213771e-7, -7.479362090417238e-24), + ]; + + const SIN_K_PI_OVER_16_TABLE: &'static [Self; 4] = &[ + f128(0.19509032201612828, -7.991079068461731e-18), + f128(0.3826834323650898, -1.0050772696461588e-17), + f128(0.5555702330196022, 4.709410940561677e-17), + f128(0.7071067811865476, -4.833646656726457e-17), + ]; + + const COS_K_PI_OVER_16_TABLE: &'static [Self; 4] = &[ + f128(0.9807852804032304, 1.8546939997825006e-17), + f128(0.9238795325112867, 1.7645047084336677e-17), + f128(0.8314696123025452, 1.4073856984728024e-18), + f128(0.7071067811865476, -4.833646656726457e-17), + ]; +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64", doc))] +#[cfg_attr(docsrs, doc(cfg(any(target_arch = "x86", target_arch = "x86_64"))))] +pub mod x86 { + #[cfg(target_arch = "x86")] + use core::arch::x86::*; + #[cfg(target_arch = "x86_64")] + use core::arch::x86_64::*; + + pulp::simd_type! { + /// Avx SIMD type. + pub struct Avx { + pub sse: "sse", + pub sse2: "sse2", + pub avx: "avx", + pub fma: "fma", + } + } + + #[cfg(feature = "nightly")] + pulp::simd_type! { + /// Avx512 SIMD type. + pub struct Avx512 { + pub sse: "sse", + pub sse2: "sse2", + pub avx: "avx", + pub avx2: "avx2", + pub fma: "fma", + pub avx512f: "avx512f", + } + } + + #[inline(always)] + pub(crate) fn _mm256_quick_two_sum(simd: Avx, a: __m256d, b: __m256d) -> (__m256d, __m256d) { + let s = simd.avx._mm256_add_pd(a, b); + (s, simd.avx._mm256_sub_pd(b, simd.avx._mm256_sub_pd(s, a))) + } + + #[inline(always)] + pub(crate) fn _mm256_two_sum(simd: Avx, a: __m256d, b: __m256d) -> (__m256d, __m256d) { + let s = simd.avx._mm256_add_pd(a, b); + let bb = simd.avx._mm256_sub_pd(s, a); + ( + s, + simd.avx._mm256_add_pd( + simd.avx._mm256_sub_pd(a, simd.avx._mm256_sub_pd(s, bb)), + simd.avx._mm256_sub_pd(b, bb), + ), + ) + } + + #[inline(always)] + pub(crate) fn _mm256_two_diff(simd: Avx, a: __m256d, b: __m256d) -> (__m256d, __m256d) { + let s = simd.avx._mm256_sub_pd(a, b); + let bb = simd.avx._mm256_sub_pd(s, a); + ( + s, + simd.avx._mm256_sub_pd( + simd.avx._mm256_sub_pd(a, simd.avx._mm256_sub_pd(s, bb)), + simd.avx._mm256_add_pd(b, bb), + ), + ) + } + + #[inline(always)] + pub(crate) fn _mm256_two_prod(simd: Avx, a: __m256d, b: __m256d) -> (__m256d, __m256d) { + let p = simd.avx._mm256_mul_pd(a, b); + (p, simd.fma._mm256_fmsub_pd(a, b, p)) + } + + #[cfg(feature = "nightly")] + #[inline(always)] + pub(crate) fn _mm512_quick_two_sum(simd: Avx512, a: __m512d, b: __m512d) -> (__m512d, __m512d) { + let s = simd.avx512f._mm512_add_pd(a, b); + ( + s, + simd.avx512f + ._mm512_sub_pd(b, simd.avx512f._mm512_sub_pd(s, a)), + ) + } + + #[cfg(feature = "nightly")] + #[inline(always)] + pub(crate) fn _mm512_two_sum(simd: Avx512, a: __m512d, b: __m512d) -> (__m512d, __m512d) { + let s = simd.avx512f._mm512_add_pd(a, b); + let bb = simd.avx512f._mm512_sub_pd(s, a); + ( + s, + simd.avx512f._mm512_add_pd( + simd.avx512f + ._mm512_sub_pd(a, simd.avx512f._mm512_sub_pd(s, bb)), + simd.avx512f._mm512_sub_pd(b, bb), + ), + ) + } + + #[cfg(feature = "nightly")] + #[inline(always)] + pub(crate) fn _mm512_two_diff(simd: Avx512, a: __m512d, b: __m512d) -> (__m512d, __m512d) { + let s = simd.avx512f._mm512_sub_pd(a, b); + let bb = simd.avx512f._mm512_sub_pd(s, a); + ( + s, + simd.avx512f._mm512_sub_pd( + simd.avx512f + ._mm512_sub_pd(a, simd.avx512f._mm512_sub_pd(s, bb)), + simd.avx512f._mm512_add_pd(b, bb), + ), + ) + } + + #[cfg(feature = "nightly")] + #[inline(always)] + pub(crate) fn _mm512_two_prod(simd: Avx512, a: __m512d, b: __m512d) -> (__m512d, __m512d) { + let p = simd.avx512f._mm512_mul_pd(a, b); + (p, simd.avx512f._mm512_fmsub_pd(a, b, p)) + } + + impl Avx { + #[inline(always)] + pub fn _mm256_add_estimate_f128_f128( + self, + a0: __m256d, + a1: __m256d, + b0: __m256d, + b1: __m256d, + ) -> (__m256d, __m256d) { + let (s, e) = _mm256_two_sum(self, a0, b0); + let e = self.avx._mm256_add_pd(e, self.avx._mm256_add_pd(a1, b1)); + _mm256_quick_two_sum(self, s, e) + } + + #[inline(always)] + pub fn _mm256_sub_estimate_f128_f128( + self, + a0: __m256d, + a1: __m256d, + b0: __m256d, + b1: __m256d, + ) -> (__m256d, __m256d) { + let (s, e) = _mm256_two_diff(self, a0, b0); + let e = self.avx._mm256_add_pd(e, a1); + let e = self.avx._mm256_sub_pd(e, b1); + _mm256_quick_two_sum(self, s, e) + } + + #[inline(always)] + pub fn _mm256_mul_f128_f128( + self, + a0: __m256d, + a1: __m256d, + b0: __m256d, + b1: __m256d, + ) -> (__m256d, __m256d) { + let (p1, p2) = _mm256_two_prod(self, a0, b0); + let p2 = self.avx._mm256_add_pd( + p2, + self.avx._mm256_add_pd( + self.avx._mm256_mul_pd(a0, b1), + self.avx._mm256_mul_pd(a1, b0), + ), + ); + _mm256_quick_two_sum(self, p1, p2) + } + } + + #[cfg(feature = "nightly")] + impl Avx512 { + #[inline(always)] + pub fn _mm512_add_estimate_f128_f128( + self, + a0: __m512d, + a1: __m512d, + b0: __m512d, + b1: __m512d, + ) -> (__m512d, __m512d) { + let (s, e) = _mm512_two_sum(self, a0, b0); + let e = self + .avx512f + ._mm512_add_pd(e, self.avx512f._mm512_add_pd(a1, b1)); + _mm512_quick_two_sum(self, s, e) + } + + #[inline(always)] + pub fn _mm512_sub_estimate_f128_f128( + self, + a0: __m512d, + a1: __m512d, + b0: __m512d, + b1: __m512d, + ) -> (__m512d, __m512d) { + let (s, e) = _mm512_two_diff(self, a0, b0); + let e = self.avx512f._mm512_add_pd(e, a1); + let e = self.avx512f._mm512_sub_pd(e, b1); + _mm512_quick_two_sum(self, s, e) + } + + #[inline(always)] + pub fn _mm512_mul_f128_f128( + self, + a0: __m512d, + a1: __m512d, + b0: __m512d, + b1: __m512d, + ) -> (__m512d, __m512d) { + let (p1, p2) = _mm512_two_prod(self, a0, b0); + let p2 = self.avx512f._mm512_add_pd( + p2, + self.avx512f._mm512_add_pd( + self.avx512f._mm512_mul_pd(a0, b1), + self.avx512f._mm512_mul_pd(a1, b0), + ), + ); + _mm512_quick_two_sum(self, p1, p2) + } + } +} + +#[cfg(all(test, target_os = "linux"))] +mod tests { + use super::*; + use more_asserts::assert_le; + use rug::{ops::Pow, Float, Integer}; + + const PREC: u32 = 1024; + + fn float_to_f128(value: &Float) -> f128 { + let x0: f64 = value.to_f64(); + let diff = value.clone() - x0; + let x1 = diff.to_f64(); + f128(x0, x1) + } + + fn f128_to_float(value: f128) -> Float { + Float::with_val(PREC, value.0) + Float::with_val(PREC, value.1) + } + + #[test] + fn test_add() { + let mut rng = rug::rand::RandState::new(); + rng.seed(&Integer::from(0u64)); + + for _ in 0..100 { + let a = Float::with_val(PREC, Float::random_normal(&mut rng)); + let b = Float::with_val(PREC, Float::random_normal(&mut rng)); + + let a_f128 = float_to_f128(&a); + let b_f128 = float_to_f128(&b); + let a = f128_to_float(a_f128); + let b = f128_to_float(b_f128); + + let sum = Float::with_val(PREC, &a + &b); + let sum_rug_f128 = float_to_f128(&sum); + let sum_f128 = a_f128 + b_f128; + + assert_le!( + (sum_f128 - sum_rug_f128).abs(), + 2.0f64.powi(-104) * sum_f128.abs() + ); + } + } + + #[test] + fn test_sub() { + let mut rng = rug::rand::RandState::new(); + rng.seed(&Integer::from(1u64)); + + for _ in 0..100 { + let a = Float::with_val(PREC, Float::random_normal(&mut rng)); + let b = Float::with_val(PREC, Float::random_normal(&mut rng)); + + let a_f128 = float_to_f128(&a); + let b_f128 = float_to_f128(&b); + let a = f128_to_float(a_f128); + let b = f128_to_float(b_f128); + + let diff = Float::with_val(PREC, &a - &b); + let diff_rug_f128 = float_to_f128(&diff); + let diff_f128 = a_f128 - b_f128; + + assert_le!( + (diff_f128 - diff_rug_f128).abs(), + 2.0f64.powi(-104) * diff_f128.abs() + ); + } + } + + #[test] + fn test_mul() { + let mut rng = rug::rand::RandState::new(); + rng.seed(&Integer::from(2u64)); + + for _ in 0..100 { + let a = Float::with_val(PREC, Float::random_normal(&mut rng)); + let b = Float::with_val(PREC, Float::random_normal(&mut rng)); + + let a_f128 = float_to_f128(&a); + let b_f128 = float_to_f128(&b); + let a = f128_to_float(a_f128); + let b = f128_to_float(b_f128); + + let prod = Float::with_val(PREC, &a * &b); + let prod_rug_f128 = float_to_f128(&prod); + let prod_f128 = a_f128 * b_f128; + + assert_le!( + (prod_f128 - prod_rug_f128).abs(), + 2.0f64.powi(-104) * prod_f128.abs() + ); + } + } + + #[test] + fn test_div() { + let mut rng = rug::rand::RandState::new(); + rng.seed(&Integer::from(3u64)); + + for _ in 0..100 { + let a = Float::with_val(PREC, Float::random_normal(&mut rng)); + let b = Float::with_val(PREC, Float::random_normal(&mut rng)); + + let a_f128 = float_to_f128(&a); + let b_f128 = float_to_f128(&b); + let a = f128_to_float(a_f128); + let b = f128_to_float(b_f128); + + let quot = Float::with_val(PREC, &a / &b); + let quot_rug_f128 = float_to_f128("); + let quot_f128 = a_f128 / b_f128; + + assert_le!( + (quot_f128 - quot_rug_f128).abs(), + 2.0f64.powi(-104) * quot_f128.abs() + ); + } + } + + #[test] + fn test_sincos_taylor() { + let mut rng = rug::rand::RandState::new(); + rng.seed(&Integer::from(4u64)); + + for _ in 0..10000 { + let a = (Float::with_val(PREC, Float::random_bits(&mut rng)) * 2.0 - 1.0) / 32; + let a_f128 = float_to_f128(&a); + let a = f128_to_float(a_f128); + + let sin = Float::with_val(PREC, a.clone().sin_pi()); + let cos = Float::with_val(PREC, a.clone().cos_pi()); + let sin_rug_f128 = float_to_f128(&sin); + let cos_rug_f128 = float_to_f128(&cos); + let (sin_f128, cos_f128) = a_f128.sincospi_taylor(); + assert_le!( + (cos_f128 - cos_rug_f128).abs(), + 2.0f64.powi(-103) * cos_f128.abs() + ); + assert_le!( + (sin_f128 - sin_rug_f128).abs(), + 2.0f64.powi(-103) * sin_f128.abs() + ); + } + } + + #[test] + fn test_sincos() { + let mut rng = rug::rand::RandState::new(); + rng.seed(&Integer::from(5u64)); + + #[track_caller] + fn test_sincos(a: Float) { + let a_f128 = float_to_f128(&a); + let a = f128_to_float(a_f128); + + let sin = Float::with_val(PREC, a.clone().sin_pi()); + let cos = Float::with_val(PREC, a.cos_pi()); + let sin_rug_f128 = float_to_f128(&sin); + let cos_rug_f128 = float_to_f128(&cos); + let (sin_f128, cos_f128) = a_f128.sincospi(); + assert_le!( + (cos_f128 - cos_rug_f128).abs(), + 2.0f64.powi(-103) * cos_f128.abs() + ); + assert_le!( + (sin_f128 - sin_rug_f128).abs(), + 2.0f64.powi(-103) * sin_f128.abs() + ); + } + + test_sincos(Float::with_val(PREC, 0.00)); + test_sincos(Float::with_val(PREC, 0.25)); + test_sincos(Float::with_val(PREC, 0.50)); + test_sincos(Float::with_val(PREC, 0.75)); + test_sincos(Float::with_val(PREC, 1.00)); + + for _ in 0..10000 { + test_sincos(Float::with_val(PREC, Float::random_bits(&mut rng)) * 2.0 - 1.0); + } + } + + #[cfg(feature = "std")] + #[test] + fn generate_constants() { + let pi = Float::with_val(PREC, rug::float::Constant::Pi); + + println!(); + println!("###############################################################################"); + println!("impl f128 {{"); + println!(" pub const PI: Self = {:?};", float_to_f128(&pi)); + + println!(); + println!(" const SINPI_TAYLOR: &'static [Self; 9] = &["); + let mut factorial = 1_u64; + for i in 1..10 { + let k = 2 * i + 1; + factorial *= (k - 1) * k; + println!( + " {:?},", + (-1.0f64).powi(i as i32) * float_to_f128(&(pi.clone().pow(k) / factorial)), + ); + } + println!(" ];"); + + println!(); + println!(" const COSPI_TAYLOR: &'static [Self; 9] = &["); + let mut factorial = 1_u64; + for i in 1..10 { + let k = 2 * i; + factorial *= (k - 1) * k; + println!( + " {:?},", + (-1.0f64).powi(i as i32) * float_to_f128(&(pi.clone().pow(k) / factorial)), + ); + } + println!(" ];"); + + println!(); + println!(" const SIN_K_PI_OVER_16_TABLE: &'static [Self; 4] = &["); + for k in 1..5 { + let x: Float = Float::with_val(PREC, k as f64 / 16.0); + println!(" {:?},", float_to_f128(&x.clone().sin_pi()),); + } + println!(" ];"); + + println!(); + println!(" const COS_K_PI_OVER_16_TABLE: &'static [Self; 4] = &["); + for k in 1..5 { + let x: Float = Float::with_val(PREC, k as f64 / 16.0); + println!(" {:?},", float_to_f128(&x.clone().cos_pi()),); + } + println!(" ];"); + + println!("}}"); + println!("###############################################################################"); + assert_eq!(float_to_f128(&pi), f128::PI); + } +} diff --git a/src/fft128/mod.rs b/src/fft128/mod.rs new file mode 100644 index 0000000..eda2bbd --- /dev/null +++ b/src/fft128/mod.rs @@ -0,0 +1,2204 @@ +mod f128_impl; + +/// 128-bit floating point number. +#[allow(non_camel_case_types)] +#[derive(Copy, Clone, Debug)] +#[repr(C)] +pub struct f128(pub f64, pub f64); + +use aligned_vec::{avec, ABox}; +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +pub use f128_impl::x86::Avx; +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[cfg(feature = "nightly")] +#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))] +pub use f128_impl::x86::Avx512; + +use pulp::{as_arrays, as_arrays_mut, cast}; + +#[allow(unused_macros)] +macro_rules! izip { + (@ __closure @ ($a:expr)) => { |a| (a,) }; + (@ __closure @ ($a:expr, $b:expr)) => { |(a, b)| (a, b) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr)) => { |((a, b), c)| (a, b, c) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr)) => { |(((a, b), c), d)| (a, b, c, d) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr)) => { |((((a, b), c), d), e)| (a, b, c, d, e) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr)) => { |(((((a, b), c), d), e), f)| (a, b, c, d, e, f) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr)) => { |((((((a, b), c), d), e), f), g)| (a, b, c, d, e, f, g) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr)) => { |(((((((a, b), c), d), e), f), g), h)| (a, b, c, d, e, f, g, h) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr)) => { |((((((((a, b), c), d), e), f), g), h), i)| (a, b, c, d, e, f, g, h, i) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr)) => { |(((((((((a, b), c), d), e), f), g), h), i), j)| (a, b, c, d, e, f, g, h, i, j) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr)) => { |((((((((((a, b), c), d), e), f), g), h), i), j), k)| (a, b, c, d, e, f, g, h, i, j, k) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr)) => { |(((((((((((a, b), c), d), e), f), g), h), i), j), k), l)| (a, b, c, d, e, f, g, h, i, j, k, l) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr)) => { |((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m)| (a, b, c, d, e, f, g, h, i, j, k, l, m) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr, $n:expr)) => { |(((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m), n)| (a, b, c, d, e, f, g, h, i, j, k, l, m, n) }; + (@ __closure @ ($a:expr, $b:expr, $c:expr, $d:expr, $e: expr, $f:expr, $g:expr, $h:expr, $i: expr, $j: expr, $k: expr, $l: expr, $m:expr, $n:expr, $o:expr)) => { |((((((((((((((a, b), c), d), e), f), g), h), i), j), k), l), m), n), o)| (a, b, c, d, e, f, g, h, i, j, k, l, m, n, o) }; + + ( $first:expr $(,)?) => { + { + ::core::iter::IntoIterator::into_iter($first) + } + }; + ( $first:expr, $($rest:expr),+ $(,)?) => { + { + ::core::iter::IntoIterator::into_iter($first) + $(.zip($rest))* + .map(izip!(@ __closure @ ($first, $($rest),*))) + } + }; +} + +trait FftSimdF128: Copy { + type Reg: Copy + core::fmt::Debug; + + fn splat(self, value: f64) -> Self::Reg; + fn add(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg); + fn sub(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg); + fn mul(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg); +} + +#[derive(Copy, Clone)] +struct Scalar; + +impl FftSimdF128 for Scalar { + type Reg = f64; + + #[inline(always)] + fn splat(self, value: f64) -> Self::Reg { + value + } + + #[inline(always)] + fn add(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let f128(o0, o1) = f128(a.0, a.1) + f128(b.0, b.1); + (o0, o1) + } + + #[inline(always)] + fn sub(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let f128(o0, o1) = f128(a.0, a.1) - f128(b.0, b.1); + (o0, o1) + } + + #[inline(always)] + fn mul(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let f128(o0, o1) = f128(a.0, a.1) * f128(b.0, b.1); + (o0, o1) + } +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +impl FftSimdF128 for Avx { + type Reg = [f64; 4]; + + #[inline(always)] + fn splat(self, value: f64) -> Self::Reg { + cast(self.avx._mm256_set1_pd(value)) + } + + #[inline(always)] + fn add(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let result = self._mm256_add_estimate_f128_f128(cast(a.0), cast(a.1), cast(b.0), cast(b.1)); + (cast(result.0), cast(result.1)) + } + + #[inline(always)] + fn sub(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let result = self._mm256_sub_estimate_f128_f128(cast(a.0), cast(a.1), cast(b.0), cast(b.1)); + (cast(result.0), cast(result.1)) + } + + #[inline(always)] + fn mul(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let result = self._mm256_mul_f128_f128(cast(a.0), cast(a.1), cast(b.0), cast(b.1)); + (cast(result.0), cast(result.1)) + } +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[cfg(feature = "nightly")] +impl FftSimdF128 for Avx512 { + type Reg = [f64; 8]; + + #[inline(always)] + fn splat(self, value: f64) -> Self::Reg { + cast(self.avx512f._mm512_set1_pd(value)) + } + + #[inline(always)] + fn add(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let result = self._mm512_add_estimate_f128_f128(cast(a.0), cast(a.1), cast(b.0), cast(b.1)); + (cast(result.0), cast(result.1)) + } + + #[inline(always)] + fn sub(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let result = self._mm512_sub_estimate_f128_f128(cast(a.0), cast(a.1), cast(b.0), cast(b.1)); + (cast(result.0), cast(result.1)) + } + + #[inline(always)] + fn mul(self, a: (Self::Reg, Self::Reg), b: (Self::Reg, Self::Reg)) -> (Self::Reg, Self::Reg) { + let result = self._mm512_mul_f128_f128(cast(a.0), cast(a.1), cast(b.0), cast(b.1)); + (cast(result.0), cast(result.1)) + } +} + +trait FftSimdF128Ext: FftSimdF128 { + #[inline(always)] + fn cplx_add( + self, + a_re: (Self::Reg, Self::Reg), + a_im: (Self::Reg, Self::Reg), + b_re: (Self::Reg, Self::Reg), + b_im: (Self::Reg, Self::Reg), + ) -> ((Self::Reg, Self::Reg), (Self::Reg, Self::Reg)) { + (self.add(a_re, b_re), self.add(a_im, b_im)) + } + + #[inline(always)] + fn cplx_sub( + self, + a_re: (Self::Reg, Self::Reg), + a_im: (Self::Reg, Self::Reg), + b_re: (Self::Reg, Self::Reg), + b_im: (Self::Reg, Self::Reg), + ) -> ((Self::Reg, Self::Reg), (Self::Reg, Self::Reg)) { + (self.sub(a_re, b_re), self.sub(a_im, b_im)) + } + + /// `a * b` + #[inline(always)] + fn cplx_mul( + self, + a_re: (Self::Reg, Self::Reg), + a_im: (Self::Reg, Self::Reg), + b_re: (Self::Reg, Self::Reg), + b_im: (Self::Reg, Self::Reg), + ) -> ((Self::Reg, Self::Reg), (Self::Reg, Self::Reg)) { + let a_re_x_b_re = self.mul(a_re, b_re); + let a_re_x_b_im = self.mul(a_re, b_im); + let a_im_x_b_re = self.mul(a_im, b_re); + let a_im_x_b_im = self.mul(a_im, b_im); + + ( + self.sub(a_re_x_b_re, a_im_x_b_im), + self.add(a_im_x_b_re, a_re_x_b_im), + ) + } + + /// `a * conj(b)` + #[inline(always)] + fn cplx_mul_conj( + self, + a_re: (Self::Reg, Self::Reg), + a_im: (Self::Reg, Self::Reg), + b_re: (Self::Reg, Self::Reg), + b_im: (Self::Reg, Self::Reg), + ) -> ((Self::Reg, Self::Reg), (Self::Reg, Self::Reg)) { + let a_re_x_b_re = self.mul(a_re, b_re); + let a_re_x_b_im = self.mul(a_re, b_im); + let a_im_x_b_re = self.mul(a_im, b_re); + let a_im_x_b_im = self.mul(a_im, b_im); + + ( + self.add(a_re_x_b_re, a_im_x_b_im), + self.sub(a_im_x_b_re, a_re_x_b_im), + ) + } +} + +impl FftSimdF128Ext for T {} + +#[doc(hidden)] +pub fn negacyclic_fwd_fft_scalar( + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + let n = data_re0.len(); + let mut t = n; + let mut m = 1; + let simd = Scalar; + + while m < n { + t /= 2; + + for i in 0..m { + let w1_re = (twid_re0[m + i], twid_re1[m + i]); + let w1_im = (twid_im0[m + i], twid_im1[m + i]); + + let start = 2 * i * t; + + let data_re0 = &mut data_re0[start..][..2 * t]; + let data_re1 = &mut data_re1[start..][..2 * t]; + let data_im0 = &mut data_im0[start..][..2 * t]; + let data_im1 = &mut data_im1[start..][..2 * t]; + + let (z0_re0, z1_re0) = data_re0.split_at_mut(t); + let (z0_re1, z1_re1) = data_re1.split_at_mut(t); + let (z0_im0, z1_im0) = data_im0.split_at_mut(t); + let (z0_im1, z1_im1) = data_im1.split_at_mut(t); + + for (z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) in + izip!(z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) + { + let (z0_re, z0_im) = ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)); + let (z1_re, z1_im) = ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + } + } + + m *= 2; + } +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[doc(hidden)] +pub fn negacyclic_fwd_fft_avxfma( + simd: Avx, + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + let n = data_re0.len(); + assert!(n >= 32); + + simd.vectorize({ + #[inline(always)] + || { + { + let mut t = n; + let mut m = 1; + + while m < n / 4 { + t /= 2; + + let twid_re0 = &twid_re0[m..]; + let twid_re1 = &twid_re1[m..]; + let twid_im0 = &twid_im0[m..]; + let twid_im1 = &twid_im1[m..]; + + let iter = izip!( + data_re0.chunks_mut(2 * t), + data_re1.chunks_mut(2 * t), + data_im0.chunks_mut(2 * t), + data_im1.chunks_mut(2 * t), + twid_re0, + twid_re1, + twid_im0, + twid_im1, + ); + for (data_re0, data_re1, data_im0, data_im1, w1_re0, w1_re1, w1_im0, w1_im1) in + iter + { + let w1_re = (*w1_re0, *w1_re1); + let w1_im = (*w1_im0, *w1_im1); + + let w1_re = (simd.splat(w1_re.0), simd.splat(w1_re.1)); + let w1_im = (simd.splat(w1_im.0), simd.splat(w1_im.1)); + + let (z0_re0, z1_re0) = data_re0.split_at_mut(t); + let (z0_re1, z1_re1) = data_re1.split_at_mut(t); + let (z0_im0, z1_im0) = data_im0.split_at_mut(t); + let (z0_im1, z1_im1) = data_im1.split_at_mut(t); + + let z0_re0 = as_arrays_mut::<4, _>(z0_re0).0; + let z0_re1 = as_arrays_mut::<4, _>(z0_re1).0; + let z0_im0 = as_arrays_mut::<4, _>(z0_im0).0; + let z0_im1 = as_arrays_mut::<4, _>(z0_im1).0; + let z1_re0 = as_arrays_mut::<4, _>(z1_re0).0; + let z1_re1 = as_arrays_mut::<4, _>(z1_re1).0; + let z1_im0 = as_arrays_mut::<4, _>(z1_im0).0; + let z1_im1 = as_arrays_mut::<4, _>(z1_im1).0; + + let iter = + izip!(z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1); + for (z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) in iter + { + let (z0_re, z0_im) = ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)); + let (z1_re, z1_im) = ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + } + } + + m *= 2; + } + } + + // m = n / 4 + // t = 2 + { + let m = n / 4; + + let twid_re0 = as_arrays::<2, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<2, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<2, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<2, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<4, _>(data_re0).0; + let data_re1 = as_arrays_mut::<4, _>(data_re1).0; + let data_im0 = as_arrays_mut::<4, _>(data_im0).0; + let data_im1 = as_arrays_mut::<4, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for ( + z0z0z1z1_re0, + z0z0z1z1_re1, + z0z0z1z1_im0, + z0z0z1z1_im1, + w1_re0, + w1_re1, + w1_im0, + w1_im1, + ) in iter + { + // 0 1 2 3 | 4 5 6 7 -> 0 1 4 5 | 2 3 6 7 + // + // is its own inverse since: + // 0 1 4 5 | 2 3 6 7 -> 0 1 2 3 | 4 5 6 7 + let interleave = { + #[inline(always)] + |z0z0z1z1: [[f64; 4]; 2]| -> [[f64; 4]; 2] { + [ + cast(simd.avx._mm256_permute2f128_pd::<0b00100000>( + cast(z0z0z1z1[0]), + cast(z0z0z1z1[1]), + )), + cast(simd.avx._mm256_permute2f128_pd::<0b00110001>( + cast(z0z0z1z1[0]), + cast(z0z0z1z1[1]), + )), + ] + } + }; + + let splat2 = { + #[inline(always)] + |w: [f64; 2]| -> [f64; 4] { + let w00 = simd.sse2._mm_set1_pd(w[0]); + let w11 = simd.sse2._mm_set1_pd(w[1]); + + let w0011 = simd.avx._mm256_insertf128_pd::<0b1>( + simd.avx._mm256_castpd128_pd256(w00), + w11, + ); + + cast(w0011) + } + }; + + let w1_re = (splat2(*w1_re0), splat2(*w1_re1)); + let w1_im = (splat2(*w1_im0), splat2(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z0z1z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z0z1z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z0z1z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z0z1z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + + *z0z0z1z1_re0 = interleave([z0_re0, z1_re0]); + *z0z0z1z1_re1 = interleave([z0_re1, z1_re1]); + *z0z0z1z1_im0 = interleave([z0_im0, z1_im0]); + *z0z0z1z1_im1 = interleave([z0_im1, z1_im1]); + } + } + + // m = n / 2 + // t = 1 + { + let m = n / 2; + + let twid_re0 = as_arrays::<4, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<4, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<4, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<4, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<4, _>(data_re0).0; + let data_re1 = as_arrays_mut::<4, _>(data_re1).0; + let data_im0 = as_arrays_mut::<4, _>(data_im0).0; + let data_im1 = as_arrays_mut::<4, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for (z0z1_re0, z0z1_re1, z0z1_im0, z0z1_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + // 0 1 2 3 -> 0 2 1 3 + let permute = { + #[inline(always)] + |w: [f64; 4]| -> [f64; 4] { + let avx = simd.avx; + let w0123 = cast(w); + let w0101 = avx._mm256_permute2f128_pd::<0b00000000>(w0123, w0123); + let w2323 = avx._mm256_permute2f128_pd::<0b00110011>(w0123, w0123); + let w0213 = avx._mm256_shuffle_pd::<0b1100>(w0101, w2323); + cast(w0213) + } + }; + + // 0 1 2 3 | 4 5 6 7 -> 0 4 2 6 | 1 5 3 7 + // + // is its own inverse since: + // 0 4 2 6 | 1 5 3 7 -> 0 1 2 3 | 4 5 6 7 + let interleave = { + #[inline(always)] + |z0z1: [[f64; 4]; 2]| -> [[f64; 4]; 2] { + [ + cast(simd.avx._mm256_unpacklo_pd(cast(z0z1[0]), cast(z0z1[1]))), + cast(simd.avx._mm256_unpackhi_pd(cast(z0z1[0]), cast(z0z1[1]))), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + + *z0z1_re0 = interleave([z0_re0, z1_re0]); + *z0z1_re1 = interleave([z0_re1, z1_re1]); + *z0z1_im0 = interleave([z0_im0, z1_im0]); + *z0z1_im1 = interleave([z0_im1, z1_im1]); + } + } + } + }); +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[cfg(feature = "nightly")] +#[doc(hidden)] +pub fn negacyclic_fwd_fft_avx512( + simd: Avx512, + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + let n = data_re0.len(); + assert!(n >= 32); + + simd.vectorize({ + #[inline(always)] + || { + { + let mut t = n; + let mut m = 1; + + while m < n / 8 { + t /= 2; + + let twid_re0 = &twid_re0[m..]; + let twid_re1 = &twid_re1[m..]; + let twid_im0 = &twid_im0[m..]; + let twid_im1 = &twid_im1[m..]; + + let iter = izip!( + data_re0.chunks_mut(2 * t), + data_re1.chunks_mut(2 * t), + data_im0.chunks_mut(2 * t), + data_im1.chunks_mut(2 * t), + twid_re0, + twid_re1, + twid_im0, + twid_im1, + ); + for (data_re0, data_re1, data_im0, data_im1, w1_re0, w1_re1, w1_im0, w1_im1) in + iter + { + let w1_re = (*w1_re0, *w1_re1); + let w1_im = (*w1_im0, *w1_im1); + + let w1_re = (simd.splat(w1_re.0), simd.splat(w1_re.1)); + let w1_im = (simd.splat(w1_im.0), simd.splat(w1_im.1)); + + let (z0_re0, z1_re0) = data_re0.split_at_mut(t); + let (z0_re1, z1_re1) = data_re1.split_at_mut(t); + let (z0_im0, z1_im0) = data_im0.split_at_mut(t); + let (z0_im1, z1_im1) = data_im1.split_at_mut(t); + + let z0_re0 = as_arrays_mut::<8, _>(z0_re0).0; + let z0_re1 = as_arrays_mut::<8, _>(z0_re1).0; + let z0_im0 = as_arrays_mut::<8, _>(z0_im0).0; + let z0_im1 = as_arrays_mut::<8, _>(z0_im1).0; + let z1_re0 = as_arrays_mut::<8, _>(z1_re0).0; + let z1_re1 = as_arrays_mut::<8, _>(z1_re1).0; + let z1_im0 = as_arrays_mut::<8, _>(z1_im0).0; + let z1_im1 = as_arrays_mut::<8, _>(z1_im1).0; + + let iter = + izip!(z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1); + for (z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) in iter + { + let (z0_re, z0_im) = ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)); + let (z1_re, z1_im) = ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + } + } + + m *= 2; + } + } + + // m = n / 8 + // t = 4 + { + let m = n / 8; + + let twid_re0 = as_arrays::<2, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<2, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<2, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<2, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<8, _>(data_re0).0; + let data_re1 = as_arrays_mut::<8, _>(data_re1).0; + let data_im0 = as_arrays_mut::<8, _>(data_im0).0; + let data_im1 = as_arrays_mut::<8, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for (z0z1_re0, z0z1_re1, z0z1_im0, z0z1_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + // 0 1 -> 0 0 0 0 1 1 1 1 + let permute = { + #[inline(always)] + |w: [f64; 2]| -> [f64; 8] { + let avx512f = simd.avx512f; + let w = cast(w); + let w01xxxxxx = avx512f._mm512_castpd128_pd512(w); + let idx = avx512f._mm512_setr_epi64(0, 0, 0, 0, 1, 1, 1, 1); + + cast(avx512f._mm512_permutexvar_pd(idx, w01xxxxxx)) + } + }; + + // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f -> 0 1 2 3 8 9 a b | 4 5 6 7 c d e f + let interleave = { + #[inline(always)] + |z0z0z1z1: [[f64; 8]; 2]| -> [[f64; 8]; 2] { + let avx512f = simd.avx512f; + let idx_0 = + avx512f._mm512_setr_epi64(0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xa, 0xb); + let idx_1 = + avx512f._mm512_setr_epi64(0x4, 0x5, 0x6, 0x7, 0xc, 0xd, 0xe, 0xf); + [ + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_0, + cast(z0z0z1z1[1]), + )), + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_1, + cast(z0z0z1z1[1]), + )), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + + *z0z1_re0 = interleave([z0_re0, z1_re0]); + *z0z1_re1 = interleave([z0_re1, z1_re1]); + *z0z1_im0 = interleave([z0_im0, z1_im0]); + *z0z1_im1 = interleave([z0_im1, z1_im1]); + } + } + + // m = n / 4 + // t = 2 + { + let m = n / 4; + + let twid_re0 = as_arrays::<4, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<4, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<4, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<4, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<8, _>(data_re0).0; + let data_re1 = as_arrays_mut::<8, _>(data_re1).0; + let data_im0 = as_arrays_mut::<8, _>(data_im0).0; + let data_im1 = as_arrays_mut::<8, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for (z0z1_re0, z0z1_re1, z0z1_im0, z0z1_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + // 0 1 2 3 -> 0 0 2 2 1 1 3 3 + let permute = { + #[inline(always)] + |w: [f64; 4]| -> [f64; 8] { + let avx512f = simd.avx512f; + let w = cast(w); + let w0123xxxx = avx512f._mm512_castpd256_pd512(w); + let idx = avx512f._mm512_setr_epi64(0, 0, 2, 2, 1, 1, 3, 3); + + cast(avx512f._mm512_permutexvar_pd(idx, w0123xxxx)) + } + }; + + // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f -> 0 1 8 9 4 5 c d | 2 3 a b 6 7 e f + let interleave = { + #[inline(always)] + |z0z0z1z1: [[f64; 8]; 2]| -> [[f64; 8]; 2] { + let avx512f = simd.avx512f; + let idx_0 = + avx512f._mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xc, 0xd); + let idx_1 = + avx512f._mm512_setr_epi64(0x2, 0x3, 0xa, 0xb, 0x6, 0x7, 0xe, 0xf); + [ + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_0, + cast(z0z0z1z1[1]), + )), + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_1, + cast(z0z0z1z1[1]), + )), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + + *z0z1_re0 = interleave([z0_re0, z1_re0]); + *z0z1_re1 = interleave([z0_re1, z1_re1]); + *z0z1_im0 = interleave([z0_im0, z1_im0]); + *z0z1_im1 = interleave([z0_im1, z1_im1]); + } + } + + // m = n / 2 + // t = 1 + { + let m = n / 2; + + let twid_re0 = as_arrays::<8, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<8, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<8, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<8, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<8, _>(data_re0).0; + let data_re1 = as_arrays_mut::<8, _>(data_re1).0; + let data_im0 = as_arrays_mut::<8, _>(data_im0).0; + let data_im1 = as_arrays_mut::<8, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for (z0z1_re0, z0z1_re1, z0z1_im0, z0z1_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + // 0 1 2 3 4 5 6 7 -> 0 4 1 5 2 6 3 7 + let permute = { + #[inline(always)] + |w: [f64; 8]| -> [f64; 8] { + let avx512f = simd.avx512f; + let w = cast(w); + let idx = avx512f._mm512_setr_epi64(0, 4, 1, 5, 2, 6, 3, 7); + cast(avx512f._mm512_permutexvar_pd(idx, w)) + } + }; + + // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f -> 0 8 2 a 4 c 6 e | 1 9 3 b 5 d 7 f + // + // is its own inverse since: + // 0 8 2 a 4 c 6 e | 1 9 3 b 5 d 7 f -> 0 1 2 3 4 5 6 7 | 8 9 a b c d e f + let interleave = { + #[inline(always)] + |z0z1: [[f64; 8]; 2]| -> [[f64; 8]; 2] { + let avx512f = simd.avx512f; + [ + cast(avx512f._mm512_unpacklo_pd(cast(z0z1[0]), cast(z0z1[1]))), + cast(avx512f._mm512_unpackhi_pd(cast(z0z1[0]), cast(z0z1[1]))), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z1w_re, z1w_im) = simd.cplx_mul(z1_re, z1_im, w1_re, w1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1w_re, z1w_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_sub(z0_re, z0_im, z1w_re, z1w_im); + + *z0z1_re0 = interleave([z0_re0, z1_re0]); + *z0z1_re1 = interleave([z0_re1, z1_re1]); + *z0z1_im0 = interleave([z0_im0, z1_im0]); + *z0z1_im1 = interleave([z0_im1, z1_im1]); + } + } + } + }); +} + +#[doc(hidden)] +pub fn negacyclic_fwd_fft( + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + #[cfg(feature = "nightly")] + if let Some(simd) = Avx512::try_new() { + return negacyclic_fwd_fft_avx512( + simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, + twid_im1, + ); + } + if let Some(simd) = Avx::try_new() { + return negacyclic_fwd_fft_avxfma( + simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, + twid_im1, + ); + } + } + negacyclic_fwd_fft_scalar( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1, + ) +} + +#[doc(hidden)] +pub fn negacyclic_inv_fft( + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + { + #[cfg(feature = "nightly")] + if let Some(simd) = Avx512::try_new() { + return negacyclic_inv_fft_avx512( + simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, + twid_im1, + ); + } + if let Some(simd) = Avx::try_new() { + return negacyclic_inv_fft_avxfma( + simd, data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, + twid_im1, + ); + } + } + negacyclic_inv_fft_scalar( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1, + ) +} + +#[doc(hidden)] +pub fn negacyclic_inv_fft_scalar( + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + let n = data_re0.len(); + let mut t = 1; + let mut m = n; + let simd = Scalar; + + while m > 1 { + m /= 2; + + for i in 0..m { + let w1_re = (twid_re0[m + i], twid_re1[m + i]); + let w1_im = (twid_im0[m + i], twid_im1[m + i]); + + let start = 2 * i * t; + + let data_re0 = &mut data_re0[start..][..2 * t]; + let data_re1 = &mut data_re1[start..][..2 * t]; + let data_im0 = &mut data_im0[start..][..2 * t]; + let data_im1 = &mut data_im1[start..][..2 * t]; + + let (z0_re0, z1_re0) = data_re0.split_at_mut(t); + let (z0_re1, z1_re1) = data_re1.split_at_mut(t); + let (z0_im0, z1_im0) = data_im0.split_at_mut(t); + let (z0_im1, z1_im1) = data_im1.split_at_mut(t); + + for (z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) in + izip!(z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) + { + let (z0_re, z0_im) = ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)); + let (z1_re, z1_im) = ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + } + } + + t *= 2; + } +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[doc(hidden)] +pub fn negacyclic_inv_fft_avxfma( + simd: Avx, + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + let n = data_re0.len(); + assert!(n >= 32); + + simd.vectorize({ + #[inline(always)] + || { + let mut t = 1; + let mut m = n; + + // m = n / 2 + // t = 1 + { + m /= 2; + + let twid_re0 = as_arrays::<4, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<4, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<4, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<4, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<4, _>(data_re0).0; + let data_re1 = as_arrays_mut::<4, _>(data_re1).0; + let data_im0 = as_arrays_mut::<4, _>(data_im0).0; + let data_im1 = as_arrays_mut::<4, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for (z0z1_re0, z0z1_re1, z0z1_im0, z0z1_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + // 0 1 2 3 -> 0 2 1 3 + let permute = { + #[inline(always)] + |w: [f64; 4]| -> [f64; 4] { + let avx = simd.avx; + let w0123 = cast(w); + let w0101 = avx._mm256_permute2f128_pd::<0b00000000>(w0123, w0123); + let w2323 = avx._mm256_permute2f128_pd::<0b00110011>(w0123, w0123); + let w0213 = avx._mm256_shuffle_pd::<0b1100>(w0101, w2323); + cast(w0213) + } + }; + + // 0 1 2 3 | 4 5 6 7 -> 0 4 2 6 | 1 5 3 7 + // + // is its own inverse since: + // 0 4 2 6 | 1 5 3 7 -> 0 1 2 3 | 4 5 6 7 + let interleave = { + #[inline(always)] + |z0z1: [[f64; 4]; 2]| -> [[f64; 4]; 2] { + [ + cast(simd.avx._mm256_unpacklo_pd(cast(z0z1[0]), cast(z0z1[1]))), + cast(simd.avx._mm256_unpackhi_pd(cast(z0z1[0]), cast(z0z1[1]))), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + + *z0z1_re0 = interleave([z0_re0, z1_re0]); + *z0z1_re1 = interleave([z0_re1, z1_re1]); + *z0z1_im0 = interleave([z0_im0, z1_im0]); + *z0z1_im1 = interleave([z0_im1, z1_im1]); + } + + t *= 2; + } + + // m = n / 4 + // t = 2 + { + m /= 2; + + let twid_re0 = as_arrays::<2, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<2, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<2, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<2, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<4, _>(data_re0).0; + let data_re1 = as_arrays_mut::<4, _>(data_re1).0; + let data_im0 = as_arrays_mut::<4, _>(data_im0).0; + let data_im1 = as_arrays_mut::<4, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for ( + z0z0z1z1_re0, + z0z0z1z1_re1, + z0z0z1z1_im0, + z0z0z1z1_im1, + w1_re0, + w1_re1, + w1_im0, + w1_im1, + ) in iter + { + // 0 1 2 3 | 4 5 6 7 -> 0 1 4 5 | 2 3 6 7 + // + // is its own inverse since: + // 0 1 4 5 | 2 3 6 7 -> 0 1 2 3 | 4 5 6 7 + let interleave = { + #[inline(always)] + |z0z0z1z1: [[f64; 4]; 2]| -> [[f64; 4]; 2] { + [ + cast(simd.avx._mm256_permute2f128_pd::<0b00100000>( + cast(z0z0z1z1[0]), + cast(z0z0z1z1[1]), + )), + cast(simd.avx._mm256_permute2f128_pd::<0b00110001>( + cast(z0z0z1z1[0]), + cast(z0z0z1z1[1]), + )), + ] + } + }; + + let splat2 = { + #[inline(always)] + |w: [f64; 2]| -> [f64; 4] { + let w00 = simd.sse2._mm_set1_pd(w[0]); + let w11 = simd.sse2._mm_set1_pd(w[1]); + + let w0011 = simd.avx._mm256_insertf128_pd::<0b1>( + simd.avx._mm256_castpd128_pd256(w00), + w11, + ); + + cast(w0011) + } + }; + + let w1_re = (splat2(*w1_re0), splat2(*w1_re1)); + let w1_im = (splat2(*w1_im0), splat2(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z0z1z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z0z1z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z0z1z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z0z1z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + + *z0z0z1z1_re0 = interleave([z0_re0, z1_re0]); + *z0z0z1z1_re1 = interleave([z0_re1, z1_re1]); + *z0z0z1z1_im0 = interleave([z0_im0, z1_im0]); + *z0z0z1z1_im1 = interleave([z0_im1, z1_im1]); + } + + t *= 2; + } + + while m > 1 { + m /= 2; + + let twid_re0 = &twid_re0[m..]; + let twid_re1 = &twid_re1[m..]; + let twid_im0 = &twid_im0[m..]; + let twid_im1 = &twid_im1[m..]; + + let iter = izip!( + data_re0.chunks_mut(2 * t), + data_re1.chunks_mut(2 * t), + data_im0.chunks_mut(2 * t), + data_im1.chunks_mut(2 * t), + twid_re0, + twid_re1, + twid_im0, + twid_im1, + ); + for (data_re0, data_re1, data_im0, data_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + let w1_re = (*w1_re0, *w1_re1); + let w1_im = (*w1_im0, *w1_im1); + + let w1_re = (simd.splat(w1_re.0), simd.splat(w1_re.1)); + let w1_im = (simd.splat(w1_im.0), simd.splat(w1_im.1)); + + let (z0_re0, z1_re0) = data_re0.split_at_mut(t); + let (z0_re1, z1_re1) = data_re1.split_at_mut(t); + let (z0_im0, z1_im0) = data_im0.split_at_mut(t); + let (z0_im1, z1_im1) = data_im1.split_at_mut(t); + + let z0_re0 = as_arrays_mut::<4, _>(z0_re0).0; + let z0_re1 = as_arrays_mut::<4, _>(z0_re1).0; + let z0_im0 = as_arrays_mut::<4, _>(z0_im0).0; + let z0_im1 = as_arrays_mut::<4, _>(z0_im1).0; + let z1_re0 = as_arrays_mut::<4, _>(z1_re0).0; + let z1_re1 = as_arrays_mut::<4, _>(z1_re1).0; + let z1_im0 = as_arrays_mut::<4, _>(z1_im0).0; + let z1_im1 = as_arrays_mut::<4, _>(z1_im1).0; + + let iter = + izip!(z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1); + for (z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) in iter { + let (z0_re, z0_im) = ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)); + let (z1_re, z1_im) = ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + } + } + + t *= 2; + } + } + }); +} + +#[cfg(any(target_arch = "x86", target_arch = "x86_64"))] +#[cfg(feature = "nightly")] +#[doc(hidden)] +pub fn negacyclic_inv_fft_avx512( + simd: Avx512, + data_re0: &mut [f64], + data_re1: &mut [f64], + data_im0: &mut [f64], + data_im1: &mut [f64], + twid_re0: &[f64], + twid_re1: &[f64], + twid_im0: &[f64], + twid_im1: &[f64], +) { + let n = data_re0.len(); + assert!(n >= 32); + + simd.vectorize({ + #[inline(always)] + || { + let mut t = 1; + let mut m = n; + + // m = n / 2 + // t = 1 + { + m /= 2; + + let twid_re0 = as_arrays::<8, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<8, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<8, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<8, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<8, _>(data_re0).0; + let data_re1 = as_arrays_mut::<8, _>(data_re1).0; + let data_im0 = as_arrays_mut::<8, _>(data_im0).0; + let data_im1 = as_arrays_mut::<8, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for (z0z1_re0, z0z1_re1, z0z1_im0, z0z1_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + // 0 1 2 3 4 5 6 7 -> 0 4 1 5 2 6 3 7 + let permute = { + #[inline(always)] + |w: [f64; 8]| -> [f64; 8] { + let avx512f = simd.avx512f; + let w = cast(w); + let idx = avx512f._mm512_setr_epi64(0, 4, 1, 5, 2, 6, 3, 7); + cast(avx512f._mm512_permutexvar_pd(idx, w)) + } + }; + + // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f -> 0 8 2 a 4 c 6 e | 1 9 3 b 5 d 7 f + // + // is its own inverse since: + // 0 8 2 a 4 c 6 e | 1 9 3 b 5 d 7 f -> 0 1 2 3 4 5 6 7 | 8 9 a b c d e f + let interleave = { + #[inline(always)] + |z0z1: [[f64; 8]; 2]| -> [[f64; 8]; 2] { + let avx512f = simd.avx512f; + [ + cast(avx512f._mm512_unpacklo_pd(cast(z0z1[0]), cast(z0z1[1]))), + cast(avx512f._mm512_unpackhi_pd(cast(z0z1[0]), cast(z0z1[1]))), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + + *z0z1_re0 = interleave([z0_re0, z1_re0]); + *z0z1_re1 = interleave([z0_re1, z1_re1]); + *z0z1_im0 = interleave([z0_im0, z1_im0]); + *z0z1_im1 = interleave([z0_im1, z1_im1]); + } + + t *= 2; + } + + // m = n / 4 + // t = 2 + { + m /= 2; + + let twid_re0 = as_arrays::<4, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<4, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<4, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<4, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<8, _>(data_re0).0; + let data_re1 = as_arrays_mut::<8, _>(data_re1).0; + let data_im0 = as_arrays_mut::<8, _>(data_im0).0; + let data_im1 = as_arrays_mut::<8, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for ( + z0z0z1z1_re0, + z0z0z1z1_re1, + z0z0z1z1_im0, + z0z0z1z1_im1, + w1_re0, + w1_re1, + w1_im0, + w1_im1, + ) in iter + { + // 0 1 2 3 -> 0 0 2 2 1 1 3 3 + let permute = { + #[inline(always)] + |w: [f64; 4]| -> [f64; 8] { + let avx512f = simd.avx512f; + let w = cast(w); + let w0123xxxx = avx512f._mm512_castpd256_pd512(w); + let idx = avx512f._mm512_setr_epi64(0, 0, 2, 2, 1, 1, 3, 3); + + cast(avx512f._mm512_permutexvar_pd(idx, w0123xxxx)) + } + }; + + // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f -> 0 1 8 9 4 5 c d | 2 3 a b 6 7 e f + let interleave = { + #[inline(always)] + |z0z0z1z1: [[f64; 8]; 2]| -> [[f64; 8]; 2] { + let avx512f = simd.avx512f; + let idx_0 = + avx512f._mm512_setr_epi64(0x0, 0x1, 0x8, 0x9, 0x4, 0x5, 0xc, 0xd); + let idx_1 = + avx512f._mm512_setr_epi64(0x2, 0x3, 0xa, 0xb, 0x6, 0x7, 0xe, 0xf); + [ + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_0, + cast(z0z0z1z1[1]), + )), + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_1, + cast(z0z0z1z1[1]), + )), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z0z1z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z0z1z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z0z1z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z0z1z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + + *z0z0z1z1_re0 = interleave([z0_re0, z1_re0]); + *z0z0z1z1_re1 = interleave([z0_re1, z1_re1]); + *z0z0z1z1_im0 = interleave([z0_im0, z1_im0]); + *z0z0z1z1_im1 = interleave([z0_im1, z1_im1]); + } + + t *= 2; + } + + // m = n / 8 + // t = 4 + { + m /= 2; + + let twid_re0 = as_arrays::<2, _>(&twid_re0[m..]).0; + let twid_re1 = as_arrays::<2, _>(&twid_re1[m..]).0; + let twid_im0 = as_arrays::<2, _>(&twid_im0[m..]).0; + let twid_im1 = as_arrays::<2, _>(&twid_im1[m..]).0; + + let data_re0 = as_arrays_mut::<8, _>(data_re0).0; + let data_re1 = as_arrays_mut::<8, _>(data_re1).0; + let data_im0 = as_arrays_mut::<8, _>(data_im0).0; + let data_im1 = as_arrays_mut::<8, _>(data_im1).0; + + let data_re0 = as_arrays_mut::<2, _>(data_re0).0; + let data_re1 = as_arrays_mut::<2, _>(data_re1).0; + let data_im0 = as_arrays_mut::<2, _>(data_im0).0; + let data_im1 = as_arrays_mut::<2, _>(data_im1).0; + + let iter = izip!( + data_re0, data_re1, data_im0, data_im1, twid_re0, twid_re1, twid_im0, twid_im1 + ); + for ( + z0z0z1z1_re0, + z0z0z1z1_re1, + z0z0z1z1_im0, + z0z0z1z1_im1, + w1_re0, + w1_re1, + w1_im0, + w1_im1, + ) in iter + { + // 0 1 -> 0 0 0 0 1 1 1 1 + let permute = { + #[inline(always)] + |w: [f64; 2]| -> [f64; 8] { + let avx512f = simd.avx512f; + let w = cast(w); + let w01xxxxxx = avx512f._mm512_castpd128_pd512(w); + let idx = avx512f._mm512_setr_epi64(0, 0, 0, 0, 1, 1, 1, 1); + + cast(avx512f._mm512_permutexvar_pd(idx, w01xxxxxx)) + } + }; + + // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f -> 0 1 2 3 8 9 a b | 4 5 6 7 c d e f + let interleave = { + #[inline(always)] + |z0z0z1z1: [[f64; 8]; 2]| -> [[f64; 8]; 2] { + let avx512f = simd.avx512f; + let idx_0 = + avx512f._mm512_setr_epi64(0x0, 0x1, 0x2, 0x3, 0x8, 0x9, 0xa, 0xb); + let idx_1 = + avx512f._mm512_setr_epi64(0x4, 0x5, 0x6, 0x7, 0xc, 0xd, 0xe, 0xf); + [ + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_0, + cast(z0z0z1z1[1]), + )), + cast(avx512f._mm512_permutex2var_pd( + cast(z0z0z1z1[0]), + idx_1, + cast(z0z0z1z1[1]), + )), + ] + } + }; + + let w1_re = (permute(*w1_re0), permute(*w1_re1)); + let w1_im = (permute(*w1_im0), permute(*w1_im1)); + + let [mut z0_re0, mut z1_re0] = interleave(*z0z0z1z1_re0); + let [mut z0_re1, mut z1_re1] = interleave(*z0z0z1z1_re1); + let [mut z0_im0, mut z1_im0] = interleave(*z0z0z1z1_im0); + let [mut z0_im1, mut z1_im1] = interleave(*z0z0z1z1_im1); + + let (z0_re, z0_im) = ((z0_re0, z0_re1), (z0_im0, z0_im1)); + let (z1_re, z1_im) = ((z1_re0, z1_re1), (z1_im0, z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((z0_re0, z0_re1), (z0_im0, z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((z1_re0, z1_re1), (z1_im0, z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + + *z0z0z1z1_re0 = interleave([z0_re0, z1_re0]); + *z0z0z1z1_re1 = interleave([z0_re1, z1_re1]); + *z0z0z1z1_im0 = interleave([z0_im0, z1_im0]); + *z0z0z1z1_im1 = interleave([z0_im1, z1_im1]); + } + + t *= 2; + } + + while m > 1 { + m /= 2; + + let twid_re0 = &twid_re0[m..]; + let twid_re1 = &twid_re1[m..]; + let twid_im0 = &twid_im0[m..]; + let twid_im1 = &twid_im1[m..]; + + let iter = izip!( + data_re0.chunks_mut(2 * t), + data_re1.chunks_mut(2 * t), + data_im0.chunks_mut(2 * t), + data_im1.chunks_mut(2 * t), + twid_re0, + twid_re1, + twid_im0, + twid_im1, + ); + for (data_re0, data_re1, data_im0, data_im1, w1_re0, w1_re1, w1_im0, w1_im1) in iter + { + let w1_re = (*w1_re0, *w1_re1); + let w1_im = (*w1_im0, *w1_im1); + + let w1_re = (simd.splat(w1_re.0), simd.splat(w1_re.1)); + let w1_im = (simd.splat(w1_im.0), simd.splat(w1_im.1)); + + let (z0_re0, z1_re0) = data_re0.split_at_mut(t); + let (z0_re1, z1_re1) = data_re1.split_at_mut(t); + let (z0_im0, z1_im0) = data_im0.split_at_mut(t); + let (z0_im1, z1_im1) = data_im1.split_at_mut(t); + + let z0_re0 = as_arrays_mut::<8, _>(z0_re0).0; + let z0_re1 = as_arrays_mut::<8, _>(z0_re1).0; + let z0_im0 = as_arrays_mut::<8, _>(z0_im0).0; + let z0_im1 = as_arrays_mut::<8, _>(z0_im1).0; + let z1_re0 = as_arrays_mut::<8, _>(z1_re0).0; + let z1_re1 = as_arrays_mut::<8, _>(z1_re1).0; + let z1_im0 = as_arrays_mut::<8, _>(z1_im0).0; + let z1_im1 = as_arrays_mut::<8, _>(z1_im1).0; + + let iter = + izip!(z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1); + for (z0_re0, z0_re1, z0_im0, z0_im1, z1_re0, z1_re1, z1_im0, z1_im1) in iter { + let (z0_re, z0_im) = ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)); + let (z1_re, z1_im) = ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)); + let (z0mz1_re, z0mz1_im) = simd.cplx_sub(z0_re, z0_im, z1_re, z1_im); + + ((*z0_re0, *z0_re1), (*z0_im0, *z0_im1)) = + simd.cplx_add(z0_re, z0_im, z1_re, z1_im); + ((*z1_re0, *z1_re1), (*z1_im0, *z1_im1)) = + simd.cplx_mul_conj(z0mz1_re, z0mz1_im, w1_re, w1_im); + } + } + + t *= 2; + } + } + }); +} + +fn bitreverse(i: usize, n: usize) -> usize { + let logn = n.trailing_zeros(); + let mut result = 0; + for k in 0..logn { + let kth_bit = (i >> k) & 1_usize; + result |= kth_bit << (logn - k - 1); + } + result +} + +#[doc(hidden)] +pub fn init_negacyclic_twiddles( + twid_re0: &mut [f64], + twid_re1: &mut [f64], + twid_im0: &mut [f64], + twid_im1: &mut [f64], +) { + let n = twid_re0.len(); + let mut m = 1_usize; + + while m < n { + for i in 0..m { + let k = 2 * m + i; + let pos = m + i; + + let theta_over_pi = f128(bitreverse(k, 2 * n) as f64 / (2 * n) as f64, 0.0); + let (s, c) = theta_over_pi.sincospi(); + twid_re0[pos] = c.0; + twid_re1[pos] = c.1; + twid_im0[pos] = s.0; + twid_im1[pos] = s.1; + } + m *= 2; + } +} + +/// 128-bit negacyclic FFT plan. +#[derive(Clone)] +pub struct Plan { + twid_re0: ABox<[f64]>, + twid_re1: ABox<[f64]>, + twid_im0: ABox<[f64]>, + twid_im1: ABox<[f64]>, +} + +impl core::fmt::Debug for Plan { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.debug_struct("Plan") + .field("fft_size", &self.fft_size()) + .finish() + } +} + +impl Plan { + /// Returns a new negacyclic FFT plan for the given vector size, following the algorithm in + /// [Fast and Error-Free Negacyclic Integer Convolution using Extended Fourier Transform][paper] + /// + /// # Panics + /// + /// - Panics if `n` is not a power of two, or if it is less than `32`. + /// + /// # Example + /// + /// ``` + /// use concrete_fft::fft128::Plan; + /// let plan = Plan::new(32); + /// ``` + /// + /// [paper]: https://eprint.iacr.org/2021/480 + #[track_caller] + pub fn new(n: usize) -> Self { + assert!(n.is_power_of_two()); + assert!(n >= 32); + + let mut twid_re0 = avec![0.0f64; n].into_boxed_slice(); + let mut twid_re1 = avec![0.0f64; n].into_boxed_slice(); + let mut twid_im0 = avec![0.0f64; n].into_boxed_slice(); + let mut twid_im1 = avec![0.0f64; n].into_boxed_slice(); + + init_negacyclic_twiddles(&mut twid_re0, &mut twid_re1, &mut twid_im0, &mut twid_im1); + + Self { + twid_re0, + twid_re1, + twid_im0, + twid_im1, + } + } + + /// Returns the vector size of the negacyclic FFT. + /// + /// # Example + /// + /// ``` + /// use concrete_fft::fft128::Plan; + /// let plan = Plan::new(32); + /// assert_eq!(plan.fft_size(), 32); + /// ``` + pub fn fft_size(&self) -> usize { + self.twid_re0.len() + } + + /// Performs a forward negacyclic FFT in place. + /// + /// # Note + /// + /// The values in `buf_re0`, `buf_re1`, `buf_im0`, `buf_im1` must be in standard order prior to + /// calling this function. When this function returns, the values in `buf_re0`, `buf_re1`, `buf_im0`, `buf_im1` will contain the + /// terms of the forward transform in bit-reversed order. + #[track_caller] + pub fn fwd( + &self, + buf_re0: &mut [f64], + buf_re1: &mut [f64], + buf_im0: &mut [f64], + buf_im1: &mut [f64], + ) { + assert_eq!(buf_re0.len(), self.fft_size()); + assert_eq!(buf_re1.len(), self.fft_size()); + assert_eq!(buf_im0.len(), self.fft_size()); + assert_eq!(buf_im1.len(), self.fft_size()); + + negacyclic_fwd_fft( + buf_re0, + buf_re1, + buf_im0, + buf_im1, + &self.twid_re0, + &self.twid_re1, + &self.twid_im0, + &self.twid_im1, + ); + } + + /// Performs an inverse negacyclic FFT in place. + /// + /// # Note + /// + /// The values in `buf_re0`, `buf_re1`, `buf_im0`, `buf_im1` must be in bit-reversed order + /// prior to calling this function. When this function returns, the values in `buf_re0`, + /// `buf_re1`, `buf_im0`, `buf_im1` will contain the terms of the inverse transform in standard + /// order. + #[track_caller] + pub fn inv( + &self, + buf_re0: &mut [f64], + buf_re1: &mut [f64], + buf_im0: &mut [f64], + buf_im1: &mut [f64], + ) { + assert_eq!(buf_re0.len(), self.fft_size()); + assert_eq!(buf_re1.len(), self.fft_size()); + assert_eq!(buf_im0.len(), self.fft_size()); + assert_eq!(buf_im1.len(), self.fft_size()); + + negacyclic_inv_fft( + buf_re0, + buf_re1, + buf_im0, + buf_im1, + &self.twid_re0, + &self.twid_re1, + &self.twid_im0, + &self.twid_im1, + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use alloc::vec; + use rand::random; + + extern crate alloc; + + #[test] + fn test_wrapper() { + let n = 1024; + + let mut lhs = vec![f128(0.0, 0.0); n]; + let mut rhs = vec![f128(0.0, 0.0); n]; + let mut result = vec![f128(0.0, 0.0); n]; + + for x in &mut lhs { + x.0 = random(); + } + for x in &mut rhs { + x.0 = random(); + } + + let mut full_convolution = vec![f128(0.0, 0.0); 2 * n]; + let mut negacyclic_convolution = vec![f128(0.0, 0.0); n]; + for i in 0..n { + for j in 0..n { + full_convolution[i + j] += lhs[i] * rhs[j]; + } + } + for i in 0..n { + negacyclic_convolution[i] = full_convolution[i] - full_convolution[i + n]; + } + + let mut lhs_fourier_re0 = vec![0.0; n / 2]; + let mut lhs_fourier_re1 = vec![0.0; n / 2]; + let mut lhs_fourier_im0 = vec![0.0; n / 2]; + let mut lhs_fourier_im1 = vec![0.0; n / 2]; + + let mut rhs_fourier_re0 = vec![0.0; n / 2]; + let mut rhs_fourier_re1 = vec![0.0; n / 2]; + let mut rhs_fourier_im0 = vec![0.0; n / 2]; + let mut rhs_fourier_im1 = vec![0.0; n / 2]; + + for i in 0..n / 2 { + lhs_fourier_re0[i] = lhs[i].0; + lhs_fourier_re1[i] = lhs[i].1; + lhs_fourier_im0[i] = lhs[i + n / 2].0; + lhs_fourier_im1[i] = lhs[i + n / 2].1; + + rhs_fourier_re0[i] = rhs[i].0; + rhs_fourier_re1[i] = rhs[i].1; + rhs_fourier_im0[i] = rhs[i + n / 2].0; + rhs_fourier_im1[i] = rhs[i + n / 2].1; + } + + let plan = Plan::new(n / 2); + + plan.fwd( + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + ); + plan.fwd( + &mut rhs_fourier_re0, + &mut rhs_fourier_re1, + &mut rhs_fourier_im0, + &mut rhs_fourier_im1, + ); + + let factor = 2.0 / n as f64; + let simd = Scalar; + for i in 0..n / 2 { + let (prod_re, prod_im) = simd.cplx_mul( + (lhs_fourier_re0[i], lhs_fourier_re1[i]), + (lhs_fourier_im0[i], lhs_fourier_im1[i]), + (rhs_fourier_re0[i], rhs_fourier_re1[i]), + (rhs_fourier_im0[i], rhs_fourier_im1[i]), + ); + + lhs_fourier_re0[i] = prod_re.0 * factor; + lhs_fourier_re1[i] = prod_re.1 * factor; + lhs_fourier_im0[i] = prod_im.0 * factor; + lhs_fourier_im1[i] = prod_im.1 * factor; + } + + plan.inv( + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + ); + + for i in 0..n / 2 { + result[i] = f128(lhs_fourier_re0[i], lhs_fourier_re1[i]); + result[i + n / 2] = f128(lhs_fourier_im0[i], lhs_fourier_im1[i]); + } + + for i in 0..n { + assert!((result[i] - negacyclic_convolution[i]).abs() < 1e-28); + } + } + + #[test] + fn test_product() { + let n = 1024; + + let mut lhs = vec![f128(0.0, 0.0); n]; + let mut rhs = vec![f128(0.0, 0.0); n]; + let mut result = vec![f128(0.0, 0.0); n]; + + for x in &mut lhs { + x.0 = random(); + } + for x in &mut rhs { + x.0 = random(); + } + + let mut full_convolution = vec![f128(0.0, 0.0); 2 * n]; + let mut negacyclic_convolution = vec![f128(0.0, 0.0); n]; + for i in 0..n { + for j in 0..n { + full_convolution[i + j] += lhs[i] * rhs[j]; + } + } + for i in 0..n { + negacyclic_convolution[i] = full_convolution[i] - full_convolution[i + n]; + } + + let mut twid_re0 = vec![0.0; n / 2]; + let mut twid_re1 = vec![0.0; n / 2]; + let mut twid_im0 = vec![0.0; n / 2]; + let mut twid_im1 = vec![0.0; n / 2]; + + let mut lhs_fourier_re0 = vec![0.0; n / 2]; + let mut lhs_fourier_re1 = vec![0.0; n / 2]; + let mut lhs_fourier_im0 = vec![0.0; n / 2]; + let mut lhs_fourier_im1 = vec![0.0; n / 2]; + + let mut rhs_fourier_re0 = vec![0.0; n / 2]; + let mut rhs_fourier_re1 = vec![0.0; n / 2]; + let mut rhs_fourier_im0 = vec![0.0; n / 2]; + let mut rhs_fourier_im1 = vec![0.0; n / 2]; + + init_negacyclic_twiddles(&mut twid_re0, &mut twid_re1, &mut twid_im0, &mut twid_im1); + + for i in 0..n / 2 { + lhs_fourier_re0[i] = lhs[i].0; + lhs_fourier_re1[i] = lhs[i].1; + lhs_fourier_im0[i] = lhs[i + n / 2].0; + lhs_fourier_im1[i] = lhs[i + n / 2].1; + + rhs_fourier_re0[i] = rhs[i].0; + rhs_fourier_re1[i] = rhs[i].1; + rhs_fourier_im0[i] = rhs[i + n / 2].0; + rhs_fourier_im1[i] = rhs[i + n / 2].1; + } + + negacyclic_fwd_fft_scalar( + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + negacyclic_fwd_fft_scalar( + &mut rhs_fourier_re0, + &mut rhs_fourier_re1, + &mut rhs_fourier_im0, + &mut rhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + + let factor = 2.0 / n as f64; + let simd = Scalar; + for i in 0..n / 2 { + let (prod_re, prod_im) = simd.cplx_mul( + (lhs_fourier_re0[i], lhs_fourier_re1[i]), + (lhs_fourier_im0[i], lhs_fourier_im1[i]), + (rhs_fourier_re0[i], rhs_fourier_re1[i]), + (rhs_fourier_im0[i], rhs_fourier_im1[i]), + ); + + lhs_fourier_re0[i] = prod_re.0 * factor; + lhs_fourier_re1[i] = prod_re.1 * factor; + lhs_fourier_im0[i] = prod_im.0 * factor; + lhs_fourier_im1[i] = prod_im.1 * factor; + } + + negacyclic_inv_fft_scalar( + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + + for i in 0..n / 2 { + result[i] = f128(lhs_fourier_re0[i], lhs_fourier_re1[i]); + result[i + n / 2] = f128(lhs_fourier_im0[i], lhs_fourier_im1[i]); + } + + for i in 0..n { + assert!((result[i] - negacyclic_convolution[i]).abs() < 1e-28); + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + #[test] + fn test_product_avxfma() { + if let Some(simd) = Avx::try_new() { + let n = 1024; + + let mut lhs = vec![f128(0.0, 0.0); n]; + let mut rhs = vec![f128(0.0, 0.0); n]; + let mut result = vec![f128(0.0, 0.0); n]; + + for x in &mut lhs { + x.0 = random(); + } + for x in &mut rhs { + x.0 = random(); + } + + let mut full_convolution = vec![f128(0.0, 0.0); 2 * n]; + let mut negacyclic_convolution = vec![f128(0.0, 0.0); n]; + for i in 0..n { + for j in 0..n { + full_convolution[i + j] += lhs[i] * rhs[j]; + } + } + for i in 0..n { + negacyclic_convolution[i] = full_convolution[i] - full_convolution[i + n]; + } + + let mut twid_re0 = vec![0.0; n / 2]; + let mut twid_re1 = vec![0.0; n / 2]; + let mut twid_im0 = vec![0.0; n / 2]; + let mut twid_im1 = vec![0.0; n / 2]; + + let mut lhs_fourier_re0 = vec![0.0; n / 2]; + let mut lhs_fourier_re1 = vec![0.0; n / 2]; + let mut lhs_fourier_im0 = vec![0.0; n / 2]; + let mut lhs_fourier_im1 = vec![0.0; n / 2]; + + let mut rhs_fourier_re0 = vec![0.0; n / 2]; + let mut rhs_fourier_re1 = vec![0.0; n / 2]; + let mut rhs_fourier_im0 = vec![0.0; n / 2]; + let mut rhs_fourier_im1 = vec![0.0; n / 2]; + + init_negacyclic_twiddles(&mut twid_re0, &mut twid_re1, &mut twid_im0, &mut twid_im1); + + for i in 0..n / 2 { + lhs_fourier_re0[i] = lhs[i].0; + lhs_fourier_re1[i] = lhs[i].1; + lhs_fourier_im0[i] = lhs[i + n / 2].0; + lhs_fourier_im1[i] = lhs[i + n / 2].1; + + rhs_fourier_re0[i] = rhs[i].0; + rhs_fourier_re1[i] = rhs[i].1; + rhs_fourier_im0[i] = rhs[i + n / 2].0; + rhs_fourier_im1[i] = rhs[i + n / 2].1; + } + + negacyclic_fwd_fft_avxfma( + simd, + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + negacyclic_fwd_fft_avxfma( + simd, + &mut rhs_fourier_re0, + &mut rhs_fourier_re1, + &mut rhs_fourier_im0, + &mut rhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + + let factor = 2.0 / n as f64; + let scalar = Scalar; + for i in 0..n / 2 { + let (prod_re, prod_im) = scalar.cplx_mul( + (lhs_fourier_re0[i], lhs_fourier_re1[i]), + (lhs_fourier_im0[i], lhs_fourier_im1[i]), + (rhs_fourier_re0[i], rhs_fourier_re1[i]), + (rhs_fourier_im0[i], rhs_fourier_im1[i]), + ); + + lhs_fourier_re0[i] = prod_re.0 * factor; + lhs_fourier_re1[i] = prod_re.1 * factor; + lhs_fourier_im0[i] = prod_im.0 * factor; + lhs_fourier_im1[i] = prod_im.1 * factor; + } + + negacyclic_inv_fft_avxfma( + simd, + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + + for i in 0..n / 2 { + result[i] = f128(lhs_fourier_re0[i], lhs_fourier_re1[i]); + result[i + n / 2] = f128(lhs_fourier_im0[i], lhs_fourier_im1[i]); + } + + for i in 0..n { + assert!((result[i] - negacyclic_convolution[i]).abs() < 1e-28); + } + } + } + + #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] + #[cfg(feature = "nightly")] + #[test] + fn test_product_avx512() { + if let Some(simd) = Avx512::try_new() { + let n = 1024; + + let mut lhs = vec![f128(0.0, 0.0); n]; + let mut rhs = vec![f128(0.0, 0.0); n]; + let mut result = vec![f128(0.0, 0.0); n]; + + for x in &mut lhs { + x.0 = random(); + } + for x in &mut rhs { + x.0 = random(); + } + + let mut full_convolution = vec![f128(0.0, 0.0); 2 * n]; + let mut negacyclic_convolution = vec![f128(0.0, 0.0); n]; + for i in 0..n { + for j in 0..n { + full_convolution[i + j] += lhs[i] * rhs[j]; + } + } + for i in 0..n { + negacyclic_convolution[i] = full_convolution[i] - full_convolution[i + n]; + } + + let mut twid_re0 = vec![0.0; n / 2]; + let mut twid_re1 = vec![0.0; n / 2]; + let mut twid_im0 = vec![0.0; n / 2]; + let mut twid_im1 = vec![0.0; n / 2]; + + let mut lhs_fourier_re0 = vec![0.0; n / 2]; + let mut lhs_fourier_re1 = vec![0.0; n / 2]; + let mut lhs_fourier_im0 = vec![0.0; n / 2]; + let mut lhs_fourier_im1 = vec![0.0; n / 2]; + + let mut rhs_fourier_re0 = vec![0.0; n / 2]; + let mut rhs_fourier_re1 = vec![0.0; n / 2]; + let mut rhs_fourier_im0 = vec![0.0; n / 2]; + let mut rhs_fourier_im1 = vec![0.0; n / 2]; + + init_negacyclic_twiddles(&mut twid_re0, &mut twid_re1, &mut twid_im0, &mut twid_im1); + + for i in 0..n / 2 { + lhs_fourier_re0[i] = lhs[i].0; + lhs_fourier_re1[i] = lhs[i].1; + lhs_fourier_im0[i] = lhs[i + n / 2].0; + lhs_fourier_im1[i] = lhs[i + n / 2].1; + + rhs_fourier_re0[i] = rhs[i].0; + rhs_fourier_re1[i] = rhs[i].1; + rhs_fourier_im0[i] = rhs[i + n / 2].0; + rhs_fourier_im1[i] = rhs[i + n / 2].1; + } + + negacyclic_fwd_fft_avx512( + simd, + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + negacyclic_fwd_fft_avx512( + simd, + &mut rhs_fourier_re0, + &mut rhs_fourier_re1, + &mut rhs_fourier_im0, + &mut rhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + + let factor = 2.0 / n as f64; + let scalar = Scalar; + for i in 0..n / 2 { + let (prod_re, prod_im) = scalar.cplx_mul( + (lhs_fourier_re0[i], lhs_fourier_re1[i]), + (lhs_fourier_im0[i], lhs_fourier_im1[i]), + (rhs_fourier_re0[i], rhs_fourier_re1[i]), + (rhs_fourier_im0[i], rhs_fourier_im1[i]), + ); + + lhs_fourier_re0[i] = prod_re.0 * factor; + lhs_fourier_re1[i] = prod_re.1 * factor; + lhs_fourier_im0[i] = prod_im.0 * factor; + lhs_fourier_im1[i] = prod_im.1 * factor; + } + + negacyclic_inv_fft_avx512( + simd, + &mut lhs_fourier_re0, + &mut lhs_fourier_re1, + &mut lhs_fourier_im0, + &mut lhs_fourier_im1, + &twid_re0, + &twid_re1, + &twid_im0, + &twid_im1, + ); + + for i in 0..n / 2 { + result[i] = f128(lhs_fourier_re0[i], lhs_fourier_re1[i]); + result[i + n / 2] = f128(lhs_fourier_im0[i], lhs_fourier_im1[i]); + } + + for i in 0..n { + assert!((result[i] - negacyclic_convolution[i]).abs() < 1e-28); + } + } + } +} diff --git a/src/lib.rs b/src/lib.rs index cf76154..d95cae6 100755 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,10 +13,14 @@ //! convolution. The only operations that are performed in the Fourier domain are elementwise, and //! so the order of the coefficients does not affect the results. //! +//! Additionally, an optional 128-bit negacyclic FFT module is provided. +//! //! # Features //! //! - `std` (default): This enables runtime arch detection for accelerated SIMD instructions, and //! an FFT plan that measures the various implementations to choose the fastest one at runtime. +//! - `fft128` (default): This flag provides access to the 128-bit FFT, which is accessible in the +//! [`fft128`] module. //! - `nightly`: This enables unstable Rust features to further speed up the FFT, by enabling //! AVX512F instructions on CPUs that support them. This feature requires a nightly Rust //! toolchain. @@ -89,6 +93,9 @@ pub(crate) mod dit2; pub(crate) mod dit4; pub(crate) mod dit8; +#[cfg(feature = "fft128")] +#[cfg_attr(docsrs, doc(cfg(feature = "fft128")))] +pub mod fft128; pub mod ordered; pub mod unordered; diff --git a/src/ordered.rs b/src/ordered.rs index 4cd24da..737c645 100644 --- a/src/ordered.rs +++ b/src/ordered.rs @@ -87,7 +87,7 @@ fn measure_n_runs( #[cfg(feature = "std")] fn duration_div_f64(duration: Duration, n: f64) -> Duration { - Duration::from_secs_f64(duration.as_secs_f64() / n as f64) + Duration::from_secs_f64(duration.as_secs_f64() / n) } #[cfg(feature = "std")] diff --git a/src/unordered.rs b/src/unordered.rs index 962a76b..e565fa4 100644 --- a/src/unordered.rs +++ b/src/unordered.rs @@ -958,7 +958,7 @@ impl Plan { /// # Note /// /// The values in `buf` must be in permuted order prior to calling this function. - /// When this function returns, the values in `buf` will contain the terms of the forward + /// When this function returns, the values in `buf` will contain the terms of the inverse /// transform in standard order. /// /// # Example @@ -1144,7 +1144,7 @@ mod tests { ); let base_n = plan.algo().1; let mut mem = GlobalMemBuffer::new(plan.fft_scratch().unwrap()); - let stack = DynStack::new(&mut *mem); + let stack = DynStack::new(&mut mem); plan.fwd(&mut z, stack); for i in 0..n { @@ -1177,7 +1177,7 @@ mod tests { }, ); let mut mem = GlobalMemBuffer::new(plan.fft_scratch().unwrap()); - let mut stack = DynStack::new(&mut *mem); + let mut stack = DynStack::new(&mut mem); plan.fwd(&mut z, stack.rb_mut()); plan.inv(&mut z, stack); @@ -1232,7 +1232,7 @@ mod tests_serde { .unwrap() .or(plan2.fft_scratch().unwrap()), ); - let mut stack = DynStack::new(&mut *mem); + let mut stack = DynStack::new(&mut mem); plan1.fwd(&mut z, stack.rb_mut());