Skip to content

Commit

Permalink
Merge branch 'main' into stark-prover-over-field-extensions-v3
Browse files Browse the repository at this point in the history
  • Loading branch information
schouhy committed Dec 20, 2023
2 parents 19c55f8 + e1a8716 commit 345ef5e
Show file tree
Hide file tree
Showing 16 changed files with 420 additions and 139 deletions.
19 changes: 17 additions & 2 deletions math/benches/criterion_fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use utils::stark252_utils;

mod utils;

const SIZE_ORDERS: [u64; 4] = [21, 22, 23, 24];
const SIZE_ORDERS: [u64; 5] = [20, 21, 22, 23, 24];

pub fn fft_benchmarks(c: &mut Criterion) {
let mut group = c.benchmark_group("Ordered FFT");
Expand All @@ -22,7 +22,7 @@ pub fn fft_benchmarks(c: &mut Criterion) {

group.bench_with_input(
"Sequential from NR radix2",
&(input_nat, twiddles_bitrev),
&(input_nat.clone(), twiddles_bitrev.clone()),
|bench, (input, twiddles)| {
bench.iter_batched(
|| input.clone(),
Expand All @@ -46,6 +46,21 @@ pub fn fft_benchmarks(c: &mut Criterion) {
);
},
);
if order % 2 == 0 {
group.bench_with_input(
"Sequential from NR radix4",
&(input_nat, twiddles_bitrev),
|bench, (input, twiddles)| {
bench.iter_batched(
|| input.clone(),
|mut input| {
fft_functions::ordered_fft_nr4(&mut input, twiddles);
},
BatchSize::LargeInput,
);
},
);
}
}

group.finish();
Expand Down
6 changes: 5 additions & 1 deletion math/benches/utils/fft_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
use criterion::black_box;
use lambdaworks_math::fft::cpu::{
bit_reversing::in_place_bit_reverse_permute,
fft::{in_place_nr_2radix_fft, in_place_rn_2radix_fft},
fft::{in_place_nr_2radix_fft, in_place_nr_4radix_fft, in_place_rn_2radix_fft},
roots_of_unity::get_twiddles,
};
use lambdaworks_math::{field::traits::RootsConfig, polynomial::Polynomial};
Expand All @@ -18,6 +18,10 @@ pub fn ordered_fft_rn(input: &mut [FE], twiddles: &[FE]) {
in_place_rn_2radix_fft(input, twiddles);
}

/// Benchmark helper for the radix-4 FFT: forwards `input` and `twiddles` unchanged to
/// `in_place_nr_4radix_fft`, which works directly on the `input` slice.
pub fn ordered_fft_nr4(input: &mut [FE], twiddles: &[FE]) {
in_place_nr_4radix_fft(input, twiddles);
}

/// Calls `get_twiddles::<F>` with the given `order` and `config` and discards the result
/// (presumably so a benchmark can time the generation alone — confirm against the caller).
///
/// # Panics
/// Panics if `get_twiddles` returns an error, because the result is unwrapped.
pub fn twiddles_generation(order: u64, config: RootsConfig) {
get_twiddles::<F>(order, config).unwrap();
}
Expand Down
95 changes: 94 additions & 1 deletion math/src/fft/cpu/fft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,79 @@ where
}
}

/// In-Place Radix-4 NR DIT FFT algorithm over a slice of two-adic field elements.
/// It's required that the twiddle factors are in bit-reverse order. Else this function will not
/// return fourier transformed values.
/// Also the input size needs to be a power of four (i.e. an even power of two). This is only
/// checked with `debug_assert!`, so builds without debug assertions do not validate it.
/// `twiddles` is indexed up to `input.len() / 2 - 1`, so it needs at least `input.len() / 2`
/// entries; a shorter slice makes the indexing panic.
/// It's recommended to use the current safe abstractions instead of this function.
///
/// Performs a fast fourier transform with the next attributes:
/// - In-Place: an auxiliary vector of data isn't needed for the algorithm.
/// - Radix-4: every stage splits each group in four, so the problem size is divided by 4
///   log4(n) times.
/// - NR: natural to reverse order, meaning that the input is naturally ordered and the output will
///   be bit-reversed ordered.
/// - DIT: decimation in time
pub fn in_place_nr_4radix_fft<F, E>(input: &mut [FieldElement<E>], twiddles: &[FieldElement<F>])
where
F: IsFFTField + IsSubFieldOf<E>,
E: IsField,
{
debug_assert!(input.len().is_power_of_two());
debug_assert!(input.len().ilog2() % 2 == 0); // Even power of 2 => input.len() is a power of 4

// divide input in groups, starting with 1, multiplying the number of groups by 4 in each stage.
let mut group_count = 1;
let mut group_size = input.len();

// for each group, there'll be group_size / 4 butterflies.
// a butterfly is the atomic operation of a FFT, e.g:
// x' = x + yw2 + zw1 + tw1w2
// y' = x - yw2 + zw1 - tw1w2
// z' = x + yw3 - zw1 - tw1w3
// t' = x - yw3 - zw1 + tw1w3
// where x, y, z, t are the elements at offsets 0, 1/4, 1/2 and 3/4 of the group.
// The 0.25 factor is what gives FFT its performance, it recursively divides the problem size
// by 4 (group size).

while group_count < input.len() {
#[allow(clippy::needless_range_loop)] // the suggestion would obfuscate a bit the algorithm
for group in 0..group_count {
let first_in_group = group * group_size;
// end of the first quarter of this group: one butterfly per index in that quarter.
let first_in_next_group = first_in_group + group_size / 4;

// the three twiddle factors shared by every butterfly of this group.
let (w1, w2, w3) = (
&twiddles[group],
&twiddles[2 * group],
&twiddles[2 * group + 1],
);

for i in first_in_group..first_in_next_group {
// indices of the other three butterfly inputs, one per remaining quarter.
let (j, k, l) = (
i + group_size / 4,
i + group_size / 2,
i + 3 * group_size / 4,
);

// shared sub-expressions: a = yw2 + tw1w2 and b = yw3 - tw1w3.
let zw1 = w1 * &input[k];
let tw1 = w1 * &input[l];
let a = w2 * (&input[j] + &tw1);
let b = w3 * (&input[j] - &tw1);

let x = &input[i] + &zw1 + &a;
let y = &input[i] + &zw1 - &a;
let z = &input[i] - &zw1 + &b;
let t = &input[i] - &zw1 - &b;

// all four outputs are computed before any input slot is overwritten.
input[i] = x;
input[j] = y;
input[k] = z;
input[l] = t;
}
}
group_count *= 4;
group_size /= 4;
}
}

#[cfg(test)]
mod tests {
use crate::fft::cpu::bit_reversing::in_place_bit_reverse_permute;
Expand All @@ -128,7 +201,12 @@ mod tests {
}
}
prop_compose! {
fn field_vec(max_exp: u8)(vec in collection::vec(field_element(), 2..1<<max_exp).prop_filter("Avoid polynomials of size not power of two", |vec| vec.len().is_power_of_two())) -> Vec<FE> {
fn field_vec(max_exp: u8)(vec in (1..max_exp).prop_flat_map(|i| collection::vec(field_element(), 1 << i))) -> Vec<FE> {
vec
}
}
// Strategy yielding vectors whose length is a power of four: an exponent `i` is drawn from
// `1..max_exp` and the vector holds `1 << (2 * i)` (that is, 4^i) field elements.
prop_compose! {
fn field_vec_r4(max_exp: u8)(vec in (1..max_exp).prop_flat_map(|i| collection::vec(field_element(), 1 << (2 * i)))) -> Vec<FE> {
vec
}
}
Expand Down Expand Up @@ -163,5 +241,20 @@ mod tests {

prop_assert_eq!(result, expected);
}

// Property-based test that ensures NR Radix-4 FFT gives the same result as a naive DFT.
#[test]
fn test_nr_4radix_fft_matches_naive_eval(coeffs in field_vec_r4(5)) {
let expected = naive_matrix_dft_test(&coeffs);

// the strategy only yields power-of-two lengths, so the trailing zero count is log2(len).
let order = coeffs.len().trailing_zeros();
let twiddles = get_twiddles(order.into(), RootsConfig::BitReverse).unwrap();

let mut result = coeffs;
in_place_nr_4radix_fft::<F, F>(&mut result, &twiddles);
// the NR FFT leaves its output bit-reversed; restore natural order before comparing.
in_place_bit_reverse_permute(&mut result);

prop_assert_eq!(expected, result);
}
}
}
137 changes: 135 additions & 2 deletions math/src/field/fields/winterfell.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
use core::ops::Add;

use crate::{
errors::ByteConversionError,
field::{
element::FieldElement,
errors::FieldError,
traits::{IsFFTField, IsField, IsPrimeField},
traits::{IsFFTField, IsField, IsPrimeField, IsSubFieldOf},
},
traits::{ByteConversion, Serializable},
unsigned_integer::element::U256,
};
pub use miden_core::Felt;
use miden_core::QuadExtension;
pub use winter_math::fields::f128::BaseElement;
use winter_math::{FieldElement as IsWinterfellFieldElement, StarkField};
use winter_math::{ExtensionOf, FieldElement as IsWinterfellFieldElement, StarkField};

impl IsFFTField for Felt {
const TWO_ADICITY: u64 = <Felt as StarkField>::TWO_ADICITY as u64;
Expand Down Expand Up @@ -118,3 +121,133 @@ impl ByteConversion for Felt {
}
}
}

pub type QuadFelt = QuadExtension<Felt>;

impl ByteConversion for QuadFelt {
    /// Big-endian encoding: the big-endian bytes of the first base element followed by the
    /// big-endian bytes of the second base element.
    fn to_bytes_be(&self) -> Vec<u8> {
        let [b0, b1] = self.to_base_elements();
        let mut bytes = b0.to_bytes_be();
        bytes.extend(&b1.to_bytes_be());
        bytes
    }

    /// Little-endian encoding: the little-endian bytes of the first base element followed by
    /// the little-endian bytes of the second base element.
    fn to_bytes_le(&self) -> Vec<u8> {
        let [b0, b1] = self.to_base_elements();
        let mut bytes = b0.to_bytes_le();
        // Both components must share the same byte order; the second one used to be encoded
        // big-endian here, which produced a mixed-endianness buffer.
        bytes.extend(&b1.to_bytes_le());
        bytes
    }

    /// Not implemented yet: calling this panics through `todo!()`.
    fn from_bytes_be(_bytes: &[u8]) -> Result<Self, ByteConversionError>
    where
        Self: Sized,
    {
        todo!()
    }

    /// Not implemented yet: calling this panics through `todo!()`.
    fn from_bytes_le(_bytes: &[u8]) -> Result<Self, ByteConversionError>
    where
        Self: Sized,
    {
        todo!()
    }
}

impl Serializable for FieldElement<QuadFelt> {
    /// Serializes the element as the big-endian bytes of its two base-field components,
    /// first component first.
    fn serialize(&self) -> Vec<u8> {
        self.value()
            .to_base_elements()
            .iter()
            .flat_map(|component| component.to_bytes_be())
            .collect()
    }
}

/// `IsField` implementation for `QuadFelt`: every operation is forwarded to the operators,
/// constants and conversions that `QuadFelt` already provides.
impl IsField for QuadFelt {
// The field's elements are represented by `QuadFelt` itself.
type BaseType = QuadFelt;

/// Returns `a + b`.
fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType {
*a + *b
}

/// Returns `a * b`.
fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType {
*a * *b
}

/// Returns `a - b`.
fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType {
*a - *b
}

/// Returns `-a`.
fn neg(a: &Self::BaseType) -> Self::BaseType {
-*a
}

/// Returns the result of `QuadFelt`'s own `inv`, always wrapped in `Ok`.
/// NOTE(review): no error is ever produced here, so what happens for a zero input depends
/// entirely on the underlying `inv` — confirm.
fn inv(a: &Self::BaseType) -> Result<Self::BaseType, FieldError> {
Ok((*a).inv())
}

/// Returns `a / b` using `QuadFelt`'s division operator.
fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType {
*a / *b
}

/// Returns whether `a` and `b` compare equal.
fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool {
*a == *b
}

/// Returns the `ZERO` constant of `QuadFelt`.
fn zero() -> Self::BaseType {
Self::BaseType::ZERO
}

/// Returns the `ONE` constant of `QuadFelt`.
fn one() -> Self::BaseType {
Self::BaseType::ONE
}

/// Converts `x` through `QuadFelt`'s `From<u64>` implementation.
fn from_u64(x: u64) -> Self::BaseType {
Self::BaseType::from(x)
}

/// Identity: the base type already is the field element representation.
fn from_base_type(x: Self::BaseType) -> Self::BaseType {
x
}
}

/// Mixed arithmetic between a base-field element (`Felt`, left operand `a`) and an element of
/// its quadratic extension (`QuadFelt`, right operand `b`).
impl IsSubFieldOf<QuadFelt> for Felt {
    /// Returns `a * b`, computed with `mul_base` on the extension element.
    fn mul(
        a: &Self::BaseType,
        b: &<QuadFelt as IsField>::BaseType,
    ) -> <QuadFelt as IsField>::BaseType {
        b.mul_base(*a)
    }

    /// Returns `a + b`: `a` is added to the first component of `b`, the second is kept.
    fn add(
        a: &Self::BaseType,
        b: &<QuadFelt as IsField>::BaseType,
    ) -> <QuadFelt as IsField>::BaseType {
        let [c0, c1] = b.to_base_elements();
        let shifted = c0.add(*a);
        QuadFelt::new(shifted, c1)
    }

    /// Returns `a / b`, computed as the inverse of `b` multiplied by `a`.
    /// NOTE(review): the result for a zero `b` depends on `QuadFelt`'s `inv` — confirm.
    fn div(
        a: &Self::BaseType,
        b: &<QuadFelt as IsField>::BaseType,
    ) -> <QuadFelt as IsField>::BaseType {
        b.inv().mul_base(*a)
    }

    /// Returns `a - b`: the first component becomes `a + (-b0)`, the second `-b1`.
    fn sub(
        a: &Self::BaseType,
        b: &<QuadFelt as IsField>::BaseType,
    ) -> <QuadFelt as IsField>::BaseType {
        let [c0, c1] = b.to_base_elements();
        let first = a.add(-c0);
        QuadFelt::new(first, -c1)
    }

    /// Embeds `a` into the extension as the element whose second component is zero.
    fn embed(a: Self::BaseType) -> <QuadFelt as IsField>::BaseType {
        QuadFelt::new(a, Felt::ZERO)
    }

    /// Returns the two base-field components of `b`, in order.
    fn to_subfield_vec(b: <QuadFelt as IsField>::BaseType) -> Vec<Self::BaseType> {
        Vec::from(b.to_base_elements())
    }
}
2 changes: 1 addition & 1 deletion math/src/polynomial.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ impl<F: IsField> Polynomial<FieldElement<F>> {
coefficients.push(c.clone());
c = coeff + c * b;
}
coefficients.reverse();
coefficients = coefficients.into_iter().rev().collect();
Polynomial::new(&coefficients)
} else {
Polynomial::zero()
Expand Down
4 changes: 2 additions & 2 deletions provers/stark/src/constraints/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl<A: AIR> ConstraintEvaluator<A> {
&domain.coset_offset,
)
})
.collect::<Result<Vec<Vec<FieldElement<A::FieldExtension>>>, FFTError>>()
.collect::<Result<Vec<Vec<FieldElement<A::Field>>>, FFTError>>()
.unwrap();

let boundary_polys_evaluations = boundary_constraints
Expand Down Expand Up @@ -193,7 +193,7 @@ impl<A: AIR> ConstraintEvaluator<A> {

let periodic_values: Vec<_> = lde_periodic_columns
.iter()
.map(|col| col[i].clone().to_extension())
.map(|col| col[i].clone())
.collect();

// Compute all the transition constraints at this
Expand Down
8 changes: 6 additions & 2 deletions provers/stark/src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,12 @@ pub fn validate_trace<A: AIR>(
.get_periodic_column_polynomials()
.iter()
.map(|poly| {
Polynomial::evaluate_fft::<A::Field>(poly, 1, Some(domain.interpolation_domain_size))
.unwrap()
Polynomial::<FieldElement<A::Field>>::evaluate_fft::<A::Field>(
poly,
1,
Some(domain.interpolation_domain_size),
)
.unwrap()
})
.collect();

Expand Down
Loading

0 comments on commit 345ef5e

Please sign in to comment.