Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add edge case handling for batch_add #169

Merged
merged 8 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions benches/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use criterion::{BenchmarkId, Criterion};
use ff::{Field, PrimeField};
use group::prime::PrimeCurveAffine;
use halo2curves::bn256::{Fr as Scalar, G1Affine as Point};
use halo2curves::msm::{best_multiexp, multiexp_serial};
use halo2curves::msm::{msm_best, msm_serial};
use rand_core::{RngCore, SeedableRng};
use rand_xorshift::XorShiftRng;
use rayon::current_thread_index;
Expand Down Expand Up @@ -136,7 +136,7 @@ fn msm(c: &mut Criterion) {
assert!(k < 64);
let n: usize = 1 << k;
let mut acc = Point::identity().into();
b.iter(|| multiexp_serial(&coeffs[b_index][..n], &bases[..n], &mut acc));
b.iter(|| msm_serial(&coeffs[b_index][..n], &bases[..n], &mut acc));
})
.sample_size(10);
}
Expand All @@ -147,7 +147,7 @@ fn msm(c: &mut Criterion) {
assert!(k < 64);
let n: usize = 1 << k;
b.iter(|| {
best_multiexp(&coeffs[b_index][..n], &bases[..n]);
msm_best(&coeffs[b_index][..n], &bases[..n]);
})
})
.sample_size(SAMPLE_SIZE);
Expand Down
102 changes: 74 additions & 28 deletions src/msm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
// Booth encoding:
// * step by `window` size
// * slice by size of `window + 1``
// * each window overlap by 1 bit
// * append a zero bit to the least significant end
// * each window overlap by 1 bit * append a zero bit to the least significant end
// Indexing rule for example window size 3 where we slice by 4 bits:
// `[0, +1, +1, +2, +2, +3, +3, +4, -4, -3, -3 -2, -2, -1, -1, 0]``
// So we can reduce the bucket size without preprocessing scalars
Expand Down Expand Up @@ -54,14 +53,15 @@ fn get_booth_index(window_index: usize, window_size: usize, el: &[u8]) -> i32 {
}
}

/// Batch addition.
fn batch_add<C: CurveAffine>(
size: usize,
buckets: &mut [BucketAffine<C>],
points: &[SchedulePoint],
bases: &[Affine<C>],
) {
let mut t = vec![C::Base::ZERO; size];
let mut z = vec![C::Base::ZERO; size];
let mut t = vec![C::Base::ZERO; size]; // Stores x2 - x1
let mut z = vec![C::Base::ZERO; size]; // Stores y2 - y1
let mut acc = C::Base::ONE;

for (
Expand All @@ -76,16 +76,42 @@ fn batch_add<C: CurveAffine>(
z,
) in points.iter().zip(t.iter_mut()).zip(z.iter_mut())
{
*z = buckets[*buck_idx].x() - bases[*base_idx].x;
if buckets[*buck_idx].is_inf() {
// We assume bases[*base_idx] != infinity always.
continue;
}

if buckets[*buck_idx].x() == bases[*base_idx].x {
// y-coordinate matches:
// 1. y1 == y2 and sign = false or
// 2. y1 != y2 and sign = true
// => ( y1 == y2) xor !sign
// (This uses the fact that x1 == x2 and both points satisfy the curve eq.)
if (buckets[*buck_idx].y() == bases[*base_idx].y) ^ !*sign {
// Doubling
let x_squared = bases[*base_idx].x.square();
*z = buckets[*buck_idx].y() + buckets[*buck_idx].y(); // 2y
*t = acc * (x_squared + x_squared + x_squared); // acc * 3x^2
acc *= *z;
continue;
}
// P + (-P)
buckets[*buck_idx].set_inf();
continue;
}
// Addition
*z = buckets[*buck_idx].x() - bases[*base_idx].x; // x2 - x1
if *sign {
*t = acc * (buckets[*buck_idx].y() - bases[*base_idx].y);
} else {
*t = acc * (buckets[*buck_idx].y() + bases[*base_idx].y);
}
} // y2 - y1
acc *= *z;
}

acc = acc.invert().unwrap();
acc = acc
.invert()
.expect("Some edge case has not been handled properly");

for (
(
Expand All @@ -99,15 +125,18 @@ fn batch_add<C: CurveAffine>(
z,
) in points.iter().zip(t.iter()).zip(z.iter()).rev()
{
if buckets[*buck_idx].is_inf() {
// We assume bases[*base_idx] != infinity always.
continue;
}
let lambda = acc * t;
acc *= z;

let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x);
acc *= z; // update acc
let x = lambda.square() - (buckets[*buck_idx].x() + bases[*base_idx].x); // x_result
if *sign {
buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) - bases[*base_idx].y));
} else {
buckets[*buck_idx].set_y(&((lambda * (bases[*base_idx].x - x)) + bases[*base_idx].y));
}
} // y_result = lambda * (x1 - x_result) - y1
buckets[*buck_idx].set_x(&x);
}
}
Expand Down Expand Up @@ -207,6 +236,13 @@ impl<C: CurveAffine> BucketAffine<C> {
}
}

fn is_inf(&self) -> bool {
match self {
Self::None => true,
Self::Point(_) => false,
}
}

fn set_x(&mut self, x: &C::Base) {
match self {
Self::None => panic!("::set_x None"),
Expand All @@ -220,6 +256,13 @@ impl<C: CurveAffine> BucketAffine<C> {
Self::Point(ref mut a) => a.y = *y,
}
}

fn set_inf(&mut self) {
match self {
Self::None => {}
Self::Point(_) => *self = Self::None,
}
}
}

struct Schedule<C: CurveAffine> {
Expand Down Expand Up @@ -286,7 +329,10 @@ impl<C: CurveAffine> Schedule<C> {
}
}

pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
/// Performs a multi-scalar multiplication operation.
///
/// This function will panic if coeffs and bases have a different length.
pub fn msm_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();

let c = if bases.len() < 4 {
Expand All @@ -303,7 +349,7 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &
let mut acc_or = vec![0; field_byte_size];
for coeff in &coeffs {
for (acc_limb, limb) in acc_or.iter_mut().zip(coeff.as_ref().iter()) {
*acc_limb = *acc_limb | *limb;
*acc_limb |= *limb;
}
}
let max_byte_size = field_byte_size
Expand All @@ -315,7 +361,7 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &
if max_byte_size == 0 {
return;
}
let number_of_windows = max_byte_size * 8 as usize / c + 1;
let number_of_windows = max_byte_size * 8_usize / c + 1;

for current_window in (0..number_of_windows).rev() {
for _ in 0..c {
Expand Down Expand Up @@ -377,12 +423,12 @@ pub fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &
}
}

/// Performs a multi-exponentiation operation.
/// Performs a multi-scalar multiplication operation.
///
/// This function will panic if coeffs and bases have a different length.
///
/// This will use multithreading if beneficial.
pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
pub fn msm_parallel<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
assert_eq!(coeffs.len(), bases.len());

let num_threads = rayon::current_num_threads();
Expand All @@ -399,25 +445,22 @@ pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Cu
.zip(results.iter_mut())
{
scope.spawn(move |_| {
multiexp_serial(coeffs, bases, acc);
msm_serial(coeffs, bases, acc);
});
}
});
results.iter().fold(C::Curve::identity(), |a, b| a + b)
} else {
let mut acc = C::Curve::identity();
multiexp_serial(coeffs, bases, &mut acc);
msm_serial(coeffs, bases, &mut acc);
acc
}
}
///

/// This function will panic if coeffs and bases have a different length.
///
/// This will use multithreading if beneficial.
pub fn best_multiexp_independent_points<C: CurveAffine>(
coeffs: &[C::Scalar],
bases: &[C],
) -> C::Curve {
pub fn msm_best<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
assert_eq!(coeffs.len(), bases.len());

// TODO: consider adjusting it with emprical data?
Expand All @@ -430,7 +473,7 @@ pub fn best_multiexp_independent_points<C: CurveAffine>(
};

if c < 10 {
return best_multiexp(coeffs, bases);
return msm_parallel(coeffs, bases);
}

// coeffs to byte representation
Expand Down Expand Up @@ -491,7 +534,6 @@ pub fn best_multiexp_independent_points<C: CurveAffine>(

#[cfg(test)]
mod test {

use std::ops::Neg;

use crate::bn256::{Fr, G1Affine, G1};
Expand Down Expand Up @@ -548,27 +590,31 @@ mod test {
}

fn run_msm_cross<C: CurveAffine>(min_k: usize, max_k: usize) {
use rayon::iter::{IntoParallelIterator, ParallelIterator};

let points = (0..1 << max_k)
.into_par_iter()
.map(|_| C::Curve::random(OsRng))
.collect::<Vec<_>>();
let mut affine_points = vec![C::identity(); 1 << max_k];
C::Curve::batch_normalize(&points[..], &mut affine_points[..]);
let points = affine_points;

let scalars = (0..1 << max_k)
.into_par_iter()
.map(|_| C::Scalar::random(OsRng))
.collect::<Vec<_>>();

for k in min_k..=max_k {
let points = &points[..1 << k];
let scalars = &scalars[..1 << k];

let t0 = start_timer!(|| format!("cyclone k={}", k));
let e0 = super::best_multiexp_independent_points(scalars, points);
let t0 = start_timer!(|| format!("cyclone indep k={}", k));
let e0 = super::msm_best(scalars, points);
end_timer!(t0);

let t1 = start_timer!(|| format!("older k={}", k));
let e1 = super::best_multiexp(scalars, points);
let e1 = super::msm_parallel(scalars, points);
end_timer!(t1);
assert_eq!(e0, e1);
}
Expand Down
Loading