Skip to content

Commit

Permalink
feat: use halo2curves cycloneMSM (#36)
Browse files Browse the repository at this point in the history
* feat: use halo2curves cycloneMSM

* chore: remove small_multiexp benchmark
  • Loading branch information
jonathanpwang authored Aug 14, 2024
1 parent 03155a4 commit a4140d7
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 228 deletions.
19 changes: 12 additions & 7 deletions halo2_proofs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "halo2-axiom"
version = "0.4.4"
version = "0.5.0-rc.1"
authors = [
"Sean Bowe <[email protected]>",
"Ying Tong Lai <[email protected]>",
Expand Down Expand Up @@ -32,10 +32,6 @@ autoexamples = false
all-features = true
rustdoc-args = ["--cfg", "docsrs", "--html-in-header", "katex-header.html"]

[[bench]]
name = "arithmetic"
harness = false

[[bench]]
name = "commit_zk"
harness = false
Expand Down Expand Up @@ -63,7 +59,11 @@ crossbeam = "0.8"
ff = "0.13"
group = "0.13"
pairing = "0.23"
halo2curves = { package = "halo2curves-axiom", version = "0.5.0", default-features = false, features = ["bits", "bn256-table", "derive_serde"] }
halo2curves = { package = "halo2curves-axiom", version = "0.7.0", default-features = false, features = [
"bits",
"bn256-table",
"derive_serde",
] }
rand = "0.8"
rand_core = { version = "0.6", default-features = false }
tracing = "0.1"
Expand Down Expand Up @@ -96,7 +96,12 @@ getrandom = { version = "0.2", features = ["js"] }
default = ["batch", "multicore", "circuit-params"]
multicore = ["maybe-rayon/threads"]
dev-graph = ["plotters", "tabbycat"]
test-dev-graph = ["dev-graph", "plotters/bitmap_backend", "plotters/bitmap_encoder", "plotters/ttf"]
test-dev-graph = [
"dev-graph",
"plotters/bitmap_backend",
"plotters/bitmap_encoder",
"plotters/ttf",
]
gadget-traces = ["backtrace"]
# thread-safe-region = []
sanity-checks = []
Expand Down
39 changes: 0 additions & 39 deletions halo2_proofs/benches/arithmetic.rs

This file was deleted.

185 changes: 4 additions & 181 deletions halo2_proofs/src/arithmetic.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
//! This module provides common utilities, traits and structures for group,
//! field and polynomial arithmetic.

use std::cmp;

use super::multicore;
pub use ff::Field;
use group::{
ff::{BatchInvert, PrimeField},
prime::PrimeCurveAffine,
Curve, Group, GroupOpsOwned, ScalarMulOwned,
Curve, GroupOpsOwned, ScalarMulOwned,
};

use halo2curves::msm::msm_best;
pub use halo2curves::{CurveAffine, CurveExt};

/// This represents an element of a group with basic operations that can be
Expand All @@ -28,190 +27,14 @@ where
{
}

// ASSUMES C::Scalar::Repr is little endian
fn multiexp_serial<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C], acc: &mut C::Curve) {
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();

let c = if bases.len() < 4 {
1
} else if bases.len() < 32 {
3
} else {
(f64::from(bases.len() as u32)).ln().ceil() as usize
};

// Group `bytes` into bits and take the `segment`th chunk of `c` bits
fn get_at<F: PrimeField>(segment: usize, c: usize, bytes: &F::Repr) -> usize {
let skip_bits = segment * c;
let skip_bytes = skip_bits / 8;

if skip_bytes >= 32 {
return 0;
}

let mut v = [0; 8];
for (v, o) in v.iter_mut().zip(bytes.as_ref()[skip_bytes..].iter()) {
*v = *o;
}

let mut tmp = u64::from_le_bytes(v);
tmp >>= skip_bits - (skip_bytes * 8);
tmp %= 1 << c;

tmp as usize
}

let segments = (C::Scalar::NUM_BITS as usize + c - 1) / c;

// this can be optimized
let mut coeffs_in_segments = Vec::with_capacity(segments);
// track what is the last segment where we actually have nonzero bits, so we completely skip buckets where the scalar bits for all coeffs are 0
let mut max_nonzero_segment = None;
for current_segment in 0..segments {
let coeff_segments: Vec<_> = coeffs
.iter()
.map(|coeff| {
let c_bits = get_at::<C::Scalar>(current_segment, c, coeff);
if c_bits != 0 {
max_nonzero_segment = Some(current_segment);
}
c_bits
})
.collect();
coeffs_in_segments.push(coeff_segments);
}

if max_nonzero_segment.is_none() {
return;
}
for coeffs_seg in coeffs_in_segments
.into_iter()
.take(max_nonzero_segment.unwrap() + 1)
.rev()
{
for _ in 0..c {
*acc = acc.double();
}

#[derive(Clone, Copy)]
enum Bucket<C: CurveAffine> {
None,
Affine(C),
Projective(C::Curve),
}

impl<C: CurveAffine> Bucket<C> {
fn add_assign(&mut self, other: &C) {
*self = match *self {
Bucket::None => Bucket::Affine(*other),
Bucket::Affine(a) => Bucket::Projective(a + *other),
Bucket::Projective(mut a) => {
a += *other;
Bucket::Projective(a)
}
}
}

fn add(self, mut other: C::Curve) -> C::Curve {
match self {
Bucket::None => other,
Bucket::Affine(a) => {
other += a;
other
}
Bucket::Projective(a) => other + &a,
}
}
}

let mut buckets: Vec<Bucket<C>> = vec![Bucket::None; (1 << c) - 1];

let mut max_bits = 0;
for (coeff, base) in coeffs_seg.into_iter().zip(bases.iter()) {
if coeff != 0 {
max_bits = cmp::max(max_bits, coeff);
buckets[coeff - 1].add_assign(base);
}
}

// Summation by parts
// e.g. 3a + 2b + 1c = a +
// (a) + b +
// ((a) + b) + c
let mut running_sum = C::Curve::identity();
for exp in buckets.into_iter().take(max_bits).rev() {
running_sum = exp.add(running_sum);
*acc += &running_sum;
}
}
}

/// Performs a small multi-exponentiation operation.
/// Uses the double-and-add algorithm with doublings shared across points.
pub fn small_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
let coeffs: Vec<_> = coeffs.iter().map(|a| a.to_repr()).collect();
let mut acc = C::Curve::identity();

// for byte idx
for byte_idx in (0..32).rev() {
// for bit idx
for bit_idx in (0..8).rev() {
acc = acc.double();
// for each coeff
for coeff_idx in 0..coeffs.len() {
let byte = coeffs[coeff_idx].as_ref()[byte_idx];
if ((byte >> bit_idx) & 1) != 0 {
acc += bases[coeff_idx];
}
}
}
}

acc
}

// [JPW] Keep this adapter to halo2curves to minimize code changes.
/// Performs a multi-exponentiation operation.
///
/// This function will panic if coeffs and bases have a different length.
///
/// This will use multithreading if beneficial.
pub fn best_multiexp<C: CurveAffine>(coeffs: &[C::Scalar], bases: &[C]) -> C::Curve {
assert_eq!(coeffs.len(), bases.len());

//println!("msm: {}", coeffs.len());

// let start = get_time();
let num_threads = multicore::current_num_threads();
let res = if coeffs.len() > num_threads {
let chunk = coeffs.len() / num_threads;
let num_chunks = coeffs.chunks(chunk).len();
let mut results = vec![C::Curve::identity(); num_chunks];
multicore::scope(|scope| {
let chunk = coeffs.len() / num_threads;

for ((coeffs, bases), acc) in coeffs
.chunks(chunk)
.zip(bases.chunks(chunk))
.zip(results.iter_mut())
{
scope.spawn(move |_| {
multiexp_serial(coeffs, bases, acc);
});
}
});
results.iter().fold(C::Curve::identity(), |a, b| a + b)
} else {
let mut acc = C::Curve::identity();
multiexp_serial(coeffs, bases, &mut acc);
acc
};

// let duration = get_duration(start);
#[allow(unsafe_code)]
// unsafe {
// MULTIEXP_TOTAL_TIME += duration;
// }
res
msm_best(coeffs, bases)
}

/// Dispatcher
Expand Down
2 changes: 1 addition & 1 deletion rust-toolchain
Original file line number Diff line number Diff line change
@@ -1 +1 @@
nightly-2023-08-11
nightly-2024-07-25

0 comments on commit a4140d7

Please sign in to comment.