diff --git a/math/src/circle/cfft.rs b/math/src/circle/cfft.rs new file mode 100644 index 000000000..f060e02d6 --- /dev/null +++ b/math/src/circle/cfft.rs @@ -0,0 +1,220 @@ +extern crate alloc; +use crate::field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}; +use alloc::vec::Vec; + +#[cfg(feature = "alloc")] +/// fft in place algorithm used to evaluate a polynomial of degree 2^n - 1 in 2^n points. +/// Input must be of size 2^n for some n. +pub fn cfft( + input: &mut [FieldElement], + twiddles: Vec>>, +) { + // If the input size is 2^n, then log_2_size is n. + let log_2_size = input.len().trailing_zeros(); + + // The cfft has n layers. + (0..log_2_size).for_each(|i| { + // In each layer i we split the current input in chunks of size 2^{i+1}. + let chunk_size = 1 << (i + 1); + let half_chunk_size = 1 << i; + input.chunks_mut(chunk_size).for_each(|chunk| { + // We split each chunk in half, calling the first half hi_part and the second hal low_part. + let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size); + + // We apply the corresponding butterfly for every element j of the high and low part. + hi_part + .iter_mut() + .zip(low_part) + .enumerate() + .for_each(|(j, (hi, low))| { + let temp = *low * twiddles[i as usize][j]; + *low = *hi - temp; + *hi += temp + }); + }); + }); +} + +#[cfg(feature = "alloc")] +/// The inverse fft algorithm used to interpolate 2^n points. +/// Input must be of size 2^n for some n. +pub fn icfft( + input: &mut [FieldElement], + twiddles: Vec>>, +) { + // If the input size is 2^n, then log_2_size is n. + let log_2_size = input.len().trailing_zeros(); + + // The icfft has n layers. + (0..log_2_size).for_each(|i| { + // In each layer i we split the current input in chunks of size 2^{n - i}. + let chunk_size = 1 << (log_2_size - i); + let half_chunk_size = chunk_size >> 1; + input.chunks_mut(chunk_size).for_each(|chunk| { + // We split each chunk in half, calling the first half hi_part and the second hal low_part. + let (hi_part, low_part) = chunk.split_at_mut(half_chunk_size); + + // We apply the corresponding butterfly for every element j of the high and low part. + hi_part + .iter_mut() + .zip(low_part) + .enumerate() + .for_each(|(j, (hi, low))| { + let temp = *hi + *low; + *low = (*hi - *low) * twiddles[i as usize][j]; + *hi = temp; + }); + }); + }); +} + +/// This function permutes a slice of field elements to order the result of the cfft in the natural way. +/// We call the natural order to [P(x0, y0), P(x1, y1), P(x2, y2), ...], +/// where (x0, y0) is the first point of the corresponding coset. +/// The cfft doesn't return the evaluations in the natural order. +/// For example, if we apply the cfft to 8 coefficients of a polynomial of degree 7 we'll get the evaluations in this order: +/// [P(x0, y0), P(x2, y2), P(x4, y4), P(x6, y6), P(x7, y7), P(x5, y5), P(x3, y3), P(x1, y1)], +/// where the even indices are found first in ascending order and then the odd indices in descending order. +/// This function permutes the slice [0, 2, 4, 6, 7, 5, 3, 1] into [0, 1, 2, 3, 4, 5, 6, 7]. +/// TODO: This can be optimized by performing in-place value swapping (WIP). +pub fn order_cfft_result_naive( + input: &[FieldElement], +) -> Vec> { + let mut result = Vec::new(); + let length = input.len(); + for i in 0..length / 2 { + result.push(input[i]); // We push the left index. + result.push(input[length - i - 1]); // We push the right index. + } + result +} + +/// This function permutes a slice of field elements to order the input of the icfft in a specific way. +/// For example, if we want to interpolate 8 points we should input them in the icfft in this order: +/// [(x0, y0), (x2, y2), (x4, y4), (x6, y6), (x7, y7), (x5, y5), (x3, y3), (x1, y1)], +/// where the even indices are found first in ascending order and then the odd indices in descending order. +/// This function permutes the slice [0, 1, 2, 3, 4, 5, 6, 7] into [0, 2, 4, 6, 7, 5, 3, 1]. +/// TODO: This can be optimized by performing in-place value swapping (WIP). +pub fn order_icfft_input_naive( + input: &mut [FieldElement], +) -> Vec> { + let mut result = Vec::new(); + + // We push the even indices. + (0..input.len()).step_by(2).for_each(|i| { + result.push(input[i]); + }); + + // We push the odd indices. + (1..input.len()).step_by(2).rev().for_each(|i| { + result.push(input[i]); + }); + result +} + +#[cfg(test)] +mod tests { + use super::*; + type FE = FieldElement; + + #[test] + fn ordering_cfft_result_works_for_4_points() { + let expected_slice = [FE::from(0), FE::from(1), FE::from(2), FE::from(3)]; + + let slice = [FE::from(0), FE::from(2), FE::from(3), FE::from(1)]; + + let res = order_cfft_result_naive(&slice); + + assert_eq!(res, expected_slice) + } + + #[test] + fn ordering_cfft_result_works_for_16_points() { + let expected_slice = [ + FE::from(0), + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + FE::from(9), + FE::from(10), + FE::from(11), + FE::from(12), + FE::from(13), + FE::from(14), + FE::from(15), + ]; + + let slice = [ + FE::from(0), + FE::from(2), + FE::from(4), + FE::from(6), + FE::from(8), + FE::from(10), + FE::from(12), + FE::from(14), + FE::from(15), + FE::from(13), + FE::from(11), + FE::from(9), + FE::from(7), + FE::from(5), + FE::from(3), + FE::from(1), + ]; + + let res = order_cfft_result_naive(&slice); + + assert_eq!(res, expected_slice) + } + + #[test] + fn from_natural_to_icfft_input_order_works() { + let mut slice = [ + FE::from(0), + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + FE::from(9), + FE::from(10), + FE::from(11), + FE::from(12), + FE::from(13), + FE::from(14), + FE::from(15), + ]; + + let expected_slice = [ + FE::from(0), + FE::from(2), + FE::from(4), + FE::from(6), + FE::from(8), + FE::from(10), + FE::from(12), + FE::from(14), + FE::from(15), + FE::from(13), + FE::from(11), + FE::from(9), + FE::from(7), + FE::from(5), + FE::from(3), + FE::from(1), + ]; + + let res = order_icfft_input_naive(&mut slice); + + assert_eq!(res, expected_slice) + } +} diff --git a/math/src/circle/cosets.rs b/math/src/circle/cosets.rs new file mode 100644 index 000000000..957097efb --- /dev/null +++ b/math/src/circle/cosets.rs @@ -0,0 +1,95 @@ +extern crate alloc; +use crate::circle::point::CirclePoint; +use crate::field::fields::mersenne31::field::Mersenne31Field; +use alloc::vec::Vec; + +/// Given g_n, a generator of the subgroup of size n of the circle, i.e. , +/// and given a shift, that is a another point of the circle, +/// we define the coset shift + which is the set of all the points in +/// plus the shift. +/// For example, if = {p1, p2, p3, p4}, then g_8 + = {g_8 + p1, g_8 + p2, g_8 + p3, g_8 + p4}. + +#[derive(Debug, Clone)] +pub struct Coset { + // Coset: shift + where n = 2^{log_2_size}. + // Example: g_16 + , n = 8, log_2_size = 3, shift = g_16. + pub log_2_size: u32, //TODO: Change log_2_size to u8 because log_2_size < 31. + pub shift: CirclePoint, +} + +impl Coset { + pub fn new(log_2_size: u32, shift: CirclePoint) -> Self { + Coset { log_2_size, shift } + } + + /// Returns the coset g_2n + + pub fn new_standard(log_2_size: u32) -> Self { + // shift is a generator of the subgroup of order 2n = 2^{log_2_size + 1}. + let shift = CirclePoint::get_generator_of_subgroup(log_2_size + 1); + Coset { log_2_size, shift } + } + + /// Returns g_n, the generator of the subgroup of order n = 2^log_2_size. + pub fn get_generator(&self) -> CirclePoint { + CirclePoint::GENERATOR.repeated_double(31 - self.log_2_size) + } + + /// Given a standard coset g_2n + , returns the subcoset with half size g_2n + + pub fn half_coset(coset: Self) -> Self { + Coset { + log_2_size: coset.log_2_size - 1, + shift: coset.shift, + } + } + + /// Given a coset shift + G returns the coset -shift + G. + /// Note that (g_2n + ) U (-g_2n + ) = g_2n + . + pub fn conjugate(coset: Self) -> Self { + Coset { + log_2_size: coset.log_2_size, + shift: coset.shift.conjugate(), + } + } + + /// Returns the vector of shift + g for every g in . + /// where g = i * g_n for i = 0, ..., n-1. + #[cfg(feature = "alloc")] + pub fn get_coset_points(coset: &Self) -> Vec> { + // g_n the generator of the subgroup of order n. + let generator_n = CirclePoint::get_generator_of_subgroup(coset.log_2_size); + let size: u8 = 1 << coset.log_2_size; + core::iter::successors(Some(coset.shift.clone()), move |prev| { + Some(prev + &generator_n) + }) + .take(size.into()) + .collect() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn coset_points_vector_has_right_size() { + let coset = Coset::new_standard(3); + let points = Coset::get_coset_points(&coset); + assert_eq!(1 << coset.log_2_size, points.len()) + } + + #[test] + fn antipode_of_coset_point_is_in_coset() { + let coset = Coset::new_standard(3); + let points = Coset::get_coset_points(&coset); + let point = points[2].clone(); + let anitpode_point = points[6].clone(); + assert_eq!(anitpode_point, point.antipode()) + } + + #[test] + fn coset_generator_has_right_order() { + let coset = Coset::new(2, CirclePoint::GENERATOR * 3); + let generator_n = coset.get_generator(); + assert_eq!(generator_n.repeated_double(2), CirclePoint::zero()); + } +} diff --git a/math/src/circle/errors.rs b/math/src/circle/errors.rs new file mode 100644 index 000000000..51dcb720b --- /dev/null +++ b/math/src/circle/errors.rs @@ -0,0 +1,4 @@ +#[derive(Debug)] +pub enum CircleError { + PointDoesntSatisfyCircleEquation, +} diff --git a/math/src/circle/mod.rs b/math/src/circle/mod.rs new file mode 100644 index 000000000..ac576194f --- /dev/null +++ b/math/src/circle/mod.rs @@ -0,0 +1,6 @@ +pub mod cfft; +pub mod cosets; +pub mod errors; +pub mod point; +pub mod polynomial; +pub mod twiddles; diff --git a/math/src/circle/point.rs b/math/src/circle/point.rs new file mode 100644 index 000000000..e0e7aa210 --- /dev/null +++ b/math/src/circle/point.rs @@ -0,0 +1,302 @@ +use super::errors::CircleError; +use crate::field::traits::IsField; +use crate::field::{ + element::FieldElement, + fields::mersenne31::{extensions::Degree4ExtensionField, field::Mersenne31Field}, +}; +use core::ops::{Add, AddAssign, Mul, MulAssign}; + +/// Given a Field F, we implement here the Group which consists of all the points (x, y) such as +/// x in F, y in F and x^2 + y^2 = 1, i.e. the Circle. The operation of the group will have +/// additive notation and is as follows: +/// (a, b) + (c, d) = (a * c - b * d, a * d + b * c) +#[derive(Debug, Clone)] +pub struct CirclePoint { + pub x: FieldElement, + pub y: FieldElement, +} + +impl> CirclePoint { + pub fn new(x: FieldElement, y: FieldElement) -> Result { + if x.square() + y.square() == FieldElement::one() { + Ok(Self { x, y }) + } else { + Err(CircleError::PointDoesntSatisfyCircleEquation) + } + } + + /// Neutral element of the Circle group (with additive notation). + pub fn zero() -> Self { + Self::new(FieldElement::one(), FieldElement::zero()).unwrap() + } + + /// Computes 2(x, y) = (2x^2 - 1, 2xy). + pub fn double(&self) -> Self { + Self::new( + self.x.square().double() - FieldElement::one(), + self.x.double() * self.y.clone(), + ) + .unwrap() + } + + /// Computes 2^n * (x, y). + pub fn repeated_double(self, n: u32) -> Self { + let mut res = self; + for _ in 0..n { + res = res.double(); + } + res + } + + /// Computes the inverse of the point. + /// We are using -(x, y) = (x, -y), i.e. the inverse of the group opertion is conjugation + /// because the norm of every point in the circle is one. + pub fn conjugate(self) -> Self { + Self { + x: self.x, + y: -self.y, + } + } + + pub fn antipode(self) -> Self { + Self { + x: -self.x, + y: -self.y, + } + } + + pub const GENERATOR: Self = Self { + x: F::CIRCLE_GENERATOR_X, + y: F::CIRCLE_GENERATOR_Y, + }; + + /// Returns the generator of the subgroup of order n = 2^log_2_size. + /// We are using that 2^k * g is a generator of the subgroup of order 2^{31 - k}. + pub fn get_generator_of_subgroup(log_2_size: u32) -> Self { + Self::GENERATOR.repeated_double(31 - log_2_size) + } + + pub const ORDER: u128 = F::ORDER; +} + +/// Parameters of the base field that we'll need to define its Circle. +pub trait HasCircleParams { + type FE; + + /// Coordinate x of the generator of the circle group. + const CIRCLE_GENERATOR_X: FieldElement; + + /// Coordinate y of the generator of the circle group. + const CIRCLE_GENERATOR_Y: FieldElement; + + const ORDER: u128; +} + +impl HasCircleParams for Mersenne31Field { + type FE = FieldElement; + + const CIRCLE_GENERATOR_X: Self::FE = Self::FE::const_from_raw(2); + + const CIRCLE_GENERATOR_Y: Self::FE = Self::FE::const_from_raw(1268011823); + + /// ORDER = 2^31 + const ORDER: u128 = 2147483648; +} + +impl HasCircleParams for Degree4ExtensionField { + type FE = FieldElement; + + // These parameters were taken from stwo's implementation: + // https://github.com/starkware-libs/stwo/blob/9cfd48af4e8ac5dd67643a92927c894066fa989c/crates/prover/src/core/circle.rs + const CIRCLE_GENERATOR_X: Self::FE = + Degree4ExtensionField::const_from_coefficients(1, 0, 478637715, 513582971); + + const CIRCLE_GENERATOR_Y: Self::FE = + Degree4ExtensionField::const_from_coefficients(992285211, 649143431, 740191619, 1186584352); + + /// ORDER = (2^31 - 1)^4 - 1 + const ORDER: u128 = 21267647892944572736998860269687930880; +} + +/// Equality between two cricle points. +impl> PartialEq for CirclePoint { + fn eq(&self, other: &Self) -> bool { + self.x == other.x && self.y == other.y + } +} + +/// Addition (i.e. group operation) between two points: +/// (a, b) + (c, d) = (a * c - b * d, a * d + b * c) +impl> Add for &CirclePoint { + type Output = CirclePoint; + fn add(self, other: Self) -> Self::Output { + let x = &self.x * &other.x - &self.y * &other.y; + let y = &self.x * &other.y + &self.y * &other.x; + CirclePoint { x, y } + } +} +impl> Add for CirclePoint { + type Output = CirclePoint; + fn add(self, rhs: CirclePoint) -> Self::Output { + &self + &rhs + } +} +impl> Add> for &CirclePoint { + type Output = CirclePoint; + fn add(self, rhs: CirclePoint) -> Self::Output { + self + &rhs + } +} +impl> Add<&CirclePoint> for CirclePoint { + type Output = CirclePoint; + fn add(self, rhs: &CirclePoint) -> Self::Output { + &self + rhs + } +} +impl> AddAssign<&CirclePoint> for CirclePoint { + fn add_assign(&mut self, rhs: &CirclePoint) { + *self = &*self + rhs; + } +} +impl> AddAssign> for CirclePoint { + fn add_assign(&mut self, rhs: CirclePoint) { + *self += &rhs; + } +} +/// Multiplication between a point and a scalar (i.e. group operation repeatedly): +/// (x, y) * n = (x ,y) + ... + (x, y) n-times. +impl> Mul for &CirclePoint { + type Output = CirclePoint; + fn mul(self, scalar: u128) -> Self::Output { + let mut scalar = scalar; + let mut res = CirclePoint::::zero(); + let mut cur = self.clone(); + loop { + if scalar == 0 { + return res; + } + if scalar & 1 == 1 { + res += &cur; + } + cur = cur.double(); + scalar >>= 1; + } + } +} +impl> Mul for CirclePoint { + type Output = CirclePoint; + fn mul(self, scalar: u128) -> Self::Output { + &self * scalar + } +} +impl> MulAssign for CirclePoint { + fn mul_assign(&mut self, scalar: u128) { + let mut scalar = scalar; + let mut res = CirclePoint::::zero(); + loop { + if scalar == 0 { + *self = res.clone(); + } + if scalar & 1 == 1 { + res += &*self; + } + *self = self.double(); + scalar >>= 1; + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + type F = Mersenne31Field; + type FE = FieldElement; + type G = CirclePoint; + + type Fp4 = Degree4ExtensionField; + type Fp4E = FieldElement; + type G4 = CirclePoint; + + #[test] + fn create_new_valid_g_point() { + let valid_point = G::new(FE::one(), FE::zero()).unwrap(); + let expected = G { + x: FE::one(), + y: FE::zero(), + }; + assert_eq!(valid_point, expected) + } + + #[test] + fn create_new_valid_g4_point() { + let valid_point = G4::new(Fp4E::one(), Fp4E::zero()).unwrap(); + let expected = G4 { + x: Fp4E::one(), + y: Fp4E::zero(), + }; + assert_eq!(valid_point, expected) + } + + #[test] + fn create_new_invalid_circle_point() { + let invalid_point = G::new(FE::one(), FE::one()); + assert!(invalid_point.is_err()) + } + + #[test] + fn create_new_invalid_g4_circle_point() { + let invalid_point = G4::new(Fp4E::one(), Fp4E::one()); + assert!(invalid_point.is_err()) + } + + #[test] + fn zero_plus_zero_is_zero() { + let a = G::zero(); + let b = G::zero(); + assert_eq!(&a + &b, G::zero()) + } + + #[test] + fn generator_plus_zero_is_generator() { + let g = G::GENERATOR; + let zero = G::zero(); + assert_eq!(&g + &zero, g) + } + + #[test] + fn double_equals_mul_two() { + let g = G::GENERATOR; + assert_eq!(g.clone().double(), g * 2) + } + + #[test] + fn mul_eight_equals_double_three_times() { + let g = G::GENERATOR; + assert_eq!(g.clone().repeated_double(3), g * 8) + } + + #[test] + fn generator_g1_has_order_two_pow_31() { + let g = G::GENERATOR; + let n = 31; + assert_eq!(g.repeated_double(n), G::zero()) + } + + #[test] + fn generator_g4_has_the_order_of_the_group() { + let g = G4::GENERATOR; + assert_eq!(g * G4::ORDER, G4::zero()) + } + + #[test] + fn conjugation_is_inverse_operation() { + let g = G::GENERATOR; + assert_eq!(&g.clone() + &g.conjugate(), G::zero()) + } + + #[test] + fn subgroup_generator_has_correct_order() { + let generator_n = G::get_generator_of_subgroup(7); + assert_eq!(generator_n.repeated_double(7), G::zero()); + } +} diff --git a/math/src/circle/polynomial.rs b/math/src/circle/polynomial.rs new file mode 100644 index 000000000..a3ee2fadc --- /dev/null +++ b/math/src/circle/polynomial.rs @@ -0,0 +1,299 @@ +extern crate alloc; +#[cfg(feature = "alloc")] +use super::{ + cfft::{cfft, icfft, order_cfft_result_naive, order_icfft_input_naive}, + cosets::Coset, + twiddles::{get_twiddles, TwiddlesConfig}, +}; +use crate::{ + fft::cpu::bit_reversing::in_place_bit_reverse_permute, + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, +}; +use alloc::vec::Vec; + +/// Given the 2^n coefficients of a two-variables polynomial of degree 2^n - 1 in the basis {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} +/// returns the evaluation of the polynomial on the points of the standard coset of size 2^n. +/// Note that coeff has to be a vector with length a power of two 2^n. +#[cfg(feature = "alloc")] +pub fn evaluate_cfft( + coeff: Vec>, +) -> Vec> { + let mut coeff = coeff; + + // We get the twiddles for the Evaluation. + let domain_log_2_size: u32 = coeff.len().trailing_zeros(); + let coset = Coset::new_standard(domain_log_2_size); + let config = TwiddlesConfig::Evaluation; + let twiddles = get_twiddles(coset, config); + + // For our algorithm to work, we must give as input the coefficients in bit reverse order. + in_place_bit_reverse_permute::>(&mut coeff); + cfft(&mut coeff, twiddles); + + // The cfft returns the evaluations in a certain order, so we permute them to get the natural order. + order_cfft_result_naive(&coeff) +} + +/// Interpolates the 2^n evaluations of a two-variables polynomial of degree 2^n - 1 on the points of the standard coset of size 2^n. +/// As a result we obtain the coefficients of the polynomial in the basis: {1, y, x, xy, 2xˆ2 -1, 2xˆ2y-y, 2xˆ3-x, 2xˆ3y-xy,...} +/// Note that eval has to be a vector of length a power of two 2^n. +/// If the vector of evaluations is empty, it returns an empty vector. +#[cfg(feature = "alloc")] +pub fn interpolate_cfft( + eval: Vec>, +) -> Vec> { + let mut eval = eval; + + if eval.is_empty() { + let poly: Vec> = Vec::new(); + return poly; + } + + // We get the twiddles for the interpolation. + let domain_log_2_size: u32 = eval.len().trailing_zeros(); + let coset = Coset::new_standard(domain_log_2_size); + let config = TwiddlesConfig::Interpolation; + let twiddles = get_twiddles(coset, config); + + // For our algorithm to work, we must give as input the evaluations ordered in a certain way. + let mut eval_ordered = order_icfft_input_naive(&mut eval); + icfft(&mut eval_ordered, twiddles); + + // The icfft returns the polynomial coefficients in bit reverse order. So we premute it to get the natural order. + in_place_bit_reverse_permute::>(&mut eval_ordered); + + // The icfft returns all the coefficients multiplied by 2^n, the length of the evaluations. + // So we multiply every element that outputs the icfft by the inverse of 2^n to get the actual coefficients. + // Note that this `unwrap` will never panic because eval.len() != 0. + let factor = (FieldElement::::from(eval.len() as u64)) + .inv() + .unwrap(); + eval_ordered.iter().map(|coef| coef * factor).collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::circle::cosets::Coset; + type FE = FieldElement; + use alloc::vec; + + /// Naive evaluation of a polynomial of degree 3. + fn evaluate_poly_4(coef: &[FE; 4], x: FE, y: FE) -> FE { + coef[0] + coef[1] * y + coef[2] * x + coef[3] * x * y + } + + /// Naive evaluation of a polynomial of degree 7. + fn evaluate_poly_8(coef: &[FE; 8], x: FE, y: FE) -> FE { + coef[0] + + coef[1] * y + + coef[2] * x + + coef[3] * x * y + + coef[4] * (x.square().double() - FE::one()) + + coef[5] * (x.square().double() - FE::one()) * y + + coef[6] * ((x.square() * x).double() - x) + + coef[7] * ((x.square() * x).double() - x) * y + } + + /// Naive evaluation of a polynomial of degree 15. + fn evaluate_poly_16(coef: &[FE; 16], x: FE, y: FE) -> FE { + let mut a = x; + let mut v = Vec::new(); + v.push(FE::one()); + v.push(x); + for _ in 2..4 { + a = a.square().double() - FE::one(); + v.push(a); + } + + coef[0] * v[0] + + coef[1] * y * v[0] + + coef[2] * v[1] + + coef[3] * y * v[1] + + coef[4] * v[2] + + coef[5] * y * v[2] + + coef[6] * v[1] * v[2] + + coef[7] * y * v[1] * v[2] + + coef[8] * v[3] + + coef[9] * y * v[3] + + coef[10] * v[1] * v[3] + + coef[11] * y * v[1] * v[3] + + coef[12] * v[2] * v[3] + + coef[13] * y * v[2] * v[3] + + coef[14] * v[1] * v[2] * v[3] + + coef[15] * y * v[1] * v[2] * v[3] + } + + #[test] + /// cfft evaluation equals naive evaluation. + fn cfft_evaluation_4_points() { + // We define the coefficients of a polynomial of degree 3. + let input = [FE::from(1), FE::from(2), FE::from(3), FE::from(4)]; + + // We create the coset points and evaluate the polynomial with the naive function. + let coset = Coset::new_standard(2); + let points = Coset::get_coset_points(&coset); + let mut expected_result: Vec = Vec::new(); + for point in points { + let point_eval = evaluate_poly_4(&input, point.x, point.y); + expected_result.push(point_eval); + } + + let input_vec = input.to_vec(); + // We evaluate the polynomial using now the cfft. + let result = evaluate_cfft(input_vec); + let slice_result: &[FE] = &result; + + assert_eq!(slice_result, expected_result); + } + + #[test] + /// cfft evaluation equals naive evaluation. + fn cfft_evaluation_8_points() { + // We define the coefficients of a polynomial of degree 7. + let input = [ + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + ]; + + // We create the coset points and evaluate them without the fft. + let coset = Coset::new_standard(3); + let points = Coset::get_coset_points(&coset); + let mut expected_result: Vec = Vec::new(); + for point in points { + let point_eval = evaluate_poly_8(&input, point.x, point.y); + expected_result.push(point_eval); + } + + // We evaluate the polynomial using now the cfft. + let result = evaluate_cfft(input.to_vec()); + let slice_result: &[FE] = &result; + + assert_eq!(slice_result, expected_result); + } + + #[test] + /// cfft evaluation equals naive evaluation. + fn cfft_evaluation_16_points() { + // We define the coefficients of a polynomial of degree 15. + let input = [ + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + FE::from(9), + FE::from(10), + FE::from(11), + FE::from(12), + FE::from(13), + FE::from(14), + FE::from(15), + FE::from(16), + ]; + + // We create the coset points and evaluate them without the fft. + let coset = Coset::new_standard(4); + let points = Coset::get_coset_points(&coset); + let mut expected_result: Vec = Vec::new(); + for point in points { + let point_eval = evaluate_poly_16(&input, point.x, point.y); + expected_result.push(point_eval); + } + + // We evaluate the polynomial using now the cfft. + let result = evaluate_cfft(input.to_vec()); + let slice_result: &[FE] = &result; + + assert_eq!(slice_result, expected_result); + } + + #[test] + fn evaluate_and_interpolate_8_points_is_identity() { + // We define the 8 coefficients of a polynomial of degree 7. + let coeff = vec![ + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + ]; + let evals = evaluate_cfft(coeff.clone()); + let new_coeff = interpolate_cfft(evals); + + assert_eq!(coeff, new_coeff); + } + + #[test] + fn evaluate_and_interpolate_8_other_points() { + let coeff = vec![ + FE::from(2147483650), + FE::from(147483647), + FE::from(2147483700), + FE::from(2147483647), + FE::from(3147483647), + FE::from(4147483647), + FE::from(2147483640), + FE::from(5147483647), + ]; + let evals = evaluate_cfft(coeff.clone()); + let new_coeff = interpolate_cfft(evals); + + assert_eq!(coeff, new_coeff); + } + + #[test] + fn evaluate_and_interpolate_32_points() { + // We define 32 coefficients of a polynomial of degree 31. + let coeff = vec![ + FE::from(1), + FE::from(2), + FE::from(3), + FE::from(4), + FE::from(5), + FE::from(6), + FE::from(7), + FE::from(8), + FE::from(9), + FE::from(10), + FE::from(11), + FE::from(12), + FE::from(13), + FE::from(14), + FE::from(15), + FE::from(16), + FE::from(17), + FE::from(18), + FE::from(19), + FE::from(20), + FE::from(21), + FE::from(22), + FE::from(23), + FE::from(24), + FE::from(25), + FE::from(26), + FE::from(27), + FE::from(28), + FE::from(29), + FE::from(30), + FE::from(31), + FE::from(32), + ]; + let evals = evaluate_cfft(coeff.clone()); + let new_coeff = interpolate_cfft(evals); + + assert_eq!(coeff, new_coeff); + } +} diff --git a/math/src/circle/twiddles.rs b/math/src/circle/twiddles.rs new file mode 100644 index 000000000..6a07804c4 --- /dev/null +++ b/math/src/circle/twiddles.rs @@ -0,0 +1,83 @@ +extern crate alloc; +use crate::{ + circle::cosets::Coset, + field::{element::FieldElement, fields::mersenne31::field::Mersenne31Field}, +}; +use alloc::vec::Vec; + +#[derive(PartialEq)] +pub enum TwiddlesConfig { + Evaluation, + Interpolation, +} +#[cfg(feature = "alloc")] +pub fn get_twiddles( + domain: Coset, + config: TwiddlesConfig, +) -> Vec>> { + // We first take the half coset. + let half_domain_points = Coset::get_coset_points(&Coset::half_coset(domain.clone())); + + // The first set of twiddles are all the y coordinates of the half coset. + let mut twiddles: Vec>> = Vec::new(); + twiddles.push(half_domain_points.iter().map(|p| p.y).collect()); + + if domain.log_2_size >= 2 { + // The second set of twiddles are the x coordinates of the first half of the half coset. + twiddles.push( + half_domain_points + .iter() + .take(half_domain_points.len() / 2) + .map(|p| p.x) + .collect(), + ); + for _ in 0..(domain.log_2_size - 2) { + // The rest of the sets of twiddles are the "square" of the x coordinates of the first half of the previous set. + let prev = twiddles.last().unwrap(); + let cur = prev + .iter() + .take(prev.len() / 2) + .map(|x| x.square().double() - FieldElement::::one()) + .collect(); + twiddles.push(cur); + } + } + + if config == TwiddlesConfig::Interpolation { + // For the interpolation, we need to take the inverse element of each twiddle in the default order. + // We can take inverse being sure that the `unwrap` won't panic because the twiddles are coordinates + // of elements of the coset (or their squares) so they can't be zero. + twiddles.iter_mut().for_each(|x| { + FieldElement::::inplace_batch_inverse(x).unwrap(); + }); + } else { + // For the evaluation, we need reverse the order of the vector of twiddles. + twiddles.reverse(); + } + twiddles +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn evaluation_twiddles_vectors_length_is_correct() { + let domain = Coset::new_standard(3); + let config = TwiddlesConfig::Evaluation; + let twiddles = get_twiddles(domain, config); + for i in 0..twiddles.len() - 1 { + assert_eq!(2 * twiddles[i].len(), twiddles[i + 1].len()) + } + } + + #[test] + fn interpolation_twiddles_vectors_length_is_correct() { + let domain = Coset::new_standard(3); + let config = TwiddlesConfig::Interpolation; + let twiddles = get_twiddles(domain, config); + for i in 0..twiddles.len() - 1 { + assert_eq!(twiddles[i].len(), 2 * twiddles[i + 1].len()) + } + } +} diff --git a/math/src/field/fields/mersenne31/extensions.rs b/math/src/field/fields/mersenne31/extensions.rs index 27c2ab118..69c64f096 100644 --- a/math/src/field/fields/mersenne31/extensions.rs +++ b/math/src/field/fields/mersenne31/extensions.rs @@ -8,6 +8,8 @@ use crate::field::{ use alloc::vec::Vec; type FpE = FieldElement; +type Fp2E = FieldElement; +type Fp4E = FieldElement; #[derive(Clone, Debug)] pub struct Degree2ExtensionField; @@ -132,11 +134,18 @@ impl IsSubFieldOf for Mersenne31Field { } } -type Fp2E = FieldElement; - #[derive(Clone, Debug)] pub struct Degree4ExtensionField; +impl Degree4ExtensionField { + pub const fn const_from_coefficients(a: u32, b: u32, c: u32, d: u32) -> Fp4E { + Fp4E::const_from_raw([ + Fp2E::const_from_raw([FpE::const_from_raw(a), FpE::const_from_raw(b)]), + Fp2E::const_from_raw([FpE::const_from_raw(c), FpE::const_from_raw(d)]), + ]) + } +} + impl IsField for Degree4ExtensionField { type BaseType = [Fp2E; 2]; diff --git a/math/src/lib.rs b/math/src/lib.rs index 56c6e598e..1f5ae60d6 100644 --- a/math/src/lib.rs +++ b/math/src/lib.rs @@ -3,6 +3,7 @@ #[cfg(feature = "alloc")] extern crate alloc; +pub mod circle; pub mod cyclic_group; pub mod elliptic_curve; pub mod errors;