Skip to content

Commit

Permalink
Optimize polynomial folding
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewmilson committed Oct 29, 2024
1 parent 879c0ba commit c786f81
Show file tree
Hide file tree
Showing 10 changed files with 149 additions and 79 deletions.
8 changes: 6 additions & 2 deletions stwo_cairo_verifier/src/fields.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ pub trait Field<T> {

pub trait FieldBatchInverse<T, +Field<T>, +Copy<T>, +Drop<T>, +Mul<T>> {
/// Computes all `1/arr[i]` with a single call to `inverse()` using Montgomery batch inversion.
fn batch_inverse(arr: Array<T>) -> Array<T> {
fn batch_inverse(
arr: Array<T>
) -> Array<
T
> {
if arr.is_empty() {
return array![];
}

// Collect array `z, zy, ..., zy..b`.
let mut prefix_product_rev = array![];
let mut cumulative_product = *arr[arr.len() - 1];
Expand Down
2 changes: 1 addition & 1 deletion stwo_cairo_verifier/src/fields/cm31.cairo
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use core::num::traits::{One, Zero};
use core::ops::{AddAssign, MulAssign, SubAssign};
use super::{Field, FieldBatchInverse};
use super::m31::{M31, M31Impl, m31};
use super::{Field, FieldBatchInverse};

#[derive(Copy, Drop, Debug, PartialEq)]
pub struct CM31 {
Expand Down
4 changes: 2 additions & 2 deletions stwo_cairo_verifier/src/fields/m31.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ pub impl M31FieldImpl of Field<M31> {
}
}

pub impl M31FieldBatchInverseImpl of FieldBatchInverse<M31> {}

fn sqn(v: M31, n: usize) -> M31 {
if n == 0 {
return v;
}
sqn(v * v, n - 1)
}

pub impl M31FieldBatchInverseImpl of FieldBatchInverse<M31> {}

#[generate_trait]
pub impl M31Impl of M31Trait {
#[inline]
Expand Down
1 change: 1 addition & 0 deletions stwo_cairo_verifier/src/fields/qm31.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ pub impl QM31Add of core::traits::Add<QM31> {
}

pub impl QM31Sub of core::traits::Sub<QM31> {
#[inline]
fn sub(lhs: QM31, rhs: QM31) -> QM31 {
QM31 { a: lhs.a - rhs.a, b: lhs.b - rhs.b }
}
Expand Down
76 changes: 23 additions & 53 deletions stwo_cairo_verifier/src/fri.cairo
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
use core::dict::Felt252Dict;
use stwo_cairo_verifier::channel::{Channel, ChannelTrait};
use stwo_cairo_verifier::circle::CosetImpl;
use stwo_cairo_verifier::fields::Field;
use stwo_cairo_verifier::fields::m31::M31;
use stwo_cairo_verifier::fields::qm31::{QM31_EXTENSION_DEGREE, QM31, QM31Zero, QM31Trait};
use stwo_cairo_verifier::poly::circle::CircleDomainImpl;
use stwo_cairo_verifier::poly::circle::{
CircleEvaluation, SparseCircleEvaluation, SparseCircleEvaluationImpl
};
use stwo_cairo_verifier::poly::circle::{SparseCircleEvaluation, SparseCircleEvaluationImpl};
use stwo_cairo_verifier::poly::line::{
LineEvaluation, LineEvaluationImpl, SparseLineEvaluation, SparseLineEvaluationImpl
};
Expand Down Expand Up @@ -341,7 +337,7 @@ pub impl FriVerifierImpl of FriVerifierTrait {

while let Option::Some(_) = column_bounds.next_if_eq(@circle_poly_degree_bound) {
let sparse_evaluation = decommitted_values.pop_front().unwrap();
let mut folded_evals = sparse_evaluation.fold(circle_poly_alpha);
let mut folded_evals = sparse_evaluation.fold_to_line(circle_poly_alpha);

let n = folded_evals.len();
assert!(folded_evals.len() == layer_query_evals.len());
Expand Down Expand Up @@ -463,52 +459,26 @@ fn get_opening_positions(
positions
}

/// Folds a degree `d` polynomial into a degree `d/2` polynomial.
///
/// Let `eval` be a polynomial evaluated on line domain `E`, `alpha` be a random field
/// element and `pi(x) = 2x^2 - 1` be the circle's x-coordinate doubling map. This function
/// returns `f' = f0 + alpha * f1` evaluated on `pi(E)` such that `2f(x) = f0(pi(x)) + x *
/// f1(pi(x))`.
pub fn fold_line(eval: LineEvaluation, alpha: QM31) -> LineEvaluation {
let domain = eval.domain;
let mut values = array![];
for i in 0
..eval.values.len()
/ 2 {
let x = domain.at(bit_reverse_index(i * FOLD_FACTOR, domain.log_size()));
let f_x = eval.values[2 * i];
let f_neg_x = eval.values[2 * i + 1];
let (f0, f1) = ibutterfly(*f_x, *f_neg_x, x.inverse());
values.append(f0 + alpha * f1);
};
LineEvaluationImpl::new(domain.double(), values)
}

/// Folds and accumulates a degree `d` circle polynomial into a degree `d/2` univariate polynomial.
///
/// Let `src` be the evaluation of a circle polynomial `f` on a [`CircleDomain`] `E`. This function
/// computes evaluations of `f' = f0 + alpha * f1` on the x-coordinates of `E` such that `2f(p) =
/// f0(px) + py * f1(px)`. The evaluations of `f'` are accumulated into `dst` by the formula
/// `dst = dst * alpha^2 + f'`.
pub fn fold_circle_into_line(eval: CircleEvaluation, alpha: QM31) -> LineEvaluation {
let domain = eval.domain;
let mut values = array![];
for i in 0
..eval.bit_reversed_values.len()
/ 2 {
let p = domain
.at(bit_reverse_index(i * CIRCLE_TO_LINE_FOLD_FACTOR, domain.log_size()));
let f_p = eval.bit_reversed_values[2 * i];
let f_neg_p = eval.bit_reversed_values[2 * i + 1];
let (f0, f1) = ibutterfly(*f_p, *f_neg_p, p.y.inverse());
values.append(f0 + alpha * f1);
};
LineEvaluation { values, domain: LineDomainImpl::new_unchecked(domain.half_coset) }
}

pub fn ibutterfly(v0: QM31, v1: QM31, itwid: M31) -> (QM31, QM31) {
(v0 + v1, (v0 - v1).mul_m31(itwid))
}
// /// Folds a degree `d` polynomial into a degree `d/2` polynomial.
// ///
// /// Let `eval` be a polynomial evaluated on line domain `E`, `alpha` be a random field
// /// element and `pi(x) = 2x^2 - 1` be the circle's x-coordinate doubling map. This function
// /// returns `f' = f0 + alpha * f1` evaluated on `pi(E)` such that `2f(x) = f0(pi(x)) + x *
// /// f1(pi(x))`.
// pub fn fold_line(eval: LineEvaluation, alpha: QM31) -> LineEvaluation {
// let domain = eval.domain;
// let mut values = array![];
// for i in 0
// ..eval.values.len()
// / 2 {
// let x = domain.at(bit_reverse_index(i * FOLD_FACTOR, domain.log_size()));
// let f_x = eval.values[2 * i];
// let f_neg_x = eval.values[2 * i + 1];
// let (f0, f1) = ibutterfly(*f_x, *f_neg_x, x.inverse());
// values.append(f0 + alpha * f1);
// };
// LineEvaluationImpl::new(domain.double(), values)
// }

#[cfg(test)]
mod test {
Expand Down Expand Up @@ -575,7 +545,7 @@ mod test {
subcircle_evals: array![CircleEvaluationImpl::new(domain, values)]
};
let alpha = qm31(260773061, 362745443, 1347591543, 1084609991);
let result = sparse_circle_evaluation.fold(alpha);
let result = sparse_circle_evaluation.fold_to_line(alpha);
let expected_result = array![qm31(730692421, 1363821003, 2146256633, 106012305)];
assert_eq!(expected_result, result);
}
Expand Down
49 changes: 46 additions & 3 deletions stwo_cairo_verifier/src/poly/circle.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ use stwo_cairo_verifier::circle::{
Coset, CosetImpl, CirclePoint, CirclePointM31Impl, CirclePointIndex, CirclePointIndexImpl,
CirclePointTrait
};
use stwo_cairo_verifier::fields::FieldBatchInverse;
use stwo_cairo_verifier::fields::m31::M31;
use stwo_cairo_verifier::fields::qm31::QM31;
use stwo_cairo_verifier::fri::fold_circle_into_line;
use stwo_cairo_verifier::poly::utils::ibutterfly;
use stwo_cairo_verifier::utils::pow;

/// A valid domain for circle polynomial interpolation and evaluation.
Expand Down Expand Up @@ -163,16 +164,58 @@ pub impl SparseCircleEvaluationImpl of SparseCircleEvaluationImplTrait {
SparseCircleEvaluation { subcircle_evals }
}

fn fold(self: SparseCircleEvaluation, alpha: QM31) -> Array<QM31> {
/// Folds degree `d` circle polynomials into degree `d/2` univariate polynomials.
///
/// Let `eval_i` be the evaluation of circle polynomial `f_i` on [`CircleDomain`] `E_i`.
/// This function returns all `f_i' = f_i0 + alpha * f_i1` evaluated on the x-coordinate
/// of `E_i` (Note `E_i = {(x, y), (x, -y)}`) such that `2f_i(p) = f_i0(px) + py * f_i1(px)`.
fn fold_to_line(self: SparseCircleEvaluation, alpha: QM31) -> Array<QM31> {
let mut domain_initial_ys = array![];

for eval in self.subcircle_evals.span() {
domain_initial_ys.append(eval.domain.at(0).y);
};

let mut domain_initial_ys_inv = FieldBatchInverse::batch_inverse(domain_initial_ys);
let mut res = array![];

for eval in self
.subcircle_evals {
res.append(*fold_circle_into_line(eval, alpha).values[0])
let y_inv = domain_initial_ys_inv.pop_front().unwrap();
let f_i_at_p = *eval.bit_reversed_values[0];
let f_i_at_neg_p = *eval.bit_reversed_values[1];
let (f_i0, f_i1) = ibutterfly(f_i_at_p, f_i_at_neg_p, y_inv);
res.append(f_i0 + alpha * f_i1);
};

res
}
}

// /// Folds and accumulates a degree `d` circle polynomial into a degree `d/2` univariate
// polynomial.
// ///
// /// Let `src` be the evaluation of a circle polynomial `f` on a [`CircleDomain`] `E`. This
// function /// computes evaluations of `f' = f0 + alpha * f1` on the x-coordinates of `E` such that
// `2f(p) =
// /// f0(px) + py * f1(px)`. The evaluations of `f'` are accumulated into `dst` by the formula
// /// `dst = dst * alpha^2 + f'`.
// pub fn fold_circle_into_line(eval: CircleEvaluation, alpha: QM31) -> LineEvaluation {
// let domain = eval.domain;
// let mut values = array![];
// for i in 0
// ..eval.bit_reversed_values.len()
// / 2 {
// let p = domain
// .at(bit_reverse_index(i * CIRCLE_TO_LINE_FOLD_FACTOR, domain.log_size()));
// let f_p = eval.bit_reversed_values[2 * i];
// let f_neg_p = eval.bit_reversed_values[2 * i + 1];
// let (f0, f1) = ibutterfly(*f_p, *f_neg_p, p.y.inverse());
// values.append(f0 + alpha * f1);
// };
// LineEvaluation { values, domain: LineDomainImpl::new_unchecked(domain.half_coset) }
// }

#[cfg(test)]
mod tests {
use stwo_cairo_verifier::circle::{Coset, CosetImpl, CirclePoint, CirclePointIndexImpl};
Expand Down
36 changes: 25 additions & 11 deletions stwo_cairo_verifier/src/poly/line.cairo
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use core::iter::Iterator;
use stwo_cairo_verifier::circle::{CirclePoint, Coset, CosetImpl, CirclePointIndexImpl, CirclePointTrait};
use stwo_cairo_verifier::fields::FieldBatchInverse;
use stwo_cairo_verifier::fields::m31::{M31, m31};
use stwo_cairo_verifier::fields::qm31::{QM31, QM31Impl, QM31Zero};
use stwo_cairo_verifier::fields::{SecureField, BaseField};
use stwo_cairo_verifier::fri::fold_line;
use stwo_cairo_verifier::poly::utils::fold;
use stwo_cairo_verifier::poly::utils::{fold, butterfly, ibutterfly};
use stwo_cairo_verifier::utils::pow;

/// A univariate polynomial defined on a [LineDomain].
Expand Down Expand Up @@ -153,12 +153,6 @@ fn gen_twiddles(self: @LineDomain) -> Array<M31> {
res
}

#[inline]
fn butterfly(v0: QM31, v1: QM31, twid: M31) -> (QM31, QM31) {
let tmp = v1.mul_m31(twid);
(v0 + tmp, v0 - tmp)
}

/// Domain comprising of the x-coordinates of points in a [Coset].
///
/// For use with univariate polynomials.
Expand Down Expand Up @@ -242,11 +236,31 @@ pub struct SparseLineEvaluation {

#[generate_trait]
pub impl SparseLineEvaluationImpl of SparseLineEvaluationTrait {
/// Folds degree `d` polynomials into degree `d/2` polynomials.
///
/// Let `eval_i` be the evaluation of polynomial `f_i` on [`LineDomain`] `E_i` and
/// `pi(x) = 2x^2 - 1` be the circle's x-coordinate doubling map. This function returns
/// all `f_i' = f_i0 + alpha * f_i1` evaluated on `pi(E_i)` (`E_i` has a two points)
/// such that `2f_i(x) = f_i0(pi(x)) + x * f_i1(pi(x))`.
fn fold(self: SparseLineEvaluation, alpha: QM31) -> Array<QM31> {
let mut res = array![];
for eval in self.subline_evals {
res.append(*fold_line(eval, alpha).values[0]);
let mut domain_initials = array![];

for eval in self.subline_evals.span() {
domain_initials.append(eval.domain.at(0));
};

let mut domain_initials_inv = FieldBatchInverse::batch_inverse(domain_initials);
let mut res = array![];

for eval in self
.subline_evals {
let x_inv = domain_initials_inv.pop_front().unwrap();
let f_i_at_x = *eval.values[0];
let f_i_at_neg_x = *eval.values[1];
let (f_i0, f_i1) = ibutterfly(f_i_at_x, f_i_at_neg_x, x_inv);
res.append(f_i0 + alpha * f_i1);
};

res
}
}
Expand Down
14 changes: 13 additions & 1 deletion stwo_cairo_verifier/src/poly/utils.cairo
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use stwo_cairo_verifier::fields::qm31::QM31Impl;
use stwo_cairo_verifier::fields::m31::M31;
use stwo_cairo_verifier::fields::qm31::{QM31, QM31Impl};
use stwo_cairo_verifier::fields::{SecureField, BaseField};

/// Folds values recursively in `O(n)` by a hierarchical application of folding factors.
Expand Down Expand Up @@ -34,3 +35,14 @@ pub fn fold(
let rhs_val = fold(values, folding_factors, index + n / 2, level + 1, n / 2);
return lhs_val + rhs_val.mul_m31(*folding_factors[level]);
}

#[inline]
pub fn butterfly(v0: QM31, v1: QM31, twid: M31) -> (QM31, QM31) {
let tmp = v1.mul_m31(twid);
(v0 + tmp, v0 - tmp)
}

#[inline]
pub fn ibutterfly(v0: QM31, v1: QM31, itwid: M31) -> (QM31, QM31) {
(v0 + v1, (v0 - v1).mul_m31(itwid))
}
22 changes: 22 additions & 0 deletions stwo_cairo_verifier/src/utils.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,28 @@ pub impl U128TrailingZerosImpl of U128TrailingZerosTrait {
}
}

// #[generate_trait]
// impl UsizeImpl of UsizeExTrait {
// #[inline]
// fn next_multiple_of(self: usize, rhs: usize) -> usize {
// let r = self % rhs;
// if r == 0 {
// return self;
// }
// self + (rhs - r)
// }

// /// Calculates the largest value less than or equal to `self` that is a multiple of `rhs`.
// #[inline]
// fn prev_multiple_of(self: usize, rhs: usize) -> usize {
// let r = self % rhs;
// if r == 0 {
// return self;
// }
// self - r
// }
// }

#[cfg(test)]
mod tests {
use super::{pow, pow_qm31, qm31, bit_reverse_index, ArrayImpl};
Expand Down
16 changes: 10 additions & 6 deletions stwo_cairo_verifier/src/vcs/hasher.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub trait MerkleHasher {
const M31_ELEMENETS_IN_HASH: usize = 8;
const M31_ELEMENETS_IN_HASH_MINUS1: usize = M31_ELEMENETS_IN_HASH - 1;
const M31_IN_HASH_SHIFT: felt252 = 0x80000000; // 2**31.
const M31_IN_HASH_SHIFT_POW_4: felt252 = 0x10000000000000000000000000000000; // (2**32)**4.
pub impl PoseidonMerkleHasher of MerkleHasher {
type Hash = felt252;

Expand All @@ -34,19 +35,22 @@ pub impl PoseidonMerkleHasher of MerkleHasher {
hash_array.append(y);
}

// Pad column_values to a multiple of 8.
// Most often a node has no column values.
// TODO(andrew): Consider handing also common `len == QM31_EXTENSION_DEGREE`.
if column_values.len() == 0 {
return poseidon_hash_span(hash_array.span());
}

let mut pad_len = M31_ELEMENETS_IN_HASH_MINUS1
- ((column_values.len() + M31_ELEMENETS_IN_HASH_MINUS1) % M31_ELEMENETS_IN_HASH);
while pad_len > 0 {
while pad_len != 0 {
column_values.append(core::num::traits::Zero::zero());
pad_len = M31_ELEMENETS_IN_HASH_MINUS1
- ((column_values.len() + M31_ELEMENETS_IN_HASH_MINUS1) % M31_ELEMENETS_IN_HASH);
pad_len -= 1;
};

while !column_values.is_empty() {
let mut word = 0;
// Hash M31_ELEMENETS_IN_HASH = 8 M31 elements into a single field element.
word = word * M31_IN_HASH_SHIFT + column_values.pop_front().unwrap().inner.into();
let mut word = column_values.pop_front().unwrap().inner.into();
word = word * M31_IN_HASH_SHIFT + column_values.pop_front().unwrap().inner.into();
word = word * M31_IN_HASH_SHIFT + column_values.pop_front().unwrap().inner.into();
word = word * M31_IN_HASH_SHIFT + column_values.pop_front().unwrap().inner.into();
Expand Down

0 comments on commit c786f81

Please sign in to comment.