From 75dd92f9593bd2629ff7bf5d69b56c16a518d04b Mon Sep 17 00:00:00 2001 From: Sergio Chouhy <41742639+schouhy@users.noreply.github.com> Date: Mon, 11 Dec 2023 15:57:34 -0300 Subject: [PATCH 1/2] Math: FFT over field extensions (#711) * add subfield trait * implement IsSubfieldOf for Degree2ExtensionField struct in BLS12381 * fix mul * avoid using field elements in issubfield impl. Add tests * clippy and fmt * simplify trait bounds * add explicit type * change iter to into_iter * fix metal code * remove explicit type * refactor fftpoly * fmt * into iter for gpu * update evaluate_fft_metal * update metal fft * fix imports * add explicit types * clippy, fmt * update metal benches * add explicit types * refactor quadratic extensions and implement IsSubFieldOf * refactor cubic extensions and implement issubfield * add to_subfield_vec method to field element * add test fft over field extension * fix type * clippy * make to_subfield_vec available only under std feature * fmt * run test only on std feature * update fft docs * update docs. Fix babybear --- math/benches/utils/fft_functions.rs | 8 +- math/benches/utils/metal_functions.rs | 5 +- math/src/fft/cpu/fft.rs | 16 +- math/src/fft/cpu/ops.rs | 15 +- math/src/fft/cpu/roots_of_unity.rs | 2 +- math/src/fft/gpu/metal/ops.rs | 18 ++- math/src/fft/gpu/metal/polynomial.rs | 22 +-- math/src/fft/polynomial.rs | 153 +++++++++--------- .../src/field/fields/fft_friendly/babybear.rs | 14 +- math/src/field/test_fields/u64_test_field.rs | 35 +++- math/src/polynomial.rs | 4 +- provers/groth16/src/qap.rs | 35 ++-- provers/plonk/src/prover.rs | 54 +++---- provers/plonk/src/setup.rs | 17 +- provers/plonk/src/test_utils/circuit_1.rs | 17 +- provers/plonk/src/test_utils/circuit_json.rs | 47 ++++-- provers/stark/src/debug.rs | 7 +- provers/stark/src/fri/mod.rs | 6 +- provers/stark/src/prover.rs | 5 +- provers/stark/src/trace.rs | 3 +- provers/stark/src/traits.rs | 4 +- 21 files changed, 277 insertions(+), 210 deletions(-) diff --git a/math/benches/utils/fft_functions.rs b/math/benches/utils/fft_functions.rs index 7dc97e04f..2b2f50fb6 100644 --- a/math/benches/utils/fft_functions.rs +++ b/math/benches/utils/fft_functions.rs @@ -6,9 +6,7 @@ use lambdaworks_math::fft::cpu::{ fft::{in_place_nr_2radix_fft, in_place_rn_2radix_fft}, roots_of_unity::get_twiddles, }; -use lambdaworks_math::{ - fft::polynomial::FFTPoly, field::traits::RootsConfig, polynomial::Polynomial, -}; +use lambdaworks_math::{field::traits::RootsConfig, polynomial::Polynomial}; use super::stark252_utils::{F, FE}; @@ -29,9 +27,9 @@ pub fn bitrev_permute(input: &mut [FE]) { } pub fn poly_evaluate_fft(poly: &Polynomial) -> Vec { - poly.evaluate_fft(black_box(1), black_box(None)).unwrap() + Polynomial::evaluate_fft::(poly, black_box(1), black_box(None)).unwrap() } pub fn poly_interpolate_fft(evals: &[FE]) { - Polynomial::interpolate_fft(evals).unwrap(); + Polynomial::interpolate_fft::(evals).unwrap(); } diff --git a/math/benches/utils/metal_functions.rs b/math/benches/utils/metal_functions.rs index 6838c291b..7fd60c514 100644 --- a/math/benches/utils/metal_functions.rs +++ b/math/benches/utils/metal_functions.rs @@ -1,6 +1,5 @@ use lambdaworks_gpu::metal::abstractions::state::MetalState; use lambdaworks_math::fft::gpu::metal::ops::*; -use lambdaworks_math::fft::polynomial::FFTPoly; use lambdaworks_math::{field::traits::RootsConfig, polynomial::Polynomial}; // WARN: These should always be fields supported by Metal, else the last two benches will use CPU FFT. @@ -22,8 +21,8 @@ pub fn bitrev_permute(input: &[FE]) { } pub fn poly_evaluate_fft(poly: &Polynomial) { - poly.evaluate_fft(1, None).unwrap(); + Polynomial::evaluate_fft::(poly, 1, None).unwrap(); } pub fn poly_interpolate_fft(evals: &[FE]) { - Polynomial::interpolate_fft(evals).unwrap(); + Polynomial::interpolate_fft::(evals).unwrap(); } diff --git a/math/src/fft/cpu/fft.rs b/math/src/fft/cpu/fft.rs index 194e596dd..724ca9ca0 100644 --- a/math/src/fft/cpu/fft.rs +++ b/math/src/fft/cpu/fft.rs @@ -1,4 +1,7 @@ -use crate::field::{element::FieldElement, traits::IsFFTField}; +use crate::field::{ + element::FieldElement, + traits::{IsFFTField, IsField, IsSubFieldOf}, +}; /// In-Place Radix-2 NR DIT FFT algorithm over a slice of two-adic field elements. /// It's required that the twiddle factors are in bit-reverse order. Else this function will not @@ -12,9 +15,12 @@ use crate::field::{element::FieldElement, traits::IsFFTField}; /// - NR: natural to reverse order, meaning that the input is naturally ordered and the output will /// be bit-reversed ordered. /// - DIT: decimation in time -pub fn in_place_nr_2radix_fft(input: &mut [FieldElement], twiddles: &[FieldElement]) +/// +/// It supports values in a field E and domain in a subfield F. +pub fn in_place_nr_2radix_fft(input: &mut [FieldElement], twiddles: &[FieldElement]) where - F: IsFFTField, + F: IsFFTField + IsSubFieldOf, + E: IsField, { // divide input in groups, starting with 1, duplicating the number of groups in each stage. let mut group_count = 1; @@ -60,6 +66,8 @@ where /// - RN: reverse to natural order, meaning that the input is bit-reversed ordered and the output will /// be naturally ordered. /// - DIT: decimation in time +/// +/// It supports values in a field E and domain in a subfield F. #[allow(dead_code)] pub fn in_place_rn_2radix_fft(input: &mut [FieldElement], twiddles: &[FieldElement]) where @@ -135,7 +143,7 @@ mod tests { let twiddles = get_twiddles(order.into(), RootsConfig::BitReverse).unwrap(); let mut result = coeffs; - in_place_nr_2radix_fft(&mut result, &twiddles); + in_place_nr_2radix_fft::(&mut result, &twiddles); in_place_bit_reverse_permute(&mut result); prop_assert_eq!(expected, result); diff --git a/math/src/fft/cpu/ops.rs b/math/src/fft/cpu/ops.rs index da9876f98..f4bdffd7a 100644 --- a/math/src/fft/cpu/ops.rs +++ b/math/src/fft/cpu/ops.rs @@ -1,16 +1,19 @@ use crate::{ fft::errors::FFTError, - field::{element::FieldElement, traits::IsFFTField}, + field::{ + element::FieldElement, + traits::{IsFFTField, IsField, IsSubFieldOf}, + }, }; use super::{bit_reversing::in_place_bit_reverse_permute, fft::in_place_nr_2radix_fft}; -/// Executes Fast Fourier Transform over elements of a two-adic finite field `F`. Usually used for -/// fast polynomial evaluation. -pub fn fft( - input: &[FieldElement], +/// Executes Fast Fourier Transform over elements of a two-adic finite field `E` and domain in a +/// subfield `F`. Usually used for fast polynomial evaluation. +pub fn fft, E: IsField>( + input: &[FieldElement], twiddles: &[FieldElement], -) -> Result>, FFTError> { +) -> Result>, FFTError> { if !input.len().is_power_of_two() { return Err(FFTError::InputError(input.len())); } diff --git a/math/src/fft/cpu/roots_of_unity.rs b/math/src/fft/cpu/roots_of_unity.rs index 6632cc6d1..43456860e 100644 --- a/math/src/fft/cpu/roots_of_unity.rs +++ b/math/src/fft/cpu/roots_of_unity.rs @@ -54,7 +54,7 @@ pub fn get_powers_of_primitive_root_coset( offset: &FieldElement, ) -> Result>, FFTError> { let root = F::get_primitive_root_of_unity(n)?; - let results = (0..count).map(|i| root.pow(i) * offset); + let results = (0..count).map(|i| offset * root.pow(i)); Ok(results.collect()) } diff --git a/math/src/fft/gpu/metal/ops.rs b/math/src/fft/gpu/metal/ops.rs index c9a347dfc..c22fc8620 100644 --- a/math/src/fft/gpu/metal/ops.rs +++ b/math/src/fft/gpu/metal/ops.rs @@ -1,6 +1,6 @@ use crate::field::{ element::FieldElement, - traits::{IsFFTField, RootsConfig}, + traits::{IsFFTField, IsField, IsSubFieldOf, RootsConfig}, }; use lambdaworks_gpu::metal::abstractions::{errors::MetalError, state::*}; @@ -15,11 +15,17 @@ use core::mem; /// in this order too. Natural order means that input[i] corresponds to the i-th coefficient, /// as opposed to bit-reverse order in which input[bit_rev(i)] corresponds to the i-th /// coefficient. -pub fn fft( - input: &[FieldElement], +/// +/// It supports values in a field E and domain in a subfield F. +pub fn fft( + input: &[FieldElement], twiddles: &[FieldElement], state: &MetalState, -) -> Result>, MetalError> { +) -> Result>, MetalError> +where + F: IsFFTField + IsSubFieldOf, + E: IsField, +{ // TODO: make a twiddle factor abstraction for handling invalid twiddles if !input.len().is_power_of_two() { return Err(MetalError::InputError(input.len())); @@ -173,7 +179,7 @@ mod tests { fn test_metal_fft_matches_sequential(input in field_vec(6)) { let metal_state = MetalState::new(None).unwrap(); let order = input.len().trailing_zeros(); - let twiddles = get_twiddles(order.into(), RootsConfig::BitReverse).unwrap(); + let twiddles = get_twiddles::(order.into(), RootsConfig::BitReverse).unwrap(); let metal_result = super::fft(&input, &twiddles, &metal_state).unwrap(); let sequential_result = crate::fft::cpu::ops::fft(&input, &twiddles).unwrap(); @@ -190,7 +196,7 @@ mod tests { let metal_state = MetalState::new(None).unwrap(); let order = input.len().trailing_zeros(); - let twiddles = get_twiddles(order.into(), RootsConfig::BitReverse).unwrap(); + let twiddles = get_twiddles::(order.into(), RootsConfig::BitReverse).unwrap(); let metal_result = super::fft(&input, &twiddles, &metal_state).unwrap(); let sequential_result = crate::fft::cpu::ops::fft(&input, &twiddles).unwrap(); diff --git a/math/src/fft/gpu/metal/polynomial.rs b/math/src/fft/gpu/metal/polynomial.rs index 4e1714fe4..f46e3e29a 100644 --- a/math/src/fft/gpu/metal/polynomial.rs +++ b/math/src/fft/gpu/metal/polynomial.rs @@ -1,7 +1,7 @@ use crate::{ field::{ element::FieldElement, - traits::{IsFFTField, RootsConfig}, + traits::{IsFFTField, IsField, IsSubFieldOf, RootsConfig}, }, polynomial::Polynomial, }; @@ -9,30 +9,34 @@ use lambdaworks_gpu::metal::abstractions::{errors::MetalError, state::MetalState use super::ops::*; -pub fn evaluate_fft_metal(coeffs: &[FieldElement]) -> Result>, MetalError> +pub fn evaluate_fft_metal( + coeffs: &[FieldElement], +) -> Result>, MetalError> where - F: IsFFTField, + F: IsFFTField + IsSubFieldOf, + E: IsField, { let state = MetalState::new(None)?; let order = coeffs.len().trailing_zeros(); - let twiddles = gen_twiddles(order.into(), RootsConfig::BitReverse, &state)?; + let twiddles = gen_twiddles::(order.into(), RootsConfig::BitReverse, &state)?; fft(coeffs, &twiddles, &state) } /// Returns a new polynomial that interpolates `fft_evals`, which are evaluations using twiddle /// factors. This is considered to be the inverse operation of [evaluate_fft_metal()]. -pub fn interpolate_fft_metal( - fft_evals: &[FieldElement], -) -> Result>, MetalError> +pub fn interpolate_fft_metal( + fft_evals: &[FieldElement], +) -> Result>, MetalError> where - F: IsFFTField, + F: IsFFTField + IsSubFieldOf, + E: IsField, { let metal_state = MetalState::new(None)?; let order = fft_evals.len().trailing_zeros(); - let twiddles = gen_twiddles(order.into(), RootsConfig::BitReverseInversed, &metal_state)?; + let twiddles = gen_twiddles::(order.into(), RootsConfig::BitReverseInversed, &metal_state)?; let coeffs = fft(fft_evals, &twiddles, &metal_state)?; diff --git a/math/src/fft/polynomial.rs b/math/src/fft/polynomial.rs index eaaf40b97..481e32692 100644 --- a/math/src/fft/polynomial.rs +++ b/math/src/fft/polynomial.rs @@ -1,5 +1,6 @@ use crate::fft::errors::FFTError; +use crate::field::traits::{IsField, IsSubFieldOf}; use crate::{ field::{ element::FieldElement, @@ -15,58 +16,37 @@ use crate::fft::gpu::metal::polynomial::{evaluate_fft_metal, interpolate_fft_met use super::cpu::{ops, roots_of_unity}; -pub trait FFTPoly { - fn evaluate_fft( - &self, - blowup_factor: usize, - domain_size: Option, - ) -> Result>, FFTError>; - fn evaluate_offset_fft( - &self, - blowup_factor: usize, - domain_size: Option, - offset: &FieldElement, - ) -> Result>, FFTError>; - fn interpolate_fft( - fft_evals: &[FieldElement], - ) -> Result>, FFTError>; - fn interpolate_offset_fft( - fft_evals: &[FieldElement], - offset: &FieldElement, - ) -> Result>, FFTError>; -} - -impl FFTPoly for Polynomial> { - /// Returns `N` evaluations of this polynomial using FFT (so the results +impl Polynomial> { + /// Returns `N` evaluations of this polynomial using FFT over a domain in a subfield F of E (so the results /// are P(w^i), with w being a primitive root of unity). /// `N = max(self.coeff_len(), domain_size).next_power_of_two() * blowup_factor`. /// If `domain_size` is `None`, it defaults to 0. - fn evaluate_fft( - &self, + pub fn evaluate_fft>( + poly: &Polynomial>, blowup_factor: usize, domain_size: Option, - ) -> Result>, FFTError> { + ) -> Result>, FFTError> { let domain_size = domain_size.unwrap_or(0); - let len = std::cmp::max(self.coeff_len(), domain_size).next_power_of_two() * blowup_factor; + let len = std::cmp::max(poly.coeff_len(), domain_size).next_power_of_two() * blowup_factor; - if self.coefficients().is_empty() { + if poly.coefficients().is_empty() { return Ok(vec![FieldElement::zero(); len]); } - let mut coeffs = self.coefficients().to_vec(); + let mut coeffs = poly.coefficients().to_vec(); coeffs.resize(len, FieldElement::zero()); // padding with zeros will make FFT return more evaluations of the same polynomial. #[cfg(feature = "metal")] { if !F::field_name().is_empty() { - Ok(evaluate_fft_metal(&coeffs)?) + Ok(evaluate_fft_metal::(&coeffs)?) } else { println!( "GPU evaluation failed for field {}. Program will fallback to CPU.", std::any::type_name::() ); - evaluate_fft_cpu(&coeffs) + evaluate_fft_cpu::(&coeffs) } } @@ -76,44 +56,46 @@ impl FFTPoly for Polynomial> { if F::field_name() == "stark256" { Ok(evaluate_fft_cuda(&coeffs)?) } else { - evaluate_fft_cpu(&coeffs) + evaluate_fft_cpu::(&coeffs) } } #[cfg(all(not(feature = "metal"), not(feature = "cuda")))] { - evaluate_fft_cpu(&coeffs) + evaluate_fft_cpu::(&coeffs) } } - /// Returns `N` evaluations with an offset of this polynomial using FFT + /// Returns `N` evaluations with an offset of this polynomial using FFT over a domain in a subfield F of E /// (so the results are P(w^i), with w being a primitive root of unity). /// `N = max(self.coeff_len(), domain_size).next_power_of_two() * blowup_factor`. /// If `domain_size` is `None`, it defaults to 0. - fn evaluate_offset_fft( - &self, + pub fn evaluate_offset_fft>( + poly: &Polynomial>, blowup_factor: usize, domain_size: Option, offset: &FieldElement, - ) -> Result>, FFTError> { - let scaled = self.scale(offset); - scaled.evaluate_fft(blowup_factor, domain_size) + ) -> Result>, FFTError> { + let scaled = poly.scale(offset); + Polynomial::evaluate_fft::(&scaled, blowup_factor, domain_size) } /// Returns a new polynomial that interpolates `(w^i, fft_evals[i])`, with `w` being a - /// Nth primitive root of unity, and `i in 0..N`, with `N = fft_evals.len()`. + /// Nth primitive root of unity in a subfield F of E, and `i in 0..N`, with `N = fft_evals.len()`. /// This is considered to be the inverse operation of [Self::evaluate_fft()]. - fn interpolate_fft(fft_evals: &[FieldElement]) -> Result { + pub fn interpolate_fft>( + fft_evals: &[FieldElement], + ) -> Result { #[cfg(feature = "metal")] { if !F::field_name().is_empty() { - Ok(interpolate_fft_metal(fft_evals)?) + Ok(interpolate_fft_metal::(fft_evals)?) } else { println!( "GPU interpolation failed for field {}. Program will fallback to CPU.", std::any::type_name::() ); - interpolate_fft_cpu(fft_evals) + interpolate_fft_cpu::(fft_evals) } } @@ -122,63 +104,67 @@ impl FFTPoly for Polynomial> { if !F::field_name().is_empty() { Ok(interpolate_fft_cuda(fft_evals)?) } else { - interpolate_fft_cpu(fft_evals) + interpolate_fft_cpu::(fft_evals) } } #[cfg(all(not(feature = "metal"), not(feature = "cuda")))] { - interpolate_fft_cpu(fft_evals) + interpolate_fft_cpu::(fft_evals) } } /// Returns a new polynomial that interpolates offset `(w^i, fft_evals[i])`, with `w` being a - /// Nth primitive root of unity, and `i in 0..N`, with `N = fft_evals.len()`. + /// Nth primitive root of unity in a subfield F of E, and `i in 0..N`, with `N = fft_evals.len()`. /// This is considered to be the inverse operation of [Self::evaluate_offset_fft()]. - fn interpolate_offset_fft( - fft_evals: &[FieldElement], + pub fn interpolate_offset_fft>( + fft_evals: &[FieldElement], offset: &FieldElement, - ) -> Result>, FFTError> { - let scaled = Polynomial::interpolate_fft(fft_evals)?; + ) -> Result>, FFTError> { + let scaled = Polynomial::interpolate_fft::(fft_evals)?; Ok(scaled.scale(&offset.inv().unwrap())) } } -pub fn compose_fft( - poly_1: &Polynomial>, - poly_2: &Polynomial>, -) -> Polynomial> +pub fn compose_fft( + poly_1: &Polynomial>, + poly_2: &Polynomial>, +) -> Polynomial> where - F: IsFFTField, + F: IsFFTField + IsSubFieldOf, + E: IsField, { - let poly_2_evaluations = poly_2.evaluate_fft(1, None).unwrap(); + let poly_2_evaluations = Polynomial::evaluate_fft::(poly_2, 1, None).unwrap(); let values: Vec<_> = poly_2_evaluations .iter() .map(|value| poly_1.evaluate(value)) .collect(); - Polynomial::interpolate_fft(values.as_slice()).unwrap() + Polynomial::interpolate_fft::(values.as_slice()).unwrap() } -pub fn evaluate_fft_cpu(coeffs: &[FieldElement]) -> Result>, FFTError> +pub fn evaluate_fft_cpu(coeffs: &[FieldElement]) -> Result>, FFTError> where - F: IsFFTField, + F: IsFFTField + IsSubFieldOf, + E: IsField, { let order = coeffs.len().trailing_zeros(); - let twiddles = roots_of_unity::get_twiddles(order.into(), RootsConfig::BitReverse)?; + let twiddles = roots_of_unity::get_twiddles::(order.into(), RootsConfig::BitReverse)?; // Bit reverse order is needed for NR DIT FFT. ops::fft(coeffs, &twiddles) } -pub fn interpolate_fft_cpu( - fft_evals: &[FieldElement], -) -> Result>, FFTError> +pub fn interpolate_fft_cpu( + fft_evals: &[FieldElement], +) -> Result>, FFTError> where - F: IsFFTField, + F: IsFFTField + IsSubFieldOf, + E: IsField, { let order = fft_evals.len().trailing_zeros(); - let twiddles = roots_of_unity::get_twiddles(order.into(), RootsConfig::BitReverseInversed)?; + let twiddles = + roots_of_unity::get_twiddles::(order.into(), RootsConfig::BitReverseInversed)?; let coeffs = ops::fft(fft_evals, &twiddles)?; @@ -191,7 +177,10 @@ mod tests { #[cfg(all(not(feature = "metal"), not(feature = "cuda")))] use crate::field::traits::IsField; - use crate::field::traits::RootsConfig; + use crate::field::{ + test_fields::u64_test_field::{U64TestField, U64TestFieldExtension}, + traits::RootsConfig, + }; use proptest::{collection, prelude::*}; use roots_of_unity::{get_powers_of_primitive_root, get_powers_of_primitive_root_coset}; @@ -206,7 +195,7 @@ mod tests { let twiddles = get_powers_of_primitive_root(order.into(), len, RootsConfig::Natural).unwrap(); - let fft_eval = poly.evaluate_fft(1, None).unwrap(); + let fft_eval = Polynomial::evaluate_fft::(&poly, 1, None).unwrap(); let naive_eval = poly.evaluate_slice(&twiddles); (fft_eval, naive_eval) @@ -222,9 +211,8 @@ mod tests { let twiddles = get_powers_of_primitive_root_coset(order.into(), len * blowup_factor, &offset).unwrap(); - let fft_eval = poly - .evaluate_offset_fft(blowup_factor, None, &offset) - .unwrap(); + let fft_eval = + Polynomial::evaluate_offset_fft::(&poly, blowup_factor, None, &offset).unwrap(); let naive_eval = poly.evaluate_slice(&twiddles); (fft_eval, naive_eval) @@ -238,7 +226,7 @@ mod tests { get_powers_of_primitive_root(order, 1 << order, RootsConfig::Natural).unwrap(); let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap(); - let fft_poly = Polynomial::interpolate_fft(fft_evals).unwrap(); + let fft_poly = Polynomial::interpolate_fft::(fft_evals).unwrap(); (fft_poly, naive_poly) } @@ -259,8 +247,8 @@ mod tests { fn gen_fft_interpolate_and_evaluate( poly: Polynomial>, ) -> (Polynomial>, Polynomial>) { - let eval = poly.evaluate_fft(1, None).unwrap(); - let new_poly = Polynomial::interpolate_fft(&eval).unwrap(); + let eval = Polynomial::evaluate_fft::(&poly, 1, None).unwrap(); + let new_poly = Polynomial::interpolate_fft::(&eval).unwrap(); (poly, new_poly) } @@ -357,7 +345,7 @@ mod tests { let p = Polynomial::new(&[FE::new(0), FE::new(2)]); let q = Polynomial::new(&[FE::new(0), FE::new(0), FE::new(0), FE::new(1)]); assert_eq!( - compose_fft(&p, &q), + compose_fft::(&p, &q), Polynomial::new(&[FE::new(0), FE::new(0), FE::new(0), FE::new(2)]) ); } @@ -447,4 +435,21 @@ mod tests { } } } + + #[test] + fn test_fft_with_values_in_field_extension_over_domain_in_prime_field() { + type TF = U64TestField; + type TL = U64TestFieldExtension; + + let a = FieldElement::::from(&[FieldElement::one(), FieldElement::one()]); + let b = FieldElement::::from(&[-FieldElement::from(2), FieldElement::from(17)]); + let c = FieldElement::::one(); + let poly = Polynomial::new(&[a, b, c]); + + let eval = Polynomial::evaluate_offset_fft::(&poly, 8, Some(4), &FieldElement::from(2)) + .unwrap(); + let new_poly = + Polynomial::interpolate_offset_fft::(&eval, &FieldElement::from(2)).unwrap(); + assert_eq!(poly, new_poly); + } } diff --git a/math/src/field/fields/fft_friendly/babybear.rs b/math/src/field/fields/fft_friendly/babybear.rs index 01aa9e920..a516e8523 100644 --- a/math/src/field/fields/fft_friendly/babybear.rs +++ b/math/src/field/fields/fft_friendly/babybear.rs @@ -102,7 +102,6 @@ mod tests { get_powers_of_primitive_root, get_powers_of_primitive_root_coset, }; #[cfg(not(any(feature = "metal", feature = "cuda")))] - use crate::fft::polynomial::FFTPoly; use crate::field::element::FieldElement; #[cfg(not(any(feature = "metal", feature = "cuda")))] use crate::field::traits::{IsFFTField, RootsConfig}; @@ -118,7 +117,7 @@ mod tests { let twiddles = get_powers_of_primitive_root(order.into(), len, RootsConfig::Natural).unwrap(); - let fft_eval = poly.evaluate_fft(1, None).unwrap(); + let fft_eval = Polynomial::evaluate_fft::(&poly, 1, None).unwrap(); let naive_eval = poly.evaluate_slice(&twiddles); (fft_eval, naive_eval) @@ -136,9 +135,8 @@ mod tests { get_powers_of_primitive_root_coset(order.into(), len * blowup_factor, &offset) .unwrap(); - let fft_eval = poly - .evaluate_offset_fft(blowup_factor, None, &offset) - .unwrap(); + let fft_eval = + Polynomial::evaluate_offset_fft::(&poly, blowup_factor, None, &offset).unwrap(); let naive_eval = poly.evaluate_slice(&twiddles); (fft_eval, naive_eval) @@ -153,7 +151,7 @@ mod tests { get_powers_of_primitive_root(order, 1 << order, RootsConfig::Natural).unwrap(); let naive_poly = Polynomial::interpolate(&twiddles, fft_evals).unwrap(); - let fft_poly = Polynomial::interpolate_fft(fft_evals).unwrap(); + let fft_poly = Polynomial::interpolate_fft::(fft_evals).unwrap(); (fft_poly, naive_poly) } @@ -176,8 +174,8 @@ mod tests { fn gen_fft_interpolate_and_evaluate( poly: Polynomial>, ) -> (Polynomial>, Polynomial>) { - let eval = poly.evaluate_fft(1, None).unwrap(); - let new_poly = Polynomial::interpolate_fft(&eval).unwrap(); + let eval = Polynomial::evaluate_fft::(&poly, 1, None).unwrap(); + let new_poly = Polynomial::interpolate_fft::(&eval).unwrap(); (poly, new_poly) } diff --git a/math/src/field/test_fields/u64_test_field.rs b/math/src/field/test_fields/u64_test_field.rs index 780924d2a..bd24582f0 100644 --- a/math/src/field/test_fields/u64_test_field.rs +++ b/math/src/field/test_fields/u64_test_field.rs @@ -1,7 +1,11 @@ use crate::{ errors::CreationError, - field::errors::FieldError, - field::traits::{IsFFTField, IsField, IsPrimeField}, + field::{ + element::FieldElement, + extensions::quadratic::QuadraticExtensionField, + traits::{IsFFTField, IsField, IsPrimeField}, + }, + field::{errors::FieldError, extensions::quadratic::HasQuadraticNonResidue}, }; #[derive(Debug, Clone, PartialEq, Eq)] @@ -94,9 +98,23 @@ impl IsFFTField for U64TestField { const TWO_ADIC_PRIMITVE_ROOT_OF_UNITY: u64 = 1753635133440165772; } +#[derive(Clone, Debug)] +pub struct TestNonResidue; +impl HasQuadraticNonResidue for TestNonResidue { + fn residue() -> FieldElement { + FieldElement::from(7) + } +} + +pub type U64TestFieldExtension = QuadraticExtensionField; + #[cfg(test)] mod tests_u64_test_field { - use crate::field::{test_fields::u64_test_field::U64TestField, traits::IsPrimeField}; + use crate::field::{ + element::FieldElement, + test_fields::u64_test_field::{U64TestField, U64TestFieldExtension}, + traits::IsPrimeField, + }; #[test] fn from_hex_for_b_is_11() { @@ -110,4 +128,15 @@ mod tests_u64_test_field { 64 ); } + + #[cfg(feature = "std")] + #[test] + fn test_to_subfield_vec() { + let a = FieldElement::::from(&[ + FieldElement::from(1), + FieldElement::from(3), + ]); + let b = a.to_subfield_vec::(); + assert_eq!(b, vec![FieldElement::from(1), FieldElement::from(3)]); + } } diff --git a/math/src/polynomial.rs b/math/src/polynomial.rs index 940e49cfc..d479f8881 100644 --- a/math/src/polynomial.rs +++ b/math/src/polynomial.rs @@ -1,5 +1,5 @@ use super::field::element::FieldElement; -use crate::field::traits::IsField; +use crate::field::traits::{IsField, IsSubFieldOf}; use std::ops; /// Represents the polynomial c_0 + c_1 * X + c_2 * X^2 + ... + c_n * X^n @@ -202,7 +202,7 @@ impl Polynomial> { } } - pub fn scale(&self, factor: &FieldElement) -> Self { + pub fn scale>(&self, factor: &FieldElement) -> Self { let scaled_coefficients = self .coefficients .iter() diff --git a/provers/groth16/src/qap.rs b/provers/groth16/src/qap.rs index 35bb8f76d..d2433c4f5 100644 --- a/provers/groth16/src/qap.rs +++ b/provers/groth16/src/qap.rs @@ -1,4 +1,4 @@ -use lambdaworks_math::{fft::polynomial::FFTPoly, polynomial::Polynomial}; +use lambdaworks_math::polynomial::Polynomial; use crate::common::*; @@ -56,9 +56,12 @@ impl QuadraticArithmeticProgram { let [l, r, o] = self.scale_and_accumulate_variable_polynomials(w, degree, offset); // TODO: Change to a vector of offsetted evaluations of x^N-1 - let mut t = (Polynomial::new_monomial(FrElement::one(), self.num_of_gates()) - - FrElement::one()) - .evaluate_offset_fft(1, Some(degree), offset) + let mut t = Polynomial::evaluate_offset_fft( + &(Polynomial::new_monomial(FrElement::one(), self.num_of_gates()) - FrElement::one()), + 1, + Some(degree), + offset, + ) .unwrap(); FrElement::inplace_batch_inverse(&mut t).unwrap(); @@ -91,7 +94,7 @@ impl QuadraticArithmeticProgram { fn build_variable_polynomials(from_matrix: &[Vec]) -> Vec> { from_matrix .iter() - .map(|row| Polynomial::interpolate_fft(row).unwrap()) + .map(|row| Polynomial::interpolate_fft::(row).unwrap()) .collect() } @@ -105,14 +108,20 @@ impl QuadraticArithmeticProgram { offset: &FrElement, ) -> [Vec; 3] { [&self.l, &self.r, &self.o].map(|var_polynomials| { - var_polynomials - .iter() - .zip(w) - .map(|(poly, coeff)| poly.mul_with_ref(&Polynomial::new_monomial(coeff.clone(), 0))) - .reduce(|poly1, poly2| poly1 + poly2) - .unwrap() - .evaluate_offset_fft(1, Some(degree), offset) - .unwrap() + Polynomial::evaluate_offset_fft( + &(var_polynomials + .iter() + .zip(w) + .map(|(poly, coeff)| { + poly.mul_with_ref(&Polynomial::new_monomial(coeff.clone(), 0)) + }) + .reduce(|poly1, poly2| poly1 + poly2) + .unwrap()), + 1, + Some(degree), + offset, + ) + .unwrap() }) } } diff --git a/provers/plonk/src/prover.rs b/provers/plonk/src/prover.rs index 5e78b70f5..b351d1c38 100644 --- a/provers/plonk/src/prover.rs +++ b/provers/plonk/src/prover.rs @@ -1,6 +1,5 @@ use lambdaworks_crypto::fiat_shamir::transcript::Transcript; use lambdaworks_math::errors::DeserializationError; -use lambdaworks_math::fft::polynomial::FFTPoly; use lambdaworks_math::field::traits::IsFFTField; use lambdaworks_math::traits::{Deserializable, IsRandomFieldElementGenerator, Serializable}; use std::marker::PhantomData; @@ -311,11 +310,11 @@ where witness: &Witness, common_preprocessed_input: &CommonPreprocessedInput, ) -> Round1Result { - let p_a = Polynomial::interpolate_fft(&witness.a) + let p_a = Polynomial::interpolate_fft::(&witness.a) .expect("xs and ys have equal length and xs are unique"); - let p_b = Polynomial::interpolate_fft(&witness.b) + let p_b = Polynomial::interpolate_fft::(&witness.b) .expect("xs and ys have equal length and xs are unique"); - let p_c = Polynomial::interpolate_fft(&witness.c) + let p_c = Polynomial::interpolate_fft::(&witness.c) .expect("xs and ys have equal length and xs are unique"); let z_h = Polynomial::new_monomial(FieldElement::one(), common_preprocessed_input.n) @@ -364,7 +363,7 @@ where coefficients.push(new_term); } - let p_z = Polynomial::interpolate_fft(&coefficients) + let p_z = Polynomial::interpolate_fft::(&coefficients) .expect("xs and ys have equal length and xs are unique"); let z_h = Polynomial::new_monomial(FieldElement::one(), common_preprocessed_input.n) - FieldElement::one(); @@ -392,7 +391,7 @@ where let k2 = &cpi.k1 * &cpi.k1; let one = Polynomial::new_monomial(FieldElement::one(), 0); - let p_x = &Polynomial::new_monomial(FieldElement::one(), 1); + let p_x = &Polynomial::new_monomial(FieldElement::::one(), 1); let zh = Polynomial::new_monomial(FieldElement::one(), cpi.n) - &one; let z_x_omega_coefficients: Vec> = p_z @@ -402,13 +401,13 @@ where .map(|(i, x)| x * &cpi.domain[i % cpi.n]) .collect(); let z_x_omega = Polynomial::new(&z_x_omega_coefficients); - let mut e1 = vec![FieldElement::zero(); cpi.domain.len()]; + let mut e1 = vec![FieldElement::::zero(); cpi.domain.len()]; e1[0] = FieldElement::one(); - let l1 = Polynomial::interpolate_fft(&e1) + let l1 = Polynomial::interpolate_fft::(&e1) .expect("xs and ys have equal length and xs are unique"); let mut p_pi_y = public_input.to_vec(); p_pi_y.append(&mut vec![FieldElement::zero(); cpi.n - public_input.len()]); - let p_pi = Polynomial::interpolate_fft(&p_pi_y) + let p_pi = Polynomial::interpolate_fft::(&p_pi_y) .expect("xs and ys have equal length and xs are unique"); // Compute p @@ -417,24 +416,23 @@ where // TODO: check a factor of 4 is a sensible upper bound let degree = 4 * cpi.n; let offset = &cpi.k1; - let p_a_eval = p_a.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let p_b_eval = p_b.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let p_c_eval = p_c.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let ql_eval = cpi.ql.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let qr_eval = cpi.qr.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let qm_eval = cpi.qm.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let qo_eval = cpi.qo.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let qc_eval = cpi.qc.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let p_pi_eval = p_pi.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let p_x_eval = p_x.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let p_z_eval = p_z.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let p_z_x_omega_eval = z_x_omega - .evaluate_offset_fft(1, Some(degree), offset) - .unwrap(); - let p_s1_eval = cpi.s1.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let p_s2_eval = cpi.s2.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let p_s3_eval = cpi.s3.evaluate_offset_fft(1, Some(degree), offset).unwrap(); - let l1_eval = l1.evaluate_offset_fft(1, Some(degree), offset).unwrap(); + let p_a_eval = Polynomial::evaluate_offset_fft(p_a, 1, Some(degree), offset).unwrap(); + let p_b_eval = Polynomial::evaluate_offset_fft(p_b, 1, Some(degree), offset).unwrap(); + let p_c_eval = Polynomial::evaluate_offset_fft(p_c, 1, Some(degree), offset).unwrap(); + let ql_eval = Polynomial::evaluate_offset_fft(&cpi.ql, 1, Some(degree), offset).unwrap(); + let qr_eval = Polynomial::evaluate_offset_fft(&cpi.qr, 1, Some(degree), offset).unwrap(); + let qm_eval = Polynomial::evaluate_offset_fft(&cpi.qm, 1, Some(degree), offset).unwrap(); + let qo_eval = Polynomial::evaluate_offset_fft(&cpi.qo, 1, Some(degree), offset).unwrap(); + let qc_eval = Polynomial::evaluate_offset_fft(&cpi.qc, 1, Some(degree), offset).unwrap(); + let p_pi_eval = Polynomial::evaluate_offset_fft(&p_pi, 1, Some(degree), offset).unwrap(); + let p_x_eval = Polynomial::evaluate_offset_fft(p_x, 1, Some(degree), offset).unwrap(); + let p_z_eval = Polynomial::evaluate_offset_fft(p_z, 1, Some(degree), offset).unwrap(); + let p_z_x_omega_eval = + Polynomial::evaluate_offset_fft(&z_x_omega, 1, Some(degree), offset).unwrap(); + let p_s1_eval = Polynomial::evaluate_offset_fft(&cpi.s1, 1, Some(degree), offset).unwrap(); + let p_s2_eval = Polynomial::evaluate_offset_fft(&cpi.s2, 1, Some(degree), offset).unwrap(); + let p_s3_eval = Polynomial::evaluate_offset_fft(&cpi.s3, 1, Some(degree), offset).unwrap(); + let l1_eval = Polynomial::evaluate_offset_fft(&l1, 1, Some(degree), offset).unwrap(); let p_constraints_eval: Vec<_> = p_a_eval .iter() @@ -496,7 +494,7 @@ where .map(|((p2, p1), co)| (p2 * &alpha + p1) * &alpha + co) .collect(); - let mut zh_eval = zh.evaluate_offset_fft(1, Some(degree), offset).unwrap(); + let mut zh_eval = Polynomial::evaluate_offset_fft(&zh, 1, Some(degree), offset).unwrap(); FieldElement::inplace_batch_inverse(&mut zh_eval).unwrap(); let c: Vec<_> = p_eval .iter() diff --git a/provers/plonk/src/setup.rs b/provers/plonk/src/setup.rs index 974c50e3f..93d8fc41a 100644 --- a/provers/plonk/src/setup.rs +++ b/provers/plonk/src/setup.rs @@ -5,7 +5,6 @@ use crate::test_utils::utils::{generate_domain, generate_permutation_coefficient use lambdaworks_crypto::commitments::traits::IsCommitmentScheme; use lambdaworks_crypto::fiat_shamir::default_transcript::DefaultTranscript; use lambdaworks_crypto::fiat_shamir::transcript::Transcript; -use lambdaworks_math::fft::polynomial::FFTPoly; use lambdaworks_math::field::traits::IsFFTField; use lambdaworks_math::field::{element::FieldElement, traits::IsField}; use lambdaworks_math::polynomial::Polynomial; @@ -86,14 +85,14 @@ impl CommonPreprocessedInput { n, omega, k1: order_r_minus_1_root_unity.clone(), - ql: Polynomial::interpolate_fft(&ql).unwrap(), // TODO: Remove unwraps - qr: Polynomial::interpolate_fft(&qr).unwrap(), - qo: Polynomial::interpolate_fft(&qo).unwrap(), - qm: Polynomial::interpolate_fft(&qm).unwrap(), - qc: Polynomial::interpolate_fft(&qc).unwrap(), - s1: Polynomial::interpolate_fft(&s1_lagrange).unwrap(), - s2: Polynomial::interpolate_fft(&s2_lagrange).unwrap(), - s3: Polynomial::interpolate_fft(&s3_lagrange).unwrap(), + ql: Polynomial::interpolate_fft::(&ql).unwrap(), // TODO: Remove unwraps + qr: Polynomial::interpolate_fft::(&qr).unwrap(), + qo: Polynomial::interpolate_fft::(&qo).unwrap(), + qm: Polynomial::interpolate_fft::(&qm).unwrap(), + qc: Polynomial::interpolate_fft::(&qc).unwrap(), + s1: Polynomial::interpolate_fft::(&s1_lagrange).unwrap(), + s2: Polynomial::interpolate_fft::(&s2_lagrange).unwrap(), + s3: Polynomial::interpolate_fft::(&s3_lagrange).unwrap(), s1_lagrange, s2_lagrange, s3_lagrange, diff --git a/provers/plonk/src/test_utils/circuit_1.rs b/provers/plonk/src/test_utils/circuit_1.rs index bc95c8962..929f73583 100644 --- a/provers/plonk/src/test_utils/circuit_1.rs +++ b/provers/plonk/src/test_utils/circuit_1.rs @@ -2,7 +2,6 @@ use super::utils::{ generate_domain, generate_permutation_coefficients, ORDER_R_MINUS_1_ROOT_UNITY, }; use crate::setup::{CommonPreprocessedInput, Witness}; -use lambdaworks_math::fft::polynomial::FFTPoly; use lambdaworks_math::{ elliptic_curve::short_weierstrass::curves::bls12_381::default_types::{FrElement, FrField}, field::{element::FieldElement, traits::IsFFTField}, @@ -40,7 +39,7 @@ pub fn test_common_preprocessed_input_1() -> CommonPreprocessedInput { domain, k1: ORDER_R_MINUS_1_ROOT_UNITY, // domain: domain.clone(), - ql: Polynomial::interpolate_fft(&[ + ql: Polynomial::interpolate_fft::(&[ -FieldElement::one(), -FieldElement::one(), FieldElement::zero(), @@ -48,7 +47,7 @@ pub fn test_common_preprocessed_input_1() -> CommonPreprocessedInput { ]) .unwrap(), - qr: Polynomial::interpolate_fft(&[ + qr: Polynomial::interpolate_fft::(&[ FieldElement::zero(), FieldElement::zero(), FieldElement::zero(), @@ -56,7 +55,7 @@ pub fn test_common_preprocessed_input_1() -> CommonPreprocessedInput { ]) .unwrap(), - qo: Polynomial::interpolate_fft(&[ + qo: Polynomial::interpolate_fft::(&[ FieldElement::zero(), FieldElement::zero(), -FieldElement::one(), @@ -64,7 +63,7 @@ pub fn test_common_preprocessed_input_1() -> CommonPreprocessedInput { ]) .unwrap(), - qm: Polynomial::interpolate_fft(&[ + qm: Polynomial::interpolate_fft::(&[ FieldElement::zero(), FieldElement::zero(), FieldElement::one(), @@ -72,7 +71,7 @@ pub fn test_common_preprocessed_input_1() -> CommonPreprocessedInput { ]) .unwrap(), - qc: Polynomial::interpolate_fft(&[ + qc: Polynomial::interpolate_fft::(&[ FieldElement::from(0_u64), FieldElement::from(0_u64), FieldElement::zero(), @@ -80,9 +79,9 @@ pub fn test_common_preprocessed_input_1() -> CommonPreprocessedInput { ]) .unwrap(), - s1: Polynomial::interpolate_fft(&s1_lagrange).unwrap(), - s2: Polynomial::interpolate_fft(&s2_lagrange).unwrap(), - s3: Polynomial::interpolate_fft(&s3_lagrange).unwrap(), + s1: Polynomial::interpolate_fft::(&s1_lagrange).unwrap(), + s2: Polynomial::interpolate_fft::(&s2_lagrange).unwrap(), + s3: Polynomial::interpolate_fft::(&s3_lagrange).unwrap(), s1_lagrange, s2_lagrange, diff --git a/provers/plonk/src/test_utils/circuit_json.rs b/provers/plonk/src/test_utils/circuit_json.rs index 13aea759e..43913c844 100644 --- a/provers/plonk/src/test_utils/circuit_json.rs +++ b/provers/plonk/src/test_utils/circuit_json.rs @@ -2,7 +2,6 @@ use super::utils::{ generate_domain, generate_permutation_coefficients, ORDER_R_MINUS_1_ROOT_UNITY, }; use crate::setup::{CommonPreprocessedInput, Witness}; -use lambdaworks_math::fft::polynomial::FFTPoly; use lambdaworks_math::field::traits::IsFFTField; use lambdaworks_math::{ elliptic_curve::short_weierstrass::curves::bls12_381::default_types::{FrElement, FrField}, @@ -62,19 +61,39 @@ pub fn common_preprocessed_input_from_json( domain, omega, k1: ORDER_R_MINUS_1_ROOT_UNITY, - ql: Polynomial::interpolate_fft(&process_vector(json_input.Ql, &FrElement::zero(), n)) - .unwrap(), - qr: Polynomial::interpolate_fft(&process_vector(json_input.Qr, &FrElement::zero(), n)) - .unwrap(), - qo: Polynomial::interpolate_fft(&process_vector(json_input.Qo, &FrElement::zero(), n)) - .unwrap(), - qm: Polynomial::interpolate_fft(&process_vector(json_input.Qm, &FrElement::zero(), n)) - .unwrap(), - qc: Polynomial::interpolate_fft(&process_vector(json_input.Qc, &FrElement::zero(), n)) - .unwrap(), - s1: Polynomial::interpolate_fft(&s1_lagrange).unwrap(), - s2: Polynomial::interpolate_fft(&s2_lagrange).unwrap(), - s3: Polynomial::interpolate_fft(&s3_lagrange).unwrap(), + ql: Polynomial::interpolate_fft::(&process_vector( + json_input.Ql, + &FrElement::zero(), + n, + )) + .unwrap(), + qr: Polynomial::interpolate_fft::(&process_vector( + json_input.Qr, + &FrElement::zero(), + n, + )) + .unwrap(), + qo: Polynomial::interpolate_fft::(&process_vector( + json_input.Qo, + &FrElement::zero(), + n, + )) + .unwrap(), + qm: Polynomial::interpolate_fft::(&process_vector( + json_input.Qm, + &FrElement::zero(), + n, + )) + .unwrap(), + qc: Polynomial::interpolate_fft::(&process_vector( + json_input.Qc, + &FrElement::zero(), + n, + )) + .unwrap(), + s1: Polynomial::interpolate_fft::(&s1_lagrange).unwrap(), + s2: Polynomial::interpolate_fft::(&s2_lagrange).unwrap(), + s3: Polynomial::interpolate_fft::(&s3_lagrange).unwrap(), s1_lagrange, s2_lagrange, s3_lagrange, diff --git a/provers/stark/src/debug.rs b/provers/stark/src/debug.rs index 1b02007a2..bede4a7a3 100644 --- a/provers/stark/src/debug.rs +++ b/provers/stark/src/debug.rs @@ -3,7 +3,6 @@ use crate::trace::TraceTable; use super::domain::Domain; use super::traits::AIR; -use lambdaworks_math::fft::polynomial::FFTPoly; use lambdaworks_math::{ field::{element::FieldElement, traits::IsFFTField}, polynomial::Polynomial, @@ -23,8 +22,7 @@ pub fn validate_trace>( let trace_columns: Vec<_> = trace_polys .iter() .map(|poly| { - poly.evaluate_fft(1, Some(domain.interpolation_domain_size)) - .unwrap() + Polynomial::evaluate_fft::(poly, 1, Some(domain.interpolation_domain_size)).unwrap() }) .collect(); @@ -34,8 +32,7 @@ pub fn validate_trace>( .get_periodic_column_polynomials() .iter() .map(|poly| { - poly.evaluate_fft(1, Some(domain.interpolation_domain_size)) - .unwrap() + Polynomial::evaluate_fft::(poly, 1, Some(domain.interpolation_domain_size)).unwrap() }) .collect(); diff --git a/provers/stark/src/fri/mod.rs b/provers/stark/src/fri/mod.rs index e562f6bd9..624e16edf 100644 --- a/provers/stark/src/fri/mod.rs +++ b/provers/stark/src/fri/mod.rs @@ -3,7 +3,6 @@ pub mod fri_decommit; mod fri_functions; use lambdaworks_math::fft::cpu::bit_reversing::in_place_bit_reverse_permute; -use lambdaworks_math::fft::polynomial::FFTPoly; use lambdaworks_math::field::traits::IsFFTField; use lambdaworks_math::traits::Serializable; pub use lambdaworks_math::{ @@ -119,9 +118,8 @@ where F: IsFFTField, FieldElement: Serializable + Sync + Send, { - let mut evaluation = poly - .evaluate_offset_fft(1, Some(domain_size), coset_offset) - .unwrap(); // TODO: return error + let mut evaluation = + Polynomial::evaluate_offset_fft(poly, 1, Some(domain_size), coset_offset).unwrap(); // TODO: return error in_place_bit_reverse_permute(&mut evaluation); diff --git a/provers/stark/src/prover.rs b/provers/stark/src/prover.rs index bc8cd9449..86dcaa99e 100644 --- a/provers/stark/src/prover.rs +++ b/provers/stark/src/prover.rs @@ -3,7 +3,7 @@ use std::time::Instant; use lambdaworks_crypto::merkle_tree::proof::Proof; use lambdaworks_math::fft::cpu::bit_reversing::{in_place_bit_reverse_permute, reverse_index}; -use lambdaworks_math::fft::{errors::FFTError, polynomial::FFTPoly}; +use lambdaworks_math::fft::errors::FFTError; use lambdaworks_math::field::fields::fft_friendly::stark_252_prime_field::Stark252PrimeField; use lambdaworks_math::traits::Serializable; use lambdaworks_math::{ @@ -88,9 +88,8 @@ pub fn evaluate_polynomial_on_lde_domain( ) -> Result>, FFTError> where F: IsFFTField, - Polynomial>: FFTPoly, { - let evaluations = p.evaluate_offset_fft(blowup_factor, Some(domain_size), offset)?; + let evaluations = Polynomial::evaluate_offset_fft(p, blowup_factor, Some(domain_size), offset)?; let step = evaluations.len() / (domain_size * blowup_factor); match step { 1 => Ok(evaluations), diff --git a/provers/stark/src/trace.rs b/provers/stark/src/trace.rs index 7c8d7a333..2a05c42ed 100644 --- a/provers/stark/src/trace.rs +++ b/provers/stark/src/trace.rs @@ -1,6 +1,5 @@ use crate::table::{Table, TableView}; use lambdaworks_math::fft::errors::FFTError; -use lambdaworks_math::fft::polynomial::FFTPoly; use lambdaworks_math::{ field::{element::FieldElement, traits::IsFFTField}, polynomial::Polynomial, @@ -121,7 +120,7 @@ impl<'t, F: IsFFTField> TraceTable { #[cfg(not(feature = "parallel"))] let iter = columns.iter(); - iter.map(|col| Polynomial::interpolate_fft(col)) + iter.map(|col| Polynomial::interpolate_fft::(col)) .collect::>>, FFTError>>() .unwrap() } diff --git a/provers/stark/src/traits.rs b/provers/stark/src/traits.rs index 4755fc12e..005332a0d 100644 --- a/provers/stark/src/traits.rs +++ b/provers/stark/src/traits.rs @@ -1,6 +1,6 @@ use itertools::Itertools; use lambdaworks_math::{ - fft::{cpu::roots_of_unity::get_powers_of_primitive_root_coset, polynomial::FFTPoly}, + fft::cpu::roots_of_unity::get_powers_of_primitive_root_coset, field::{element::FieldElement, traits::IsFFTField}, polynomial::Polynomial, }; @@ -137,7 +137,7 @@ pub trait AIR { .take(self.trace_length()) .cloned() .collect(); - let poly = Polynomial::interpolate_fft(&values).unwrap(); + let poly = Polynomial::interpolate_fft::(&values).unwrap(); result.push(poly); } result From 606319702d15482877a05f1c47f98ca9161130b9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Garassino?= Date: Tue, 12 Dec 2023 12:29:03 -0300 Subject: [PATCH 2/2] Miden adapter (#681) * Add winterfell adapter in other crate * Add README * Better README * Reestructure files * Remove unnecessary code * todo's at the end * Reorder code * Clippy * cargo format * Improve readme * remove from workspace dependencies * Fix * Fixes * Fix for new frame * Many changes for a proof of concept of miden compatibility * Add periodic columns and a Fibonacci example using a periodic column * Format and better code * Clippy * Cargo fmt * Adding periodic values, adding degree bound for composition polynomial * Add num_transition_exemptions to air adapter * Fix in compute transitions of adapter and add tests for fibnacci miden/cubic air * Fix bug in constraint indexing. Cargo fmt and clippy. * Benches * Use same proof options for winterfell and lambdaworks in benchmark * Update fibonacci size, add benches to readme * put winterfell and miden as optional dependencies in the math and provers lambdaworks packages * add metadata to fromcolumns * add metadata to air adapter * Transcript is not a fixed number * remove trace from air adapter public inputs * move files to examples * move winterfell field to fields dir * use winterfell fork * fix dependencies to use forked miden and winterfell * Delete provers/stark/src/examples/fibonacci_periodic_cols.rs * Update mod.rs * move tests to correct files * fmt * fix benches --------- Co-authored-by: Sergio Chouhy Co-authored-by: Sergio Chouhy <41742639+schouhy@users.noreply.github.com> --- Cargo.toml | 11 + math/Cargo.toml | 4 + math/src/field/fields/mod.rs | 5 + math/src/field/fields/winterfell.rs | 120 ++++++++++ math/src/field/mod.rs | 3 +- provers/stark/Cargo.toml | 2 + provers/stark/src/verifier.rs | 7 +- winterfell_adapter/Cargo.toml | 26 ++- winterfell_adapter/README.md | 13 ++ winterfell_adapter/benches/proving.rs | 118 ++++++++++ winterfell_adapter/src/adapter/air.rs | 216 +++++++----------- winterfell_adapter/src/adapter/mod.rs | 62 +++++ .../src/adapter/public_inputs.rs | 33 ++- winterfell_adapter/src/examples/cubic.rs | 120 ++++++++++ .../src/examples/fibonacci_2_terms.rs | 76 ++++-- .../src/examples/fibonacci_rap.rs | 99 ++++++-- winterfell_adapter/src/examples/miden_vm.rs | 191 ++++++++++++++++ winterfell_adapter/src/examples/mod.rs | 2 + .../src/field_element/element.rs | 13 +- winterfell_adapter/src/utils.rs | 43 ++-- 20 files changed, 961 insertions(+), 203 deletions(-) create mode 100644 math/src/field/fields/winterfell.rs create mode 100644 winterfell_adapter/benches/proving.rs create mode 100644 winterfell_adapter/src/examples/cubic.rs create mode 100644 winterfell_adapter/src/examples/miden_vm.rs diff --git a/Cargo.toml b/Cargo.toml index dbcf7cad7..611dc73ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,6 +17,17 @@ lambdaworks-math = { path = "./math", version = "0.3.0" } stark-platinum-prover = { path = "./provers/stark", version = "0.3.0" } cairo-platinum-prover = { path = "./provers/cairo", version = "0.3.0" } +[patch.crates-io] +winter-air = { git = "https://github.com/lambdaclass/winterfell-for-lambdaworks.git", branch = "derive-clone-v6.4"} +winter-prover = { git = "https://github.com/lambdaclass/winterfell-for-lambdaworks.git", branch = "derive-clone-v6.4"} +winter-math = { git = "https://github.com/lambdaclass/winterfell-for-lambdaworks.git", branch = "derive-clone-v6.4"} +winter-utils = { git = "https://github.com/lambdaclass/winterfell-for-lambdaworks.git", branch = "derive-clone-v6.4"} +winter-crypto = { git = "https://github.com/lambdaclass/winterfell-for-lambdaworks.git", branch = "derive-clone-v6.4"} +miden-air = { git = "https://github.com/lambdaclass/miden-vm" } +miden-core = { git = "https://github.com/lambdaclass/miden-vm" } +miden-assembly = { git = "https://github.com/lambdaclass/miden-vm" } +miden-processor = { git = "https://github.com/lambdaclass/miden-vm" } + [profile.bench] lto = true codegen-units = 1 diff --git a/math/Cargo.toml b/math/Cargo.toml index 7860acea7..287ebd72c 100644 --- a/math/Cargo.toml +++ b/math/Cargo.toml @@ -12,6 +12,8 @@ thiserror = { version = "1.0", optional = true } serde = { version = "1.0", features = ["derive"], optional = true } serde_json = { version = "1.0", optional = true } proptest = { version = "1.1.0", optional = true } +winter-math = { package = "winter-math", version = "0.6.4", default-features = false, optional = true } +miden-core = { package = "miden-core" , version = "0.7", default-features = false, optional = true } # rayon rayon = { version = "1.7", optional = true } @@ -41,6 +43,7 @@ std = ["dep:thiserror"] lambdaworks-serde-binary = ["dep:serde", "std"] lambdaworks-serde-string = ["dep:serde", "dep:serde_json", "std"] proptest = ["dep:proptest"] +winter_compatibility = ["winter-math", "miden-core"] # gpu metal = [ @@ -88,3 +91,4 @@ harness = false name = "criterion_metal" harness = false required-features = ["metal"] + diff --git a/math/src/field/fields/mod.rs b/math/src/field/fields/mod.rs index 0bcbd2868..d9d923d6c 100644 --- a/math/src/field/fields/mod.rs +++ b/math/src/field/fields/mod.rs @@ -11,5 +11,10 @@ pub mod pallas_field; pub mod u64_goldilocks_field; /// Implementation of prime fields over 64 bit unsigned integers. pub mod u64_prime_field; + +/// Winterfell and miden field compatibility +#[cfg(feature = "winter_compatibility")] +pub mod winterfell; + /// Implemenation of Vesta Prime field (p = 2^254 + 45560315531506369815346746415080538113) mod vesta_field; diff --git a/math/src/field/fields/winterfell.rs b/math/src/field/fields/winterfell.rs new file mode 100644 index 000000000..8d06f7347 --- /dev/null +++ b/math/src/field/fields/winterfell.rs @@ -0,0 +1,120 @@ +use crate::{ + errors::ByteConversionError, + field::{ + element::FieldElement, + errors::FieldError, + traits::{IsFFTField, IsField, IsPrimeField}, + }, + traits::{ByteConversion, Serializable}, + unsigned_integer::element::U256, +}; +pub use miden_core::Felt; +pub use winter_math::fields::f128::BaseElement; +use winter_math::{FieldElement as IsWinterfellFieldElement, StarkField}; + +impl IsFFTField for Felt { + const TWO_ADICITY: u64 = ::TWO_ADICITY as u64; + const TWO_ADIC_PRIMITVE_ROOT_OF_UNITY: Self::BaseType = Felt::TWO_ADIC_ROOT_OF_UNITY; +} + +impl IsPrimeField for Felt { + type RepresentativeType = U256; + + fn representative(_a: &Self::BaseType) -> Self::RepresentativeType { + todo!() + } + + fn from_hex(_hex_string: &str) -> Result { + todo!() + } + + fn field_bit_size() -> usize { + 128 // TODO + } +} + +impl IsField for Felt { + type BaseType = Felt; + + fn add(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + *a + *b + } + + fn mul(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + *a * *b + } + + fn sub(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + *a - *b + } + + fn neg(a: &Self::BaseType) -> Self::BaseType { + -*a + } + + fn inv(a: &Self::BaseType) -> Result { + Ok((*a).inv()) + } + + fn div(a: &Self::BaseType, b: &Self::BaseType) -> Self::BaseType { + *a / *b + } + + fn eq(a: &Self::BaseType, b: &Self::BaseType) -> bool { + *a == *b + } + + fn zero() -> Self::BaseType { + Self::BaseType::ZERO + } + + fn one() -> Self::BaseType { + Self::BaseType::ONE + } + + fn from_u64(x: u64) -> Self::BaseType { + Self::BaseType::from(x) + } + + fn from_base_type(x: Self::BaseType) -> Self::BaseType { + x + } +} + +impl Serializable for FieldElement { + fn serialize(&self) -> Vec { + Felt::elements_as_bytes(&[*self.value()]).to_vec() + } +} + +impl ByteConversion for Felt { + fn to_bytes_be(&self) -> Vec { + Felt::elements_as_bytes(&[*self]).to_vec() + } + + fn to_bytes_le(&self) -> Vec { + Felt::elements_as_bytes(&[*self]).to_vec() + } + + fn from_bytes_be(bytes: &[u8]) -> Result + where + Self: Sized, + { + unsafe { + let res = Felt::bytes_as_elements(bytes) + .map_err(|_| ByteConversionError::FromBEBytesError)?; + Ok(res[0]) + } + } + + fn from_bytes_le(bytes: &[u8]) -> Result + where + Self: Sized, + { + unsafe { + let res = Felt::bytes_as_elements(bytes) + .map_err(|_| ByteConversionError::FromBEBytesError)?; + Ok(res[0]) + } + } +} diff --git a/math/src/field/mod.rs b/math/src/field/mod.rs index 2fe6c4272..5cb5f2506 100644 --- a/math/src/field/mod.rs +++ b/math/src/field/mod.rs @@ -1,5 +1,6 @@ /// Implementation of FieldElement, a generic element of a field. pub mod element; +pub mod errors; /// Implementation of quadratic extensions of fields. pub mod extensions; /// Implementation of particular cases of fields. @@ -8,5 +9,3 @@ pub mod fields; pub mod test_fields; /// Common behaviour for field elements. pub mod traits; - -pub mod errors; diff --git a/provers/stark/Cargo.toml b/provers/stark/Cargo.toml index aa1f5c683..033e1a9c9 100644 --- a/provers/stark/Cargo.toml +++ b/provers/stark/Cargo.toml @@ -11,6 +11,7 @@ crate-type = ["cdylib", "rlib"] [dependencies] lambdaworks-math = { workspace = true , features = ["lambdaworks-serde-binary"] } lambdaworks-crypto.workspace = true +miden-core = { git="https://github.com/lambdaclass/miden-vm", optional=true} rand = "0.8.5" thiserror = "1.0.38" @@ -47,6 +48,7 @@ instruments = [] # This enables timing prints in prover and ve metal = ["lambdaworks-math/metal"] parallel = ["dep:rayon", "lambdaworks-crypto/parallel"] wasm = ["dep:wasm-bindgen", "dep:serde-wasm-bindgen", "dep:web-sys"] +winter_compatibility = ["miden-core"] [target.'cfg(not(all(target_arch = "wasm32", target_os = "unknown")))'.dev-dependencies] proptest = "1.2.0" diff --git a/provers/stark/src/verifier.rs b/provers/stark/src/verifier.rs index c498578d0..bf2ec9c85 100644 --- a/provers/stark/src/verifier.rs +++ b/provers/stark/src/verifier.rs @@ -6,6 +6,9 @@ use lambdaworks_crypto::merkle_tree::proof::Proof; #[cfg(not(feature = "test_fiat_shamir"))] use log::error; +use crate::{ + config::Commitment, proof::stark::DeepPolynomialOpening, transcript::IsStarkTranscript, +}; use lambdaworks_math::{ fft::cpu::bit_reversing::reverse_index, field::{ @@ -15,10 +18,6 @@ use lambdaworks_math::{ traits::Serializable, }; -use crate::{ - config::Commitment, proof::stark::DeepPolynomialOpening, transcript::IsStarkTranscript, -}; - use super::{ config::BatchedMerkleTreeBackend, domain::Domain, diff --git a/winterfell_adapter/Cargo.toml b/winterfell_adapter/Cargo.toml index 18d40a14a..0dcf68df3 100644 --- a/winterfell_adapter/Cargo.toml +++ b/winterfell_adapter/Cargo.toml @@ -5,8 +5,26 @@ edition.workspace = true license.workspace = true [dependencies] -lambdaworks-math = { path = "../math" } -stark-platinum-prover = { path = "../provers/stark" } -winterfell = { git = "https://github.com/facebook/winterfell" } -winter-utils = { git = "https://github.com/facebook/winterfell" } +lambdaworks-math = { path = "../math", features=["winter_compatibility"] } +stark-platinum-prover = { path = "../provers/stark" , features=["winter_compatibility"]} rand = "0.8.5" +winter-air = { package = "winter-air", version = "0.6.4", default-features = false } +winter-prover = { package = "winter-prover", version = "0.6.4", default-features = false } +winter-math = { package = "winter-math", version = "0.6.4", default-features = false } +winter-utils = { package = "winter-utils", version = "0.6.4", default-features = false } +miden-air = { package = "miden-air", version = "0.7", default-features = false } +miden-core = { package = "miden-core" , version = "0.7", default-features = false } +miden-assembly = { package = "miden-assembly", version = "0.7", default-features = false } +miden-processor = { package = "miden-processor", version = "0.7", default-features = false } +sha3 = "0.10" + + +[dev-dependencies] +criterion = { version = "0.4", default-features = false } +miden-prover = { package = "miden-prover", version = "0.7", default-features = false } + + +[[bench]] +name = "proving" +harness = false + diff --git a/winterfell_adapter/README.md b/winterfell_adapter/README.md index 83177559a..03e6e11f6 100644 --- a/winterfell_adapter/README.md +++ b/winterfell_adapter/README.md @@ -51,3 +51,16 @@ let proof = Prover::prove::>>( ``` Here `TraceTable` is the Winterfell type that represents your trace table. To check more examples you can see the `examples` folder inside this crate. + +# Benchmarks +To run the fibonacci Miden benchmark run: + +```rust +cargo bench +``` + +To run it with parallelization run: + +```rust +cargo bench --features stark-platinum-prover/parallel,winter-prover/concurrent +``` diff --git a/winterfell_adapter/benches/proving.rs b/winterfell_adapter/benches/proving.rs new file mode 100644 index 000000000..6a828d671 --- /dev/null +++ b/winterfell_adapter/benches/proving.rs @@ -0,0 +1,118 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use miden_air::{HashFunction, ProcessorAir, ProvingOptions, PublicInputs}; +use miden_assembly::Assembler; +use miden_core::{Felt, Program, StackInputs}; +use miden_processor::DefaultHost; +use miden_processor::{self as processor}; +use miden_prover::prove; +use processor::ExecutionTrace; +use stark_platinum_prover::{proof::options::ProofOptions, prover::IsStarkProver}; +use winter_air::FieldExtension; +use winter_prover::Trace; +use winterfell_adapter::adapter::public_inputs::AirAdapterPublicInputs; +use winterfell_adapter::adapter::{air::AirAdapter, Prover, Transcript}; +use winterfell_adapter::examples::miden_vm::ExecutionTraceMetadata; + +struct BenchInstance { + program: Program, + stack_inputs: StackInputs, + lambda_proof_options: ProofOptions, +} + +fn create_bench_instance(fibonacci_number: usize) -> BenchInstance { + let program = format!( + "begin + repeat.{} + swap dup.1 add + end + end", + fibonacci_number - 1 + ); + let program = Assembler::default().compile(program).unwrap(); + let stack_inputs = StackInputs::try_from_values([0, 1]).unwrap(); + let mut lambda_proof_options = ProofOptions::default_test_options(); + lambda_proof_options.blowup_factor = 8; + + BenchInstance { + program, + stack_inputs, + lambda_proof_options, + } +} + +pub fn bench_prove_miden_fibonacci(c: &mut Criterion) { + let instance = create_bench_instance(100); + + c.bench_function("winterfell_prover", |b| { + b.iter(|| { + let proving_options = ProvingOptions::new( + instance.lambda_proof_options.fri_number_of_queries, + instance.lambda_proof_options.blowup_factor as usize, + instance.lambda_proof_options.grinding_factor as u32, + FieldExtension::None, + 2, + 0, + HashFunction::Blake3_192, + ); + + let (_outputs, _proof) = black_box( + prove( + &instance.program, + instance.stack_inputs.clone(), + DefaultHost::default(), + proving_options, + ) + .unwrap(), + ); + }) + }); + + c.bench_function("lambda_prover", |b| { + b.iter(|| { + // This is here because the only pub method in miden + // is a prove function that executes AND proves. + // This makes the benchmark a more fair + // in the case that the program execution takes + // too long. + let winter_trace = processor::execute( + &instance.program, + instance.stack_inputs.clone(), + DefaultHost::default(), + *ProvingOptions::default().execution_options(), + ) + .unwrap(); + + let program_info = winter_trace.program_info().clone(); + let stack_outputs = winter_trace.stack_outputs().clone(); + let pub_inputs = AirAdapterPublicInputs::::new( + PublicInputs::new( + program_info, + instance.stack_inputs.clone(), + stack_outputs.clone(), + ), + vec![2; 182], + vec![0, 1], + winter_trace.get_info(), + winter_trace.clone().into(), + ); + + let trace = + AirAdapter::::convert_winterfell_trace_table( + winter_trace.main_segment().clone(), + ); + + let _proof = black_box( + Prover::prove::>( + &trace, + &pub_inputs, + &instance.lambda_proof_options, + Transcript::new(&[]), + ) + .unwrap(), + ); + }) + }); +} + +criterion_group!(benches, bench_prove_miden_fibonacci); +criterion_main!(benches); diff --git a/winterfell_adapter/src/adapter/air.rs b/winterfell_adapter/src/adapter/air.rs index 9ca8014fe..1aaecc3a9 100644 --- a/winterfell_adapter/src/adapter/air.rs +++ b/winterfell_adapter/src/adapter/air.rs @@ -1,69 +1,87 @@ -use crate::field_element::element::AdapterFieldElement; -use crate::utils::{matrix_adapter2field, matrix_field2adapter, vec_field2adapter}; -use lambdaworks_math::field::{ - element::FieldElement, fields::fft_friendly::stark_252_prime_field::Stark252PrimeField, +use crate::utils::{ + matrix_lambda2winter, matrix_winter2lambda, vec_lambda2winter, vec_winter2lambda, }; +use lambdaworks_math::field::element::FieldElement; +use lambdaworks_math::field::traits::{IsFFTField, IsField}; +use lambdaworks_math::traits::ByteConversion; +use miden_core::Felt; use stark_platinum_prover::{ constraints::boundary::{BoundaryConstraint, BoundaryConstraints}, traits::AIR, }; use std::marker::PhantomData; -use winterfell::{ - Air, AuxTraceRandElements, EvaluationFrame, FieldExtension, ProofOptions, Trace, TraceTable, -}; +use winter_air::{Air, AuxTraceRandElements, EvaluationFrame, FieldExtension, ProofOptions}; +use winter_math::{FieldElement as IsWinterfellFieldElement, StarkField}; +use winter_prover::{ColMatrix, Trace, TraceTable}; use super::public_inputs::AirAdapterPublicInputs; -pub trait FromColumns { - fn from_cols(columns: Vec>) -> Self; +pub trait FromColumns { + fn from_cols(columns: Vec>, metadata: &M) -> Self; } -impl FromColumns for TraceTable { - fn from_cols(columns: Vec>) -> Self { +impl FromColumns for TraceTable { + fn from_cols(columns: Vec>, _: &()) -> Self { TraceTable::init(columns) } } #[derive(Clone)] -pub struct AirAdapter +pub struct AirAdapter where - A: Air, + FE: IsWinterfellFieldElement + StarkField + ByteConversion + Unpin + IsFFTField, + A: Air, A::PublicInputs: Clone, - T: Trace + Clone + FromColumns, + T: Trace + Clone + FromColumns, + M: Clone, { winterfell_air: A, - public_inputs: AirAdapterPublicInputs, + public_inputs: AirAdapterPublicInputs, air_context: stark_platinum_prover::context::AirContext, phantom: PhantomData, } -impl AirAdapter +impl AirAdapter where - A: Air + Clone, + FE: IsWinterfellFieldElement + + StarkField + + ByteConversion + + Unpin + + IsFFTField + + IsField, + A: Air + Clone, A::PublicInputs: Clone, - T: Trace + Clone + FromColumns, + T: Trace + Clone + FromColumns, + M: Clone, { pub fn convert_winterfell_trace_table( - trace: TraceTable, - ) -> stark_platinum_prover::trace::TraceTable { + trace: ColMatrix, + ) -> stark_platinum_prover::trace::TraceTable { let mut columns = Vec::new(); - for i in 0..trace.width() { + for i in 0..trace.num_cols() { columns.push(trace.get_column(i).to_owned()); } - stark_platinum_prover::trace::TraceTable::from_columns(matrix_adapter2field(&columns), 1) + stark_platinum_prover::trace::TraceTable::from_columns(matrix_winter2lambda(&columns), 1) } } -impl AIR for AirAdapter +impl AIR for AirAdapter where - A: Air + Clone, + FE: IsWinterfellFieldElement + + StarkField + + ByteConversion + + Unpin + + IsFFTField + + IsField, + A: Air + Clone, A::PublicInputs: Clone, - T: Trace + Clone + FromColumns, + T: Trace + Clone + FromColumns, + M: Clone, { - type Field = Stark252PrimeField; - type RAPChallenges = Vec; - type PublicInputs = AirAdapterPublicInputs; + type Field = FE; + type RAPChallenges = Vec; + type PublicInputs = AirAdapterPublicInputs; const STEP_SIZE: usize = 1; fn new( @@ -109,19 +127,22 @@ where rap_challenges: &Self::RAPChallenges, ) -> stark_platinum_prover::trace::TraceTable { // We support at most a one-stage RAP. This covers most use cases. - if let Some(winter_trace) = T::from_cols(matrix_field2adapter(&main_trace.columns())) - .build_aux_segment(&[], rap_challenges) + if let Some(winter_trace) = T::from_cols( + matrix_lambda2winter(&main_trace.columns()), + &self.pub_inputs().metadata, + ) + .build_aux_segment(&[], rap_challenges) { let mut columns = Vec::new(); for i in 0..winter_trace.num_cols() { columns.push(winter_trace.get_column(i).to_owned()); } stark_platinum_prover::trace::TraceTable::from_columns( - matrix_adapter2field(&columns), + matrix_winter2lambda(&columns), 1, ) } else { - stark_platinum_prover::trace::TraceTable::empty() + stark_platinum_prover::trace::TraceTable::::empty() } } @@ -137,7 +158,7 @@ where for _ in 0..trace_layout.get_aux_segment_rand_elements(0) { result.push(transcript.sample_field_element()); } - vec_field2adapter(&result) + vec_lambda2winter(&result) } else if num_segments == 0 { Vec::new() } else { @@ -150,13 +171,16 @@ where } fn composition_poly_degree_bound(&self) -> usize { - self.public_inputs.composition_poly_degree_bound + self.winterfell_air + .context() + .num_constraint_composition_columns() + * self.trace_length() } fn compute_transition( &self, frame: &stark_platinum_prover::frame::Frame, - _periodic_values: &[FieldElement], + periodic_values: &[FieldElement], rap_challenges: &Self::RAPChallenges, ) -> Vec> { let num_aux_columns = self.number_auxiliary_rap_columns(); @@ -166,22 +190,27 @@ where let second_step = frame.get_evaluation_step(1); let main_frame = EvaluationFrame::from_rows( - vec_field2adapter(&first_step.get_row(0)[..num_main_columns]), - vec_field2adapter(&second_step.get_row(0)[..num_main_columns]), + vec_lambda2winter(&first_step.get_row(0)[..num_main_columns]), + vec_lambda2winter(&second_step.get_row(0)[..num_main_columns]), ); + let periodic_values = vec_lambda2winter(periodic_values); + let mut main_result = vec![ FieldElement::zero(); self.winterfell_air .context() .num_main_transition_constraints() ]; - self.winterfell_air - .evaluate_transition::( - &main_frame, - &[], - &mut vec_field2adapter(&main_result), - ); // Periodic values not supported + + let mut main_result_winter = vec_lambda2winter(&main_result); + self.winterfell_air.evaluate_transition::( + &main_frame, + &periodic_values, + &mut main_result_winter, + ); // Periodic values not supported + + main_result = vec_winter2lambda(&main_result_winter); if self.winterfell_air.trace_layout().num_aux_segments() == 1 { let mut rand_elements = AuxTraceRandElements::new(); @@ -191,24 +220,25 @@ where let second_step = frame.get_evaluation_step(1); let aux_frame = EvaluationFrame::from_rows( - vec_field2adapter(&first_step.get_row(0)[num_main_columns..]), - vec_field2adapter(&second_step.get_row(0)[num_main_columns..]), + vec_lambda2winter(&first_step.get_row(0)[num_main_columns..]), + vec_lambda2winter(&second_step.get_row(0)[num_main_columns..]), ); - let aux_result = vec![ + let mut aux_result = vec![ FieldElement::zero(); self.winterfell_air .context() .num_aux_transition_constraints() ]; + let mut winter_aux_result = vec_lambda2winter(&aux_result); self.winterfell_air.evaluate_aux_transition( &main_frame, &aux_frame, - &[], + &periodic_values, &rand_elements, - &mut vec_field2adapter(&aux_result), + &mut winter_aux_result, ); - + aux_result = vec_winter2lambda(&winter_aux_result); main_result.extend_from_slice(&aux_result); } main_result @@ -217,14 +247,17 @@ where fn boundary_constraints( &self, rap_challenges: &Self::RAPChallenges, - ) -> stark_platinum_prover::constraints::boundary::BoundaryConstraints { + ) -> stark_platinum_prover::constraints::boundary::BoundaryConstraints { + let num_aux_columns = self.number_auxiliary_rap_columns(); + let num_main_columns = self.context().trace_columns - num_aux_columns; + let mut result = Vec::new(); for assertion in self.winterfell_air.get_assertions() { assert!(assertion.is_single()); result.push(BoundaryConstraint::new( assertion.column(), assertion.first_step(), - assertion.values()[0].0, + FieldElement::::const_from_raw(assertion.values()[0]), )); } @@ -234,9 +267,9 @@ where for assertion in self.winterfell_air.get_aux_assertions(&rand_elements) { assert!(assertion.is_single()); result.push(BoundaryConstraint::new( - assertion.column(), + assertion.column() + num_main_columns, assertion.first_step(), - assertion.values()[0].0, + FieldElement::::const_from_raw(assertion.values()[0]), )); } @@ -254,81 +287,8 @@ where fn pub_inputs(&self) -> &Self::PublicInputs { &self.public_inputs } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::examples::fibonacci_2_terms::{self, FibAir2Terms}; - use crate::examples::fibonacci_rap::{self, FibonacciRAP, RapTraceTable}; - use stark_platinum_prover::{ - proof::options::ProofOptions, - prover::{IsStarkProver, Prover}, - transcript::StoneProverTranscript, - verifier::{IsStarkVerifier, Verifier}, - }; - use winterfell::{TraceInfo, TraceLayout}; - - #[test] - fn prove_and_verify_a_winterfell_fibonacci_2_terms_air() { - let lambda_proof_options = ProofOptions::default_test_options(); - let trace = AirAdapter::>::convert_winterfell_trace_table( - fibonacci_2_terms::build_trace(16), - ); - let pub_inputs = AirAdapterPublicInputs { - winterfell_public_inputs: AdapterFieldElement(trace.columns()[1][7]), - transition_exemptions: vec![1, 1], - transition_offsets: vec![0, 1], - composition_poly_degree_bound: 8, - trace_info: TraceInfo::new(2, 8), - }; - - let proof = Prover::prove::>>( - &trace, - &pub_inputs, - &lambda_proof_options, - StoneProverTranscript::new(&[]), - ) - .unwrap(); - assert!(Verifier::verify::>>( - &proof, - &pub_inputs, - &lambda_proof_options, - StoneProverTranscript::new(&[]), - )); - } - #[test] - fn prove_and_verify_a_winterfell_fibonacci_rap_air() { - let lambda_proof_options = ProofOptions::default_test_options(); - let trace = AirAdapter::>::convert_winterfell_trace_table( - fibonacci_rap::build_trace(16), - ); - let trace_layout = TraceLayout::new(3, [1], [1]); - let trace_info = TraceInfo::new_multi_segment(trace_layout, 16, vec![]); - let fibonacci_result = trace.columns()[1][15]; - let pub_inputs = AirAdapterPublicInputs { - winterfell_public_inputs: AdapterFieldElement(fibonacci_result), - transition_exemptions: vec![1, 1, 1], - transition_offsets: vec![0, 1], - composition_poly_degree_bound: 32, - trace_info, - }; - - let proof = Prover::prove::>>( - &trace, - &pub_inputs, - &lambda_proof_options, - StoneProverTranscript::new(&[]), - ) - .unwrap(); - assert!( - Verifier::verify::>>( - &proof, - &pub_inputs, - &lambda_proof_options, - StoneProverTranscript::new(&[]), - ) - ); + fn get_periodic_column_values(&self) -> Vec>> { + matrix_winter2lambda(&self.winterfell_air.get_periodic_column_values()) } } diff --git a/winterfell_adapter/src/adapter/mod.rs b/winterfell_adapter/src/adapter/mod.rs index 3987d9706..5975ea734 100644 --- a/winterfell_adapter/src/adapter/mod.rs +++ b/winterfell_adapter/src/adapter/mod.rs @@ -1,2 +1,64 @@ +use lambdaworks_math::traits::ByteConversion; +use miden_core::Felt; +use sha3::{Digest, Keccak256}; +use stark_platinum_prover::{ + fri::FieldElement, prover::IsStarkProver, transcript::IsStarkTranscript, + verifier::IsStarkVerifier, +}; +use winter_math::StarkField; + pub mod air; pub mod public_inputs; + +pub struct Prover; +impl IsStarkProver for Prover { + type Field = Felt; +} + +pub struct Verifier {} +impl IsStarkVerifier for Verifier { + type Field = Felt; +} + +pub struct Transcript { + hasher: Keccak256, +} + +impl Transcript { + pub fn new(data: &[u8]) -> Self { + let mut res = Self { + hasher: Keccak256::new(), + }; + res.append_bytes(data); + res + } +} + +impl IsStarkTranscript for Transcript { + fn append_field_element(&mut self, element: &FieldElement) { + self.append_bytes(&element.value().to_bytes_be()); + } + + fn append_bytes(&mut self, new_bytes: &[u8]) { + self.hasher.update(&mut new_bytes.to_owned()); + } + + fn state(&self) -> [u8; 32] { + self.hasher.clone().finalize().into() + } + + fn sample_field_element(&mut self) -> FieldElement { + let mut bytes = self.state()[..8].try_into().unwrap(); + let mut x = u64::from_be_bytes(bytes); + while x >= Felt::MODULUS { + self.append_bytes(&bytes); + bytes = self.state()[..8].try_into().unwrap(); + x = u64::from_be_bytes(bytes); + } + FieldElement::const_from_raw(Felt::new(x)) + } + + fn sample_u64(&mut self, upper_bound: u64) -> u64 { + u64::from_be_bytes(self.state()[..8].try_into().unwrap()) % upper_bound + } +} diff --git a/winterfell_adapter/src/adapter/public_inputs.rs b/winterfell_adapter/src/adapter/public_inputs.rs index fc463ebb5..fc5e49ba9 100644 --- a/winterfell_adapter/src/adapter/public_inputs.rs +++ b/winterfell_adapter/src/adapter/public_inputs.rs @@ -1,15 +1,38 @@ -use crate::field_element::element::AdapterFieldElement; -use winterfell::{Air, TraceInfo}; +use winter_air::{Air, TraceInfo}; #[derive(Clone)] -pub struct AirAdapterPublicInputs +pub struct AirAdapterPublicInputs where - A: Air, + A: Air, A::PublicInputs: Clone, + M: Clone, { pub(crate) winterfell_public_inputs: A::PublicInputs, pub(crate) transition_exemptions: Vec, pub(crate) transition_offsets: Vec, pub(crate) trace_info: TraceInfo, - pub(crate) composition_poly_degree_bound: usize, + pub(crate) metadata: M, +} + +impl AirAdapterPublicInputs +where + A: Air, + A::PublicInputs: Clone, + M: Clone, +{ + pub fn new( + winterfell_public_inputs: A::PublicInputs, + transition_exemptions: Vec, + transition_offsets: Vec, + trace_info: TraceInfo, + metadata: M, + ) -> Self { + Self { + winterfell_public_inputs, + transition_exemptions, + transition_offsets, + trace_info, + metadata, + } + } } diff --git a/winterfell_adapter/src/examples/cubic.rs b/winterfell_adapter/src/examples/cubic.rs new file mode 100644 index 000000000..ea8320c50 --- /dev/null +++ b/winterfell_adapter/src/examples/cubic.rs @@ -0,0 +1,120 @@ +use miden_core::Felt; +use winter_air::{ + Air, AirContext, Assertion, EvaluationFrame, ProofOptions, TraceInfo, + TransitionConstraintDegree, +}; +use winter_math::FieldElement as IsWinterfellFieldElement; +use winter_prover::TraceTable; + +/// A fibonacci winterfell AIR example. Two terms are computed +/// at each step. This was taken from the original winterfell +/// repository and adapted to work with lambdaworks. +#[derive(Clone)] +pub struct Cubic { + context: AirContext, + result: Felt, +} + +impl Air for Cubic { + type BaseField = Felt; + type PublicInputs = Felt; + + fn new(trace_info: TraceInfo, pub_inputs: Self::BaseField, options: ProofOptions) -> Self { + let degrees = vec![TransitionConstraintDegree::new(3)]; + Cubic { + context: AirContext::new(trace_info, degrees, 2, options), + result: pub_inputs, + } + } + + fn context(&self) -> &AirContext { + &self.context + } + + fn evaluate_transition>( + &self, + frame: &EvaluationFrame, + _periodic_values: &[E], + result: &mut [E], + ) { + let current = frame.current(); + let next = frame.next(); + + // s_i = (s_{i-1})³ + result[0] = next[0] - (current[0] * current[0] * current[0]); + } + + fn get_assertions(&self) -> Vec> { + // A valid Fibonacci sequence should start with two ones and terminate with + // the expected result + let last_step = self.trace_length() - 1; + vec![ + Assertion::single(0, 0, Self::BaseField::from(2u16)), + Assertion::single(0, last_step, self.result), + ] + } +} + +pub fn build_trace(sequence_length: usize) -> TraceTable { + assert!( + sequence_length.is_power_of_two(), + "sequence length must be a power of 2" + ); + + let mut accum = Felt::from(2u16); + let mut column = vec![accum]; + while column.len() < sequence_length { + accum = accum * accum * accum; + column.push(accum); + } + TraceTable::init(vec![column]) +} + +#[cfg(test)] +mod tests { + use miden_core::Felt; + use stark_platinum_prover::{ + proof::options::ProofOptions, prover::IsStarkProver, verifier::IsStarkVerifier, + }; + use winter_air::TraceInfo; + use winter_prover::{Trace, TraceTable}; + + use crate::{ + adapter::{ + air::AirAdapter, public_inputs::AirAdapterPublicInputs, Prover, Transcript, Verifier, + }, + examples::cubic::{self, Cubic}, + }; + + #[test] + fn prove_and_verify_a_winterfell_cubic_air() { + let lambda_proof_options = ProofOptions::default_test_options(); + let winter_trace = cubic::build_trace(16); + let trace = AirAdapter::, Felt, ()>::convert_winterfell_trace_table( + winter_trace.main_segment().clone(), + ); + let pub_inputs = AirAdapterPublicInputs { + winterfell_public_inputs: *trace.columns()[0][15].value(), + transition_exemptions: vec![1], + transition_offsets: vec![0, 1], + trace_info: TraceInfo::new(1, 16), + metadata: (), + }; + + let proof = Prover::prove::, Felt, _>>( + &trace, + &pub_inputs, + &lambda_proof_options, + Transcript::new(&[]), + ) + .unwrap(); + assert!( + Verifier::verify::, Felt, _>>( + &proof, + &pub_inputs, + &lambda_proof_options, + Transcript::new(&[]), + ) + ); + } +} diff --git a/winterfell_adapter/src/examples/fibonacci_2_terms.rs b/winterfell_adapter/src/examples/fibonacci_2_terms.rs index c86296ee2..babd4f9de 100644 --- a/winterfell_adapter/src/examples/fibonacci_2_terms.rs +++ b/winterfell_adapter/src/examples/fibonacci_2_terms.rs @@ -1,24 +1,23 @@ -use lambdaworks_math::field::element::FieldElement; -use winterfell::math::FieldElement as IsWinterfellFieldElement; -use winterfell::{ - Air, AirContext, Assertion, EvaluationFrame, ProofOptions, TraceInfo, TraceTable, +use miden_core::Felt; +use winter_air::{ + Air, AirContext, Assertion, EvaluationFrame, ProofOptions, TraceInfo, TransitionConstraintDegree, }; - -use crate::field_element::element::AdapterFieldElement; +use winter_math::FieldElement as IsWinterfellFieldElement; +use winter_prover::TraceTable; /// A fibonacci winterfell AIR example. Two terms are computed /// at each step. This was taken from the original winterfell /// repository and adapted to work with lambdaworks. #[derive(Clone)] pub struct FibAir2Terms { - context: AirContext, - result: AdapterFieldElement, + context: AirContext, + result: Felt, } impl Air for FibAir2Terms { - type BaseField = AdapterFieldElement; - type PublicInputs = AdapterFieldElement; + type BaseField = Felt; + type PublicInputs = Felt; fn new(trace_info: TraceInfo, pub_inputs: Self::BaseField, options: ProofOptions) -> Self { let degrees = vec![ @@ -63,7 +62,7 @@ impl Air for FibAir2Terms { } } -pub fn build_trace(sequence_length: usize) -> TraceTable { +pub fn build_trace(sequence_length: usize) -> TraceTable { assert!( sequence_length.is_power_of_two(), "sequence length must be a power of 2" @@ -72,8 +71,8 @@ pub fn build_trace(sequence_length: usize) -> TraceTable { let mut trace = TraceTable::new(2, sequence_length / 2); trace.fill( |state| { - state[0] = AdapterFieldElement(FieldElement::one()); - state[1] = AdapterFieldElement(FieldElement::one()); + state[0] = Felt::ONE; + state[1] = Felt::ONE; }, |_, state| { state[0] += state[1]; @@ -83,3 +82,54 @@ pub fn build_trace(sequence_length: usize) -> TraceTable { trace } + +#[cfg(test)] +mod tests { + use miden_core::Felt; + use stark_platinum_prover::{ + proof::options::ProofOptions, prover::IsStarkProver, verifier::IsStarkVerifier, + }; + use winter_air::TraceInfo; + use winter_prover::{Trace, TraceTable}; + + use crate::{ + adapter::{ + air::AirAdapter, public_inputs::AirAdapterPublicInputs, Prover, Transcript, Verifier, + }, + examples::fibonacci_2_terms::{self, FibAir2Terms}, + }; + + #[test] + fn prove_and_verify_a_winterfell_fibonacci_2_terms_air() { + let lambda_proof_options = ProofOptions::default_test_options(); + let winter_trace = fibonacci_2_terms::build_trace(16); + let trace = + AirAdapter::, Felt, ()>::convert_winterfell_trace_table( + winter_trace.main_segment().clone(), + ); + let pub_inputs = AirAdapterPublicInputs { + winterfell_public_inputs: *trace.columns()[1][7].value(), + transition_exemptions: vec![1, 1], + transition_offsets: vec![0, 1], + trace_info: TraceInfo::new(2, 8), + metadata: (), + }; + + let proof = Prover::prove::, Felt, _>>( + &trace, + &pub_inputs, + &lambda_proof_options, + Transcript::new(&[]), + ) + .unwrap(); + + assert!(Verifier::verify::< + AirAdapter, Felt, _>, + >( + &proof, + &pub_inputs, + &lambda_proof_options, + Transcript::new(&[]), + )); + } +} diff --git a/winterfell_adapter/src/examples/fibonacci_rap.rs b/winterfell_adapter/src/examples/fibonacci_rap.rs index 3b9155e6d..f46e4e94e 100644 --- a/winterfell_adapter/src/examples/fibonacci_rap.rs +++ b/winterfell_adapter/src/examples/fibonacci_rap.rs @@ -1,16 +1,14 @@ use crate::adapter::air::FromColumns; -use crate::field_element::element::AdapterFieldElement; -use crate::utils::vec_field2adapter; -use lambdaworks_math::field::element::FieldElement; +use miden_core::Felt; use rand::seq::SliceRandom; use rand::thread_rng; -use winter_utils::{collections::Vec, uninit_vector}; -use winterfell::math::FieldElement as IsWinterfellFieldElement; -use winterfell::{math::StarkField, matrix::ColMatrix, Trace, TraceLayout}; -use winterfell::{ - Air, AirContext, Assertion, EvaluationFrame, ProofOptions, TraceInfo, TraceTable, - TransitionConstraintDegree, +use winter_air::{ + Air, AirContext, Assertion, AuxTraceRandElements, EvaluationFrame, ProofOptions, TraceInfo, + TraceLayout, TransitionConstraintDegree, }; +use winter_math::{ExtensionOf, FieldElement as IsWinterfellFieldElement, StarkField}; +use winter_prover::{ColMatrix, Trace, TraceTable}; +use winter_utils::{collections::Vec, uninit_vector}; #[derive(Clone)] pub struct RapTraceTable { @@ -132,21 +130,21 @@ impl Trace for RapTraceTable { } } -impl FromColumns for RapTraceTable { - fn from_cols(columns: Vec>) -> Self { +impl FromColumns for RapTraceTable { + fn from_cols(columns: Vec>, _: &()) -> Self { RapTraceTable::init(columns) } } #[derive(Clone)] pub struct FibonacciRAP { - context: AirContext, - result: AdapterFieldElement, + context: AirContext, + result: Felt, } impl Air for FibonacciRAP { - type BaseField = AdapterFieldElement; - type PublicInputs = AdapterFieldElement; + type BaseField = Felt; + type PublicInputs = Felt; fn new(trace_info: TraceInfo, pub_inputs: Self::BaseField, options: ProofOptions) -> Self { let degrees = vec![ @@ -183,11 +181,11 @@ impl Air for FibonacciRAP { main_frame: &EvaluationFrame, aux_frame: &EvaluationFrame, _periodic_values: &[F], - aux_rand_elements: &winterfell::AuxTraceRandElements, + aux_rand_elements: &AuxTraceRandElements, result: &mut [E], ) where F: IsWinterfellFieldElement, - E: IsWinterfellFieldElement + winterfell::math::ExtensionOf, + E: IsWinterfellFieldElement + ExtensionOf, { let gamma = aux_rand_elements.get_segment_elements(0)[0]; let curr_aux = aux_frame.current(); @@ -211,20 +209,20 @@ impl Air for FibonacciRAP { fn get_aux_assertions>( &self, - _aux_rand_elements: &winterfell::AuxTraceRandElements, + _aux_rand_elements: &AuxTraceRandElements, ) -> Vec> { let last_step = self.trace_length() - 1; - vec![Assertion::single(3, last_step, Self::BaseField::ONE.into())] + vec![Assertion::single(0, last_step, Self::BaseField::ONE.into())] } } -pub fn build_trace(sequence_length: usize) -> TraceTable { +pub fn build_trace(sequence_length: usize) -> TraceTable { assert!( sequence_length.is_power_of_two(), "sequence length must be a power of 2" ); - let mut fibonacci = vec![FieldElement::one(), FieldElement::one()]; + let mut fibonacci = vec![Felt::ONE, Felt::ONE]; for i in 2..(sequence_length + 1) { fibonacci.push(fibonacci[i - 2] + fibonacci[i - 1]) } @@ -234,8 +232,61 @@ pub fn build_trace(sequence_length: usize) -> TraceTable { permuted.shuffle(&mut rng); TraceTable::init(vec![ - vec_field2adapter(&fibonacci[..fibonacci.len() - 1]), - vec_field2adapter(&fibonacci[1..]), - vec_field2adapter(&permuted), + fibonacci[..fibonacci.len() - 1].to_vec(), + fibonacci[1..].to_vec(), + permuted, ]) } + +#[cfg(test)] +mod tests { + use miden_core::Felt; + use stark_platinum_prover::{ + proof::options::ProofOptions, prover::IsStarkProver, verifier::IsStarkVerifier, + }; + use winter_air::{TraceInfo, TraceLayout}; + use winter_prover::Trace; + + use crate::{ + adapter::{ + air::AirAdapter, public_inputs::AirAdapterPublicInputs, Prover, Transcript, Verifier, + }, + examples::fibonacci_rap::{self, FibonacciRAP, RapTraceTable}, + }; + + #[test] + fn prove_and_verify_a_winterfell_fibonacci_rap_air() { + let lambda_proof_options = ProofOptions::default_test_options(); + let winter_trace = fibonacci_rap::build_trace(16); + let trace = + AirAdapter::, Felt, ()>::convert_winterfell_trace_table( + winter_trace.main_segment().clone(), + ); + let trace_layout = TraceLayout::new(3, [1], [1]); + let trace_info = TraceInfo::new_multi_segment(trace_layout, 16, vec![]); + let fibonacci_result = trace.columns()[1][15]; + let pub_inputs = AirAdapterPublicInputs:: { + winterfell_public_inputs: *fibonacci_result.value(), + transition_exemptions: vec![1, 1, 1], + transition_offsets: vec![0, 1], + trace_info, + metadata: (), + }; + + let proof = Prover::prove::, Felt, _>>( + &trace, + &pub_inputs, + &lambda_proof_options, + Transcript::new(&[]), + ) + .unwrap(); + assert!(Verifier::verify::< + AirAdapter, Felt, _>, + >( + &proof, + &pub_inputs, + &lambda_proof_options, + Transcript::new(&[]), + )); + } +} diff --git a/winterfell_adapter/src/examples/miden_vm.rs b/winterfell_adapter/src/examples/miden_vm.rs new file mode 100644 index 000000000..0f6da58f6 --- /dev/null +++ b/winterfell_adapter/src/examples/miden_vm.rs @@ -0,0 +1,191 @@ +use miden_core::{Felt, ProgramInfo, StackOutputs}; +use miden_processor::{AuxTraceHints, ExecutionTrace, TraceLenSummary}; +use winter_air::TraceLayout; +use winter_prover::ColMatrix; + +use crate::adapter::air::FromColumns; + +#[derive(Clone)] +pub struct ExecutionTraceMetadata { + meta: Vec, + layout: TraceLayout, + aux_trace_hints: AuxTraceHints, + program_info: ProgramInfo, + stack_outputs: StackOutputs, + trace_len_summary: TraceLenSummary, +} + +impl From for ExecutionTraceMetadata { + fn from(value: ExecutionTrace) -> Self { + Self { + meta: value.meta, + layout: value.layout, + aux_trace_hints: value.aux_trace_hints, + program_info: value.program_info, + stack_outputs: value.stack_outputs, + trace_len_summary: value.trace_len_summary, + } + } +} + +impl FromColumns for ExecutionTrace { + fn from_cols(columns: Vec>, metadata: &ExecutionTraceMetadata) -> Self { + ExecutionTrace { + meta: metadata.meta.clone(), + layout: metadata.layout.clone(), + main_trace: ColMatrix::new(columns), + aux_trace_hints: metadata.aux_trace_hints.clone(), + program_info: metadata.program_info.clone(), + stack_outputs: metadata.stack_outputs.clone(), + trace_len_summary: metadata.trace_len_summary, + } + } +} + +#[cfg(test)] +mod tests { + use crate::adapter::air::AirAdapter; + use crate::adapter::public_inputs::AirAdapterPublicInputs; + use crate::adapter::{Prover, Transcript, Verifier}; + use crate::examples::fibonacci_2_terms::FibAir2Terms; + use miden_air::{ProcessorAir, ProvingOptions, PublicInputs}; + use miden_assembly::Assembler; + use miden_core::{Felt, StackInputs}; + use miden_processor::DefaultHost; + use miden_processor::{self as processor}; + use processor::ExecutionTrace; + use stark_platinum_prover::{ + proof::options::ProofOptions, prover::IsStarkProver, verifier::IsStarkVerifier, + }; + use winter_math::{FieldElement, StarkField}; + use winter_prover::Trace; + + #[test] + fn prove_and_verify_miden_readme_example() { + let mut lambda_proof_options = ProofOptions::default_test_options(); + lambda_proof_options.blowup_factor = 32; + let assembler = Assembler::default(); + + let program = assembler.compile("begin push.3 push.5 add end").unwrap(); + + let winter_trace = processor::execute( + &program, + StackInputs::default(), + DefaultHost::default(), + *ProvingOptions::default().execution_options(), + ) + .unwrap(); + let program_info = winter_trace.program_info().clone(); + let stack_outputs = winter_trace.stack_outputs().clone(); + + let pub_inputs = PublicInputs::new(program_info, StackInputs::default(), stack_outputs); + + let pub_inputs = AirAdapterPublicInputs { + winterfell_public_inputs: pub_inputs, + transition_exemptions: vec![2; 182], + transition_offsets: vec![0, 1], + trace_info: winter_trace.get_info(), + metadata: winter_trace.clone().into(), + }; + + let trace = + AirAdapter::::convert_winterfell_trace_table( + winter_trace.main_segment().clone(), + ); + + let proof = Prover::prove::>( + &trace, + &pub_inputs, + &lambda_proof_options, + Transcript::new(&[]), + ) + .unwrap(); + + assert!(Verifier::verify::< + AirAdapter, + >( + &proof, + &pub_inputs, + &lambda_proof_options, + Transcript::new(&[]), + )); + } + + fn compute_fibonacci(n: usize) -> Felt { + let mut t0 = Felt::ZERO; + let mut t1 = Felt::ONE; + + for _ in 0..n { + t1 = t0 + t1; + core::mem::swap(&mut t0, &mut t1); + } + t0 + } + + #[test] + fn prove_and_verify_miden_fibonacci() { + let fibonacci_number = 16; + let program = format!( + "begin + repeat.{} + swap dup.1 add + end + end", + fibonacci_number - 1 + ); + let program = Assembler::default().compile(program).unwrap(); + let expected_result = vec![compute_fibonacci(fibonacci_number).as_int()]; + let stack_inputs = StackInputs::try_from_values([0, 1]).unwrap(); + + let mut lambda_proof_options = ProofOptions::default_test_options(); + lambda_proof_options.blowup_factor = 8; + + let winter_trace = processor::execute( + &program, + stack_inputs.clone(), + DefaultHost::default(), + *ProvingOptions::default().execution_options(), + ) + .unwrap(); + let program_info = winter_trace.program_info().clone(); + let stack_outputs = winter_trace.stack_outputs().clone(); + + let pub_inputs = PublicInputs::new(program_info, stack_inputs, stack_outputs.clone()); + + assert_eq!( + expected_result, + stack_outputs.clone().stack_truncated(1), + "Program result was computed incorrectly" + ); + + let pub_inputs = AirAdapterPublicInputs { + winterfell_public_inputs: pub_inputs, + transition_exemptions: vec![2; 182], + transition_offsets: vec![0, 1], + trace_info: winter_trace.get_info(), + metadata: winter_trace.clone().into(), + }; + + let trace = + AirAdapter::::convert_winterfell_trace_table( + winter_trace.main_segment().clone(), + ); + + let proof = Prover::prove::>( + &trace, + &pub_inputs, + &lambda_proof_options, + Transcript::new(&[]), + ) + .unwrap(); + + assert!(Verifier::verify::< + AirAdapter, + >( + &proof, + &pub_inputs, + &lambda_proof_options, + Transcript::new(&[]), + )); + } +} diff --git a/winterfell_adapter/src/examples/mod.rs b/winterfell_adapter/src/examples/mod.rs index 9a74354be..117c0609a 100644 --- a/winterfell_adapter/src/examples/mod.rs +++ b/winterfell_adapter/src/examples/mod.rs @@ -1,2 +1,4 @@ +pub mod cubic; pub mod fibonacci_2_terms; pub mod fibonacci_rap; +pub mod miden_vm; diff --git a/winterfell_adapter/src/field_element/element.rs b/winterfell_adapter/src/field_element/element.rs index dda01351e..8137587d4 100644 --- a/winterfell_adapter/src/field_element/element.rs +++ b/winterfell_adapter/src/field_element/element.rs @@ -1,3 +1,4 @@ +use crate::field_element::positive_integer::AdapterPositiveInteger; use core::fmt; use core::{ mem, @@ -12,14 +13,8 @@ use lambdaworks_math::{ traits::ByteConversion, }; use std::ops::{Add, AddAssign, Div, Mul, Neg, Sub}; -use winter_utils::{AsBytes, DeserializationError, Randomizable}; -use winterfell::math::ExtensibleField; -use winterfell::{ - math::{FieldElement as IsWinterfellFieldElement, StarkField}, - Deserializable, Serializable, -}; - -use crate::field_element::positive_integer::AdapterPositiveInteger; +use winter_math::{ExtensibleField, FieldElement as IsWinterfellFieldElement, StarkField}; +use winter_utils::{AsBytes, Deserializable, DeserializationError, Randomizable, Serializable}; #[derive(Debug, Copy, Clone, Default)] pub struct AdapterFieldElement(pub FieldElement); @@ -68,7 +63,7 @@ impl IsWinterfellFieldElement for AdapterFieldElement { unsafe { slice::from_raw_parts(p as *const u8, len) } } - unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], winterfell::DeserializationError> { + unsafe fn bytes_as_elements(bytes: &[u8]) -> Result<&[Self], DeserializationError> { if bytes.len() % Self::ELEMENT_BYTES != 0 { return Err(DeserializationError::InvalidValue(format!( "number of bytes ({}) does not divide into whole number of field elements", diff --git a/winterfell_adapter/src/utils.rs b/winterfell_adapter/src/utils.rs index 40e082c83..8fa8422c3 100644 --- a/winterfell_adapter/src/utils.rs +++ b/winterfell_adapter/src/utils.rs @@ -1,23 +1,38 @@ -use crate::field_element::element::AdapterFieldElement; -use lambdaworks_math::field::fields::fft_friendly::stark_252_prime_field::Stark252PrimeField; +use lambdaworks_math::{field::traits::IsField, traits::ByteConversion}; use stark_platinum_prover::fri::FieldElement; +use winter_math::FieldElement as IsWinterfellFieldElement; -pub fn vec_field2adapter(input: &[FieldElement]) -> Vec { - input.iter().map(|&e| AdapterFieldElement(e)).collect() +pub fn vec_lambda2winter< + FE: IsField + IsWinterfellFieldElement + ByteConversion + Unpin, +>( + input: &[FieldElement], +) -> Vec { + input.iter().map(|&e| *e.value()).collect() } -pub fn vec_adapter2field(input: &[AdapterFieldElement]) -> Vec> { - input.iter().map(|&e| e.0).collect() +pub fn vec_winter2lambda< + FE: IsField + IsWinterfellFieldElement + ByteConversion + Unpin, +>( + input: &[FE], +) -> Vec> { + input + .iter() + .map(|&e| FieldElement::::const_from_raw(e)) + .collect() } -pub fn matrix_field2adapter( - input: &[Vec>], -) -> Vec> { - input.iter().map(|v| vec_field2adapter(v)).collect() +pub fn matrix_lambda2winter< + FE: IsField + IsWinterfellFieldElement + ByteConversion + Unpin, +>( + input: &[Vec>], +) -> Vec> { + input.iter().map(|v| vec_lambda2winter(v)).collect() } -pub fn matrix_adapter2field( - input: &[Vec], -) -> Vec>> { - input.iter().map(|v| vec_adapter2field(v)).collect() +pub fn matrix_winter2lambda< + FE: IsField + IsWinterfellFieldElement + ByteConversion + Unpin, +>( + input: &[Vec], +) -> Vec>> { + input.iter().map(|v| vec_winter2lambda(v)).collect() }