From 415f9767b4a07c04eb59361aaa47c1f1fb4b303b Mon Sep 17 00:00:00 2001 From: Han Date: Mon, 12 Jun 2023 17:55:45 +0800 Subject: [PATCH] Implement `Sangria` backend based on `HyperPlonk` (#22) * feat: implement sangria with hyperplonk * feat: refactor evaluators * refactor: for future changes --- benchmark/benches/proof_system.rs | 2 +- benchmark/src/bin/plotter.rs | 11 +- plonkish_backend/benches/zero_check.rs | 2 +- plonkish_backend/src/backend.rs | 2 + plonkish_backend/src/backend/hyperplonk.rs | 6 +- .../src/backend/hyperplonk/folding.rs | 3 + .../src/backend/hyperplonk/folding/sangria.rs | 642 ++++++++++++++++++ .../folding/sangria/preprocessor.rs | 280 ++++++++ .../hyperplonk/folding/sangria/prover.rs | 271 ++++++++ .../hyperplonk/folding/sangria/verifier.rs | 86 +++ .../src/backend/hyperplonk/preprocessor.rs | 22 +- .../src/backend/hyperplonk/prover.rs | 42 +- .../src/backend/hyperplonk/util.rs | 8 +- .../src/backend/hyperplonk/verifier.rs | 23 +- plonkish_backend/src/piop/sum_check.rs | 8 +- .../src/piop/sum_check/classic.rs | 20 +- .../src/piop/sum_check/classic/eval.rs | 426 +++--------- plonkish_backend/src/util/expression.rs | 228 +++++-- .../src/util/expression/evaluator.rs | 324 +++++++++ .../src/util/expression/relaxed.rs | 245 +++++++ 20 files changed, 2177 insertions(+), 474 deletions(-) create mode 100644 plonkish_backend/src/backend/hyperplonk/folding.rs create mode 100644 plonkish_backend/src/backend/hyperplonk/folding/sangria.rs create mode 100644 plonkish_backend/src/backend/hyperplonk/folding/sangria/preprocessor.rs create mode 100644 plonkish_backend/src/backend/hyperplonk/folding/sangria/prover.rs create mode 100644 plonkish_backend/src/backend/hyperplonk/folding/sangria/verifier.rs create mode 100644 plonkish_backend/src/util/expression/evaluator.rs create mode 100644 plonkish_backend/src/util/expression/relaxed.rs diff --git a/benchmark/benches/proof_system.rs b/benchmark/benches/proof_system.rs index 30171500..6cef79ef 100644 --- a/benchmark/benches/proof_system.rs +++ b/benchmark/benches/proof_system.rs @@ -71,7 +71,7 @@ fn bench_hyperplonk>(k: usize) { let _timer = start_timer(|| format!("hyperplonk_verify-{k}")); let accept = { let mut transcript = Keccak256Transcript::from_proof(proof.as_slice()); - HyperPlonk::verify(&vp, &instances, &mut transcript, std_rng()).is_ok() + HyperPlonk::verify(&vp, (), &instances, &mut transcript, std_rng()).is_ok() }; assert!(accept); } diff --git a/benchmark/src/bin/plotter.rs b/benchmark/src/bin/plotter.rs index 4a6ac91a..7fe8e62f 100644 --- a/benchmark/src/bin/plotter.rs +++ b/benchmark/src/bin/plotter.rs @@ -57,11 +57,9 @@ fn main() { } fn parse_args() -> (bool, Vec) { - let (verbose, logs) = args() - .into_iter() - .chain(Some("".to_string())) - .tuple_windows() - .fold((false, None), |(mut verbose, mut logs), (key, value)| { + let (verbose, logs) = args().chain(Some("".to_string())).tuple_windows().fold( + (false, None), + |(mut verbose, mut logs), (key, value)| { match key.as_str() { "-" => { logs = Some( @@ -85,7 +83,8 @@ fn parse_args() -> (bool, Vec) { _ => {} }; (verbose, logs) - }); + }, + ); ( verbose, logs.expect("Either \"--log \" or \"-\" specified"), diff --git a/plonkish_backend/benches/zero_check.rs b/plonkish_backend/benches/zero_check.rs index b0cebf2a..de085d85 100644 --- a/plonkish_backend/benches/zero_check.rs +++ b/plonkish_backend/benches/zero_check.rs @@ -22,7 +22,7 @@ fn run(num_vars: usize, virtual_poly: VirtualPolynomial) { fn zero_check(c: &mut Criterion) { let setup = |num_vars: usize| { - let expression = vanilla_plonk_expression(); + let expression = vanilla_plonk_expression(num_vars); let (polys, challenges) = rand_vanilla_plonk_assignment::(num_vars, seeded_std_rng(), seeded_std_rng()); let ys = [rand_vec(num_vars, seeded_std_rng())]; diff --git a/plonkish_backend/src/backend.rs b/plonkish_backend/src/backend.rs index eb9b786b..da17e2bf 100644 --- a/plonkish_backend/src/backend.rs +++ b/plonkish_backend/src/backend.rs @@ -21,6 +21,7 @@ where type ProverParam: Debug; type VerifierParam: Debug; type ProverState: Debug; + type VerifierState: Debug; fn setup(circuit_info: &PlonkishCircuitInfo, rng: impl RngCore) -> Result; @@ -41,6 +42,7 @@ where fn verify( vp: &Self::VerifierParam, + state: impl BorrowMut, instances: &[&[F]], transcript: &mut impl TranscriptRead, rng: impl RngCore, diff --git a/plonkish_backend/src/backend/hyperplonk.rs b/plonkish_backend/src/backend/hyperplonk.rs index 378cfe29..af1e499d 100644 --- a/plonkish_backend/src/backend/hyperplonk.rs +++ b/plonkish_backend/src/backend/hyperplonk.rs @@ -29,6 +29,8 @@ mod preprocessor; mod prover; mod verifier; +pub mod folding; + #[cfg(any(test, feature = "benchmark"))] pub mod util; @@ -81,6 +83,7 @@ where type ProverParam = HyperPlonkProverParam; type VerifierParam = HyperPlonkVerifierParam; type ProverState = (); + type VerifierState = (); fn setup( circuit_info: &PlonkishCircuitInfo, @@ -291,6 +294,7 @@ where fn verify( vp: &Self::VerifierParam, + _: impl BorrowMut, instances: &[&[F]], transcript: &mut impl TranscriptRead, _: impl RngCore, @@ -447,7 +451,7 @@ pub(crate) mod test { let timer = start_timer(|| format!("verify-{num_vars}")); let result = { let mut transcript = T::from_proof(proof.as_slice()); - HyperPlonk::::verify(&vp, &instances, &mut transcript, seeded_std_rng()) + HyperPlonk::::verify(&vp, (), &instances, &mut transcript, seeded_std_rng()) }; assert_eq!(result, Ok(())); end_timer(timer); diff --git a/plonkish_backend/src/backend/hyperplonk/folding.rs b/plonkish_backend/src/backend/hyperplonk/folding.rs new file mode 100644 index 00000000..e46fffad --- /dev/null +++ b/plonkish_backend/src/backend/hyperplonk/folding.rs @@ -0,0 +1,3 @@ +mod sangria; + +pub use sangria::{Sangria, SangriaProverParam, SangriaProverState, SangriaVerifierParam}; diff --git a/plonkish_backend/src/backend/hyperplonk/folding/sangria.rs b/plonkish_backend/src/backend/hyperplonk/folding/sangria.rs new file mode 100644 index 00000000..e471cbd6 --- /dev/null +++ b/plonkish_backend/src/backend/hyperplonk/folding/sangria.rs @@ -0,0 +1,642 @@ +use crate::{ + backend::{ + hyperplonk::{ + folding::sangria::{ + preprocessor::batch_size, + preprocessor::preprocess, + prover::{evaluate_cross_term, lookup_h_polys, SangriaWitness}, + verifier::SangriaInstance, + }, + prover::{ + instance_polys, lookup_compressed_polys, lookup_m_polys, permutation_z_polys, + prove_zero_check, + }, + verifier::verify_zero_check, + HyperPlonk, HyperPlonkProverParam, HyperPlonkVerifierParam, + }, + PlonkishBackend, PlonkishCircuit, PlonkishCircuitInfo, WitnessEncoding, + }, + pcs::{AdditiveCommitment, PolynomialCommitmentScheme}, + poly::multilinear::MultilinearPolynomial, + util::{ + arithmetic::PrimeField, + end_timer, + expression::Expression, + start_timer, + transcript::{TranscriptRead, TranscriptWrite}, + Itertools, + }, + Error, +}; +use rand::RngCore; +use std::{borrow::BorrowMut, hash::Hash, iter, marker::PhantomData}; + +pub(crate) mod preprocessor; +pub(crate) mod prover; +mod verifier; + +#[derive(Clone, Debug)] +pub struct Sangria(PhantomData); + +#[derive(Debug)] +pub struct SangriaProverParam +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, +{ + pp: HyperPlonkProverParam, + num_theta_primes: usize, + num_alpha_primes: usize, + num_folding_wintess_polys: usize, + num_folding_challenges: usize, + cross_term_expressions: Vec>, + zero_check_expression: Expression, +} + +impl SangriaProverParam +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, +{ + pub fn init(&self) -> SangriaProverState { + SangriaProverState { + is_folding: true, + witness: SangriaWitness::init( + self.pp.num_vars, + &self.pp.num_instances, + self.num_folding_wintess_polys, + self.num_folding_challenges, + ), + } + } +} + +#[derive(Debug)] +pub struct SangriaVerifierParam +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, +{ + vp: HyperPlonkVerifierParam, + num_theta_primes: usize, + num_alpha_primes: usize, + num_folding_wintess_polys: usize, + num_folding_challenges: usize, + num_cross_terms: usize, + zero_check_expression: Expression, +} + +impl SangriaVerifierParam +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, +{ + pub fn init(&self) -> SangriaVerifierState { + SangriaVerifierState { + is_folding: true, + instance: SangriaInstance::init( + &self.vp.num_instances, + self.num_folding_wintess_polys, + self.num_folding_challenges, + ), + } + } +} +#[derive(Debug)] +pub struct SangriaProverState +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, +{ + is_folding: bool, + witness: SangriaWitness, +} + +impl SangriaProverState +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, +{ + pub fn set_folding(&mut self, is_folding: bool) { + self.is_folding = is_folding; + } +} + +#[derive(Debug)] +pub struct SangriaVerifierState +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, +{ + is_folding: bool, + instance: SangriaInstance, +} + +impl SangriaVerifierState +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, +{ + pub fn set_folding(&mut self, is_folding: bool) { + self.is_folding = is_folding; + } +} + +impl PlonkishBackend for Sangria> +where + F: PrimeField + Ord + Hash, + Pcs: PolynomialCommitmentScheme>, + Pcs::Commitment: AdditiveCommitment, + Pcs::CommitmentChunk: AdditiveCommitment, +{ + type ProverParam = SangriaProverParam; + type VerifierParam = SangriaVerifierParam; + type ProverState = SangriaProverState; + type VerifierState = SangriaVerifierState; + + fn setup( + circuit_info: &PlonkishCircuitInfo, + rng: impl RngCore, + ) -> Result { + assert!(circuit_info.is_well_formed()); + + let num_vars = circuit_info.k; + let poly_size = 1 << num_vars; + let batch_size = batch_size(circuit_info); + Pcs::setup(poly_size, batch_size, rng) + } + + fn preprocess( + param: &Pcs::Param, + circuit_info: &PlonkishCircuitInfo, + ) -> Result<(Self::ProverParam, Self::VerifierParam), Error> { + assert!(circuit_info.is_well_formed()); + + preprocess(param, circuit_info) + } + + fn prove( + pp: &Self::ProverParam, + mut state: impl BorrowMut, + instances: &[&[F]], + circuit: &impl PlonkishCircuit, + transcript: &mut impl TranscriptWrite, + _: impl RngCore, + ) -> Result<(), Error> { + let SangriaProverParam { + pp, + num_theta_primes, + num_alpha_primes, + cross_term_expressions, + zero_check_expression, + .. + } = pp; + let state = state.borrow_mut(); + + for (num_instances, instances) in pp.num_instances.iter().zip_eq(instances) { + assert_eq!(instances.len(), *num_instances); + for instance in instances.iter() { + transcript.common_field_element(instance)?; + } + } + + // Round 0..n + + let mut witness_polys = Vec::with_capacity(pp.num_witness_polys.iter().sum()); + let mut witness_comms = Vec::with_capacity(witness_polys.len()); + let mut challenges = Vec::with_capacity(pp.num_challenges.iter().sum::()); + for (round, (num_witness_polys, num_folding_challenges)) in pp + .num_witness_polys + .iter() + .zip_eq(pp.num_challenges.iter()) + .enumerate() + { + let timer = start_timer(|| format!("witness_collector-{round}")); + let polys = circuit + .synthesize(round, &challenges)? + .into_iter() + .map(MultilinearPolynomial::new) + .collect_vec(); + assert_eq!(polys.len(), *num_witness_polys); + end_timer(timer); + + witness_comms.extend(Pcs::batch_commit_and_write(&pp.pcs, &polys, transcript)?); + witness_polys.extend(polys); + challenges.extend(transcript.squeeze_challenges(*num_folding_challenges)); + } + + // Round n + + let theta_primes = transcript.squeeze_challenges(*num_theta_primes); + + let timer = start_timer(|| format!("lookup_compressed_polys-{}", pp.lookups.len())); + let lookup_compressed_polys = { + let instance_polys = instance_polys(pp.num_vars, instances.iter().cloned()); + let polys = iter::empty() + .chain(instance_polys.iter()) + .chain(pp.preprocess_polys.iter()) + .chain(witness_polys.iter()) + .collect_vec(); + let thetas = iter::empty() + .chain(Some(F::ONE)) + .chain(theta_primes.iter().cloned()) + .collect_vec(); + lookup_compressed_polys(&pp.lookups, &polys, &challenges, &thetas) + }; + end_timer(timer); + + let timer = start_timer(|| format!("lookup_m_polys-{}", pp.lookups.len())); + let lookup_m_polys = lookup_m_polys(&lookup_compressed_polys)?; + end_timer(timer); + + let lookup_m_comms = Pcs::batch_commit_and_write(&pp.pcs, &lookup_m_polys, transcript)?; + + // Round n+1 + + let beta_prime = transcript.squeeze_challenge(); + + let timer = start_timer(|| format!("lookup_h_polys-{}", pp.lookups.len())); + let lookup_h_polys = lookup_h_polys(&lookup_compressed_polys, &lookup_m_polys, &beta_prime); + end_timer(timer); + + let lookup_h_comms = { + let polys = lookup_h_polys.iter().flatten(); + Pcs::batch_commit_and_write(&pp.pcs, polys, transcript)? + }; + + // Round n+2 + + let alpha_primes = transcript.squeeze_challenges(*num_alpha_primes); + + let incoming = SangriaWitness::from_committed( + pp.num_vars, + instances, + iter::empty() + .chain(witness_polys) + .chain(lookup_m_polys) + .chain(lookup_h_polys.into_iter().flatten()), + iter::empty() + .chain(witness_comms) + .chain(lookup_m_comms) + .chain(lookup_h_comms), + iter::empty() + .chain(challenges) + .chain(theta_primes) + .chain(Some(beta_prime)) + .chain(alpha_primes) + .collect(), + ); + + let timer = start_timer(|| format!("cross_term_polys-{}", cross_term_expressions.len())); + let cross_term_polys = evaluate_cross_term( + cross_term_expressions, + pp.num_vars, + &pp.preprocess_polys, + &state.witness, + &incoming, + ); + end_timer(timer); + + let cross_term_comms = Pcs::batch_commit_and_write(&pp.pcs, &cross_term_polys, transcript)?; + + // Round n+3 + + let r = transcript.squeeze_challenge(); + + let timer = start_timer(|| "fold"); + state + .witness + .fold(&incoming, &cross_term_polys, &cross_term_comms, &r); + end_timer(timer); + + if !state.is_folding { + let beta = transcript.squeeze_challenge(); + let gamma = transcript.squeeze_challenge(); + + let timer = + start_timer(|| format!("permutation_z_polys-{}", pp.permutation_polys.len())); + let builtin_witness_poly_offset = pp.num_witness_polys.iter().sum::(); + let instance_polys = instance_polys(pp.num_vars, &state.witness.instance.instances); + let polys = iter::empty() + .chain(&instance_polys) + .chain(&pp.preprocess_polys) + .chain(&state.witness.witness_polys[..builtin_witness_poly_offset]) + .chain(pp.permutation_polys.iter().map(|(_, poly)| poly)) + .collect_vec(); + let permutation_z_polys = permutation_z_polys( + pp.num_permutation_z_polys, + &pp.permutation_polys, + &polys, + &beta, + &gamma, + ); + end_timer(timer); + + let permutation_z_comms = + Pcs::batch_commit_and_write(&pp.pcs, &permutation_z_polys, transcript)?; + + // Round n+4 + + let alpha = transcript.squeeze_challenge(); + let y = transcript.squeeze_challenges(pp.num_vars); + + let polys = iter::empty() + .chain(polys) + .chain(&state.witness.witness_polys[builtin_witness_poly_offset..]) + .chain(permutation_z_polys.iter()) + .chain(Some(&state.witness.e_poly)) + .collect_vec(); + let challenges = iter::empty() + .chain(state.witness.instance.challenges.iter().copied()) + .chain([beta, gamma, alpha, state.witness.instance.u]) + .collect(); + let (points, evals) = { + prove_zero_check( + pp.num_instances.len(), + zero_check_expression, + &polys, + challenges, + y, + transcript, + )? + }; + + // PCS open + + let dummy_comm = Pcs::Commitment::default(); + let comms = iter::empty() + .chain(iter::repeat(&dummy_comm).take(pp.num_instances.len())) + .chain(&pp.preprocess_comms) + .chain(&state.witness.instance.witness_comms[..builtin_witness_poly_offset]) + .chain(&pp.permutation_comms) + .chain(&state.witness.instance.witness_comms[builtin_witness_poly_offset..]) + .chain(&permutation_z_comms) + .chain(Some(&state.witness.instance.e_comm)) + .collect_vec(); + let timer = start_timer(|| format!("pcs_batch_open-{}", evals.len())); + Pcs::batch_open(&pp.pcs, polys, comms, &points, &evals, transcript)?; + end_timer(timer); + } + + Ok(()) + } + + fn verify( + vp: &Self::VerifierParam, + mut state: impl BorrowMut, + instances: &[&[F]], + transcript: &mut impl TranscriptRead, + _: impl RngCore, + ) -> Result<(), Error> { + let SangriaVerifierParam { + vp, + num_theta_primes, + num_alpha_primes, + num_cross_terms, + zero_check_expression, + .. + } = vp; + let state = state.borrow_mut(); + + for (num_instances, instances) in vp.num_instances.iter().zip_eq(instances) { + assert_eq!(instances.len(), *num_instances); + for instance in instances.iter() { + transcript.common_field_element(instance)?; + } + } + + // Round 0..n + + let mut witness_comms = Vec::with_capacity(vp.num_witness_polys.iter().sum()); + let mut challenges = Vec::with_capacity(vp.num_challenges.iter().sum::() + 4); + for (num_polys, num_folding_challenges) in + vp.num_witness_polys.iter().zip_eq(vp.num_challenges.iter()) + { + witness_comms.extend(Pcs::read_commitments(&vp.pcs, *num_polys, transcript)?); + challenges.extend(transcript.squeeze_challenges(*num_folding_challenges)); + } + + // Round n + + let theta_primes = transcript.squeeze_challenges(*num_theta_primes); + + let lookup_m_comms = Pcs::read_commitments(&vp.pcs, vp.num_lookups, transcript)?; + + // Round n+1 + + let beta_prime = transcript.squeeze_challenge(); + + let lookup_h_comms = Pcs::read_commitments(&vp.pcs, 2 * vp.num_lookups, transcript)?; + + // Round n+2 + + let alpha_primes = transcript.squeeze_challenges(*num_alpha_primes); + + let incoming = SangriaInstance::from_committed( + instances, + iter::empty() + .chain(witness_comms) + .chain(lookup_m_comms) + .chain(lookup_h_comms), + iter::empty() + .chain(challenges) + .chain(theta_primes) + .chain(Some(beta_prime)) + .chain(alpha_primes) + .collect(), + ); + + let cross_term_comms = Pcs::read_commitments(&vp.pcs, *num_cross_terms, transcript)?; + + // Round n+3 + + let r = transcript.squeeze_challenge(); + + state.instance.fold(&incoming, &cross_term_comms, &r); + + if !state.is_folding { + let beta = transcript.squeeze_challenge(); + let gamma = transcript.squeeze_challenge(); + + let permutation_z_comms = + Pcs::read_commitments(&vp.pcs, vp.num_permutation_z_polys, transcript)?; + + // Round n+4 + + let alpha = transcript.squeeze_challenge(); + let y = transcript.squeeze_challenges(vp.num_vars); + + let instances = state.instance.instance_slices(); + let challenges = iter::empty() + .chain(state.instance.challenges.iter().copied()) + .chain([beta, gamma, alpha, state.instance.u]) + .collect_vec(); + let (points, evals) = { + verify_zero_check( + vp.num_vars, + zero_check_expression, + &instances, + &challenges, + &y, + transcript, + )? + }; + + // PCS verify + + let builtin_witness_poly_offset = vp.num_witness_polys.iter().sum::(); + let dummy_comm = Pcs::Commitment::default(); + let comms = iter::empty() + .chain(iter::repeat(&dummy_comm).take(vp.num_instances.len())) + .chain(&vp.preprocess_comms) + .chain(&state.instance.witness_comms[..builtin_witness_poly_offset]) + .chain(vp.permutation_comms.iter().map(|(_, comm)| comm)) + .chain(&state.instance.witness_comms[builtin_witness_poly_offset..]) + .chain(&permutation_z_comms) + .chain(Some(&state.instance.e_comm)) + .collect_vec(); + Pcs::batch_verify(&vp.pcs, comms, &points, &evals, transcript)?; + } + + Ok(()) + } +} + +impl WitnessEncoding for Sangria { + fn row_mapping(k: usize) -> Vec { + Pb::row_mapping(k) + } +} + +#[cfg(test)] +pub(crate) mod test { + use crate::{ + backend::{ + hyperplonk::{ + folding::sangria::Sangria, + util::{rand_vanilla_plonk_circuit, rand_vanilla_plonk_with_lookup_circuit}, + HyperPlonk, + }, + PlonkishBackend, PlonkishCircuit, PlonkishCircuitInfo, + }, + pcs::{ + multilinear::{MultilinearIpa, MultilinearKzg, MultilinearSimulator}, + univariate::UnivariateKzg, + AdditiveCommitment, PolynomialCommitmentScheme, + }, + poly::multilinear::MultilinearPolynomial, + util::{ + arithmetic::PrimeField, + end_timer, start_timer, + test::{seeded_std_rng, std_rng}, + transcript::{ + InMemoryTranscript, Keccak256Transcript, TranscriptRead, TranscriptWrite, + }, + Itertools, + }, + }; + use halo2_curves::{bn256::Bn256, grumpkin}; + use std::{hash::Hash, iter, ops::Range}; + + pub(crate) fn run_sangria_hyperplonk( + num_vars_range: Range, + circuit_fn: impl Fn(usize) -> (PlonkishCircuitInfo, Vec>>, Vec), + ) where + F: PrimeField + Ord + Hash, + Pcs: PolynomialCommitmentScheme>, + Pcs::Commitment: AdditiveCommitment, + Pcs::CommitmentChunk: AdditiveCommitment, + T: TranscriptRead + + TranscriptWrite + + InMemoryTranscript, + C: PlonkishCircuit, + { + for num_vars in num_vars_range { + let (circuit_info, instances, circuits) = circuit_fn(num_vars); + + let timer = start_timer(|| format!("setup-{num_vars}")); + let param = Sangria::>::setup(&circuit_info, seeded_std_rng()).unwrap(); + end_timer(timer); + + let timer = start_timer(|| format!("preprocess-{num_vars}")); + let (pp, vp) = Sangria::>::preprocess(¶m, &circuit_info).unwrap(); + end_timer(timer); + + let (mut prover_state, mut verifier_state) = (pp.init(), vp.init()); + for (idx, (instances, circuit)) in instances.iter().zip_eq(circuits.iter()).enumerate() + { + let is_folding = idx != circuits.len() - 1; + let instances = instances.iter().map(Vec::as_slice).collect_vec(); + + let timer = start_timer(|| format!("prove-{num_vars}")); + let proof = { + prover_state.set_folding(is_folding); + let mut transcript = T::default(); + Sangria::>::prove( + &pp, + &mut prover_state, + &instances, + circuit, + &mut transcript, + seeded_std_rng(), + ) + .unwrap(); + transcript.into_proof() + }; + end_timer(timer); + + let timer = start_timer(|| format!("verify-{num_vars}")); + let result = { + verifier_state.set_folding(is_folding); + let mut transcript = T::from_proof(proof.as_slice()); + Sangria::>::verify( + &vp, + &mut verifier_state, + &instances, + &mut transcript, + seeded_std_rng(), + ) + }; + assert_eq!(result, Ok(())); + end_timer(timer); + } + } + } + + macro_rules! tests { + ($name:ident, $pcs:ty, $num_vars_range:expr) => { + paste::paste! { + #[test] + fn [<$name _hyperplonk_sangria_vanilla_plonk>]() { + run_sangria_hyperplonk::<_, $pcs, Keccak256Transcript<_>, _>($num_vars_range, |num_vars| { + let (circuit_info, _, _) = rand_vanilla_plonk_circuit(num_vars, std_rng(), seeded_std_rng()); + let (instances, circuits) = iter::repeat_with(|| { + let (_, instances, circuit) = rand_vanilla_plonk_circuit(num_vars, std_rng(), seeded_std_rng()); + (instances, circuit) + }).take(3).unzip(); + (circuit_info, instances, circuits) + }); + } + + #[test] + fn [<$name _hyperplonk_sangria_vanilla_plonk_with_lookup>]() { + run_sangria_hyperplonk::<_, $pcs, Keccak256Transcript<_>, _>($num_vars_range, |num_vars| { + let (circuit_info, _, _) = rand_vanilla_plonk_with_lookup_circuit(num_vars, std_rng(), seeded_std_rng()); + let (instances, circuits) = iter::repeat_with(|| { + let (_, instances, circuit) = rand_vanilla_plonk_with_lookup_circuit(num_vars, std_rng(), seeded_std_rng()); + (instances, circuit) + }).take(3).unzip(); + (circuit_info, instances, circuits) + }); + } + } + }; + ($name:ident, $pcs:ty) => { + tests!($name, $pcs, 2..16); + }; + } + + tests!(ipa, MultilinearIpa); + tests!(kzg, MultilinearKzg); + tests!(sim_kzg, MultilinearSimulator>); +} diff --git a/plonkish_backend/src/backend/hyperplonk/folding/sangria/preprocessor.rs b/plonkish_backend/src/backend/hyperplonk/folding/sangria/preprocessor.rs new file mode 100644 index 00000000..b114e2b0 --- /dev/null +++ b/plonkish_backend/src/backend/hyperplonk/folding/sangria/preprocessor.rs @@ -0,0 +1,280 @@ +use crate::{ + backend::{ + hyperplonk::{ + folding::sangria::{SangriaProverParam, SangriaVerifierParam}, + preprocessor::permutation_constraints, + HyperPlonk, + }, + PlonkishBackend, PlonkishCircuitInfo, + }, + pcs::PolynomialCommitmentScheme, + poly::multilinear::MultilinearPolynomial, + util::{ + arithmetic::{div_ceil, PrimeField}, + chain, + expression::{ + relaxed::{cross_term_expressions, products, relaxed_expression}, + Expression, Query, Rotation, + }, + Itertools, + }, + Error, +}; +use std::{ + array, + borrow::Cow, + collections::{BTreeSet, HashSet}, + hash::Hash, + iter, +}; + +pub(crate) fn batch_size(circuit_info: &PlonkishCircuitInfo) -> usize { + let num_lookups = circuit_info.lookups.len(); + let num_permutation_polys = circuit_info.permutation_polys().len(); + chain![ + [circuit_info.preprocess_polys.len() + circuit_info.permutation_polys().len()], + circuit_info.num_witness_polys.clone(), + [num_lookups], + [2 * num_lookups + div_ceil(num_permutation_polys, max_degree(circuit_info, None) - 1)], + [1], + ] + .sum() +} + +#[allow(clippy::type_complexity)] +pub(super) fn preprocess( + param: &Pcs::Param, + circuit_info: &PlonkishCircuitInfo, +) -> Result<(SangriaProverParam, SangriaVerifierParam), Error> +where + F: PrimeField + Ord + Hash, + Pcs: PolynomialCommitmentScheme>, +{ + let challenge_offset = circuit_info.num_challenges.iter().sum::(); + let max_lookup_width = circuit_info.lookups.iter().map(Vec::len).max().unwrap_or(0); + let num_theta_primes = max_lookup_width.checked_sub(1).unwrap_or_default(); + let theta_primes = (challenge_offset..) + .take(num_theta_primes) + .map(Expression::::Challenge) + .collect_vec(); + let beta_prime = &Expression::::Challenge(challenge_offset + num_theta_primes); + + let (lookup_constraints, lookup_zero_checks) = + lookup_constraints(circuit_info, &theta_primes, beta_prime); + + let max_degree = iter::empty() + .chain(circuit_info.constraints.iter()) + .chain(lookup_constraints.iter()) + .map(Expression::degree) + .chain(circuit_info.max_degree) + .chain(Some(2)) + .max() + .unwrap(); + + let permutation_polys = circuit_info.permutation_polys(); + let preprocess_polys = iter::empty() + .chain((circuit_info.num_instances.len()..).take(circuit_info.preprocess_polys.len())) + .chain((circuit_info.num_poly()..).take(permutation_polys.len())) + .collect(); + + let num_constraints = circuit_info.constraints.len() + lookup_constraints.len(); + let num_alpha_primes = num_constraints.checked_sub(1).unwrap_or_default(); + + let products = { + let mut constraints = iter::empty() + .chain(circuit_info.constraints.iter()) + .chain(lookup_constraints.iter()) + .collect_vec(); + let folding_degrees = constraints + .iter() + .map(|constraint| folding_degree(&preprocess_polys, constraint)) + .enumerate() + .sorted_by(|a, b| b.1.cmp(&a.1)) + .collect_vec(); + if let &[a, b, ..] = &folding_degrees[..] { + if a.1 != b.1 { + constraints.swap(0, a.0); + } + } + let constraint = iter::empty() + .chain(constraints.first().cloned().cloned()) + .chain( + constraints + .into_iter() + .skip(1) + .zip((challenge_offset + num_theta_primes + 1..).map(Expression::Challenge)) + .map(|(constraint, challenge)| constraint * challenge), + ) + .sum(); + products(&preprocess_polys, &constraint) + }; + + let num_witness_polys = circuit_info.num_witness_polys.iter().sum::(); + let witness_poly_offset = + circuit_info.num_instances.len() + circuit_info.preprocess_polys.len(); + let internal_witness_poly_offset = + witness_poly_offset + num_witness_polys + permutation_polys.len(); + + let folding_polys = iter::empty() + .chain(0..circuit_info.num_instances.len()) + .chain((witness_poly_offset..).take(num_witness_polys)) + .chain((internal_witness_poly_offset..).take(3 * circuit_info.lookups.len())) + .collect::>(); + let num_folding_wintess_polys = num_witness_polys + 3 * circuit_info.lookups.len(); + let num_folding_challenges = challenge_offset + num_theta_primes + 1 + num_alpha_primes; + + let cross_term_expressions = cross_term_expressions( + circuit_info.num_instances.len(), + circuit_info.preprocess_polys.len(), + folding_polys, + num_folding_challenges, + &products, + ); + let num_cross_terms = cross_term_expressions.len(); + + let [beta, gamma, alpha] = + &array::from_fn(|idx| Expression::::Challenge(num_folding_challenges + idx)); + let (num_chunks, permutation_constraints) = permutation_constraints( + circuit_info, + max_degree, + beta, + gamma, + 3 * circuit_info.lookups.len(), + ); + + let relexed_constraint = { + let u = num_folding_challenges + 3; + let e_poly = circuit_info.num_poly() + + permutation_polys.len() + + circuit_info.lookups.len() * 3 + + num_chunks; + relaxed_expression(&products, u) + - Expression::Polynomial(Query::new(e_poly, Rotation::cur())) + }; + let zero_check_on_every_row = Expression::distribute_powers( + iter::empty() + .chain(Some(&relexed_constraint)) + .chain(permutation_constraints.iter()), + alpha, + ) * Expression::eq_xy(0); + let zero_check_expression = Expression::distribute_powers( + iter::empty() + .chain(lookup_zero_checks.iter()) + .chain(Some(&zero_check_on_every_row)), + alpha, + ); + + let (mut pp, mut vp) = HyperPlonk::preprocess(param, circuit_info)?; + let (pcs_pp, pcs_vp) = Pcs::trim(param, 1 << circuit_info.k, batch_size(circuit_info))?; + pp.pcs = pcs_pp; + vp.pcs = pcs_vp; + + Ok(( + SangriaProverParam { + pp, + num_theta_primes, + num_alpha_primes, + num_folding_wintess_polys, + num_folding_challenges, + cross_term_expressions, + zero_check_expression: zero_check_expression.clone(), + }, + SangriaVerifierParam { + vp, + num_theta_primes, + num_alpha_primes, + num_folding_wintess_polys, + num_folding_challenges, + num_cross_terms, + zero_check_expression, + }, + )) +} + +pub(crate) fn max_degree( + circuit_info: &PlonkishCircuitInfo, + lookup_constraints: Option<&[Expression]>, +) -> usize { + let lookup_constraints = lookup_constraints.map(Cow::Borrowed).unwrap_or_else(|| { + let n = circuit_info.lookups.iter().map(Vec::len).max().unwrap_or(1); + let dummy_challenges = vec![Expression::zero(); n]; + Cow::Owned( + self::lookup_constraints(circuit_info, &dummy_challenges, &dummy_challenges[0]).0, + ) + }); + iter::empty() + .chain(circuit_info.constraints.iter().map(Expression::degree)) + .chain(lookup_constraints.iter().map(Expression::degree)) + .chain(circuit_info.max_degree) + .chain(Some(2)) + .max() + .unwrap() +} + +pub(crate) fn lookup_constraints( + circuit_info: &PlonkishCircuitInfo, + theta_primes: &[Expression], + beta_prime: &Expression, +) -> (Vec>, Vec>) { + let one = &Expression::one(); + let m_offset = circuit_info.num_poly() + circuit_info.permutation_polys().len(); + let h_offset = m_offset + circuit_info.lookups.len(); + let constraints = circuit_info + .lookups + .iter() + .zip(m_offset..) + .zip((h_offset..).step_by(2)) + .flat_map(|((lookup, m), h)| { + let [m, h_input, h_table] = &[m, h, h + 1] + .map(|poly| Query::new(poly, Rotation::cur())) + .map(Expression::::Polynomial); + let (inputs, tables) = lookup + .iter() + .map(|(input, table)| (input, table)) + .unzip::<_, _, Vec<_>, Vec<_>>(); + let [input, table] = &[inputs, tables].map(|exprs| { + iter::empty() + .chain(exprs.first().cloned().cloned()) + .chain( + exprs + .into_iter() + .skip(1) + .zip(theta_primes) + .map(|(expr, theta_prime)| expr * theta_prime), + ) + .sum::>() + }); + [ + h_input * (input + beta_prime) - one, + h_table * (table + beta_prime) - m, + ] + }) + .collect_vec(); + let sum_check = (h_offset..) + .step_by(2) + .take(circuit_info.lookups.len()) + .map(|h| { + let [h_input, h_table] = &[h, h + 1] + .map(|poly| Query::new(poly, Rotation::cur())) + .map(Expression::::Polynomial); + h_input - h_table + }) + .collect_vec(); + (constraints, sum_check) +} + +pub(crate) fn folding_degree( + preprocess_polys: &HashSet, + expression: &Expression, +) -> usize { + expression.evaluate( + &|_| 0, + &|_| 0, + &|query| (!preprocess_polys.contains(&query.poly())) as usize, + &|_| 1, + &|a| a, + &|a, b| a.max(b), + &|a, b| a + b, + &|a, _| a, + ) +} diff --git a/plonkish_backend/src/backend/hyperplonk/folding/sangria/prover.rs b/plonkish_backend/src/backend/hyperplonk/folding/sangria/prover.rs new file mode 100644 index 00000000..ec837c77 --- /dev/null +++ b/plonkish_backend/src/backend/hyperplonk/folding/sangria/prover.rs @@ -0,0 +1,271 @@ +use crate::{ + backend::hyperplonk::{folding::sangria::verifier::SangriaInstance, prover::instance_polys}, + pcs::{AdditiveCommitment, Polynomial, PolynomialCommitmentScheme}, + poly::multilinear::MultilinearPolynomial, + util::{ + arithmetic::{div_ceil, powers, sum, BatchInvert, BooleanHypercube, PrimeField}, + chain, + expression::{evaluator::ExpressionRegistry, Expression}, + izip, izip_eq, + parallel::{num_threads, par_map_collect, parallelize, parallelize_iter}, + Itertools, + }, +}; +use std::{hash::Hash, iter}; + +#[derive(Debug)] +pub(crate) struct SangriaWitness +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, +{ + pub(crate) instance: SangriaInstance, + pub(crate) witness_polys: Vec, + pub(crate) e_poly: Pcs::Polynomial, +} + +impl SangriaWitness +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, +{ + pub(crate) fn init( + k: usize, + num_instances: &[usize], + num_witness_polys: usize, + num_challenges: usize, + ) -> Self { + let zero_poly = Pcs::Polynomial::from_evals(vec![F::ZERO; 1 << k]); + Self { + instance: SangriaInstance::init(num_instances, num_witness_polys, num_challenges), + witness_polys: iter::repeat_with(|| zero_poly.clone()) + .take(num_witness_polys) + .collect(), + e_poly: zero_poly, + } + } + + pub(crate) fn from_committed( + k: usize, + instances: &[&[F]], + witness_polys: impl IntoIterator, + witness_comms: impl IntoIterator, + challenges: Vec, + ) -> Self { + Self { + instance: SangriaInstance::from_committed(instances, witness_comms, challenges), + witness_polys: witness_polys.into_iter().collect(), + e_poly: Pcs::Polynomial::from_evals(vec![F::ZERO; 1 << k]), + } + } +} + +impl SangriaWitness +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, + Pcs::Commitment: AdditiveCommitment, +{ + pub(crate) fn fold( + &mut self, + rhs: &Self, + cross_term_polys: &[Pcs::Polynomial], + cross_term_comms: &[Pcs::Commitment], + r: &F, + ) { + self.instance.fold(&rhs.instance, cross_term_comms, r); + izip_eq!(&mut self.witness_polys, &rhs.witness_polys) + .for_each(|(lhs, rhs)| *lhs += (r, rhs)); + izip!(powers(*r).skip(1), chain![cross_term_polys, [&rhs.e_poly]]) + .for_each(|(power_of_r, poly)| self.e_poly += (&power_of_r, poly)); + } +} + +pub(crate) fn lookup_h_polys( + compressed_polys: &[[MultilinearPolynomial; 2]], + m_polys: &[MultilinearPolynomial], + beta: &F, +) -> Vec<[MultilinearPolynomial; 2]> { + compressed_polys + .iter() + .zip(m_polys.iter()) + .map(|(compressed_polys, m_poly)| lookup_h_poly(compressed_polys, m_poly, beta)) + .collect() +} + +fn lookup_h_poly( + compressed_polys: &[MultilinearPolynomial; 2], + m_poly: &MultilinearPolynomial, + beta: &F, +) -> [MultilinearPolynomial; 2] { + let [input, table] = compressed_polys; + let mut h_input = vec![F::ZERO; 1 << input.num_vars()]; + let mut h_table = vec![F::ZERO; 1 << input.num_vars()]; + + parallelize(&mut h_input, |(h_input, start)| { + for (h_input, input) in h_input.iter_mut().zip(input[start..].iter()) { + *h_input = *beta + input; + } + }); + parallelize(&mut h_table, |(h_table, start)| { + for (h_table, table) in h_table.iter_mut().zip(table[start..].iter()) { + *h_table = *beta + table; + } + }); + + let chunk_size = div_ceil(2 * h_input.len(), num_threads()); + parallelize_iter( + iter::empty() + .chain(h_input.chunks_mut(chunk_size)) + .chain(h_table.chunks_mut(chunk_size)), + |h| { + h.iter_mut().batch_invert(); + }, + ); + + parallelize(&mut h_table, |(h_table, start)| { + for (h_table, m) in h_table.iter_mut().zip(m_poly[start..].iter()) { + *h_table *= m; + } + }); + + if cfg!(feature = "sanity-check") { + assert_eq!(sum::(&h_input), sum::(&h_table)); + } + + [ + MultilinearPolynomial::new(h_input), + MultilinearPolynomial::new(h_table), + ] +} + +pub(crate) fn evaluate_cross_term( + cross_term_expressions: &[Expression], + num_vars: usize, + preprocess_polys: &[MultilinearPolynomial], + folded: &SangriaWitness, + incoming: &SangriaWitness, +) -> Vec> +where + F: PrimeField + Ord, + Pcs: PolynomialCommitmentScheme>, +{ + if cross_term_expressions.is_empty() { + return Vec::new(); + } + + let folded_instance_polys = instance_polys(num_vars, &folded.instance.instances); + let incoming_instance_polys = instance_polys(num_vars, &incoming.instance.instances); + let polys = iter::empty() + .chain(preprocess_polys) + .chain(&folded_instance_polys) + .chain(&folded.witness_polys) + .chain(&incoming_instance_polys) + .chain(&incoming.witness_polys) + .collect_vec(); + let challenges = iter::empty() + .chain(folded.instance.challenges.iter().cloned()) + .chain(Some(folded.instance.u)) + .chain(incoming.instance.challenges.iter().cloned()) + .chain(Some(incoming.instance.u)) + .collect_vec(); + + let cross_term_expressions = cross_term_expressions + .iter() + .map(|expression| { + expression + .simplified(Some(&challenges)) + .unwrap_or_else(Expression::zero) + }) + .collect_vec(); + + let ev = HadamardEvaluator::new(num_vars, &cross_term_expressions); + let size = 1 << num_vars; + let chunk_size = div_ceil(size, num_threads()); + let num_cross_terms = cross_term_expressions.len(); + + let mut outputs = vec![F::ZERO; num_cross_terms << num_vars]; + parallelize_iter( + outputs + .chunks_mut(chunk_size * num_cross_terms) + .zip((0..).step_by(chunk_size)), + |(outputs, start)| { + let mut data = ev.cache(); + let bs = start..(start + chunk_size).min(size); + for (b, outputs) in bs.zip(outputs.chunks_mut(num_cross_terms)) { + ev.evaluate(outputs, &mut data, polys.as_slice(), b); + } + }, + ); + + (0..num_cross_terms) + .map(|offset| par_map_collect(0..size, |idx| outputs[idx * num_cross_terms + offset])) + .map(MultilinearPolynomial::new) + .collect_vec() +} + +#[derive(Clone, Debug)] +pub(crate) struct HadamardEvaluator { + num_vars: usize, + reg: ExpressionRegistry, + lagranges: Vec, +} + +impl HadamardEvaluator { + pub(crate) fn new(num_vars: usize, expressions: &[Expression]) -> Self { + let mut reg = ExpressionRegistry::new(); + for expression in expressions.iter() { + reg.register(expression); + } + assert!(reg.eq_xys().is_empty()); + + let bh = BooleanHypercube::new(num_vars).iter().collect_vec(); + let lagranges = reg + .lagranges() + .iter() + .map(|i| bh[i.rem_euclid(1 << num_vars) as usize]) + .collect_vec(); + + Self { + num_vars, + reg, + lagranges, + } + } + + pub(crate) fn cache(&self) -> Vec { + self.reg.cache() + } + + pub(crate) fn evaluate( + &self, + evals: &mut [F], + cache: &mut [F], + polys: &[&MultilinearPolynomial], + b: usize, + ) { + let bh = BooleanHypercube::new(self.num_vars); + if self.reg.has_identity() { + cache[self.reg.offsets().identity()] = F::from(b as u64); + } + cache[self.reg.offsets().lagranges()..] + .iter_mut() + .zip(&self.lagranges) + .for_each(|(value, i)| *value = if &b == i { F::ONE } else { F::ZERO }); + cache[self.reg.offsets().polys()..] + .iter_mut() + .zip(self.reg.polys()) + .for_each(|(value, (query, _))| { + *value = polys[query.poly()][bh.rotate(b, query.rotation())] + }); + self.reg + .indexed_calculations() + .iter() + .zip(self.reg.offsets().calculations()..) + .for_each(|(calculation, idx)| calculation.calculate(cache, idx)); + evals + .iter_mut() + .zip(self.reg.indexed_outputs()) + .for_each(|(eval, idx)| *eval = cache[*idx]) + } +} diff --git a/plonkish_backend/src/backend/hyperplonk/folding/sangria/verifier.rs b/plonkish_backend/src/backend/hyperplonk/folding/sangria/verifier.rs new file mode 100644 index 00000000..81f2c66c --- /dev/null +++ b/plonkish_backend/src/backend/hyperplonk/folding/sangria/verifier.rs @@ -0,0 +1,86 @@ +use crate::{ + pcs::{AdditiveCommitment, PolynomialCommitmentScheme}, + util::{ + arithmetic::{powers, PrimeField}, + chain, izip_eq, Itertools, + }, +}; +use std::{fmt::Debug, iter}; + +#[derive(Debug)] +pub(crate) struct SangriaInstance +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, +{ + pub(crate) instances: Vec>, + pub(crate) witness_comms: Vec, + pub(crate) challenges: Vec, + pub(crate) u: F, + pub(crate) e_comm: Pcs::Commitment, +} + +impl SangriaInstance +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, +{ + pub(crate) fn init( + num_instances: &[usize], + num_witness_polys: usize, + num_challenges: usize, + ) -> Self { + Self { + instances: num_instances.iter().map(|n| vec![F::ZERO; *n]).collect(), + witness_comms: iter::repeat_with(Pcs::Commitment::default) + .take(num_witness_polys) + .collect(), + challenges: vec![F::ZERO; num_challenges], + u: F::ZERO, + e_comm: Pcs::Commitment::default(), + } + } + + pub(crate) fn from_committed( + instances: &[&[F]], + witness_comms: impl IntoIterator, + challenges: Vec, + ) -> Self { + Self { + instances: instances + .iter() + .map(|instances| instances.to_vec()) + .collect(), + witness_comms: witness_comms.into_iter().collect(), + challenges, + u: F::ONE, + e_comm: Pcs::Commitment::default(), + } + } + + pub(crate) fn instance_slices(&self) -> Vec<&[F]> { + self.instances.iter().map(Vec::as_slice).collect() + } +} + +impl SangriaInstance +where + F: PrimeField, + Pcs: PolynomialCommitmentScheme, + Pcs::Commitment: AdditiveCommitment, +{ + pub(crate) fn fold(&mut self, rhs: &Self, cross_term_comms: &[Pcs::Commitment], r: &F) { + let one = F::ONE; + let powers_of_r = powers(*r).take(cross_term_comms.len() + 2).collect_vec(); + izip_eq!(&mut self.instances, &rhs.instances) + .for_each(|(lhs, rhs)| izip_eq!(lhs, rhs).for_each(|(lhs, rhs)| *lhs += &(*rhs * r))); + izip_eq!(&mut self.witness_comms, &rhs.witness_comms) + .for_each(|(lhs, rhs)| *lhs = Pcs::Commitment::sum_with_scalar([&one, r], [lhs, rhs])); + izip_eq!(&mut self.challenges, &rhs.challenges).for_each(|(lhs, rhs)| *lhs += &(*rhs * r)); + self.u += &(rhs.u * r); + self.e_comm = { + let comms = chain![[&self.e_comm], cross_term_comms, [&rhs.e_comm]]; + Pcs::Commitment::sum_with_scalar(&powers_of_r, comms) + }; + } +} diff --git a/plonkish_backend/src/backend/hyperplonk/preprocessor.rs b/plonkish_backend/src/backend/hyperplonk/preprocessor.rs index fc750dac..12552144 100644 --- a/plonkish_backend/src/backend/hyperplonk/preprocessor.rs +++ b/plonkish_backend/src/backend/hyperplonk/preprocessor.rs @@ -4,7 +4,7 @@ use crate::{ util::{ arithmetic::{div_ceil, steps, PrimeField}, chain, - expression::{CommonPolynomial, Expression, Query, Rotation}, + expression::{Expression, Query, Rotation}, Itertools, }, }; @@ -19,7 +19,6 @@ pub(super) fn batch_size(circuit_info: &PlonkishCircuitInfo) - [num_lookups], [num_lookups + div_ceil(num_permutation_polys, max_degree(circuit_info, None) - 1)], ] - .into_iter() .sum() } @@ -126,7 +125,10 @@ pub(super) fn permutation_constraints( .map(|idx| Expression::Polynomial(Query::new(*idx, Rotation::cur()))) .collect_vec(); let ids = (0..polys.len()) - .map(|idx| Expression::CommonPolynomial(CommonPolynomial::Identity(idx))) + .map(|idx| { + let offset = F::from((idx << circuit_info.k) as u64); + Expression::Constant(offset) + Expression::identity() + }) .collect_vec(); let permutations = (permutation_offset..) .map(|idx| Expression::Polynomial(Query::new(idx, Rotation::cur()))) @@ -213,7 +215,8 @@ pub(crate) mod test { #[test] fn compose_vanilla_plonk() { - let expression = vanilla_plonk_expression(); + let num_vars = 3; + let expression = vanilla_plonk_expression(num_vars); assert_eq!(expression, { let [pi, q_l, q_r, q_m, q_o, q_c, w_l, w_r, w_o, s_1, s_2, s_3] = &array::from_fn(|poly| Query::new(poly, Rotation::cur())) @@ -224,7 +227,9 @@ pub(crate) mod test { ] .map(Expression::Polynomial); let [beta, gamma, alpha] = &array::from_fn(Expression::::Challenge); - let [id_1, id_2, id_3] = array::from_fn(Expression::identity); + let [id_1, id_2, id_3] = array::from_fn(|idx| { + Expression::Constant(Fr::from((idx << num_vars) as u64)) + Expression::identity() + }); let l_1 = Expression::::lagrange(1); let one = Expression::one(); let constraints = { @@ -247,7 +252,8 @@ pub(crate) mod test { #[test] fn compose_vanilla_plonk_with_lookup() { - let expression = vanilla_plonk_with_lookup_expression(); + let num_vars = 3; + let expression = vanilla_plonk_with_lookup_expression(num_vars); assert_eq!(expression, { let [pi, q_l, q_r, q_m, q_o, q_c, q_lookup, t_l, t_r, t_o, w_l, w_r, w_o, s_1, s_2, s_3] = &array::from_fn(|poly| Query::new(poly, Rotation::cur())) @@ -263,7 +269,9 @@ pub(crate) mod test { ] .map(Expression::Polynomial); let [beta, gamma, alpha] = &array::from_fn(Expression::::Challenge); - let [id_1, id_2, id_3] = array::from_fn(Expression::identity); + let [id_1, id_2, id_3] = array::from_fn(|idx| { + Expression::Constant(Fr::from((idx << num_vars) as u64)) + Expression::identity() + }); let l_1 = &Expression::::lagrange(1); let one = &Expression::one(); let lookup_input = diff --git a/plonkish_backend/src/backend/hyperplonk/prover.rs b/plonkish_backend/src/backend/hyperplonk/prover.rs index b7144a27..1cc0e279 100644 --- a/plonkish_backend/src/backend/hyperplonk/prover.rs +++ b/plonkish_backend/src/backend/hyperplonk/prover.rs @@ -70,28 +70,15 @@ pub(super) fn lookup_compressed_polys( .map(|i| (i, bh[i.rem_euclid(1 << num_vars) as usize])) .collect::>() }; - let identities = { - let max_used_identity = expression - .used_identity() - .into_iter() - .max() - .unwrap_or_default(); - (0..=max_used_identity) - .map(|idx| (idx as u64) << num_vars) - .collect_vec() - }; lookups .iter() - .map(|lookup| { - lookup_compressed_poly(lookup, &lagranges, &identities, polys, challenges, betas) - }) + .map(|lookup| lookup_compressed_poly(lookup, &lagranges, polys, challenges, betas)) .collect() } pub(super) fn lookup_compressed_poly( lookup: &[(Expression, Expression)], lagranges: &HashSet<(i32, usize)>, - identities: &[u64], polys: &[&MultilinearPolynomial], challenges: &[F], betas: &[F], @@ -109,6 +96,7 @@ pub(super) fn lookup_compressed_poly( *compressed = expression.evaluate( &|constant| constant, &|common_poly| match common_poly { + CommonPolynomial::Identity => F::from(b as u64), CommonPolynomial::Lagrange(i) => { if lagranges.contains(&(i, b)) { F::ONE @@ -116,9 +104,6 @@ pub(super) fn lookup_compressed_poly( F::ZERO } } - CommonPolynomial::Identity(idx) => { - F::from(b as u64 + identities[idx]) - } CommonPolynomial::EqXY(_) => unreachable!(), }, &|query| polys[query.poly()][bh.rotate(b, query.rotation())], @@ -367,6 +352,27 @@ pub(super) fn prove_zero_check( challenges: Vec, y: Vec, transcript: &mut impl FieldTranscriptWrite, +) -> Result<(Vec>, Vec>), Error> { + prove_sum_check( + num_instance_poly, + expression, + F::ZERO, + polys, + challenges, + y, + transcript, + ) +} + +#[allow(clippy::type_complexity)] +pub(super) fn prove_sum_check( + num_instance_poly: usize, + expression: &Expression, + sum: F, + polys: &[&MultilinearPolynomial], + challenges: Vec, + y: Vec, + transcript: &mut impl FieldTranscriptWrite, ) -> Result<(Vec>, Vec>), Error> { let num_vars = polys[0].num_vars(); let ys = [y]; @@ -375,7 +381,7 @@ pub(super) fn prove_zero_check( &(), num_vars, virtual_poly, - F::ZERO, + sum, transcript, )?; diff --git a/plonkish_backend/src/backend/hyperplonk/util.rs b/plonkish_backend/src/backend/hyperplonk/util.rs index 7b120757..09f117d5 100644 --- a/plonkish_backend/src/backend/hyperplonk/util.rs +++ b/plonkish_backend/src/backend/hyperplonk/util.rs @@ -48,9 +48,9 @@ pub fn vanilla_plonk_circuit_info( } } -pub fn vanilla_plonk_expression() -> Expression { +pub fn vanilla_plonk_expression(num_vars: usize) -> Expression { let circuit_info = vanilla_plonk_circuit_info( - 0, + num_vars, 0, Default::default(), vec![vec![(6, 1)], vec![(7, 1)], vec![(8, 1)]], @@ -85,9 +85,9 @@ pub fn vanilla_plonk_with_lookup_circuit_info( } } -pub fn vanilla_plonk_with_lookup_expression() -> Expression { +pub fn vanilla_plonk_with_lookup_expression(num_vars: usize) -> Expression { let circuit_info = vanilla_plonk_with_lookup_circuit_info( - 0, + num_vars, 0, Default::default(), vec![vec![(10, 1)], vec![(11, 1)], vec![(12, 1)]], diff --git a/plonkish_backend/src/backend/hyperplonk/verifier.rs b/plonkish_backend/src/backend/hyperplonk/verifier.rs index 4417ba0d..c1e9b033 100644 --- a/plonkish_backend/src/backend/hyperplonk/verifier.rs +++ b/plonkish_backend/src/backend/hyperplonk/verifier.rs @@ -23,12 +23,33 @@ pub(super) fn verify_zero_check( challenges: &[F], y: &[F], transcript: &mut impl FieldTranscriptRead, +) -> Result<(Vec>, Vec>), Error> { + verify_sum_check( + num_vars, + expression, + F::ZERO, + instances, + challenges, + y, + transcript, + ) +} + +#[allow(clippy::type_complexity)] +pub(super) fn verify_sum_check( + num_vars: usize, + expression: &Expression, + sum: F, + instances: &[&[F]], + challenges: &[F], + y: &[F], + transcript: &mut impl FieldTranscriptRead, ) -> Result<(Vec>, Vec>), Error> { let (x_eval, x) = ClassicSumCheck::>::verify( &(), num_vars, expression.degree(), - F::ZERO, + sum, transcript, )?; diff --git a/plonkish_backend/src/piop/sum_check.rs b/plonkish_backend/src/piop/sum_check.rs index a4a63315..6617f5cd 100644 --- a/plonkish_backend/src/piop/sum_check.rs +++ b/plonkish_backend/src/piop/sum_check.rs @@ -66,6 +66,7 @@ pub fn evaluate( x: &[F], ) -> F { assert!(num_vars > 0 && expression.max_used_rotation_distance() <= num_vars); + let identity = identity_eval(x); let lagranges = { let bh = BooleanHypercube::new(num_vars).iter().collect_vec(); expression @@ -78,13 +79,12 @@ pub fn evaluate( .collect::>() }; let eq_xys = ys.iter().map(|y| eq_xy_eval(x, y)).collect_vec(); - let identity = identity_eval(x); expression.evaluate( &|scalar| scalar, &|poly| match poly { + CommonPolynomial::Identity => identity, CommonPolynomial::Lagrange(i) => lagranges[&i], CommonPolynomial::EqXY(idx) => eq_xys[idx], - CommonPolynomial::Identity(idx) => F::from((idx << num_vars) as u64) + identity, }, &|query| evals[&query], &|idx| challenges[idx], @@ -309,7 +309,7 @@ pub(super) mod test { run_zero_check::<$impl>( 2..16, - |_| vanilla_plonk_expression(), + |num_vars| vanilla_plonk_expression(num_vars), |_| ((), ()), |num_vars| { let (polys, challenges) = rand_vanilla_plonk_assignment( @@ -336,7 +336,7 @@ pub(super) mod test { run_zero_check::<$impl>( 2..16, - |_| vanilla_plonk_with_lookup_expression(), + |num_vars| vanilla_plonk_with_lookup_expression(num_vars), |_| ((), ()), |num_vars| { let (polys, challenges) = rand_vanilla_plonk_with_lookup_assignment( diff --git a/plonkish_backend/src/piop/sum_check/classic.rs b/plonkish_backend/src/piop/sum_check/classic.rs index 9143beb5..531c9253 100644 --- a/plonkish_backend/src/piop/sum_check/classic.rs +++ b/plonkish_backend/src/piop/sum_check/classic.rs @@ -28,7 +28,7 @@ pub struct ProverState<'a, F: Field> { degree: usize, sum: F, lagranges: HashMap, - identities: Vec, + identity: F, eq_xys: Vec>, polys: Vec>>>, challenges: &'a [F], @@ -53,18 +53,6 @@ impl<'a, F: PrimeField> ProverState<'a, F> { }) .collect() }; - let identities = (0..) - .map(|idx| F::from(idx << num_vars)) - .take( - virtual_poly - .expression - .used_identity() - .into_iter() - .max() - .unwrap_or_default() - + 1, - ) - .collect_vec(); let eq_xys = virtual_poly .ys .iter() @@ -85,7 +73,7 @@ impl<'a, F: PrimeField> ProverState<'a, F> { degree: virtual_poly.expression.degree(), sum, lagranges, - identities, + identity: F::ZERO, eq_xys, polys, challenges: virtual_poly.challenges, @@ -101,6 +89,7 @@ impl<'a, F: PrimeField> ProverState<'a, F> { fn next_round(&mut self, sum: F, challenge: &F) { self.sum = sum; + self.identity += F::from(1 << self.round) * challenge; self.lagranges.values_mut().for_each(|(b, value)| { if b.is_even() { *value *= &(F::ONE - challenge); @@ -109,9 +98,6 @@ impl<'a, F: PrimeField> ProverState<'a, F> { } *b >>= 1; }); - self.identities - .iter_mut() - .for_each(|constant| *constant += F::from(1 << self.round) * challenge); self.eq_xys .iter_mut() .for_each(|eq_xy| eq_xy.fix_var_in_place(challenge, &mut self.buf)); diff --git a/plonkish_backend/src/piop/sum_check/classic/eval.rs b/plonkish_backend/src/piop/sum_check/classic/eval.rs index 868af1a3..be3ea1de 100644 --- a/plonkish_backend/src/piop/sum_check/classic/eval.rs +++ b/plonkish_backend/src/piop/sum_check/classic/eval.rs @@ -5,7 +5,10 @@ use crate::{ barycentric_interpolate, barycentric_weights, div_ceil, steps, BooleanHypercube, PrimeField, }, - expression::{CommonPolynomial, Expression, Query, Rotation}, + expression::{ + evaluator::{ExpressionRegistry, Offsets}, + CommonPolynomial, Expression, + }, impl_index, parallel::{num_threads, parallelize_iter}, transcript::{FieldTranscriptRead, FieldTranscriptWrite}, @@ -13,12 +16,7 @@ use crate::{ Error, }; use num_integer::Integer; -use std::{ - collections::BTreeSet, - fmt::Debug, - iter, - ops::{AddAssign, Deref}, -}; +use std::{collections::BTreeSet, fmt::Debug, iter, ops::AddAssign}; #[derive(Clone, Debug)] pub struct Evaluations(Vec); @@ -70,7 +68,7 @@ impl<'rhs, F: PrimeField> AddAssign<&'rhs Evaluations> for Evaluations { impl_index!(Evaluations, 0); #[derive(Clone, Debug)] -pub struct EvaluationsProver(Vec>); +pub struct EvaluationsProver(Vec>); impl ClassicSumCheckProver for EvaluationsProver where @@ -85,7 +83,7 @@ where .chain(Some((&dense, false))) .chain(sparse.iter().zip(iter::repeat(true))) .filter_map(|(expression, is_sparse)| { - GraphEvaluator::new(state.num_vars, state.challenges, expression, is_sparse) + SumCheckEvaluator::new(state.num_vars, state.challenges, expression, is_sparse) }) .collect(), ) @@ -109,18 +107,18 @@ impl EvaluationsProver { let mut partials = vec![Evaluations::new(state.degree); div_ceil(size, chunk_size)]; for ev in self.0.iter() { if let Some(sparse_bs) = ev.sparse_bs(state) { - let mut data = ev.data(state, 0); + let mut cache = ev.cache(state); sparse_bs.into_iter().for_each(|b| { - ev.evaluate::(&mut partials[0], &mut data, state, b) + ev.evaluate::(&mut partials[0], &mut cache, state, b) }) } else { parallelize_iter( partials.iter_mut().zip((0..).step_by(chunk_size)), |(partials, start)| { let bs = start..(start + chunk_size).min(size); - let mut data = ev.data(state, start); + let mut cache = ev.cache(state); bs.for_each(|b| { - ev.evaluate::(partials, &mut data, state, b) + ev.evaluate::(partials, &mut cache, state, b) }) }, ); @@ -134,193 +132,30 @@ impl EvaluationsProver { } #[derive(Clone, Debug, Default)] -struct GraphEvaluator { +struct SumCheckEvaluator { num_vars: usize, - constants: Vec, - lagranges: Vec, - identities: Vec, - eq_xys: Vec, - rotations: Vec<(Rotation, usize)>, - polys: Vec<(usize, usize)>, - calculations: Vec>, - indexed_calculations: Vec>, - offsets: Offsets, + reg: ExpressionRegistry, sparse: Option>, } -impl GraphEvaluator { +impl SumCheckEvaluator { fn new( num_vars: usize, challenges: &[F], expression: &Expression, is_sparse: bool, ) -> Option { - let mut ev = Self { - num_vars, - constants: vec![F::ZERO, F::ONE, F::ONE.double()], - rotations: vec![(Rotation(0), num_vars)], - ..Default::default() - }; - let expression = expression.simplified(Some(challenges))?; - ev.register_expression(&expression); - ev.offsets = Offsets::new( - ev.constants.len(), - ev.lagranges.len(), - ev.identities.len(), - ev.eq_xys.len(), - ev.polys.len(), - ); - ev.indexed_calculations = ev - .calculations - .iter() - .map(|calculation| calculation.indexed(&ev.offsets)) - .collect(); - - if is_sparse { - ev.sparse = Some(expression) - } - - Some(ev) - } - - fn register( - &mut self, - field: impl FnOnce(&mut Self) -> &mut Vec, - item: &T, - ) -> usize { - let field = field(self); - if let Some(idx) = field.iter().position(|lhs| lhs == item) { - idx - } else { - let idx = field.len(); - field.push(item.clone()); - idx - } - } - - fn register_constant(&mut self, constant: &F) -> ValueSource { - ValueSource::Constant(self.register(|ev| &mut ev.constants, constant)) - } - - fn register_lagrange(&mut self, i: i32) -> ValueSource { - ValueSource::Lagrange(self.register(|ev| &mut ev.lagranges, &i)) - } - - fn register_identity(&mut self, idx: usize) -> ValueSource { - ValueSource::Identity(self.register(|ev| &mut ev.identities, &idx)) - } - - fn register_eq_xy(&mut self, idx: usize) -> ValueSource { - ValueSource::EqXY(self.register(|ev| &mut ev.eq_xys, &idx)) - } - - fn register_rotation(&mut self, rotation: Rotation) -> usize { - let rotated_poly = (rotation.0 + self.num_vars as i32) as usize; - self.register(|ev| &mut ev.rotations, &(rotation, rotated_poly)) - } + let mut reg = ExpressionRegistry::new(); + reg.register(&expression); - fn register_poly_eval(&mut self, query: &Query) -> ValueSource { - let rotation = self.register_rotation(query.rotation()); - ValueSource::Poly(self.register(|ev| &mut ev.polys, &(query.poly(), rotation))) - } - - fn register_calculation(&mut self, calculation: Calculation) -> ValueSource { - ValueSource::Calculation(self.register(|ev| &mut ev.calculations, &calculation)) - } + let sparse = is_sparse.then_some(expression); - fn register_expression(&mut self, expr: &Expression) -> ValueSource { - match expr { - Expression::Constant(constant) => self.register_constant(constant), - Expression::CommonPolynomial(poly) => match poly { - CommonPolynomial::Lagrange(i) => self.register_lagrange(*i), - CommonPolynomial::Identity(idx) => self.register_identity(*idx), - CommonPolynomial::EqXY(idx) => self.register_eq_xy(*idx), - }, - Expression::Polynomial(query) => self.register_poly_eval(query), - Expression::Challenge(_) => unreachable!(), - Expression::Negated(value) => { - if let Expression::Constant(constant) = value.deref() { - self.register_constant(&-*constant) - } else { - let value = self.register_expression(value); - if let ValueSource::Constant(idx) = value { - self.register_constant(&-self.constants[idx]) - } else { - self.register_calculation(Calculation::Negate(value)) - } - } - } - Expression::Sum(lhs, rhs) => match (lhs.deref(), rhs.deref()) { - (minuend, Expression::Negated(subtrahend)) - | (Expression::Negated(subtrahend), minuend) => { - let minuend = self.register_expression(minuend); - let subtrahend = self.register_expression(subtrahend); - match (minuend, subtrahend) { - (ValueSource::Constant(minuend), ValueSource::Constant(subtrahend)) => self - .register_constant( - &(self.constants[minuend] - &self.constants[subtrahend]), - ), - (ValueSource::Constant(0), _) => { - self.register_calculation(Calculation::Negate(subtrahend)) - } - (_, ValueSource::Constant(0)) => minuend, - _ => self.register_calculation(Calculation::Sub(minuend, subtrahend)), - } - } - _ => { - let lhs = self.register_expression(lhs); - let rhs = self.register_expression(rhs); - match (lhs, rhs) { - (ValueSource::Constant(lhs), ValueSource::Constant(rhs)) => { - self.register_constant(&(self.constants[lhs] + &self.constants[rhs])) - } - (ValueSource::Constant(0), other) | (other, ValueSource::Constant(0)) => { - other - } - _ => { - if lhs <= rhs { - self.register_calculation(Calculation::Add(lhs, rhs)) - } else { - self.register_calculation(Calculation::Add(rhs, lhs)) - } - } - } - } - }, - Expression::Product(lhs, rhs) => { - let lhs = self.register_expression(lhs); - let rhs = self.register_expression(rhs); - match (lhs, rhs) { - (ValueSource::Constant(0), _) | (_, ValueSource::Constant(0)) => { - ValueSource::Constant(0) - } - (ValueSource::Constant(1), other) | (other, ValueSource::Constant(1)) => other, - (ValueSource::Constant(2), other) | (other, ValueSource::Constant(2)) => { - self.register_calculation(Calculation::Add(other, other)) - } - (lhs, rhs) => { - if lhs <= rhs { - self.register_calculation(Calculation::Mul(lhs, rhs)) - } else { - self.register_calculation(Calculation::Mul(rhs, lhs)) - } - } - } - } - Expression::Scaled(value, scalar) => { - if scalar == &F::ZERO { - ValueSource::Constant(0) - } else if scalar == &F::ONE { - self.register_expression(value) - } else { - let value = self.register_expression(value); - let scalar = self.register_constant(scalar); - self.register_calculation(Calculation::Mul(value, scalar)) - } - } - Expression::DistributePowers(_, _) => unreachable!(), - } + Some(Self { + num_vars, + reg, + sparse, + }) } fn sparse_bs(&self, state: &ProverState) -> Option> { @@ -329,8 +164,8 @@ impl GraphEvaluator { .evaluate( &|_| None, &|poly| match poly { + CommonPolynomial::Identity => unimplemented!(), CommonPolynomial::Lagrange(i) => Some(vec![state.lagranges[&i].0 >> 1]), - CommonPolynomial::Identity(_) => unimplemented!(), _ => None, }, &|_| None, @@ -360,42 +195,31 @@ impl GraphEvaluator { }) } - fn data(&self, state: &ProverState, start: usize) -> EvaluatorData { - let mut data = EvaluatorData { - offsets: self.offsets, - bs: vec![(0, 0); self.rotations.len()], - lagrange_steps: vec![F::ZERO; self.lagranges.len()], - identity_step_first: F::from(1 << (state.round + 1)) - - F::from(((state.degree - 1) << state.round) as u64), + fn cache(&self, state: &ProverState) -> EvaluatorCache { + EvaluatorCache { + offsets: *self.reg.offsets(), + bs: vec![(0, 0); self.reg.rotations().len()], identity_step: F::from(1 << state.round), - eq_xy_steps: vec![F::ZERO; self.eq_xys.len()], - poly_steps: vec![F::ZERO; self.polys.len()], - calculations: vec![F::ZERO; self.offsets.calculations() + self.calculations.len()], - }; - data.calculations[..self.constants.len()].clone_from_slice(&self.constants); - data.calculations[self.offsets.identities()..] - .iter_mut() - .zip(self.identities.iter()) - .for_each(|(eval, idx)| { - *eval = state.identities[*idx] - + F::from((1 << state.round) + (start << (state.round + 1)) as u64) - - data.identity_step_first; - }); - data + lagrange_steps: vec![F::ZERO; self.reg.lagranges().len()], + eq_xy_steps: vec![F::ZERO; self.reg.eq_xys().len()], + poly_steps: vec![F::ZERO; self.reg.polys().len()], + cache: self.reg.cache(), + } } fn evaluate_polys_next( &self, - data: &mut EvaluatorData, + cache: &mut EvaluatorCache, state: &ProverState, b: usize, ) { if IS_FIRST_ROUND && IS_FIRST_POINT { let bh = BooleanHypercube::new(self.num_vars); - data.bs + cache + .bs .iter_mut() - .zip(self.rotations.iter()) - .for_each(|(bs, (rotation, _))| { + .zip(self.reg.rotations()) + .for_each(|(bs, rotation)| { let [b_0, b_1] = [b << 1, (b << 1) + 1].map(|b| bh.rotate(b, *rotation)); *bs = (b_0, b_1); }); @@ -403,12 +227,15 @@ impl GraphEvaluator { if IS_FIRST_POINT { let (b_0, b_1) = if IS_FIRST_ROUND { - data.bs[0] + cache.bs[0] } else { (b << 1, (b << 1) + 1) }; - data.lagrange_iter_mut() - .zip(self.lagranges.iter()) + cache.cache[cache.offsets.identity()] = + state.identity + F::from(((1 << state.round) + (b << (state.round + 1))) as u64); + cache + .lagrange_iter_mut() + .zip(self.reg.lagranges()) .for_each(|((eval, step), i)| { let lagrange = &state.lagranges[i]; if b == lagrange.0 >> 1 { @@ -423,37 +250,39 @@ impl GraphEvaluator { *step = F::ZERO; } }); - data.identity_iter_mut() - .for_each(|(eval, (step_first, _))| *eval += step_first); - data.eq_xy_iter_mut() - .zip(self.eq_xys.iter()) + cache + .eq_xy_iter_mut() + .zip(self.reg.eq_xys()) .for_each(|((eval, step), idx)| { *eval = state.eq_xys[*idx][b_1]; *step = state.eq_xys[*idx][b_1] - &state.eq_xys[*idx][b_0]; }); - data.poly_iter_mut().zip(self.polys.iter()).for_each( - |(((eval, step), bs), (poly, rotation))| { + cache.poly_iter_mut().zip(self.reg.polys()).for_each( + |(((eval, step), bs), (query, rotation))| { if IS_FIRST_ROUND { let (b_0, b_1) = bs[*rotation]; - let poly = &state.polys[*poly][self.num_vars]; + let poly = &state.polys[query.poly()][self.num_vars]; *eval = poly[b_1]; *step = poly[b_1] - &poly[b_0]; } else { - let poly = &state.polys[*poly][self.rotations[*rotation].1]; + let rotation = (self.num_vars as i32 + query.rotation().0) as usize; + let poly = &state.polys[query.poly()][rotation]; *eval = poly[b_1]; *step = poly[b_1] - &poly[b_0]; } }, ); } else { - data.lagrange_iter_mut() + cache.cache[cache.offsets.identity()] += &cache.identity_step; + cache + .lagrange_iter_mut() .for_each(|(eval, step)| *eval += step as &_); - data.eq_xy_iter_mut() + cache + .eq_xy_iter_mut() .for_each(|(eval, step)| *eval += step as &_); - data.poly_iter_mut() + cache + .poly_iter_mut() .for_each(|((eval, step), _)| *eval += step as &_); - data.identity_iter_mut() - .for_each(|(eval, (_, step))| *eval += step); } } @@ -461,167 +290,64 @@ impl GraphEvaluator { &self, eval: &mut F, state: &ProverState, - data: &mut EvaluatorData, + cache: &mut EvaluatorCache, b: usize, ) { - self.evaluate_polys_next::(data, state, b); + self.evaluate_polys_next::(cache, state, b); for (calculation, idx) in self - .indexed_calculations + .reg + .indexed_calculations() .iter() - .zip(self.offsets.calculations()..) + .zip(self.reg.offsets().calculations()..) { - calculation.calculate(&mut data.calculations, idx); + calculation.calculate(&mut cache.cache, idx); } - *eval += data.calculations.last().unwrap(); + *eval += cache.cache.last().unwrap(); } fn evaluate( &self, evals: &mut Evaluations, - data: &mut EvaluatorData, + cache: &mut EvaluatorCache, state: &ProverState, b: usize, ) { assert!(evals.0.len() > 2); - self.evaluate_next::(&mut evals[1], state, data, b); + self.evaluate_next::(&mut evals[1], state, cache, b); for eval in evals[2..].iter_mut() { - self.evaluate_next::(eval, state, data, b); - } - } -} - -#[derive(Clone, Copy, Debug, Default)] -struct Offsets(usize, usize, usize, usize, usize); - -impl Offsets { - fn new( - num_constants: usize, - num_lagranges: usize, - num_identities: usize, - num_eq_xys: usize, - num_polys: usize, - ) -> Self { - let mut offset = Self::default(); - offset.0 = num_constants; - offset.1 = offset.0 + num_lagranges; - offset.2 = offset.1 + num_identities; - offset.3 = offset.2 + num_eq_xys; - offset.4 = offset.3 + num_polys; - offset - } - - fn lagranges(&self) -> usize { - self.0 - } - - fn identities(&self) -> usize { - self.1 - } - - fn eq_xys(&self) -> usize { - self.2 - } - - fn polys(&self) -> usize { - self.3 - } - - fn calculations(&self) -> usize { - self.4 - } -} - -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] -enum ValueSource { - Constant(usize), - Lagrange(usize), - Identity(usize), - EqXY(usize), - Poly(usize), - Calculation(usize), -} - -impl ValueSource { - fn indexed(&self, offsets: &Offsets) -> usize { - use ValueSource::*; - match self { - Constant(idx) => *idx, - Lagrange(idx) => offsets.lagranges() + idx, - Identity(idx) => offsets.identities() + idx, - EqXY(idx) => offsets.eq_xys() + idx, - Poly(idx) => offsets.polys() + idx, - Calculation(idx) => offsets.calculations() + idx, + self.evaluate_next::(eval, state, cache, b); } } } -#[derive(Clone, Debug, PartialEq, Eq)] -enum Calculation { - Negate(T), - Add(T, T), - Sub(T, T), - Mul(T, T), -} - -impl Calculation { - fn indexed(&self, offsets: &Offsets) -> Calculation { - use Calculation::*; - match self { - Negate(value) => Negate(value.indexed(offsets)), - Add(lhs, rhs) => Add(lhs.indexed(offsets), rhs.indexed(offsets)), - Sub(lhs, rhs) => Sub(lhs.indexed(offsets), rhs.indexed(offsets)), - Mul(lhs, rhs) => Mul(lhs.indexed(offsets), rhs.indexed(offsets)), - } - } -} - -impl Calculation { - fn calculate(&self, data: &mut [F], idx: usize) { - use Calculation::*; - data[idx] = match self { - Negate(value) => -data[*value], - Add(lhs, rhs) => data[*lhs] + &data[*rhs], - Sub(lhs, rhs) => data[*lhs] - &data[*rhs], - Mul(lhs, rhs) => data[*lhs] * &data[*rhs], - }; - } -} - -#[derive(Debug)] -struct EvaluatorData { +#[derive(Debug, Default)] +struct EvaluatorCache { offsets: Offsets, bs: Vec<(usize, usize)>, - lagrange_steps: Vec, - identity_step_first: F, identity_step: F, + lagrange_steps: Vec, eq_xy_steps: Vec, poly_steps: Vec, - calculations: Vec, + cache: Vec, } -impl EvaluatorData { +impl EvaluatorCache { fn lagrange_iter_mut(&mut self) -> impl Iterator { - self.calculations[self.offsets.lagranges()..] + self.cache[self.offsets.lagranges()..] .iter_mut() .zip(self.lagrange_steps.iter_mut()) } - fn identity_iter_mut(&mut self) -> impl Iterator { - self.calculations[self.offsets.identities()..self.offsets.eq_xys()] - .iter_mut() - .zip(iter::repeat(&self.identity_step_first).zip(iter::repeat(&self.identity_step))) - } - fn eq_xy_iter_mut(&mut self) -> impl Iterator { - self.calculations[self.offsets.eq_xys()..] + self.cache[self.offsets.eq_xys()..] .iter_mut() .zip(self.eq_xy_steps.iter_mut()) } fn poly_iter_mut(&mut self) -> impl Iterator { - self.calculations[self.offsets.polys()..] + self.cache[self.offsets.polys()..] .iter_mut() .zip(self.poly_steps.iter_mut()) .zip(iter::repeat(self.bs.as_slice())) diff --git a/plonkish_backend/src/util/expression.rs b/plonkish_backend/src/util/expression.rs index 4f6ae24d..f2c05b8a 100644 --- a/plonkish_backend/src/util/expression.rs +++ b/plonkish_backend/src/util/expression.rs @@ -7,6 +7,9 @@ use std::{ ops::{Add, Mul, Neg, Sub}, }; +pub(crate) mod evaluator; +pub mod relaxed; + #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Rotation(pub i32); @@ -56,9 +59,9 @@ impl Query { #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] pub enum CommonPolynomial { + Identity, Lagrange(i32), EqXY(usize), - Identity(usize), } #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] @@ -75,6 +78,10 @@ pub enum Expression { } impl Expression { + pub fn identity() -> Self { + Expression::CommonPolynomial(CommonPolynomial::Identity) + } + pub fn lagrange(i: i32) -> Self { Expression::CommonPolynomial(CommonPolynomial::Lagrange(i)) } @@ -83,10 +90,6 @@ impl Expression { Expression::CommonPolynomial(CommonPolynomial::EqXY(idx)) } - pub fn identity(idx: usize) -> Self { - Expression::CommonPolynomial(CommonPolynomial::Identity(idx)) - } - pub fn distribute_powers<'a>( exprs: impl IntoIterator + 'a, base: &Self, @@ -188,16 +191,6 @@ impl Expression { ) } - pub fn used_identity(&self) -> BTreeSet { - self.used_primitive( - &|poly| match poly { - CommonPolynomial::Identity(idx) => idx.into(), - _ => None, - }, - &|_| None, - ) - } - pub fn used_query(&self) -> BTreeSet { self.used_primitive(&|_| None, &|query| query.into()) } @@ -257,8 +250,8 @@ impl Expression { match self { Expression::Constant(constant) => write!(writer, "{:?}", *constant), Expression::CommonPolynomial(poly) => match poly { + CommonPolynomial::Identity => write!(writer, "id"), CommonPolynomial::Lagrange(i) => write!(writer, "l_{i}"), - CommonPolynomial::Identity(idx) => write!(writer, "id_{idx}"), CommonPolynomial::EqXY(idx) => write!(writer, "eq_{idx}"), }, Expression::Polynomial(query) => { @@ -331,57 +324,164 @@ impl Expression { } pub fn simplified(&self, challenges: Option<&[F]>) -> Option> { - let combine = |(scalar, expression): (F, Option>)| { - if scalar == F::ZERO { - None - } else if let Some(expression) = expression { - if scalar == F::ONE { - Some(expression) - } else if scalar == -F::ONE { - Some(-expression) - } else { - Some(expression * scalar) + #[derive(Clone)] + enum Case { + Constant(F), + Sum(F, Expression), + Scaled(F, F, Expression), + } + + impl Case { + fn into_simplified(self) -> Self { + match self { + Case::Scaled(scalar, constant, expression) => { + if scalar == F::ZERO { + Case::Constant(F::ZERO) + } else if scalar == F::ONE { + Case::Sum(constant, expression) + } else if scalar == -F::ONE { + Case::Sum(-constant, -expression) + } else { + Case::Scaled(scalar, constant, expression) + } + } + rest => rest, } - } else { - Some(Expression::Constant(scalar)) } - }; - let output = self.evaluate( - &|constant| (constant, None), - &|poly| (F::ONE, Some(poly.into())), - &|query| (F::ONE, Some(query.into())), + + fn into_expression(self) -> Option> { + match self { + Case::Constant(constant) => Some(Expression::Constant(constant)), + Case::Sum(constant, expression) => { + if constant == F::ZERO { + Some(expression) + } else { + Some(expression + Expression::Constant(constant)) + } + } + Case::Scaled(scalar, constant, expression) => { + debug_assert!(![F::ZERO, F::ONE, -F::ONE].contains(&scalar)); + Case::Sum(scalar * constant, expression * scalar).into_expression() + } + } + } + } + + impl Add for Case { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (Case::Constant(lhs), Case::Constant(rhs)) => Case::Constant(lhs + rhs), + (Case::Constant(lhs), Case::Sum(rhs, expression)) + | (Case::Sum(rhs, expression), Case::Constant(lhs)) => { + Case::Sum(lhs + rhs, expression) + } + ( + Case::Sum(lhs_constant, lhs_expression), + Case::Sum(rhs_constant, rhs_expression), + ) => Case::Sum(lhs_constant + rhs_constant, lhs_expression + rhs_expression), + (Case::Constant(lhs), Case::Scaled(scalar, rhs, expression)) + | (Case::Scaled(scalar, rhs, expression), Case::Constant(lhs)) => { + Case::Sum(lhs + scalar * rhs, expression * scalar) + } + ( + Case::Sum(lhs_constant, lhs_expression), + Case::Scaled(rhs_scalar, rhs_constant, rhs_expression), + ) + | ( + Case::Scaled(rhs_scalar, rhs_constant, rhs_expression), + Case::Sum(lhs_constant, lhs_expression), + ) => { + let rhs_constant = rhs_scalar * rhs_constant; + let rhs_expression = rhs_expression * rhs_scalar; + Case::Sum(lhs_constant + rhs_constant, lhs_expression + rhs_expression) + } + ( + Case::Scaled(lhs_scalar, lhs_constant, lhs_expression), + Case::Scaled(rhs_scalar, rhs_constant, rhs_expression), + ) => { + let lhs_constant = lhs_scalar * lhs_constant; + let lhs_expression = lhs_expression * lhs_scalar; + let rhs_constant = rhs_scalar * rhs_constant; + let rhs_expression = rhs_expression * rhs_scalar; + Case::Sum(lhs_constant + rhs_constant, lhs_expression + rhs_expression) + } + } + } + } + + impl Neg for Case { + type Output = Self; + + fn neg(self) -> Self::Output { + match self { + Case::Constant(constant) => Case::Constant(-constant), + Case::Sum(constant, expression) => Case::Sum(-constant, -expression), + Case::Scaled(scalar, constant, expression) => { + Case::Scaled(-scalar, constant, expression) + } + } + .into_simplified() + } + } + + impl Mul for Case { + type Output = Self; + + fn mul(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (Case::Constant(lhs), Case::Constant(rhs)) => Case::Constant(lhs * rhs), + (Case::Constant(scalar), Case::Sum(constant, expression)) + | (Case::Sum(constant, expression), Case::Constant(scalar)) => { + Case::Scaled(scalar, constant, expression) + } + (Case::Constant(lhs), Case::Scaled(rhs, constant, expression)) + | (Case::Scaled(rhs, constant, expression), Case::Constant(lhs)) => { + Case::Scaled(lhs * rhs, constant, expression) + } + (lhs, rhs) => match (lhs.into_expression(), rhs.into_expression()) { + (Some(lhs), Some(rhs)) => Case::Sum(F::ZERO, lhs * rhs), + (Some(expression), None) | (None, Some(expression)) => { + Case::Sum(F::ZERO, expression) + } + (None, None) => Case::Constant(F::ZERO), + }, + } + .into_simplified() + } + } + + impl Mul for Case { + type Output = Self; + + fn mul(self, rhs: F) -> Self::Output { + match self { + Case::Constant(lhs) => Case::Constant(lhs * rhs), + Case::Sum(constant, expression) => Case::Scaled(rhs, constant, expression), + Case::Scaled(lhs, constant, expression) => { + Case::Scaled(lhs * rhs, constant, expression) + } + } + .into_simplified() + } + } + + self.evaluate( + &|constant| Case::Constant(constant), + &|poly| Case::Sum(F::ZERO, poly.into()), + &|query| Case::Sum(F::ZERO, query.into()), &|challenge| { challenges - .map(|challenges| (challenges[challenge], None)) - .unwrap_or_else(|| (F::ONE, Some(Expression::Challenge(challenge)))) + .map(|challenges| Case::Constant(challenges[challenge])) + .unwrap_or_else(|| Case::Sum(F::ZERO, Expression::Challenge(challenge))) }, - &|(scalar, expression)| match expression { - Some(expression) if scalar == F::ONE => (F::ONE, Some(-expression)), - _ => (-scalar, expression), - }, - &|lhs, rhs| match (lhs, rhs) { - ((lhs, None), (rhs, None)) => (lhs + rhs, None), - (lhs, rhs) => { - let output = match (combine(lhs), combine(rhs)) { - (Some(lhs), Some(rhs)) => Some(lhs + rhs), - (Some(expression), None) | (None, Some(expression)) => Some(expression), - (None, None) => None, - }; - let scalar = if output.is_some() { F::ONE } else { F::ZERO }; - (scalar, output) - } - }, - &|(lhs_scalar, lhs_expression), (rhs_scalar, rhs_expression)| { - let output = match (lhs_expression, rhs_expression) { - (Some(lhs), Some(rhs)) => Some(lhs * rhs), - (Some(expression), None) | (None, Some(expression)) => Some(expression), - (None, None) => None, - }; - (lhs_scalar * rhs_scalar, output) - }, - &|(lhs, expression), rhs| (lhs * rhs, expression), - ); - combine(output) + &|case| -case, + &|lhs, rhs| lhs + rhs, + &|lhs, rhs| lhs * rhs, + &|lhs, rhs| lhs * rhs, + ) + .into_expression() } } @@ -465,8 +565,8 @@ fn merge_left_right( ) -> Option> { match (lhs, rhs) { (Some(lhs), None) | (None, Some(lhs)) => Some(lhs), - (Some(mut lhs), Some(rhs)) => { - lhs.extend(rhs); + (Some(mut lhs), Some(mut rhs)) => { + lhs.append(&mut rhs); Some(lhs) } _ => None, diff --git a/plonkish_backend/src/util/expression/evaluator.rs b/plonkish_backend/src/util/expression/evaluator.rs new file mode 100644 index 00000000..742d3f5b --- /dev/null +++ b/plonkish_backend/src/util/expression/evaluator.rs @@ -0,0 +1,324 @@ +use crate::util::{ + arithmetic::Field, + expression::{CommonPolynomial, Expression, Query, Rotation}, +}; +use std::{fmt::Debug, ops::Deref}; + +#[derive(Clone, Debug, Default)] +pub(crate) struct ExpressionRegistry { + offsets: Offsets, + constants: Vec, + has_identity: bool, + lagranges: Vec, + eq_xys: Vec, + rotations: Vec, + polys: Vec<(Query, usize)>, + calculations: Vec>, + indexed_calculations: Vec>, + outputs: Vec, + indexed_outputs: Vec, +} + +impl ExpressionRegistry { + pub(crate) fn new() -> Self { + Self { + constants: vec![F::ZERO, F::ONE, F::ONE.double()], + rotations: vec![Rotation(0)], + ..Default::default() + } + } + + pub(crate) fn register(&mut self, expression: &Expression) { + let output = self.register_expression(expression); + self.offsets = Offsets::new( + self.constants.len(), + self.lagranges.len(), + self.eq_xys.len(), + self.polys.len(), + ); + self.indexed_calculations = self + .calculations + .iter() + .map(|calculation| calculation.indexed(&self.offsets)) + .collect(); + self.outputs.push(output); + self.indexed_outputs = self + .outputs + .iter() + .map(|output| output.indexed(&self.offsets)) + .collect(); + } + + pub(crate) fn offsets(&self) -> &Offsets { + &self.offsets + } + + pub(crate) fn has_identity(&self) -> bool { + self.has_identity + } + + pub(crate) fn lagranges(&self) -> &[i32] { + &self.lagranges + } + + pub(crate) fn eq_xys(&self) -> &[usize] { + &self.eq_xys + } + + pub(crate) fn rotations(&self) -> &[Rotation] { + &self.rotations + } + + pub(crate) fn polys(&self) -> &[(Query, usize)] { + &self.polys + } + + pub(crate) fn indexed_calculations(&self) -> &[Calculation] { + &self.indexed_calculations + } + + pub(crate) fn indexed_outputs(&self) -> &[usize] { + &self.indexed_outputs + } + + pub(crate) fn cache(&self) -> Vec { + let mut cache = vec![F::ZERO; self.offsets.calculations() + self.calculations.len()]; + cache[..self.constants.len()].clone_from_slice(&self.constants); + cache + } + + fn register_value( + &mut self, + field: impl FnOnce(&mut Self) -> &mut Vec, + item: &T, + ) -> usize { + let field = field(self); + if let Some(idx) = field.iter().position(|lhs| lhs == item) { + idx + } else { + let idx = field.len(); + field.push(item.clone()); + idx + } + } + + fn register_constant(&mut self, constant: &F) -> ValueSource { + ValueSource::Constant(self.register_value(|ev| &mut ev.constants, constant)) + } + + fn register_identity(&mut self) -> ValueSource { + self.has_identity = true; + ValueSource::Identity + } + + fn register_lagrange(&mut self, i: i32) -> ValueSource { + ValueSource::Lagrange(self.register_value(|ev| &mut ev.lagranges, &i)) + } + + fn register_eq_xy(&mut self, idx: usize) -> ValueSource { + ValueSource::EqXY(self.register_value(|ev| &mut ev.eq_xys, &idx)) + } + + fn register_rotation(&mut self, rotation: Rotation) -> usize { + self.register_value(|ev| &mut ev.rotations, &(rotation)) + } + + fn register_poly(&mut self, query: &Query) -> ValueSource { + let rotation = self.register_rotation(query.rotation()); + ValueSource::Poly(self.register_value(|ev| &mut ev.polys, &(*query, rotation))) + } + + fn register_calculation(&mut self, calculation: Calculation) -> ValueSource { + ValueSource::Calculation(self.register_value(|ev| &mut ev.calculations, &calculation)) + } + + fn register_expression(&mut self, expr: &Expression) -> ValueSource { + match expr { + Expression::Constant(constant) => self.register_constant(constant), + Expression::CommonPolynomial(poly) => match poly { + CommonPolynomial::Identity => self.register_identity(), + CommonPolynomial::Lagrange(i) => self.register_lagrange(*i), + CommonPolynomial::EqXY(idx) => self.register_eq_xy(*idx), + }, + Expression::Polynomial(query) => self.register_poly(query), + Expression::Challenge(_) => unreachable!(), + Expression::Negated(value) => { + if let Expression::Constant(constant) = value.deref() { + self.register_constant(&-*constant) + } else { + let value = self.register_expression(value); + if let ValueSource::Constant(idx) = value { + self.register_constant(&-self.constants[idx]) + } else { + self.register_calculation(Calculation::Negated(value)) + } + } + } + Expression::Sum(lhs, rhs) => match (lhs.deref(), rhs.deref()) { + (minuend, Expression::Negated(subtrahend)) + | (Expression::Negated(subtrahend), minuend) => { + let minuend = self.register_expression(minuend); + let subtrahend = self.register_expression(subtrahend); + match (minuend, subtrahend) { + (ValueSource::Constant(minuend), ValueSource::Constant(subtrahend)) => self + .register_constant( + &(self.constants[minuend] - &self.constants[subtrahend]), + ), + (ValueSource::Constant(0), _) => { + self.register_calculation(Calculation::Negated(subtrahend)) + } + (_, ValueSource::Constant(0)) => minuend, + _ => self.register_calculation(Calculation::Sub(minuend, subtrahend)), + } + } + _ => { + let lhs = self.register_expression(lhs); + let rhs = self.register_expression(rhs); + match (lhs, rhs) { + (ValueSource::Constant(lhs), ValueSource::Constant(rhs)) => { + self.register_constant(&(self.constants[lhs] + &self.constants[rhs])) + } + (ValueSource::Constant(0), other) | (other, ValueSource::Constant(0)) => { + other + } + _ => { + if lhs <= rhs { + self.register_calculation(Calculation::Add(lhs, rhs)) + } else { + self.register_calculation(Calculation::Add(rhs, lhs)) + } + } + } + } + }, + Expression::Product(lhs, rhs) => { + let lhs = self.register_expression(lhs); + let rhs = self.register_expression(rhs); + match (lhs, rhs) { + (ValueSource::Constant(0), _) | (_, ValueSource::Constant(0)) => { + ValueSource::Constant(0) + } + (ValueSource::Constant(1), other) | (other, ValueSource::Constant(1)) => other, + (ValueSource::Constant(2), other) | (other, ValueSource::Constant(2)) => { + self.register_calculation(Calculation::Add(other, other)) + } + (lhs, rhs) => { + if lhs <= rhs { + self.register_calculation(Calculation::Mul(lhs, rhs)) + } else { + self.register_calculation(Calculation::Mul(rhs, lhs)) + } + } + } + } + Expression::Scaled(value, scalar) => { + if scalar == &F::ZERO { + ValueSource::Constant(0) + } else if scalar == &F::ONE { + self.register_expression(value) + } else { + let value = self.register_expression(value); + let scalar = self.register_constant(scalar); + self.register_calculation(Calculation::Mul(value, scalar)) + } + } + Expression::DistributePowers(_, _) => unreachable!(), + } + } +} + +#[derive(Clone, Copy, Debug, Default)] +pub(crate) struct Offsets(usize, usize, usize, usize, usize); + +impl Offsets { + fn new( + num_constants: usize, + num_lagranges: usize, + num_eq_xys: usize, + num_polys: usize, + ) -> Self { + let mut offset = Self::default(); + offset.0 = num_constants; + offset.1 = offset.0 + 1; + offset.2 = offset.1 + num_lagranges; + offset.3 = offset.2 + num_eq_xys; + offset.4 = offset.3 + num_polys; + offset + } + + pub(crate) fn identity(&self) -> usize { + self.0 + } + + pub(crate) fn lagranges(&self) -> usize { + self.1 + } + + pub(crate) fn eq_xys(&self) -> usize { + self.2 + } + + pub(crate) fn polys(&self) -> usize { + self.3 + } + + pub(crate) fn calculations(&self) -> usize { + self.4 + } +} + +#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +enum ValueSource { + Constant(usize), + Identity, + Lagrange(usize), + EqXY(usize), + Poly(usize), + Calculation(usize), +} + +impl ValueSource { + fn indexed(&self, offsets: &Offsets) -> usize { + use ValueSource::*; + match self { + Constant(idx) => *idx, + Identity => offsets.identity(), + Lagrange(idx) => offsets.lagranges() + idx, + EqXY(idx) => offsets.eq_xys() + idx, + Poly(idx) => offsets.polys() + idx, + Calculation(idx) => offsets.calculations() + idx, + } + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub(crate) enum Calculation { + Negated(T), + Add(T, T), + Sub(T, T), + Mul(T, T), +} + +impl Calculation { + fn indexed(&self, offsets: &Offsets) -> Calculation { + use Calculation::*; + match self { + Negated(value) => Negated(value.indexed(offsets)), + Add(lhs, rhs) => Add(lhs.indexed(offsets), rhs.indexed(offsets)), + Sub(lhs, rhs) => Sub(lhs.indexed(offsets), rhs.indexed(offsets)), + Mul(lhs, rhs) => Mul(lhs.indexed(offsets), rhs.indexed(offsets)), + } + } +} + +impl Calculation { + pub(crate) fn calculate(&self, cache: &mut [F], idx: usize) { + use Calculation::*; + cache[idx] = match self { + Negated(value) => -cache[*value], + Add(lhs, rhs) => cache[*lhs] + &cache[*rhs], + Sub(lhs, rhs) => cache[*lhs] - &cache[*rhs], + Mul(lhs, rhs) => cache[*lhs] * &cache[*rhs], + }; + } +} diff --git a/plonkish_backend/src/util/expression/relaxed.rs b/plonkish_backend/src/util/expression/relaxed.rs new file mode 100644 index 00000000..5f9cd0c3 --- /dev/null +++ b/plonkish_backend/src/util/expression/relaxed.rs @@ -0,0 +1,245 @@ +use crate::util::{ + arithmetic::PrimeField, + expression::{Expression, Query}, + BitIndex, Itertools, +}; +use std::{ + collections::{BTreeMap, BTreeSet, HashSet}, + fmt::Debug, + iter, +}; + +pub(crate) fn cross_term_expressions( + num_instance_polys: usize, + num_preprocess_polys: usize, + folding_polys: BTreeSet, + num_challenges: usize, + products: &[Product], +) -> Vec> { + let folding_degree = folding_degree(products); + let num_ts = folding_degree.checked_sub(1).unwrap_or_default(); + let u = num_challenges; + let folding_poly_indices = folding_polys.iter().zip(0..).collect::>(); + + products + .iter() + .fold( + vec![BTreeMap::>, Expression>::new(); num_ts], + |mut scalars, product| { + let (common_scalar, common_expr) = product.preprocess.evaluate( + &|constant| (constant, Vec::new()), + &|common_poly| (F::ONE, vec![Expression::CommonPolynomial(common_poly)]), + &|query| { + let poly = query.poly() - num_instance_polys; + ( + F::ONE, + vec![Expression::Polynomial(Query::new(poly, query.rotation()))], + ) + }, + &|_| unreachable!(), + &|(scalar, expr)| (-scalar, expr), + &|_, _| unreachable!(), + &|(lhs_scalar, lhs_expr), (rhs_scalar, rhs_expr)| { + (lhs_scalar * rhs_scalar, [lhs_expr, rhs_expr].concat()) + }, + &|(lhs, expr), rhs| (lhs * rhs, expr), + ); + for idx in 1usize..(1 << folding_degree) - 1 { + let (scalar, mut exprs) = iter::empty() + .chain(iter::repeat(None).take(folding_degree - product.folding_degree())) + .chain(product.foldees.iter().map(Some)) + .enumerate() + .fold( + (Expression::Constant(common_scalar), common_expr.clone()), + |(mut scalar, mut exprs), (nth, foldee)| { + let (poly_offset, challenge_offset) = if idx.nth_bit(nth) { + ( + num_preprocess_polys + folding_poly_indices.len(), + num_challenges + 1, + ) + } else { + (num_preprocess_polys, 0) + }; + match foldee { + None => { + scalar = + &scalar * Expression::Challenge(challenge_offset + u) + } + Some(Expression::Challenge(challenge)) => { + scalar = &scalar + * Expression::Challenge(challenge_offset + challenge) + } + Some(Expression::Polynomial(query)) => { + let poly = + poly_offset + folding_poly_indices[&query.poly()]; + let query = Query::new(poly, query.rotation()); + exprs.push(Expression::Polynomial(query)); + } + _ => unreachable!(), + } + (scalar, exprs) + }, + ); + exprs.sort_unstable(); + scalars[idx.count_ones() as usize - 1] + .entry(exprs) + .and_modify(|value| *value = value as &Expression<_> + &scalar) + .or_insert(scalar); + } + scalars + }, + ) + .into_iter() + .map(|exprs| { + exprs + .into_iter() + .map(|(polys, scalar)| polys.into_iter().product::>() * scalar) + .sum::>() + }) + .collect_vec() +} + +pub(crate) fn relaxed_expression( + products: &[Product], + u: usize, +) -> Expression { + let folding_degree = folding_degree(products); + let powers_of_u = iter::successors(Some(Expression::::one()), |power_of_u| { + Some(power_of_u * Expression::Challenge(u)) + }) + .take(folding_degree + 1) + .collect_vec(); + products + .iter() + .map(|product| { + &powers_of_u[folding_degree - product.folding_degree()] * product.expression() + }) + .sum() +} + +pub(crate) fn products( + preprocess_polys: &HashSet, + constraint: &Expression, +) -> Vec> { + let products = constraint.evaluate( + &|constant| vec![Product::new(Expression::Constant(constant), Vec::new())], + &|poly| vec![Product::new(Expression::CommonPolynomial(poly), Vec::new())], + &|query| { + if preprocess_polys.contains(&query.poly()) { + vec![Product::new(Expression::Polynomial(query), Vec::new())] + } else { + vec![Product::new( + Expression::Constant(F::ONE), + vec![Expression::Polynomial(query)], + )] + } + }, + &|challenge| { + vec![Product::new( + Expression::Constant(F::ONE), + vec![Expression::Challenge(challenge)], + )] + }, + &|products| { + products + .into_iter() + .map(|mut product| { + product.preprocess = -product.preprocess; + product + }) + .collect_vec() + }, + &|lhs, rhs| [lhs, rhs].concat(), + &|lhs, rhs| { + lhs.iter() + .cartesian_product(rhs.iter()) + .map(|(lhs, rhs)| { + Product::new( + &lhs.preprocess * &rhs.preprocess, + iter::empty() + .chain(&lhs.foldees) + .chain(&rhs.foldees) + .cloned() + .collect(), + ) + }) + .collect_vec() + }, + &|products, scalar| { + products + .into_iter() + .map(|mut product| { + product.preprocess = product.preprocess * scalar; + product + }) + .collect_vec() + }, + ); + products + .into_iter() + .map(|mut product| { + let (scalar, preprocess) = product.preprocess.evaluate( + &|constant| (constant, None), + &|poly| (F::ONE, Some(Expression::CommonPolynomial(poly))), + &|query| (F::ONE, Some(Expression::Polynomial(query))), + &|_| unreachable!(), + &|(scalar, preprocess)| (-scalar, preprocess), + &|_, _| unreachable!(), + &|(lhs_scalar, lhs_common), (rhs_scalar, rhs_common)| { + let preprocess = match (lhs_common, rhs_common) { + (Some(lhs_common), Some(rhs_common)) => Some(lhs_common * rhs_common), + (Some(preprocess), None) | (None, Some(preprocess)) => Some(preprocess), + (None, None) => None, + }; + (lhs_scalar * rhs_scalar, preprocess) + }, + &|(lhs, preprocess), rhs| (lhs * rhs, preprocess), + ); + + product.preprocess = preprocess + .map(|preprocess| { + if scalar == F::ONE { + preprocess + } else { + preprocess * scalar + } + }) + .unwrap_or_else(|| Expression::Constant(scalar)); + product + }) + .collect() +} + +#[derive(Clone, Debug)] +pub(crate) struct Product { + preprocess: Expression, + foldees: Vec>, +} + +impl Product { + fn new(preprocess: Expression, foldees: Vec>) -> Self { + Self { + preprocess, + foldees, + } + } + + fn folding_degree(&self) -> usize { + self.foldees.len() + } + + fn expression(&self) -> Expression + where + F: PrimeField, + { + &self.preprocess * self.foldees.iter().product::>() + } +} + +fn folding_degree(products: &[Product]) -> usize { + products + .iter() + .map(Product::folding_degree) + .max() + .unwrap_or_default() +}