diff --git a/core/src/stark/folder.rs b/core/src/stark/folder.rs index 47d4217e3..28aa49e65 100644 --- a/core/src/stark/folder.rs +++ b/core/src/stark/folder.rs @@ -113,7 +113,7 @@ pub struct GenericVerifierConstraintFolder<'a, F, EF, Var, Expr> { pub is_first_row: Var, pub is_last_row: Var, pub is_transition: Var, - pub alpha: EF, + pub alpha: Var, pub accumulator: Expr, pub _marker: PhantomData, } @@ -170,7 +170,7 @@ where fn assert_zero>(&mut self, x: I) { let x: Expr = x.into(); - self.accumulator *= self.alpha; + self.accumulator *= self.alpha.into(); self.accumulator += x; } } diff --git a/recursion/compiler/Cargo.toml b/recursion/compiler/Cargo.toml index 2fbff1a25..9c1a82e5e 100644 --- a/recursion/compiler/Cargo.toml +++ b/recursion/compiler/Cargo.toml @@ -6,11 +6,11 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +p3-air = { workspace = true } p3-field = { workspace = true } sp1-recursion-core = { path = "../core" } +sp1-core = { path = "../../core" } [dev-dependencies] p3-baby-bear = { workspace = true } -p3-air = { workspace = true } -sp1-core = { path = "../../core" } rand = "0.8.4" diff --git a/recursion/compiler/examples/verifier.rs b/recursion/compiler/examples/verifier.rs index 36d340ae3..feb15fde1 100644 --- a/recursion/compiler/examples/verifier.rs +++ b/recursion/compiler/examples/verifier.rs @@ -1,33 +1,63 @@ -use p3_air::Air; +use std::fs::File; +use std::marker::PhantomData; +use p3_field::AbstractField; use sp1_core::air::MachineAir; -use sp1_core::stark::{GenericVerifierConstraintFolder, MachineChip, StarkGenericConfig}; -use sp1_recursion_compiler::ir::{Ext, SymbolicExt}; - -#[allow(clippy::type_complexity)] -#[allow(dead_code)] -fn verify_constraints>( - chip: MachineChip, - folder: &mut GenericVerifierConstraintFolder< - SC::Val, - SC::Challenge, - Ext, - SymbolicExt, - >, -) where - A: for<'a> Air< - GenericVerifierConstraintFolder< - 'a, - SC::Val, - SC::Challenge, - Ext, - SymbolicExt, - >, - >, -{ - chip.eval(folder); -} +use sp1_core::stark::RiscvAir; +use sp1_core::stark::StarkGenericConfig; +use sp1_core::utils; +use sp1_core::utils::BabyBearPoseidon2; +use sp1_core::SP1Prover; +use sp1_core::SP1Stdin; +use sp1_recursion_compiler::gnark::GnarkBackend; +use sp1_recursion_compiler::ir::Builder; +use sp1_recursion_compiler::ir::{Ext, Felt}; +use sp1_recursion_compiler::verifier::verify_constraints; +use sp1_recursion_compiler::verifier::StarkGenericBuilderConfig; +use std::collections::HashMap; +use std::io::Write; fn main() { - println!("Hello, world!"); + type SC = BabyBearPoseidon2; + type F = ::Val; + type EF = ::Challenge; + + // Generate a dummy proof. + utils::setup_logger(); + let elf = + include_bytes!("../../../examples/cycle-tracking/program/elf/riscv32im-succinct-zkvm-elf"); + let proofs = SP1Prover::prove(elf, SP1Stdin::new()) + .unwrap() + .proof + .shard_proofs; + let proof = &proofs[0]; + + // Extract verification metadata. + let machine = RiscvAir::machine(SC::new()); + let chips = machine + .chips() + .iter() + .filter(|chip| proof.chip_ids.contains(&chip.name())) + .collect::>(); + let chip = chips[0]; + let opened_values = &proof.opened_values.chips[0]; + + // Run the verify inside the DSL. + let mut builder = Builder::>::default(); + let g: Felt = builder.eval(F::one()); + let zeta: Ext = builder.eval(F::one()); + let alpha: Ext = builder.eval(F::one()); + verify_constraints::(&mut builder, chip, opened_values, g, zeta, alpha); + + // Emit the constraints using the Gnark backend. + let mut backend = GnarkBackend::> { + nb_backend_vars: 0, + used: HashMap::new(), + phantom: PhantomData, + }; + let result = backend.compile(builder.operations); + let manifest_dir = env!("CARGO_MANIFEST_DIR"); + let path = format!("{}/src/gnark/lib/main.go", manifest_dir); + let mut file = File::create(path).unwrap(); + file.write_all(result.as_bytes()).unwrap(); } diff --git a/recursion/compiler/src/gnark/lib/main.go b/recursion/compiler/src/gnark/lib/main.go index 60831197b..b39e4c2e1 100644 --- a/recursion/compiler/src/gnark/lib/main.go +++ b/recursion/compiler/src/gnark/lib/main.go @@ -13,15 +13,15 @@ type Circuit struct { func (circuit *Circuit) Define(api frontend.API) error { fieldChip := babybear.NewChip(api) - + // Variables. - var felt2 *babybear.Variable - var backend1 frontend.Variable var var0 frontend.Variable - var backend0 frontend.Variable - var felt1 *babybear.Variable var felt0 *babybear.Variable - + var felt1 *babybear.Variable + var backend0 frontend.Variable + var felt2 *babybear.Variable + var backend1 frontend.Variable + // Operations. var0 = frontend.Variable(0) felt0 = babybear.NewVariable(0) @@ -33,10 +33,10 @@ func (circuit *Circuit) Define(api frontend.API) error { } fieldChip.AssertEq(felt0, babybear.NewVariable(144)) backend0 = api.IsZero(api.Sub(var0, var0)) - felt0 = fieldChip.Select(backend0, fieldChip.Add(felt1, babybear.NewVariable(0)), felt0) - felt0 = fieldChip.Select(backend0, fieldChip.Add(felt0, felt1), felt0) + felt0 = fieldChip.Select(backend0, fieldChip.Add(felt1, babybear.NewVariable(0)), felt0) + felt0 = fieldChip.Select(backend0, fieldChip.Add(felt0, felt1), felt0) backend1 = api.Sub(frontend.Variable(1), api.IsZero(api.Sub(var0, var0))) - felt0 = fieldChip.Select(backend1, fieldChip.Add(felt1, babybear.NewVariable(0)), felt0) - felt0 = fieldChip.Select(backend1, fieldChip.Add(felt0, felt1), felt0) + felt0 = fieldChip.Select(backend1, fieldChip.Add(felt1, babybear.NewVariable(0)), felt0) + felt0 = fieldChip.Select(backend1, fieldChip.Add(felt0, felt1), felt0) return nil } diff --git a/recursion/compiler/src/ir/builder.rs b/recursion/compiler/src/ir/builder.rs index c2d587c33..dcad47353 100644 --- a/recursion/compiler/src/ir/builder.rs +++ b/recursion/compiler/src/ir/builder.rs @@ -8,7 +8,7 @@ pub struct Builder { pub(crate) felt_count: u32, pub(crate) ext_count: u32, pub(crate) var_count: u32, - pub(crate) operations: Vec>, + pub operations: Vec>, } impl Default for Builder { diff --git a/recursion/compiler/src/ir/mod.rs b/recursion/compiler/src/ir/mod.rs index cc7378c44..79677f7fd 100644 --- a/recursion/compiler/src/ir/mod.rs +++ b/recursion/compiler/src/ir/mod.rs @@ -6,6 +6,7 @@ mod instructions; mod ptr; mod symbolic; mod types; +mod utils; mod var; pub use builder::*; diff --git a/recursion/compiler/src/ir/symbolic.rs b/recursion/compiler/src/ir/symbolic.rs index 5a13b2e92..53fd76c57 100644 --- a/recursion/compiler/src/ir/symbolic.rs +++ b/recursion/compiler/src/ir/symbolic.rs @@ -1002,7 +1002,21 @@ impl, E: Any> ExtensionOperand for E { let value = value_ref.clone(); ExtOperand::::Sym(value) } - _ => unimplemented!("Unsupported type"), + ty if ty == TypeId::of::>() => { + let value_ref = unsafe { mem::transmute::<&E, &ExtOperand>(&self) }; + value_ref.clone() + } + _ => unimplemented!("unsupported type"), } } } + +impl Div for Felt { + type Output = SymbolicFelt; + + fn div(self, rhs: F) -> Self::Output { + let lhs = SymbolicFelt::Val(self); + let rhs = SymbolicFelt::Const(rhs); + SymbolicFelt::Div(lhs.into(), rhs.into()) + } +} diff --git a/recursion/compiler/src/ir/utils.rs b/recursion/compiler/src/ir/utils.rs new file mode 100644 index 000000000..a993bdaf2 --- /dev/null +++ b/recursion/compiler/src/ir/utils.rs @@ -0,0 +1,20 @@ +use std::ops::MulAssign; + +use super::{Builder, Config, Variable}; + +impl Builder { + pub fn exp_power_of_2, E: Into>( + &mut self, + e: E, + power_log: usize, + ) -> V + where + V::Expression: MulAssign + Clone, + { + let mut e = e.into(); + for _ in 0..power_log { + e *= e.clone(); + } + self.eval(e) + } +} diff --git a/recursion/compiler/src/lib.rs b/recursion/compiler/src/lib.rs index 752c91323..ae52d5d1a 100644 --- a/recursion/compiler/src/lib.rs +++ b/recursion/compiler/src/lib.rs @@ -5,6 +5,7 @@ pub mod builder; pub mod gnark; pub mod ir; pub mod util; +pub mod verifier; pub mod prelude { pub use crate::asm::AsmCompiler; diff --git a/recursion/compiler/src/verifier/constraints.rs b/recursion/compiler/src/verifier/constraints.rs new file mode 100644 index 000000000..b98a74697 --- /dev/null +++ b/recursion/compiler/src/verifier/constraints.rs @@ -0,0 +1,116 @@ +use std::marker::PhantomData; + +use p3_air::Air; +use p3_field::AbstractExtensionField; +use p3_field::AbstractField; +use p3_field::Field; +use sp1_core::air::MachineAir; +use sp1_core::stark::ChipOpenedValues; +use sp1_core::stark::{ + AirOpenedValues, GenericVerifierConstraintFolder, MachineChip, StarkGenericConfig, +}; + +use crate::prelude::{Builder, Config, Ext, Felt, SymbolicExt}; +use crate::verifier::StarkGenericBuilderConfig; + +impl Builder { + pub fn const_opened_values( + &mut self, + opened_values: &AirOpenedValues, + ) -> AirOpenedValues> { + AirOpenedValues::> { + local: opened_values + .local + .iter() + .map(|s| self.eval(SymbolicExt::Const(*s))) + .collect(), + next: opened_values + .next + .iter() + .map(|s| self.eval(SymbolicExt::Const(*s))) + .collect(), + } + } +} + +pub fn verify_constraints>( + builder: &mut Builder>, + chip: &MachineChip, + opening: &ChipOpenedValues, + g: Felt, + zeta: Ext, + alpha: Ext, +) where + A: for<'a> Air< + GenericVerifierConstraintFolder< + 'a, + SC::Val, + SC::Challenge, + Ext, + SymbolicExt, + >, + >, +{ + let g_inv: Felt = builder.eval(g / SC::Val::one()); + let z_h: Ext = builder.exp_power_of_2(zeta, opening.log_degree); + let one: Ext = builder.eval(SC::Val::one()); + let is_first_row = builder.eval(z_h / (zeta - one)); + let is_last_row = builder.eval(z_h / (zeta - g_inv)); + let is_transition = builder.eval(zeta - g_inv); + + let preprocessed = builder.const_opened_values(&opening.preprocessed); + let main = builder.const_opened_values(&opening.main); + let perm = builder.const_opened_values(&opening.permutation); + + let zero: Ext = builder.eval(SC::Val::zero()); + let zero_expr: SymbolicExt = zero.into(); + let mut folder = GenericVerifierConstraintFolder::< + SC::Val, + SC::Challenge, + Ext, + SymbolicExt, + > { + preprocessed: preprocessed.view(), + main: main.view(), + perm: perm.view(), + perm_challenges: &[SC::Challenge::zero(), SC::Challenge::zero()], + cumulative_sum: builder.eval(SC::Val::zero()), + is_first_row, + is_last_row, + is_transition, + alpha, + accumulator: zero_expr, + _marker: PhantomData, + }; + + let monomials = (0..SC::Challenge::D) + .map(SC::Challenge::monomial) + .collect::>(); + + let quotient_parts = opening + .quotient + .chunks_exact(SC::Challenge::D) + .map(|chunk| { + chunk + .iter() + .zip(monomials.iter()) + .map(|(x, m)| *x * *m) + .sum() + }) + .collect::>(); + + let mut zeta_powers = zeta; + let quotient: Ext = builder.eval(SC::Val::zero()); + let quotient_expr: SymbolicExt = quotient.into(); + for quotient_part in quotient_parts { + zeta_powers = builder.eval(zeta_powers * zeta); + builder.assign(quotient, zeta_powers * quotient_part); + } + let quotient: Ext = builder.eval(quotient_expr); + folder.alpha = alpha; + + chip.eval(&mut folder); + let folded_constraints = folder.accumulator; + let expected_folded_constraints = z_h * quotient; + builder.assert_ext_eq(folded_constraints, expected_folded_constraints); +} diff --git a/recursion/compiler/src/verifier/mod.rs b/recursion/compiler/src/verifier/mod.rs new file mode 100644 index 000000000..a43e7638b --- /dev/null +++ b/recursion/compiler/src/verifier/mod.rs @@ -0,0 +1,21 @@ +mod constraints; + +use std::marker::PhantomData; + +#[allow(unused_imports)] +pub use constraints::*; +use p3_field::Field; +use sp1_core::stark::StarkGenericConfig; + +use crate::prelude::Config; + +#[derive(Clone)] +pub struct StarkGenericBuilderConfig { + marker: PhantomData<(N, SC)>, +} + +impl Config for StarkGenericBuilderConfig { + type N = N; + type F = SC::Val; + type EF = SC::Challenge; +}