From 290a8aebf1012fa7d70a2f5cd4b20b6bb5fcf35c Mon Sep 17 00:00:00 2001 From: Skalman Date: Thu, 29 Aug 2024 15:54:53 +0200 Subject: [PATCH] Boiler plate for TE CondAdd --- common/src/gadgets/cond_add.rs | 7 +- common/src/gadgets/mod.rs | 1 + common/src/gadgets/sw_cond_add.rs | 3 +- common/src/gadgets/te_cond_add.rs | 265 ++++++++++++++++++++++++++++++ 4 files changed, 272 insertions(+), 4 deletions(-) create mode 100644 common/src/gadgets/te_cond_add.rs diff --git a/common/src/gadgets/cond_add.rs b/common/src/gadgets/cond_add.rs index 2108a10..fc6df54 100644 --- a/common/src/gadgets/cond_add.rs +++ b/common/src/gadgets/cond_add.rs @@ -38,18 +38,19 @@ impl> AffineColumn { } } -pub trait CondAdd where +pub trait CondAdd where F: FftField, Curve: CurveConfig, AffinePoint: AffineRepr, - ContAddVal: CondAddValues + { + type CondAddValT: CondAddValues; fn init(bitmask: BitColumn, points: AffineColumn, seed: AffinePoint, domain: &Domain) -> Self; - fn evaluate_assignment(&self, z: &F) -> ContAddVal; + fn evaluate_assignment(&self, z: &F) -> Self::CondAddValT; } diff --git a/common/src/gadgets/mod.rs b/common/src/gadgets/mod.rs index 901ec14..b610e80 100644 --- a/common/src/gadgets/mod.rs +++ b/common/src/gadgets/mod.rs @@ -7,6 +7,7 @@ pub mod booleanity; // pub mod inner_prod_pub; pub mod cond_add; pub mod sw_cond_add; +pub mod te_cond_add; pub mod fixed_cells; pub mod inner_prod; diff --git a/common/src/gadgets/sw_cond_add.rs b/common/src/gadgets/sw_cond_add.rs index 2eee3e2..3c5a9d2 100644 --- a/common/src/gadgets/sw_cond_add.rs +++ b/common/src/gadgets/sw_cond_add.rs @@ -31,11 +31,12 @@ pub struct SwCondAddValues { pub acc: (F, F), } -impl CondAdd, SwCondAddValues> for SwCondAdd> where +impl CondAdd > for SwCondAdd> where F: FftField, Curve: SWCurveConfig, { + type CondAddValT = SwCondAddValues; // Populates the acc column starting from the supplied seed (as 0 doesn't have an affine SW representation). // As the SW addition formula used is not complete, the seed must be selected in a way that would prevent // exceptional cases (doublings or adding the opposite point). diff --git a/common/src/gadgets/te_cond_add.rs b/common/src/gadgets/te_cond_add.rs new file mode 100644 index 0000000..f8dea86 --- /dev/null +++ b/common/src/gadgets/te_cond_add.rs @@ -0,0 +1,265 @@ +use ark_ec::{AffineRepr, CurveGroup}; +use ark_ec::twisted_edwards::{Affine,TECurveConfig}; +use ark_ff::{FftField, Field}; +use ark_poly::{Evaluations, GeneralEvaluationDomain}; +use ark_poly::univariate::DensePolynomial; +use ark_std::{vec, vec::Vec}; + +use crate::{Column, FieldColumn, const_evals}; +use crate::domain::Domain; +use crate::gadgets::{ProverGadget, VerifierGadget}; +use crate::gadgets::booleanity::BitColumn; +use crate::gadgets::cond_add::{AffineColumn, CondAdd, CondAddValues}; + +// Conditional affine addition: +// if the bit is set for a point, add the point to the acc and store, +// otherwise copy the acc value +pub struct TeCondAdd> { + pub(super)bitmask: BitColumn, + pub(super)points: AffineColumn, + // The polynomial `X - w^{n-1}` in the Lagrange basis + pub(super)not_last: FieldColumn, + // Accumulates the (conditional) rolling sum of the points + pub acc: AffineColumn, + pub result: P, +} + +pub struct TeCondAddValues { + pub bitmask: F, + pub points: (F, F), + pub not_last: F, + pub acc: (F, F), +} + +impl CondAdd> for TeCondAdd> where + F: FftField, + Curve: TECurveConfig, +{ + type CondAddValT = TeCondAddValues; + // Populates the acc column starting from the supplied seed (as 0 doesn't work with the addition formula). + // As the TE addition formula used is not complete, the seed must be selected in a way that would prevent + // exceptional cases (doublings or adding the opposite point). + // The last point of the input column is ignored, as adding it would made the acc column overflow due the initial point. + fn init(bitmask: BitColumn, + points: AffineColumn>, + seed: Affine, + domain: &Domain) -> Self { + assert_eq!(bitmask.bits.len(), domain.capacity - 1); + assert_eq!(points.points.len(), domain.capacity - 1); + let not_last = domain.not_last_row.clone(); + let acc = bitmask.bits.iter() + .zip(points.points.iter()) + .scan(seed, |acc, (&b, point)| { + if b { + *acc = (*acc + point).into_affine(); + } + Some(*acc) + }); + let acc: Vec<_> = ark_std::iter::once(seed) + .chain(acc) + .collect(); + let init_plus_result = acc.last().unwrap(); + let result = init_plus_result.into_group() - seed.into_group(); + let result = result.into_affine(); + let acc = AffineColumn::private_column(acc, domain); + + Self { bitmask, points, acc, not_last, result } + } + + fn evaluate_assignment(&self, z: &F) -> TeCondAddValues { + TeCondAddValues { + bitmask: self.bitmask.evaluate(z), + points: self.points.evaluate(z), + not_last: self.not_last.evaluate(z), + acc: self.acc.evaluate(z), + } + } +} + +impl ProverGadget for TeCondAdd> + where + F: FftField, + Curve: TECurveConfig, +{ + fn witness_columns(&self) -> Vec> { + vec![self.acc.xs.poly.clone(), self.acc.ys.poly.clone()] + } + + fn constraints(&self) -> Vec> { + let domain = self.bitmask.domain_4x(); + let b = &self.bitmask.col.evals_4x; + let one = &const_evals(F::one(), domain); + let (x1, y1) = (&self.acc.xs.evals_4x, &self.acc.ys.evals_4x); + let (x2, y2) = (&self.points.xs.evals_4x, &self.points.ys.evals_4x); + let (x3, y3) = (&self.acc.xs.shifted_4x(), &self.acc.ys.shifted_4x()); + + let mut c1 = + &( + b * + &( + &( + &( + &(x1 - x2) * &(x1 - x2) + ) * + &( + &(x1 + x2) + x3 + ) + ) - + &( + &(y2 - y1) * &(y2 - y1) + ) + ) + ) + + &( + &(one - b) * &(y3 - y1) + ); + + let mut c2 = + &( + b * + &( + &( + &(x1 - x2) * &(y3 + y1) + ) - + &( + &(y2 - y1) * &(x3 - x1) + ) + ) + ) + + &( + &(one - b) * &(x3 - x1) + ); + + let not_last = &self.not_last.evals_4x; + c1 *= not_last; + c2 *= not_last; + + vec![c1, c2] + } + + fn constraints_linearized(&self, z: &F) -> Vec> { + let vals = self.evaluate_assignment(z); + let acc_x = self.acc.xs.as_poly(); + let acc_y = self.acc.ys.as_poly(); + + let (c_acc_x, c_acc_y) = vals.acc_coeffs_1(); + let c1_lin = acc_x * c_acc_x + acc_y * c_acc_y; + + let (c_acc_x, c_acc_y) = vals.acc_coeffs_2(); + let c2_lin = acc_x * c_acc_x + acc_y * c_acc_y; + + vec![c1_lin, c2_lin] + } + + fn domain(&self) -> GeneralEvaluationDomain { + self.bitmask.domain() + } +} + + +impl VerifierGadget for TeCondAddValues { + fn evaluate_constraints_main(&self) -> Vec { + let b = self.bitmask; + let (x1, y1) = self.acc; + let (x2, y2) = self.points; + let (x3, y3) = (F::zero(), F::zero()); + + let mut c1 = + b * ( + (x1 - x2) * (x1 - x2) * (x1 + x2 + x3) + - (y2 - y1) * (y2 - y1) + ) + (F::one() - b) * (y3 - y1); + + let mut c2 = + b * ( + (x1 - x2) * (y3 + y1) + - (y2 - y1) * (x3 - x1) + ) + (F::one() - b) * (x3 - x1); + + c1 *= self.not_last; + c2 *= self.not_last; + + vec![c1, c2] + } +} + + +impl CondAddValues for TeCondAddValues { + fn acc_coeffs_1(&self) -> (F, F) { + let b = self.bitmask; + let (x1, _y1) = self.acc; + let (x2, _y2) = self.points; + + let mut c_acc_x = b * (x1 - x2) * (x1 - x2); + let mut c_acc_y = F::one() - b; + + c_acc_x *= self.not_last; + c_acc_y *= self.not_last; + + (c_acc_x, c_acc_y) + } + + fn acc_coeffs_2(&self) -> (F, F) { + let b = self.bitmask; + let (x1, y1) = self.acc; + let (x2, y2) = self.points; + + let mut c_acc_x = b * (y1 - y2) + F::one() - b; + let mut c_acc_y = b * (x1 - x2); + + c_acc_x *= self.not_last; + c_acc_y *= self.not_last; + + (c_acc_x, c_acc_y) + } +} + + +#[cfg(test)] +mod tests { + use ark_ed_on_bls12_381_bandersnatch::EdwardsAffine; + use ark_poly::Polynomial; + use ark_std::test_rng; + + use crate::test_helpers::*; + use crate::test_helpers::cond_sum; + + use super::*; + + fn _test_te_cond_add_gadget(hiding: bool) { + let rng = &mut test_rng(); + + let log_n = 10; + let n = 2usize.pow(log_n); + let domain = Domain::new(n, hiding); + let seed = EdwardsAffine::generator(); + + let bitmask = random_bitvec(domain.capacity - 1, 0.5, rng); + let points = random_vec::(domain.capacity - 1, rng); + let expected_res = seed + cond_sum(&bitmask, &points); + + let bitmask_col = BitColumn::init(bitmask, &domain); + let points_col = AffineColumn::private_column(points, &domain); + let gadget = TeCondAdd::init(bitmask_col, points_col, seed, &domain); + let res = gadget.acc.points.last().unwrap(); + assert_eq!(res, &expected_res); + + let cs = gadget.constraints(); + let (c1, c2) = (&cs[0], &cs[1]); + let c1 = c1.interpolate_by_ref(); + let c2 = c2.interpolate_by_ref(); + assert_eq!(c1.degree(), 4 * n - 3); + assert_eq!(c2.degree(), 3 * n - 2); + + domain.divide_by_vanishing_poly(&c1); + domain.divide_by_vanishing_poly(&c2); + + // test_gadget(gadget); + } + + #[test] + fn test_te_cond_add_gadget() { + _test_te_cond_add_gadget(false); + _test_te_cond_add_gadget(true); + } +}