Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WNAF over EC points #858

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
5 changes: 3 additions & 2 deletions math/benches/criterion_elliptic_curve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@ use pprof::criterion::{Output, PProfProfiler};

mod elliptic_curves;
use elliptic_curves::{
bls12_377::bls12_377_elliptic_curve_benchmarks, bls12_381::bls12_381_elliptic_curve_benchmarks,
bls12_377::bls12_377_elliptic_curve_benchmarks,
bls12_381::{bls12_381_elliptic_curve_benchmarks, wnaf::wnaf_bls12_381_benchmarks},
};

criterion_group!(
name = elliptic_curve_benches;
config = Criterion::default().with_profiler(PProfProfiler::new(100, Output::Flamegraph(None)));
targets = bls12_381_elliptic_curve_benchmarks, bls12_377_elliptic_curve_benchmarks
targets = bls12_381_elliptic_curve_benchmarks, bls12_377_elliptic_curve_benchmarks, wnaf_bls12_381_benchmarks
);
criterion_main!(elliptic_curve_benches);
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
pub mod wnaf;

use criterion::{black_box, Criterion};
use lambdaworks_math::{
cyclic_group::IsGroup,
Expand Down
73 changes: 73 additions & 0 deletions math/benches/elliptic_curves/bls12_381/wnaf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
use criterion::{black_box, Criterion};
use lambdaworks_math::{
cyclic_group::IsGroup,
elliptic_curve::{
short_weierstrass::curves::bls12_381::{
curve::BLS12381Curve,
default_types::{FrElement, FrField},
},
traits::IsEllipticCurve,
wnaf::WnafTable,
},
unsigned_integer::element::U256,
};
use rand::{Rng, SeedableRng};

#[allow(dead_code)]
pub fn wnaf_bls12_381_benchmarks(c: &mut Criterion) {
let scalar_size = 1000;

let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(9001);
let mut scalars = Vec::new();
for _i in 0..scalar_size {
scalars.push(FrElement::new(U256::from(rng.gen::<u128>())));
}

let g = BLS12381Curve::generator();

let mut group = c.benchmark_group("BLS12-381 WNAF");
group.significance_level(0.1).sample_size(100);
group.throughput(criterion::Throughput::Elements(1));

group.bench_function(
format!(
"Naive BLS12-381 vector multiplication with size {}",
scalar_size
),
|bencher| {
bencher.iter(|| {
black_box(
scalars
.clone()
.iter()
.map(|scalar| {
black_box(
black_box(g.clone())
.operate_with_self(black_box(scalar.clone().representative()))
.to_affine(),
)
})
.collect::<Vec<_>>(),
)
});
},
);

group.bench_function(
format!(
"WNAF BLS12-381 vector multiplication with size {}",
scalar_size
),
|bencher| {
bencher.iter(|| {
black_box(
black_box(WnafTable::<BLS12381Curve, FrField>::new(
black_box(&g.clone()),
scalar_size,
))
.multi_scalar_mul(&black_box(scalars.clone())),
)
});
},
);
}
1 change: 1 addition & 0 deletions math/src/elliptic_curve/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ pub mod montgomery;
pub mod point;
pub mod short_weierstrass;
pub mod traits;
pub mod wnaf;
189 changes: 189 additions & 0 deletions math/src/elliptic_curve/wnaf.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
use crate::{
cyclic_group::IsGroup,
elliptic_curve::short_weierstrass::{
point::ShortWeierstrassProjectivePoint, traits::IsShortWeierstrass,
},
field::{element::FieldElement, traits::IsPrimeField},
traits::ByteConversion,
};
use alloc::{vec, vec::Vec};
use core::marker::PhantomData;

#[cfg(feature = "parallel")]
use rayon::iter::{
IndexedParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator,
};

extern crate std; // To be able to use f64::ln()

pub struct WnafTable<EC, ScalarField>
where
EC: IsShortWeierstrass<PointRepresentation = ShortWeierstrassProjectivePoint<EC>>,
EC::PointRepresentation: Send + Sync,
ScalarField: IsPrimeField + Sync,
FieldElement<ScalarField>: ByteConversion + Send + Sync,
{
table: Vec<Vec<ShortWeierstrassProjectivePoint<EC>>>,
window_size: usize,
phantom: PhantomData<ScalarField>,
}

impl<EC, ScalarField> WnafTable<EC, ScalarField>
where
EC: IsShortWeierstrass<PointRepresentation = ShortWeierstrassProjectivePoint<EC>>,
EC::PointRepresentation: Send + Sync,
ScalarField: IsPrimeField + Sync,
FieldElement<ScalarField>: ByteConversion + Send + Sync,
{
pub fn new(base: &ShortWeierstrassProjectivePoint<EC>, max_num_of_scalars: usize) -> Self {
let scalar_field_bit_size = ScalarField::field_bit_size();
let window = Self::get_mul_window_size(max_num_of_scalars);
let in_window = 1 << window;
let outerc = (scalar_field_bit_size + window - 1) / window;
let last_in_window = 1 << (scalar_field_bit_size - (outerc - 1) * window);

let mut g_outer = base.clone();
let mut g_outers = Vec::with_capacity(outerc);
for _ in 0..outerc {
g_outers.push(g_outer.clone());
for _ in 0..window {
g_outer = g_outer.double();
}
}

let mut table =
vec![vec![ShortWeierstrassProjectivePoint::<EC>::neutral_element(); in_window]; outerc];

#[cfg(feature = "parallel")]
let iter = table.par_iter_mut();
#[cfg(not(feature = "parallel"))]
let iter = table.iter_mut();

iter.enumerate().take(outerc).zip(g_outers).for_each(
|((outer, multiples_of_g), g_outer)| {
let curr_in_window = if outer == outerc - 1 {
last_in_window
} else {
in_window
};

let mut g_inner = ShortWeierstrassProjectivePoint::<EC>::neutral_element();
for inner in multiples_of_g.iter_mut().take(curr_in_window) {
*inner = g_inner.clone();
g_inner = g_inner.operate_with(&g_outer);
}
},
);

Self {
table,
window_size: window,
phantom: PhantomData,
}
}

pub fn multi_scalar_mul(
&self,
v: &[FieldElement<ScalarField>],
) -> Vec<ShortWeierstrassProjectivePoint<EC>> {
#[cfg(feature = "parallel")]
let iter = v.par_iter();
#[cfg(not(feature = "parallel"))]
let iter = v.iter();

iter.map(|e| self.windowed_mul(e.clone())).collect()
}

fn windowed_mul(
&self,
scalar: FieldElement<ScalarField>,
) -> ShortWeierstrassProjectivePoint<EC> {
let mut res = self.table[0][0].clone();

let modulus_size = ScalarField::field_bit_size();
let outerc = (modulus_size + self.window_size - 1) / self.window_size;
let scalar_bits_le: Vec<bool> = scalar
.to_bytes_le()
.iter()
.flat_map(|byte| (0..8).map(|i| (byte >> i) & 1 == 1).collect::<Vec<_>>())
.collect();

for outer in 0..outerc {
let mut inner = 0usize;
for i in 0..self.window_size {
if outer * self.window_size + i < modulus_size
&& scalar_bits_le[outer * self.window_size + i]
{
inner |= 1 << i;
}
}
res = res.operate_with(&self.table[outer][inner]);
}

res.to_affine()
}

fn get_mul_window_size(max_num_of_scalars: usize) -> usize {
if max_num_of_scalars < 32 {
3
} else {
f64::ln(max_num_of_scalars as f64).ceil() as usize
}
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::{
elliptic_curve::{
short_weierstrass::curves::bls12_381::{
curve::BLS12381Curve,
default_types::{FrElement, FrField},
},
traits::IsEllipticCurve,
},
unsigned_integer::element::U256,
};
use rand::*;
use std::time::Instant;

#[test]
fn wnaf_works() {
let point_count = 100;
let g1 = BLS12381Curve::generator();

let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(9001);
let mut scalars = Vec::new();
for _i in 0..point_count {
scalars.push(FrElement::new(U256::from(rng.gen::<u128>())));
}

let start1 = Instant::now();
let naive_result: Vec<_> = scalars
.iter()
.map(|scalar| {
g1.operate_with_self(scalar.clone().representative())
.to_affine()
})
.collect();
let duration1 = start1.elapsed();
println!(
"Time taken for naive ksk with {} scalars: {:?}",
point_count, duration1
);

let start2 = Instant::now();
let wnaf_result =
WnafTable::<BLS12381Curve, FrField>::new(&g1, point_count).multi_scalar_mul(&scalars);
let duration2 = start2.elapsed();
println!(
"Time taken for wnaf msm including table generation with {} scalars: {:?}",
point_count, duration2
);

for i in 0..point_count {
assert_eq!(naive_result[i], wnaf_result[i]);
}
}
}
59 changes: 32 additions & 27 deletions provers/groth16/src/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use crate::{common::*, QuadraticArithmeticProgram};
use lambdaworks_math::{
cyclic_group::IsGroup,
elliptic_curve::{
short_weierstrass::{point::ShortWeierstrassProjectivePoint, traits::IsShortWeierstrass},
short_weierstrass::curves::bls12_381::{curve::BLS12381Curve, twist::BLS12381TwistCurve},
traits::{IsEllipticCurve, IsPairing},
wnaf::WnafTable,
},
};

Expand Down Expand Up @@ -92,44 +93,48 @@ pub fn setup(qap: &QuadraticArithmeticProgram) -> (ProvingKey, VerifyingKey) {

let delta_g2 = g2.operate_with_self(tw.delta.representative());

let z_powers_of_tau = &core::iter::successors(
// Start from delta^{-1} * t(τ)
// Note that t(τ) = (τ^N - 1) because our domain is roots of unity
Some(&delta_inv * (&tw.tau.pow(qap.num_of_gates) - FrElement::one())),
|prev| Some(prev * &tw.tau),
)
.take(qap.num_of_gates * 2)
.collect::<Vec<_>>();

let g1_wnaf = WnafTable::<BLS12381Curve, FrField>::new(
&g1,
*[
qap.num_of_public_inputs,
r_tau.len(),
l_tau.len(),
k_tau.len() - qap.num_of_public_inputs,
z_powers_of_tau.len(),
]
.iter()
.max()
.unwrap(),
);

(
ProvingKey {
alpha_g1,
beta_g1: g1.operate_with_self(tw.beta.representative()),
beta_g2,
delta_g1: g1.operate_with_self(tw.delta.representative()),
delta_g2: delta_g2.clone(),
l_tau_g1: batch_operate(&l_tau, &g1),
r_tau_g1: batch_operate(&r_tau, &g1),
r_tau_g2: batch_operate(&r_tau, &g2),
prover_k_tau_g1: batch_operate(&k_tau[qap.num_of_public_inputs..], &g1),
z_powers_of_tau_g1: batch_operate(
&core::iter::successors(
// Start from delta^{-1} * t(τ)
// Note that t(τ) = (τ^N - 1) because our domain is roots of unity
Some(&delta_inv * (&tw.tau.pow(qap.num_of_gates) - FrElement::one())),
|prev| Some(prev * &tw.tau),
)
.take(qap.num_of_gates * 2)
.collect::<Vec<_>>(),
&g1,
),
l_tau_g1: g1_wnaf.multi_scalar_mul(&l_tau),
r_tau_g1: g1_wnaf.multi_scalar_mul(&r_tau),
r_tau_g2: WnafTable::<BLS12381TwistCurve, FrField>::new(&g2, r_tau.len())
.multi_scalar_mul(&r_tau),
prover_k_tau_g1: g1_wnaf.multi_scalar_mul(&k_tau[qap.num_of_public_inputs..]),
z_powers_of_tau_g1: g1_wnaf.multi_scalar_mul(z_powers_of_tau),
},
VerifyingKey {
alpha_g1_times_beta_g2,
delta_g2,
gamma_g2: g2.operate_with_self(tw.gamma.representative()),
verifier_k_tau_g1: batch_operate(&k_tau[..qap.num_of_public_inputs], &g1),
verifier_k_tau_g1: g1_wnaf.multi_scalar_mul(&k_tau[..qap.num_of_public_inputs]),
},
)
}

fn batch_operate<E: IsShortWeierstrass>(
elems: &[FrElement],
point: &ShortWeierstrassProjectivePoint<E>,
) -> Vec<ShortWeierstrassProjectivePoint<E>> {
elems
.iter()
.map(|elem| point.operate_with_self(elem.representative()))
.collect()
}
Loading