From c7df8ae1fffce979d5a110d5d4175b29ab70d8ff Mon Sep 17 00:00:00 2001 From: Andrew Milson Date: Tue, 22 Oct 2024 13:48:33 -0400 Subject: [PATCH] Optimizations --- stwo_cairo_verifier/Scarb.lock | 16 -- stwo_cairo_verifier/src/circle.cairo | 1 + stwo_cairo_verifier/src/fri.cairo | 25 ++- stwo_cairo_verifier/src/poly/line.cairo | 273 ++++++++++++++++++++--- stwo_cairo_verifier/src/poly/utils.cairo | 7 +- 5 files changed, 261 insertions(+), 61 deletions(-) diff --git a/stwo_cairo_verifier/Scarb.lock b/stwo_cairo_verifier/Scarb.lock index e776b9b5..0c9db427 100644 --- a/stwo_cairo_verifier/Scarb.lock +++ b/stwo_cairo_verifier/Scarb.lock @@ -1,22 +1,6 @@ # Code generated by scarb DO NOT EDIT. version = 1 -[[package]] -name = "snforge_scarb_plugin" -version = "0.32.0" -source = "git+https://github.com/foundry-rs/starknet-foundry?tag=v0.32.0#3817c903b640201c72e743b9bbe70a97149828a2" - -[[package]] -name = "snforge_std" -version = "0.32.0" -source = "git+https://github.com/foundry-rs/starknet-foundry?tag=v0.32.0#3817c903b640201c72e743b9bbe70a97149828a2" -dependencies = [ - "snforge_scarb_plugin", -] - [[package]] name = "stwo_cairo_verifier" version = "0.1.0" -dependencies = [ - "snforge_std", -] diff --git a/stwo_cairo_verifier/src/circle.cairo b/stwo_cairo_verifier/src/circle.cairo index 7acd9720..146f30ca 100644 --- a/stwo_cairo_verifier/src/circle.cairo +++ b/stwo_cairo_verifier/src/circle.cairo @@ -95,6 +95,7 @@ pub trait CirclePointTrait< impl CirclePointAdd, +Sub, +Mul, +Drop, +Copy> of Add> { /// Performs the operation of the circle as a group with additive notation. + #[inline] fn add(lhs: CirclePoint, rhs: CirclePoint) -> CirclePoint { CirclePoint { x: lhs.x * rhs.x - lhs.y * rhs.y, y: lhs.x * rhs.y + lhs.y * rhs.x } } diff --git a/stwo_cairo_verifier/src/fri.cairo b/stwo_cairo_verifier/src/fri.cairo index c3d6a6eb..528de94c 100644 --- a/stwo_cairo_verifier/src/fri.cairo +++ b/stwo_cairo_verifier/src/fri.cairo @@ -176,7 +176,9 @@ impl FriLayerVerifierImpl of FriLayerVerifierTrait { let subline_initial_index = bit_reverse_index(subline_start, self.domain.log_size()); let subline_initial = self.domain.coset.index_at(subline_initial_index); - let subline_domain = LineDomainImpl::new(CosetImpl::new(subline_initial, FOLD_STEP)); + let subline_domain = LineDomainImpl::new_unchecked( + CosetImpl::new(subline_initial, FOLD_STEP) + ); all_subline_evals.append(LineEvaluationImpl::new(subline_domain, subline_evals)); }; @@ -238,7 +240,7 @@ pub impl FriVerifierImpl of FriVerifierTrait { let mut inner_layers = array![]; let mut layer_bound = *max_column_bound - CIRCLE_TO_LINE_FOLD_STEP; - let mut layer_domain = LineDomainImpl::new( + let mut layer_domain = LineDomainImpl::new_unchecked( CosetImpl::half_odds(layer_bound + config.log_blowup_factor) ); @@ -316,7 +318,6 @@ pub impl FriVerifierImpl of FriVerifierTrait { self: @FriVerifier, queries: @Queries, decommitted_values: Array ) -> Result<(), FriVerificationError> { assert!(queries.log_domain_size == self.expected_query_log_domain_size); - let (last_layer_queries, last_layer_query_evals) = self .decommit_inner_layers(queries, @decommitted_values)?; self.decommit_last_layer(last_layer_queries, last_layer_query_evals) @@ -397,6 +398,11 @@ pub impl FriVerifierImpl of FriVerifierTrait { ) -> Result<(), FriVerificationError> { let FriVerifier { last_layer_domain, last_layer_poly, .. } = self; + let domain_log_size = last_layer_domain.log_size(); + // TODO(andrew): Note depending on the proof parameters, doing FFT on the last layer poly vs + // pointwize evaluation is less efficient. + let last_layer_evals = last_layer_poly.evaluate(*last_layer_domain).values; + let mut i = 0; loop { if i == queries.positions.len() { @@ -404,10 +410,11 @@ pub impl FriVerifierImpl of FriVerifierTrait { } let query = *queries.positions[i]; - let query_eval = *query_evals[i]; - let x = last_layer_domain.at(bit_reverse_index(query, last_layer_domain.log_size())); + // TODO(andrew): Makes more sense for the proof to provide coeffs in natural order and + // the FFT return evals in bit-reversed order to prevent this unnessesary bit-reverse. + let last_layer_eval_i = bit_reverse_index(query, domain_log_size); - if query_eval != last_layer_poly.eval_at_point(x.into()) { + if query_evals[i] != last_layer_evals[last_layer_eval_i] { break Result::Err(FriVerificationError::LastLayerEvaluationsInvalid); } @@ -514,7 +521,7 @@ pub fn fold_circle_into_line(eval: CircleEvaluation, alpha: QM31) -> LineEvaluat let (f0, f1) = ibutterfly(*f_p, *f_neg_p, p.y.inverse()); values.append(f0 + alpha * f1); }; - LineEvaluation { values, domain: LineDomainImpl::new(domain.half_coset) } + LineEvaluation { values, domain: LineDomainImpl::new_unchecked(domain.half_coset) } } pub fn ibutterfly(v0: QM31, v1: QM31, itwid: M31) -> (QM31, QM31) { @@ -557,7 +564,9 @@ mod test { #[test] fn test_fold_line_2() { - let domain = LineDomainImpl::new(CosetImpl::new(CirclePointIndexImpl::new(553648128), 1)); + let domain = LineDomainImpl::new_unchecked( + CosetImpl::new(CirclePointIndexImpl::new(553648128), 1) + ); let values = array![ qm31(730692421, 1363821003, 2146256633, 106012305), qm31(1387266930, 149259209, 1148988082, 1930518101) diff --git a/stwo_cairo_verifier/src/poly/line.cairo b/stwo_cairo_verifier/src/poly/line.cairo index d75a73dc..602ef6aa 100644 --- a/stwo_cairo_verifier/src/poly/line.cairo +++ b/stwo_cairo_verifier/src/poly/line.cairo @@ -1,9 +1,13 @@ -use stwo_cairo_verifier::circle::{Coset, CosetImpl, CirclePointIndexImpl, CirclePointTrait}; -use stwo_cairo_verifier::fields::SecureField; +use core::iter::Iterator; +use stwo_cairo_verifier::circle::{ + CirclePoint, Coset, CosetImpl, CirclePointIndexImpl, CirclePointTrait +}; use stwo_cairo_verifier::fields::m31::{M31, m31}; -use stwo_cairo_verifier::fields::qm31::{QM31, QM31Zero}; +use stwo_cairo_verifier::fields::qm31::{QM31, QM31Impl, QM31Zero}; +use stwo_cairo_verifier::fields::{SecureField, BaseField}; use stwo_cairo_verifier::fri::fold_line; use stwo_cairo_verifier::poly::utils::fold; +use stwo_cairo_verifier::utils::pow; /// A univariate polynomial defined on a [LineDomain]. #[derive(Debug, Drop, Clone)] @@ -30,18 +34,129 @@ pub impl LinePolyImpl of LinePolyTrait { } /// Evaluates the polynomial at a single point. - fn eval_at_point(self: @LinePoly, mut x: SecureField) -> SecureField { + // TODO(andrew): Can remove if only use `Self::evaluate()` in the verifier. + // Note there are tradeoffs depending on the blowup factor last FRI layer degree bound. + fn eval_at_point(self: @LinePoly, mut x: BaseField) -> SecureField { let mut doublings = array![]; - let mut i = 0; - while i < *self.log_size { - doublings.append(x); - let x_square = x * x; - x = x_square + x_square - m31(1).into(); - i += 1; - }; + for _ in 0 + ..*self + .log_size { + doublings.append(x); + let x_square = x * x; + x = x_square + x_square - m31(1); + }; fold(self.coeffs, @doublings, 0, 0, self.coeffs.len()) } + + fn evaluate(self: @LinePoly, domain: LineDomain) -> LineEvaluation { + assert!(domain.size() >= self.coeffs.len()); + + // The first few FFT layers may just copy coefficients so we do it directly. + // See the docs for `n_skipped_layers` in `line_fft()`. + let log_domain_size = domain.log_size(); + let log_degree_bound = *self.log_size; + let n_skipped_layers = log_domain_size - log_degree_bound; + let duplicity = pow(2, n_skipped_layers); + let coeffs = repeat_value(self.coeffs.span(), duplicity); + + LineEvaluationImpl::new(domain, line_fft(coeffs, domain, n_skipped_layers)) + } +} + +/// Repeats each value sequentially `duplicity` many times. +pub fn repeat_value(values: Span, duplicity: usize) -> Array { + let mut res = array![]; + for v in values { + for _ in 0..duplicity { + res.append(*v) + }; + }; + res +} + +/// Performs a FFT on a univariate polynomial. +/// +/// `values` is the coefficients stored in bit-reversed order. The evaluations of the polynomial +/// over `domain` is returned in natural order. +/// +/// `n_skipped_layers` specifies how many of the initial butterfly layers to skip. This is used for +/// more efficient degree aware FFTs as the butterflies in the first layers of the FFT only involve +/// copying coefficients to different locations (because one or more of the coefficients is zero). +/// This new algorithm is `O(n log d)` vs `O(n log n)` where `n` is the domain size and `d` is the +/// degree of the polynomial. +/// +/// Note the algorithm does not operate on coefficients in the standard monomial basis but rather +/// coefficients in a basis relating to the circle's x-coordinate doubling map `pi(x) = 2x^2 - 1` +/// i.e. +/// +/// ```text +/// B = { 1 } * { x } * { pi(x) } * { pi(pi(x)) } * ... +/// = { 1, x, pi(x), pi(x) * x, pi(pi(x)), pi(pi(x)) * x, pi(pi(x)) * pi(x), ... } +/// ``` +/// +/// # Panics +/// +/// Panics if the number of values doesn't match the size of the domain. +#[inline] +fn line_fft( + mut values: Array, mut domain: LineDomain, n_skipped_layers: usize +) -> Array { + let n = values.len(); + assert!(values.len() == domain.size()); + + let mut domains = array![]; + while domain.log_size() != n_skipped_layers { + domains.append(domain); + domain = domain.double(); + }; + + let mut domains = domains.span(); + + while let Option::Some(domain) = domains.pop_back() { + let chunk_size = domain.size(); + let twiddles = gen_twiddles(domain).span(); + let n_chunks = n / chunk_size; + let stride = chunk_size / 2; + let mut next_values = array![]; + for chunk_i in 0 + ..n_chunks { + let mut chunk_values_rhs = array![]; + let mut i0 = chunk_i * chunk_size; + let mut i1 = i0 + stride; + for twiddle in twiddles { + let (v0, v1) = butterfly(*values[i0], *values[i1], *twiddle); + next_values.append(v0); + chunk_values_rhs.append(v1); + i0 += 1; + i1 += 1; + }; + next_values.append_span(chunk_values_rhs.span()); + }; + values = next_values; + }; + + values +} + +#[inline] +fn gen_twiddles(self: @LineDomain) -> Array { + let mut iter = LineDomainIterator { + cur: self.coset.initial_index.to_point(), + step: self.coset.step_size.to_point(), + remaining: self.size() / 2 + }; + let mut res = array![]; + while let Option::Some(v) = iter.next() { + res.append(v); + }; + res +} + +#[inline] +fn butterfly(v0: QM31, v1: QM31, twid: M31) -> (QM31, QM31) { + let tmp = v1.mul_m31(twid); + (v0 + tmp, v0 - tmp) } /// Domain comprising of the x-coordinates of points in a [Coset]. @@ -72,6 +187,15 @@ pub impl LineDomainImpl of LineDomainTrait { LineDomain { coset: coset } } + /// Returns a domain comprising of the x-coordinates of points in a coset. + /// + /// # Saftey + /// + /// All coset points must have unique `x` coordinates. + fn new_unchecked(coset: Coset) -> LineDomain { + LineDomain { coset: coset } + } + /// Returns the `i`th domain element. fn at(self: @LineDomain, index: usize) -> M31 { self.coset.at(index).x @@ -96,6 +220,7 @@ pub impl LineDomainImpl of LineDomainTrait { /// Evaluations of a univariate polynomial on a [LineDomain]. #[derive(Drop)] pub struct LineEvaluation { + /// Evaluations in natural order. pub values: Array, pub domain: LineDomain } @@ -126,32 +251,46 @@ pub impl SparseLineEvaluationImpl of SparseLineEvaluationTrait { } } +#[derive(Drop, Clone)] +struct LineDomainIterator { + pub cur: CirclePoint, + pub step: CirclePoint, + pub remaining: usize, +} + +impl LineDomainIteratorImpl of Iterator { + type Item = M31; + + fn next(ref self: LineDomainIterator) -> Option { + if self.remaining == 0 { + return Option::None; + } + self.remaining -= 1; + let res = self.cur.x; + self.cur = self.cur + self.step; + Option::Some(res) + } +} + #[cfg(test)] mod tests { + use core::iter::{IntoIterator, Iterator}; use stwo_cairo_verifier::circle::{CosetImpl, CirclePointIndexImpl}; use stwo_cairo_verifier::fields::m31::m31; use stwo_cairo_verifier::fields::qm31::qm31; - use super::{LinePoly, LinePolyTrait, LineDomainImpl}; - - #[test] - #[should_panic] - fn bad_line_domain() { - // This coset doesn't have points with unique x-coordinates. - let coset = CosetImpl::odds(2); - LineDomainImpl::new(coset); - } + use super::{LinePoly, LinePolyTrait, LineDomain, LineDomainImpl, LineDomainIterator}; #[test] fn line_domain_of_size_two_works() { let coset = CosetImpl::new(CirclePointIndexImpl::new(0), 1); - LineDomainImpl::new(coset); + LineDomainImpl::new_unchecked(coset); } #[test] fn line_domain_of_size_one_works() { let coset = CosetImpl::new(CirclePointIndexImpl::new(0), 0); - LineDomainImpl::new(coset); + LineDomainImpl::new_unchecked(coset); } #[test] @@ -165,7 +304,7 @@ mod tests { }; let x = m31(590768354); - let result = line_poly.eval_at_point(x.into()); + let result = line_poly.eval_at_point(x); assert_eq!(result, qm31(515899232, 1030391528, 1006544539, 11142505)); } @@ -177,7 +316,7 @@ mod tests { }; let x = m31(10); - let result = line_poly.eval_at_point(x.into()); + let result = line_poly.eval_at_point(x); assert_eq!(result, qm31(51, 62, 73, 84)); } @@ -186,21 +325,87 @@ mod tests { fn test_eval_at_point_3() { let poly = LinePoly { coeffs: array![ - qm31(1, 2, 3, 4), - qm31(5, 6, 7, 8), - qm31(9, 10, 11, 12), - qm31(13, 14, 15, 16), - qm31(17, 18, 19, 20), - qm31(21, 22, 23, 24), - qm31(25, 26, 27, 28), - qm31(29, 30, 31, 32), + qm31(1, 8, 0, 1), + qm31(2, 7, 1, 2), + qm31(3, 6, 0, 1), + qm31(4, 5, 1, 3), + qm31(5, 4, 0, 1), + qm31(6, 3, 1, 4), + qm31(7, 2, 0, 1), + qm31(8, 1, 1, 5), ], log_size: 3 }; - let x = qm31(2, 5, 7, 11); + let x = m31(10); let result = poly.eval_at_point(x); - assert_eq!(result, qm31(1857853974, 839310133, 939318020, 651207981)); + assert_eq!(result, qm31(1328848956, 239350644, 174242200, 838661589)); + } + + #[test] + fn test_evaluate() { + let log_size = 3; + let domain = LineDomainImpl::new_unchecked(CosetImpl::half_odds(log_size)); + let poly = LinePoly { + coeffs: array![ + qm31(1, 8, 0, 1), + qm31(2, 7, 1, 2), + qm31(3, 6, 0, 1), + qm31(4, 5, 1, 3), + qm31(5, 4, 0, 1), + qm31(6, 3, 1, 4), + qm31(7, 2, 0, 1), + qm31(8, 1, 1, 5), + ], + log_size, + }; + + let result = poly.evaluate(domain); + let mut result_iter = result.values.into_iter(); + + for x in domain + .into_iter() { + assert_eq!(result_iter.next().unwrap(), poly.eval_at_point(x)); + } + } + + #[test] + fn test_evaluate_with_large_domain() { + let log_size = 3; + let domain = LineDomainImpl::new_unchecked(CosetImpl::half_odds(log_size + 2)); + let poly = LinePoly { + coeffs: array![ + qm31(1, 8, 0, 1), + qm31(2, 7, 1, 2), + qm31(3, 6, 0, 1), + qm31(4, 5, 1, 3), + qm31(5, 4, 0, 1), + qm31(6, 3, 1, 4), + qm31(7, 2, 0, 1), + qm31(8, 1, 1, 5), + ], + log_size, + }; + + let result = poly.evaluate(domain); + let mut result_iter = result.values.into_iter(); + + for x in domain + .into_iter() { + assert_eq!(result_iter.next().unwrap(), poly.eval_at_point(x)); + } + } + + impl LineDomainIntoIterator of IntoIterator { + type IntoIter = LineDomainIterator; + + fn into_iter(self: LineDomain) -> LineDomainIterator { + LineDomainIterator { + cur: self.coset.initial_index.to_point(), + step: self.coset.step_size.to_point(), + remaining: self.size(), + } + } } } diff --git a/stwo_cairo_verifier/src/poly/utils.cairo b/stwo_cairo_verifier/src/poly/utils.cairo index 31e0ca10..48ae86e5 100644 --- a/stwo_cairo_verifier/src/poly/utils.cairo +++ b/stwo_cairo_verifier/src/poly/utils.cairo @@ -1,4 +1,5 @@ -use stwo_cairo_verifier::fields::SecureField; +use stwo_cairo_verifier::fields::qm31::QM31Impl; +use stwo_cairo_verifier::fields::{SecureField, BaseField}; /// Folds values recursively in `O(n)` by a hierarchical application of folding factors. /// @@ -20,7 +21,7 @@ use stwo_cairo_verifier::fields::SecureField; /// factors is provided. pub fn fold( values: @Array, - folding_factors: @Array, + folding_factors: @Array, index: usize, level: usize, n: usize @@ -31,5 +32,5 @@ pub fn fold( let lhs_val = fold(values, folding_factors, index, level + 1, n / 2); let rhs_val = fold(values, folding_factors, index + n / 2, level + 1, n / 2); - return lhs_val + rhs_val * *folding_factors[level]; + return lhs_val + rhs_val.mul_m31(*folding_factors[level]); }