From 5d3075e6cc4f61ab3bc1879e65e83f4a9dc1055c Mon Sep 17 00:00:00 2001 From: Tamir Hemo Date: Thu, 21 Mar 2024 14:46:40 -0700 Subject: [PATCH] feat: verify constraints (#409) --- Cargo.lock | 1 + core/src/air/extension.rs | 46 +- core/src/stark/verifier.rs | 101 ++-- recursion/compiler/Cargo.toml | 1 + recursion/compiler/examples/fibonacci.rs | 14 + recursion/compiler/examples/verifier.rs | 65 --- recursion/compiler/src/asm/compiler.rs | 164 +++--- recursion/compiler/src/asm/instruction.rs | 235 ++++++++ recursion/compiler/src/builder.rs | 500 ------------------ recursion/compiler/src/gnark/lib/main.go | 6 +- recursion/compiler/src/ir/symbolic.rs | 174 ++++++ recursion/compiler/src/ir/types.rs | 57 +- recursion/compiler/src/lib.rs | 1 - .../src/verifier/constraints/domain.rs | 218 ++++++++ .../compiler/src/verifier/constraints/mod.rs | 430 ++++++++++++--- .../src/verifier/constraints/opening.rs | 30 ++ .../src/verifier/constraints/utils.rs | 46 +- recursion/compiler/src/verifier/folder.rs | 96 ++++ recursion/compiler/src/verifier/fri/pcs.rs | 7 +- recursion/compiler/src/verifier/mod.rs | 1 + recursion/compiler/tests/eval_constraints.rs | 165 ++++++ recursion/core/src/air/block.rs | 104 ++++ recursion/core/src/air/extension.rs | 18 + recursion/core/src/air/is_ext_zero.rs | 93 ++++ recursion/core/src/air/mod.rs | 90 +--- recursion/core/src/cpu/air.rs | 210 ++++---- recursion/core/src/cpu/columns.rs | 59 --- recursion/core/src/cpu/columns/alu.rs | 28 + recursion/core/src/cpu/columns/branch.rs | 14 + recursion/core/src/cpu/columns/instruction.rs | 38 ++ recursion/core/src/cpu/columns/jump.rs | 15 + recursion/core/src/cpu/columns/mod.rs | 36 ++ recursion/core/src/cpu/columns/opcode.rs | 129 +++++ .../core/src/cpu/columns/opcode_specific.rs | 1 + recursion/core/src/cpu/mod.rs | 3 + recursion/core/src/program/mod.rs | 2 +- recursion/core/src/runtime/instruction.rs | 7 +- recursion/core/src/runtime/mod.rs | 4 +- recursion/core/src/runtime/opcode.rs | 10 + recursion/core/src/stark/mod.rs | 9 +- 40 files changed, 2146 insertions(+), 1082 deletions(-) delete mode 100644 recursion/compiler/examples/verifier.rs delete mode 100644 recursion/compiler/src/builder.rs create mode 100644 recursion/compiler/src/verifier/constraints/domain.rs create mode 100644 recursion/compiler/src/verifier/constraints/opening.rs create mode 100644 recursion/compiler/src/verifier/folder.rs create mode 100644 recursion/compiler/tests/eval_constraints.rs create mode 100644 recursion/core/src/air/block.rs create mode 100644 recursion/core/src/air/is_ext_zero.rs delete mode 100644 recursion/core/src/cpu/columns.rs create mode 100644 recursion/core/src/cpu/columns/alu.rs create mode 100644 recursion/core/src/cpu/columns/branch.rs create mode 100644 recursion/core/src/cpu/columns/instruction.rs create mode 100644 recursion/core/src/cpu/columns/jump.rs create mode 100644 recursion/core/src/cpu/columns/mod.rs create mode 100644 recursion/core/src/cpu/columns/opcode.rs create mode 100644 recursion/core/src/cpu/columns/opcode_specific.rs diff --git a/Cargo.lock b/Cargo.lock index 674c88b2f..88b937fdb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2606,6 +2606,7 @@ dependencies = [ "itertools 0.12.1", "p3-air", "p3-baby-bear", + "p3-challenger", "p3-commit", "p3-field", "p3-fri", diff --git a/core/src/air/extension.rs b/core/src/air/extension.rs index 881499765..75fb6ad07 100644 --- a/core/src/air/extension.rs +++ b/core/src/air/extension.rs @@ -1,4 +1,7 @@ -use p3_field::AbstractField; +use p3_field::{ + extension::{BinomialExtensionField, BinomiallyExtendable}, + AbstractExtensionField, AbstractField, +}; use sp1_derive::AlignedBorrow; use std::ops::{Add, Mul, Neg, Sub}; @@ -8,6 +11,17 @@ const DEGREE: usize = 4; #[repr(C)] pub struct BinomialExtension(pub [T; DEGREE]); +impl BinomialExtension { + pub fn from_base(b: T) -> Self + where + T: AbstractField, + { + let mut arr: [T; DEGREE] = core::array::from_fn(|_| T::zero()); + arr[0] = b; + Self(arr) + } +} + impl + Clone> Add for BinomialExtension { type Output = Self; @@ -62,3 +76,33 @@ impl Neg for BinomialExtension { Self([-self.0[0], -self.0[1], -self.0[2], -self.0[3]]) } } + +impl From> for BinomialExtension +where + AF: AbstractField + Copy, + AF::F: BinomiallyExtendable, +{ + fn from(value: BinomialExtensionField) -> Self { + let arr: [AF; DEGREE] = value.as_base_slice().try_into().unwrap(); + Self(arr) + } +} + +impl From> for BinomialExtensionField +where + AF: AbstractField + Copy, + AF::F: BinomiallyExtendable, +{ + fn from(value: BinomialExtension) -> Self { + BinomialExtensionField::from_base_slice(&value.0) + } +} + +impl IntoIterator for BinomialExtension { + type Item = T; + type IntoIter = core::array::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} diff --git a/core/src/stark/verifier.rs b/core/src/stark/verifier.rs index dc6ca1ccb..3c5c90ee1 100644 --- a/core/src/stark/verifier.rs +++ b/core/src/stark/verifier.rs @@ -5,6 +5,7 @@ use itertools::Itertools; use p3_air::Air; use p3_challenger::CanObserve; use p3_challenger::FieldChallenger; +use p3_commit::LagrangeSelectors; use p3_commit::Pcs; use p3_commit::PolynomialSpace; use p3_field::AbstractExtensionField; @@ -197,38 +198,31 @@ impl>> Verifier { where A: for<'a> Air>, { - use p3_field::Field; let sels = trace_domain.selectors_at_point(zeta); - let zps = qc_domains - .iter() - .enumerate() - .map(|(i, domain)| { - qc_domains - .iter() - .enumerate() - .filter(|(j, _)| *j != i) - .map(|(_, other_domain)| { - other_domain.zp_at_point(zeta) - * other_domain.zp_at_point(domain.first_point()).inverse() - }) - .product::() - }) - .collect_vec(); + let quotient = Self::recompute_quotient(&opening, &qc_domains, zeta); + let folded_constraints = + Self::eval_constraints(chip, &opening, &sels, alpha, permutation_challenges); - let quotient = opening - .quotient - .iter() - .enumerate() - .map(|(ch_i, ch)| { - assert_eq!(ch.len(), SC::Challenge::D); - ch.iter() - .enumerate() - .map(|(e_i, &c)| zps[ch_i] * SC::Challenge::monomial(e_i) * c) - .sum::() - }) - .sum::(); + // Check that the constraints match the quotient, i.e. + // folded_constraints(zeta) / Z_H(zeta) = quotient(zeta) + match folded_constraints * sels.inv_zeroifier == quotient { + true => Ok(()), + false => Err(OodEvaluationMismatch), + } + } + #[cfg(feature = "perf")] + pub fn eval_constraints( + chip: &MachineChip, + opening: &ChipOpenedValues, + selectors: &LagrangeSelectors, + alpha: SC::Challenge, + permutation_challenges: &[SC::Challenge], + ) -> SC::Challenge + where + A: for<'a> Air>, + { // Reconstruct the prmutation opening values as extention elements. let unflatten = |v: &[SC::Challenge]| { v.chunks_exact(SC::Challenge::D) @@ -253,23 +247,54 @@ impl>> Verifier { perm: perm_opening.view(), perm_challenges: permutation_challenges, cumulative_sum: opening.cumulative_sum, - is_first_row: sels.is_first_row, - is_last_row: sels.is_last_row, - is_transition: sels.is_transition, + is_first_row: selectors.is_first_row, + is_last_row: selectors.is_last_row, + is_transition: selectors.is_transition, alpha, accumulator: SC::Challenge::zero(), _marker: PhantomData, }; chip.eval(&mut folder); - let folded_constraints = folder.accumulator; + folder.accumulator + } - // Check that the constraints match the quotient, i.e. - // folded_constraints(zeta) / Z_H(zeta) = quotient(zeta) - match folded_constraints * sels.inv_zeroifier == quotient { - true => Ok(()), - false => Err(OodEvaluationMismatch), - } + #[cfg(feature = "perf")] + pub fn recompute_quotient( + opening: &ChipOpenedValues, + qc_domains: &[Domain], + zeta: SC::Challenge, + ) -> SC::Challenge { + use p3_field::Field; + + let zps = qc_domains + .iter() + .enumerate() + .map(|(i, domain)| { + qc_domains + .iter() + .enumerate() + .filter(|(j, _)| *j != i) + .map(|(_, other_domain)| { + other_domain.zp_at_point(zeta) + * other_domain.zp_at_point(domain.first_point()).inverse() + }) + .product::() + }) + .collect_vec(); + + opening + .quotient + .iter() + .enumerate() + .map(|(ch_i, ch)| { + assert_eq!(ch.len(), SC::Challenge::D); + ch.iter() + .enumerate() + .map(|(e_i, &c)| zps[ch_i] * SC::Challenge::monomial(e_i) * c) + .sum::() + }) + .sum::() } } diff --git a/recursion/compiler/Cargo.toml b/recursion/compiler/Cargo.toml index 9955ac55a..fdd6287e5 100644 --- a/recursion/compiler/Cargo.toml +++ b/recursion/compiler/Cargo.toml @@ -19,4 +19,5 @@ serde = { version = "1.0.197", features = ["derive"] } [dev-dependencies] p3-baby-bear = { workspace = true } +p3-challenger = { workspace = true } rand = "0.8.4" diff --git a/recursion/compiler/examples/fibonacci.rs b/recursion/compiler/examples/fibonacci.rs index 31dda8b06..9ad4d33d3 100644 --- a/recursion/compiler/examples/fibonacci.rs +++ b/recursion/compiler/examples/fibonacci.rs @@ -48,7 +48,21 @@ fn main() { println!("{}", code); let program = code.machine_code(); + println!("Program size = {}", program.instructions.len()); let mut runtime = Runtime::::new(&program); runtime.run(); + + // let config = SC::new(); + // let machine = RecursionAir::machine(config); + // let (pk, vk) = machine.setup(&program); + // let mut challenger = machine.config().challenger(); + + // let start = Instant::now(); + // let proof = machine.prove::>(&pk, runtime.record, &mut challenger); + // let duration = start.elapsed().as_secs(); + + // let mut challenger = machine.config().challenger(); + // machine.verify(&vk, &proof, &mut challenger).unwrap(); + // println!("proving duration = {}", duration); } diff --git a/recursion/compiler/examples/verifier.rs b/recursion/compiler/examples/verifier.rs deleted file mode 100644 index 6177ca9ca..000000000 --- a/recursion/compiler/examples/verifier.rs +++ /dev/null @@ -1,65 +0,0 @@ -// use std::fs::File; -// use std::marker::PhantomData; - -// use p3_field::AbstractField; -// use sp1_core::air::MachineAir; -// use sp1_core::stark::RiscvAir; -// use sp1_core::stark::StarkGenericConfig; -// use sp1_core::utils; -// use sp1_core::utils::BabyBearPoseidon2; -// use sp1_core::SP1Prover; -// use sp1_core::SP1Stdin; -// use sp1_recursion_compiler::gnark::GnarkBackend; -// use sp1_recursion_compiler::ir::Builder; -// use sp1_recursion_compiler::ir::{Ext, Felt}; -// use sp1_recursion_compiler::verifier::verify_constraints; -// use sp1_recursion_compiler::verifier::StarkGenericBuilderConfig; -// use std::collections::HashMap; -// use std::io::Write; - -// fn main() { -// type SC = BabyBearPoseidon2; -// type F = ::Val; -// type EF = ::Challenge; - -// // Generate a dummy proof. -// utils::setup_logger(); -// let elf = -// include_bytes!("../../../examples/cycle-tracking/program/elf/riscv32im-succinct-zkvm-elf"); -// let proofs = SP1Prover::prove(elf, SP1Stdin::new()) -// .unwrap() -// .proof -// .shard_proofs; -// let proof = &proofs[0]; - -// // Extract verification metadata. -// let machine = RiscvAir::machine(SC::new()); -// let chips = machine -// .chips() -// .iter() -// .filter(|chip| proof.chip_ids.contains(&chip.name())) -// .collect::>(); -// let chip = chips[0]; -// let opened_values = &proof.opened_values.chips[0]; - -// // Run the verify inside the DSL. -// let mut builder = Builder::>::default(); -// let g: Felt = builder.eval(F::one()); -// let zeta: Ext = builder.eval(F::one()); -// let alpha: Ext = builder.eval(F::one()); -// verify_constraints::(&mut builder, chip, opened_values, g, zeta, alpha); - -// // Emit the constraints using the Gnark backend. -// let mut backend = GnarkBackend::> { -// nb_backend_vars: 0, -// used: HashMap::new(), -// phantom: PhantomData, -// }; -// let result = backend.compile(builder.operations); -// let manifest_dir = env!("CARGO_MANIFEST_DIR"); -// let path = format!("{}/src/gnark/lib/main.go", manifest_dir); -// let mut file = File::create(path).unwrap(); -// file.write_all(result.as_bytes()).unwrap(); -// } - -fn main() {} diff --git a/recursion/compiler/src/asm/compiler.rs b/recursion/compiler/src/asm/compiler.rs index aed401c90..63ebac7ce 100644 --- a/recursion/compiler/src/asm/compiler.rs +++ b/recursion/compiler/src/asm/compiler.rs @@ -17,9 +17,16 @@ use crate::ir::Usize; use crate::ir::{Config, DslIR, Ext, Felt, Ptr, Var}; use p3_field::Field; +pub(crate) const STACK_START_OFFSET: i32 = 16; + pub(crate) const ZERO: i32 = 0; pub(crate) const HEAP_PTR: i32 = -4; +#[allow(dead_code)] +pub(crate) const A0: i32 = -8; +#[allow(dead_code)] +pub(crate) const A1: i32 = -12; + pub type VmBuilder = Builder>; #[derive(Debug, Clone)] @@ -54,13 +61,13 @@ impl> VmBuilder { impl Var { fn fp(&self) -> i32 { - -((self.0 as i32) * 3 + 1 + 8) + -((self.0 as i32) * 3 + 1 + STACK_START_OFFSET) } } impl Felt { fn fp(&self) -> i32 { - -((self.0 as i32) * 3 + 2 + 8) + -((self.0 as i32) * 3 + 2 + STACK_START_OFFSET) } } @@ -72,7 +79,7 @@ impl Ptr { impl Ext { pub fn fp(&self) -> i32 { - -((self.0 as i32) * 3 + 8) + -((self.0 as i32) * 3 + STACK_START_OFFSET) } } @@ -118,8 +125,12 @@ impl> AsmCompiler { DslIR::AddEI(dst, lhs, rhs) => { self.push(AsmInstruction::EADDI(dst.fp(), lhs.fp(), rhs)); } - DslIR::AddEF(_dst, _lhs, _rhs) => todo!(), - DslIR::AddEFFI(_dst, _lhs, _rhs) => todo!(), + DslIR::AddEF(dst, lhs, rhs) => { + self.push(AsmInstruction::EADDF(dst.fp(), lhs.fp(), rhs.fp())); + } + DslIR::AddEFFI(dst, lhs, rhs) => { + self.push(AsmInstruction::FADDEI(dst.fp(), lhs.fp(), rhs)); + } DslIR::AddEFI(dst, lhs, rhs) => { self.push(AsmInstruction::EADDI( dst.fp(), @@ -166,7 +177,9 @@ impl> AsmCompiler { DslIR::InvF(dst, src) => { self.push(AsmInstruction::DIVIN(dst.fp(), F::one(), src.fp())); } - DslIR::DivEF(_dst, _lhs, _rhs) => todo!(), + DslIR::DivEF(dst, lhs, rhs) => { + self.push(AsmInstruction::EDIVF(dst.fp(), lhs.fp(), rhs.fp())); + } DslIR::DivEFI(dst, lhs, rhs) => { self.push(AsmInstruction::EDIVI( dst.fp(), @@ -200,7 +213,9 @@ impl> AsmCompiler { rhs.fp(), )); } - DslIR::SubEF(_dst, _lhs, _rhs) => todo!(), + DslIR::SubEF(dst, lhs, rhs) => { + self.push(AsmInstruction::ESUBF(dst.fp(), lhs.fp(), rhs.fp())); + } DslIR::SubEFI(dst, lhs, rhs) => { self.push(AsmInstruction::ESUBI( dst.fp(), @@ -238,8 +253,16 @@ impl> AsmCompiler { DslIR::MulEI(dst, lhs, rhs) => { self.push(AsmInstruction::EMULI(dst.fp(), lhs.fp(), rhs)); } - DslIR::MulEF(_dst, _lhs, _rhs) => todo!(), - DslIR::MulEFI(_dst, _lhs, _rhs) => todo!(), + DslIR::MulEF(dst, lhs, rhs) => { + self.push(AsmInstruction::EMULF(dst.fp(), lhs.fp(), rhs.fp())); + } + DslIR::MulEFI(dst, lhs, rhs) => { + self.push(AsmInstruction::EMULI( + dst.fp(), + lhs.fp(), + EF::from_base(rhs), + )); + } DslIR::IfEq(lhs, rhs, then_block, else_block) => { let if_compiler = IfCompiler { compiler: self, @@ -324,123 +347,51 @@ impl> AsmCompiler { } DslIR::AssertEqV(lhs, rhs) => { // If lhs != rhs, execute TRAP - let if_compiler = IfCompiler { - compiler: self, - lhs: lhs.fp(), - rhs: ValueOrConst::Val(rhs.fp()), - is_eq: false, - }; - if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + self.assert(lhs.fp(), ValueOrConst::Val(rhs.fp()), false) } DslIR::AssertEqVI(lhs, rhs) => { // If lhs != rhs, execute TRAP - let if_compiler = IfCompiler { - compiler: self, - lhs: lhs.fp(), - rhs: ValueOrConst::Const(rhs), - is_eq: false, - }; - if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + self.assert(lhs.fp(), ValueOrConst::Const(rhs), false) } DslIR::AssertNeV(lhs, rhs) => { // If lhs == rhs, execute TRAP - let if_compiler = IfCompiler { - compiler: self, - lhs: lhs.fp(), - rhs: ValueOrConst::Val(rhs.fp()), - is_eq: true, - }; - if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + self.assert(lhs.fp(), ValueOrConst::Val(rhs.fp()), true) } DslIR::AssertNeVI(lhs, rhs) => { // If lhs == rhs, execute TRAP - let if_compiler = IfCompiler { - compiler: self, - lhs: lhs.fp(), - rhs: ValueOrConst::Const(rhs), - is_eq: true, - }; - if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + self.assert(lhs.fp(), ValueOrConst::Const(rhs), true) } DslIR::AssertEqF(lhs, rhs) => { // If lhs != rhs, execute TRAP - let if_compiler = IfCompiler { - compiler: self, - lhs: lhs.fp(), - rhs: ValueOrConst::Val(rhs.fp()), - is_eq: false, - }; - if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + self.assert(lhs.fp(), ValueOrConst::Val(rhs.fp()), false) } DslIR::AssertEqFI(lhs, rhs) => { // If lhs != rhs, execute TRAP - let if_compiler = IfCompiler { - compiler: self, - lhs: lhs.fp(), - rhs: ValueOrConst::Const(rhs), - is_eq: false, - }; - if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + self.assert(lhs.fp(), ValueOrConst::Const(rhs), false) } DslIR::AssertNeF(lhs, rhs) => { // If lhs == rhs, execute TRAP - let if_compiler = IfCompiler { - compiler: self, - lhs: lhs.fp(), - rhs: ValueOrConst::Val(rhs.fp()), - is_eq: true, - }; - if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + self.assert(lhs.fp(), ValueOrConst::Val(rhs.fp()), true) } DslIR::AssertNeFI(lhs, rhs) => { // If lhs == rhs, execute TRAP - let if_compiler = IfCompiler { - compiler: self, - lhs: lhs.fp(), - rhs: ValueOrConst::Const(rhs), - is_eq: true, - }; - if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + self.assert(lhs.fp(), ValueOrConst::Const(rhs), true) } DslIR::AssertEqE(lhs, rhs) => { // If lhs != rhs, execute TRAP - let if_compiler = IfCompiler { - compiler: self, - lhs: lhs.fp(), - rhs: ValueOrConst::ExtVal(rhs.fp()), - is_eq: false, - }; - if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + self.assert(lhs.fp(), ValueOrConst::ExtVal(rhs.fp()), false) } DslIR::AssertEqEI(lhs, rhs) => { // If lhs != rhs, execute TRAP - let if_compiler = IfCompiler { - compiler: self, - lhs: lhs.fp(), - rhs: ValueOrConst::ExtConst(rhs), - is_eq: false, - }; - if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + self.assert(lhs.fp(), ValueOrConst::ExtConst(rhs), false) } DslIR::AssertNeE(lhs, rhs) => { // If lhs == rhs, execute TRAP - let if_compiler = IfCompiler { - compiler: self, - lhs: lhs.fp(), - rhs: ValueOrConst::ExtVal(rhs.fp()), - is_eq: true, - }; - if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + self.assert(lhs.fp(), ValueOrConst::ExtVal(rhs.fp()), true) } DslIR::AssertNeEI(lhs, rhs) => { // If lhs == rhs, execute TRAP - let if_compiler = IfCompiler { - compiler: self, - lhs: lhs.fp(), - rhs: ValueOrConst::ExtConst(rhs), - is_eq: true, - }; - if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + self.assert(lhs.fp(), ValueOrConst::ExtConst(rhs), true) } DslIR::Alloc(ptr, len) => { self.alloc(ptr, len); @@ -451,7 +402,15 @@ impl> AsmCompiler { DslIR::StoreV(ptr, var) => self.push(AsmInstruction::SW(ptr.fp(), var.fp())), DslIR::StoreF(ptr, var) => self.push(AsmInstruction::SW(ptr.fp(), var.fp())), DslIR::StoreE(ptr, var) => self.push(AsmInstruction::SE(ptr.fp(), var.fp())), - _ => unimplemented!(), + DslIR::Num2BitsF(_, _) => unimplemented!(), + DslIR::Num2BitsV(_, _) => unimplemented!(), + DslIR::Poseidon2Compress(_, _, _) => unimplemented!(), + DslIR::Poseidon2Permute(_, _) => unimplemented!(), + DslIR::ReverseBitsLen(_, _, _) => unimplemented!(), + DslIR::TwoAdicGenerator(_, _) => unimplemented!(), + DslIR::ExpUsizeV(_, _, _) => unimplemented!(), + DslIR::ExpUsizeF(_, _, _) => unimplemented!(), + DslIR::Error() => self.push(AsmInstruction::TRAP), } } } @@ -471,6 +430,23 @@ impl> AsmCompiler { } } + pub fn assert(&mut self, lhs: i32, rhs: ValueOrConst, is_eq: bool) { + let if_compiler = IfCompiler { + compiler: self, + lhs, + rhs, + is_eq, + }; + if_compiler.then(|builder| builder.push(AsmInstruction::TRAP)); + } + + pub fn num_2_bits(&mut self, array_ptr: Ptr, element: i32) { + // Store the bits of `F` in the array, storing `4` bits at a time. + for i in 0..4 { + self.push(AsmInstruction::HintBits(array_ptr.fp(), element, i)); + } + } + pub fn code(self) -> AssemblyCode { let labels = self .function_labels diff --git a/recursion/compiler/src/asm/instruction.rs b/recursion/compiler/src/asm/instruction.rs index 15faa6883..35d338614 100644 --- a/recursion/compiler/src/asm/instruction.rs +++ b/recursion/compiler/src/asm/instruction.rs @@ -12,6 +12,7 @@ use super::ZERO; #[derive(Debug, Clone)] pub enum AsmInstruction { + // Field operations /// Load work (dst, src) : load a value from the address stored at src(fp) into dstfp). LW(i32, i32), /// Store word (dst, src) : store a value from src(fp) into the address stored at dest(fp). @@ -67,6 +68,42 @@ pub enum AsmInstruction { /// Divide value from immediate extension, dst = lhs / rhs. EDIVIN(i32, EF, i32), + // Mixed base-extension operations + /// Add base to extension, dst = lhs + rhs. + EADDF(i32, i32, i32), + /// Add immediate base to extension, dst = lhs + rhs. + EADDFI(i32, i32, F), + /// Add immediate extension element to base, dst = lhs + rhs. + FADDEI(i32, i32, EF), + // Subtract base from extension, dst = lhs - rhs. + ESUBF(i32, i32, i32), + /// Subtract immediate base from extension, dst = lhs - rhs. + ESUBFI(i32, i32, F), + /// Subtract value from immediate base to extension, dst = lhs - rhs. + ESUBFIN(i32, F, i32), + /// Subtract extension from base, dst = lhs - rhs. + FSUBE(i32, i32, i32), + /// Subtract immediate extension from base, dst = lhs - rhs. + FSUBEI(i32, i32, EF), + /// Subtract value from immediate extension to base, dst = lhs - rhs. + FSUBEIN(i32, EF, i32), + /// Multiply base and extension, dst = lhs * rhs. + EMULF(i32, i32, i32), + /// Multiply immediate base and extension. + EMULFI(i32, i32, F), + /// Multiply base by immediate extension, dst = lhs * rhs. + FMULEI(i32, i32, EF), + /// Divide base and extension, dst = lhs / rhs. + EDIVF(i32, i32, i32), + /// Divide immediate base and extension, dst = lhs / rhs. + EDIVFI(i32, i32, F), + /// Divide value from immediate base to extension, dst = lhs / rhs. + EDIVFIN(i32, F, i32), + /// Divide extension from immediate base, dst = lhs / rhs. + FDIVI(i32, i32, EF), + /// Divide value from immediate extension to base, dst = lhs / rhs. + FDIVIN(i32, EF, i32), + /// Jump and link JAL(i32, F, F), /// Jump and link value @@ -89,6 +126,13 @@ pub enum AsmInstruction { EBEQI(F, i32, EF), /// Trap TRAP, + + // Store the 4 most significant bits of the source into a contiguous chunk of memory at the + // address stored in the destination. + HintBits(i32, i32, usize), + /// Compute srarting at the address stored in the source and store the result at the + /// destination. + Bits4toNum(i32, i32, usize), } impl> AsmInstruction { @@ -292,6 +336,142 @@ impl> AsmInstruction { true, false, ), + AsmInstruction::EADDF(dst, lhs, rhs) => Instruction::new( + Opcode::EFADD, + i32_f(dst), + i32_f_arr(lhs), + i32_f_arr(rhs), + false, + false, + ), + AsmInstruction::EADDFI(dst, lhs, rhs) => Instruction::new( + Opcode::EFADD, + i32_f(dst), + i32_f_arr(lhs), + f_u32(rhs), + false, + true, + ), + AsmInstruction::FADDEI(dst, lhs, rhs) => Instruction::new( + Opcode::EFADD, + i32_f(dst), + rhs.as_base_slice().try_into().unwrap(), + i32_f_arr(lhs), + true, + false, + ), + AsmInstruction::ESUBF(dst, lhs, rhs) => Instruction::new( + Opcode::EFSUB, + i32_f(dst), + i32_f_arr(lhs), + i32_f_arr(rhs), + false, + false, + ), + AsmInstruction::ESUBFI(dst, lhs, rhs) => Instruction::new( + Opcode::EFSUB, + i32_f(dst), + i32_f_arr(lhs), + f_u32(rhs), + false, + true, + ), + AsmInstruction::ESUBFIN(dst, lhs, rhs) => Instruction::new( + Opcode::EFSUB, + i32_f(dst), + f_u32(lhs), + i32_f_arr(rhs), + true, + false, + ), + AsmInstruction::FSUBE(dst, lhs, rhs) => Instruction::new( + Opcode::FESUB, + i32_f(dst), + i32_f_arr(rhs), + i32_f_arr(lhs), + false, + false, + ), + AsmInstruction::FSUBEI(dst, lhs, rhs) => Instruction::new( + Opcode::FESUB, + i32_f(dst), + i32_f_arr(lhs), + rhs.as_base_slice().try_into().unwrap(), + false, + true, + ), + AsmInstruction::FSUBEIN(dst, lhs, rhs) => Instruction::new( + Opcode::FESUB, + i32_f(dst), + lhs.as_base_slice().try_into().unwrap(), + i32_f_arr(rhs), + true, + false, + ), + AsmInstruction::EMULF(dst, lhs, rhs) => Instruction::new( + Opcode::EFMUL, + i32_f(dst), + i32_f_arr(lhs), + i32_f_arr(rhs), + false, + false, + ), + AsmInstruction::EMULFI(dst, lhs, rhs) => Instruction::new( + Opcode::EFMUL, + i32_f(dst), + i32_f_arr(lhs), + f_u32(rhs), + false, + true, + ), + AsmInstruction::FMULEI(dst, lhs, rhs) => Instruction::new( + Opcode::EFMUL, + i32_f(dst), + i32_f_arr(lhs), + rhs.as_base_slice().try_into().unwrap(), + false, + true, + ), + AsmInstruction::EDIVF(dst, lhs, rhs) => Instruction::new( + Opcode::EFDIV, + i32_f(dst), + i32_f_arr(lhs), + i32_f_arr(rhs), + false, + false, + ), + AsmInstruction::EDIVFI(dst, lhs, rhs) => Instruction::new( + Opcode::EFDIV, + i32_f(dst), + i32_f_arr(lhs), + f_u32(rhs), + false, + true, + ), + AsmInstruction::EDIVFIN(dst, lhs, rhs) => Instruction::new( + Opcode::FEDIV, + i32_f(dst), + f_u32(lhs), + i32_f_arr(rhs), + true, + false, + ), + AsmInstruction::FDIVI(dst, lhs, rhs) => Instruction::new( + Opcode::FEDIV, + i32_f(dst), + i32_f_arr(lhs), + rhs.as_base_slice().try_into().unwrap(), + false, + true, + ), + AsmInstruction::FDIVIN(dst, lhs, rhs) => Instruction::new( + Opcode::EFDIV, + i32_f(dst), + lhs.as_base_slice().try_into().unwrap(), + i32_f_arr(rhs), + true, + false, + ), AsmInstruction::BEQ(label, lhs, rhs) => { let offset = F::from_canonical_usize(label_to_pc[&label]) - F::from_canonical_usize(pc); @@ -411,6 +591,8 @@ impl> AsmInstruction { AsmInstruction::TRAP => { Instruction::new(Opcode::TRAP, F::zero(), zero, zero, false, false) } + AsmInstruction::HintBits(_, _, _) => unimplemented!(), + AsmInstruction::Bits4toNum(_, _, _) => unimplemented!(), } } @@ -499,6 +681,57 @@ impl> AsmInstruction { offset ) } + AsmInstruction::EADDF(dst, lhs, rhs) => { + write!(f, "eaddf ({})fp, ({})fp, ({})fp", dst, lhs, rhs) + } + AsmInstruction::EADDFI(dst, lhs, rhs) => { + write!(f, "eaddfi ({})fp, ({})fp, {}", dst, lhs, rhs) + } + AsmInstruction::FADDEI(dst, lhs, rhs) => { + write!(f, "faddei ({})fp, ({})fp, {}", dst, lhs, rhs) + } + AsmInstruction::ESUBF(dst, lhs, rhs) => { + write!(f, "esubf ({})fp, ({})fp, ({})fp", dst, lhs, rhs) + } + AsmInstruction::ESUBFI(dst, lhs, rhs) => { + write!(f, "esubfi ({})fp, ({})fp, {}", dst, lhs, rhs) + } + AsmInstruction::ESUBFIN(dst, lhs, rhs) => { + write!(f, "esubfin ({})fp, {}, ({})fp", dst, lhs, rhs) + } + AsmInstruction::FSUBE(dst, lhs, rhs) => { + write!(f, "fsube ({})fp, ({})fp, ({})fp", dst, lhs, rhs) + } + AsmInstruction::FSUBEI(dst, lhs, rhs) => { + write!(f, "fsubei ({})fp, ({})fp, {}", dst, lhs, rhs) + } + AsmInstruction::FSUBEIN(dst, lhs, rhs) => { + write!(f, "fsubein ({})fp, {}, ({})fp", dst, lhs, rhs) + } + AsmInstruction::EMULF(dst, lhs, rhs) => { + write!(f, "emulf ({})fp, ({})fp, ({})fp", dst, lhs, rhs) + } + AsmInstruction::EMULFI(dst, lhs, rhs) => { + write!(f, "emulfi ({})fp, ({})fp, {}", dst, lhs, rhs) + } + AsmInstruction::FMULEI(dst, lhs, rhs) => { + write!(f, "fmulei ({})fp, ({})fp, {}", dst, lhs, rhs) + } + AsmInstruction::EDIVF(dst, lhs, rhs) => { + write!(f, "edivf ({})fp, ({})fp, ({})fp", dst, lhs, rhs) + } + AsmInstruction::EDIVFI(dst, lhs, rhs) => { + write!(f, "edivfi ({})fp, ({})fp, {}", dst, lhs, rhs) + } + AsmInstruction::EDIVFIN(dst, lhs, rhs) => { + write!(f, "edivfin ({})fp, {}, ({})fp", dst, lhs, rhs) + } + AsmInstruction::FDIVI(dst, lhs, rhs) => { + write!(f, "fdivi ({})fp, ({})fp, {}", dst, lhs, rhs) + } + AsmInstruction::FDIVIN(dst, lhs, rhs) => { + write!(f, "fdivin ({})fp, {}, ({})fp", dst, lhs, rhs) + } AsmInstruction::JALR(dst, label, offset) => { write!(f, "jalr ({})fp, ({})fp, ({})fp", dst, label, offset) } @@ -575,6 +808,8 @@ impl> AsmInstruction { ) } AsmInstruction::TRAP => write!(f, "trap"), + AsmInstruction::HintBits(_, _, _) => unimplemented!(), + AsmInstruction::Bits4toNum(_, _, _) => unimplemented!(), } } } diff --git a/recursion/compiler/src/builder.rs b/recursion/compiler/src/builder.rs deleted file mode 100644 index 97fb99b3f..000000000 --- a/recursion/compiler/src/builder.rs +++ /dev/null @@ -1,500 +0,0 @@ -// use crate::asm::AsmInstruction; -// use crate::old_ir::Constant; -// use crate::old_ir::Variable; - -// use crate::asm::BasicBlock; -// use crate::old_ir::Expression; -// use crate::old_ir::Felt; -// use crate::old_ir::Int; -// use crate::prelude::Symbolic; -// use crate::prelude::SymbolicLogic; - -// use p3_field::AbstractField; -// use p3_field::PrimeField32; - -// pub trait Builder: Sized { -// type F: PrimeField32; -// /// Get stack memory. -// fn get_mem(&mut self, size: usize) -> i32; -// // Allocate heap memory. -// fn alloc(&mut self, size: Int) -> Int; - -// fn push(&mut self, instruction: AsmInstruction); - -// fn get_block_mut(&mut self, label: Self::F) -> &mut BasicBlock; - -// fn basic_block(&mut self); - -// fn block_label(&mut self) -> Self::F; - -// fn push_to_block(&mut self, block_label: Self::F, instruction: AsmInstruction) { -// self.get_block_mut(block_label).push(instruction); -// } - -// fn uninit>(&mut self) -> T { -// T::uninit(self) -// } - -// fn constant>(&mut self, value: T::Constant) -> T { -// let var = T::uninit(self); -// var.imm(value, self); -// var -// } - -// fn assign>(&mut self, dst: E::Value, expr: E) { -// expr.assign(dst, self); -// } - -// fn eval>(&mut self, expr: E) -> E::Value { -// let dst = E::Value::uninit(self); -// expr.assign(dst, self); -// dst -// } - -// fn range(&mut self, start: Felt, end: Felt) -> ForBuilder { -// let loop_var = Felt::uninit(self); -// ForBuilder { -// builder: self, -// start, -// end, -// loop_var, -// } -// } - -// fn if_eq(&mut self, lhs: E1, rhs: E2) -> IfBuilder -// where -// E1: Into>, -// E2: Into>, -// { -// IfBuilder { -// builder: self, -// lhs: lhs.into(), -// rhs: rhs.into(), -// is_eq: true, -// } -// } - -// fn if_neq(&mut self, lhs: E1, rhs: E2) -> IfBuilder -// where -// E1: Into>, -// E2: Into>, -// { -// IfBuilder { -// builder: self, -// lhs: lhs.into(), -// rhs: rhs.into(), -// is_eq: false, -// } -// } - -// fn assert_eq(&mut self, lhs: E1, rhs: E2) -// where -// E1: Into>, -// E2: Into>, -// { -// self.if_neq(lhs, rhs) -// .then(|builder| builder.push(AsmInstruction::TRAP)); -// } - -// fn assert_ne(&mut self, lhs: E1, rhs: E2) -// where -// E1: Into>, -// E2: Into>, -// { -// self.if_eq(lhs, rhs) -// .then(|builder| builder.push(AsmInstruction::TRAP)); -// } - -// fn if_true(&mut self, expr: E) -> IfBoolBuilder -// where -// E: Into, -// { -// IfBoolBuilder { -// builder: self, -// expr: expr.into(), -// is_true: true, -// } -// } - -// fn if_false(&mut self, expr: E) -> IfBoolBuilder -// where -// E: Into, -// { -// IfBoolBuilder { -// builder: self, -// expr: expr.into(), -// is_true: false, -// } -// } - -// fn assert(&mut self, expr: E) -// where -// E: Into, -// { -// self.if_false(expr) -// .then(|builder| builder.push(AsmInstruction::TRAP)); -// } - -// fn assert_not(&mut self, expr: E) -// where -// E: Into, -// { -// self.if_true(expr) -// .then(|builder| builder.push(AsmInstruction::TRAP)); -// } -// } - -// pub struct IfBoolBuilder<'a, B: Builder> { -// builder: &'a mut B, -// expr: SymbolicLogic, -// is_true: bool, -// } - -// impl<'a, B: Builder> Builder for IfBoolBuilder<'a, B> { -// type F = B::F; -// fn get_mem(&mut self, size: usize) -> i32 { -// self.builder.get_mem(size) -// } - -// fn alloc(&mut self, size: Int) -> Int { -// self.builder.alloc(size) -// } - -// fn push(&mut self, instruction: AsmInstruction) { -// self.builder.push(instruction); -// } - -// fn get_block_mut(&mut self, label: Self::F) -> &mut BasicBlock { -// self.builder.get_block_mut(label) -// } - -// fn basic_block(&mut self) { -// self.builder.basic_block(); -// } - -// fn block_label(&mut self) -> B::F { -// self.builder.block_label() -// } -// } - -// impl<'a, B: Builder> IfBoolBuilder<'a, B> { -// pub fn then(self, f: Func) -// where -// Func: FnOnce(&mut B), -// { -// let Self { -// builder, -// expr, -// is_true, -// } = self; -// let after_if_block = builder.block_label() + B::F::two(); -// Self::branch(expr, is_true, after_if_block, builder); -// builder.basic_block(); -// f(builder); -// builder.basic_block(); -// } - -// pub fn then_or_else(self, then_f: ThenFunc, else_f: ElseFunc) -// where -// ThenFunc: FnOnce(&mut B), -// ElseFunc: FnOnce(&mut B), -// { -// let Self { -// builder, -// expr, -// is_true, -// } = self; -// let else_block = builder.block_label() + B::F::two(); -// let main_flow_block = else_block + B::F::one(); -// Self::branch(expr, is_true, else_block, builder); -// builder.basic_block(); -// then_f(builder); -// let instr = AsmInstruction::j(main_flow_block, builder); -// builder.push(instr); -// builder.basic_block(); -// else_f(builder); -// builder.basic_block(); -// } - -// fn branch(expr: SymbolicLogic, is_true: bool, block: B::F, builder: &mut B) { -// match (expr, is_true) { -// (SymbolicLogic::Const(true), true) => { -// let instr = AsmInstruction::j(block, builder); -// builder.push(instr); -// } -// (SymbolicLogic::Const(true), false) => {} -// (SymbolicLogic::Const(false), true) => {} -// (SymbolicLogic::Const(false), false) => { -// let instr = AsmInstruction::j(block, builder); -// builder.push(instr); -// } -// (SymbolicLogic::Value(expr), true) => { -// let instr = AsmInstruction::BNEI(block, expr.0, B::F::one()); -// builder.push(instr); -// } -// (SymbolicLogic::Value(expr), false) => { -// let instr = AsmInstruction::BEQI(block, expr.0, B::F::one()); -// builder.push(instr); -// } -// (expr, true) => { -// let value = builder.eval(expr); -// let instr = AsmInstruction::BNEI(block, value.0, B::F::one()); -// builder.push(instr); -// } -// (expr, false) => { -// let value = builder.eval(expr); -// let instr = AsmInstruction::BEQI(block, value.0, B::F::one()); -// builder.push(instr); -// } -// } -// } -// } - -// pub struct IfBuilder<'a, B: Builder> { -// builder: &'a mut B, -// lhs: Symbolic, -// rhs: Symbolic, -// is_eq: bool, -// } - -// impl<'a, B: Builder> Builder for IfBuilder<'a, B> { -// type F = B::F; -// fn get_mem(&mut self, size: usize) -> i32 { -// self.builder.get_mem(size) -// } - -// fn alloc(&mut self, size: Int) -> Int { -// self.builder.alloc(size) -// } - -// fn push(&mut self, instruction: AsmInstruction) { -// self.builder.push(instruction); -// } - -// fn get_block_mut(&mut self, label: Self::F) -> &mut BasicBlock { -// self.builder.get_block_mut(label) -// } - -// fn basic_block(&mut self) { -// self.builder.basic_block(); -// } - -// fn block_label(&mut self) -> B::F { -// self.builder.block_label() -// } -// } - -// impl<'a, B: Builder> IfBuilder<'a, B> { -// pub fn then(self, f: Func) -// where -// Func: FnOnce(&mut B), -// { -// let Self { -// builder, -// lhs, -// rhs, -// is_eq, -// } = self; -// // Get the label for the block after the if block, and generate the conditional branch -// // instruction to it, if the condition is not met. -// let after_if_block = builder.block_label() + B::F::two(); -// Self::branch(lhs, rhs, is_eq, after_if_block, builder); -// // Generate the block for the then branch. -// builder.basic_block(); -// f(builder); -// // Generate the block for returning to the main flow. -// builder.basic_block(); -// } - -// pub fn then_or_else(self, then_f: ThenFunc, else_f: ElseFunc) -// where -// ThenFunc: FnOnce(&mut B), -// ElseFunc: FnOnce(&mut B), -// { -// let Self { -// builder, -// lhs, -// rhs, -// is_eq, -// } = self; -// // Get the label for the else block, and the continued main flow block, and generate the -// // conditional branc instruction to it, if the condition is not met. -// let else_block = builder.block_label() + B::F::two(); -// let main_flow_block = else_block + B::F::one(); -// Self::branch(lhs, rhs, is_eq, else_block, builder); -// // Generate the block for the then branch. -// builder.basic_block(); -// then_f(builder); -// // Generate the jump instruction to the main flow block. -// let instr = AsmInstruction::j(main_flow_block, builder); -// builder.push(instr); -// // Generate the block for the else branch. -// builder.basic_block(); -// else_f(builder); -// // Generate the block for returning to the main flow. -// builder.basic_block(); -// } - -// fn branch(lhs: Symbolic, rhs: Symbolic, is_eq: bool, block: B::F, builder: &mut B) { -// match (lhs, rhs, is_eq) { -// (Symbolic::Const(lhs), Symbolic::Const(rhs), true) => { -// if lhs == rhs { -// let instr = AsmInstruction::j(block, builder); -// builder.push(instr); -// } -// } -// (Symbolic::Const(lhs), Symbolic::Const(rhs), false) => { -// if lhs != rhs { -// let instr = AsmInstruction::j(block, builder); -// builder.push(instr); -// } -// } -// (Symbolic::Const(lhs), Symbolic::Value(rhs), true) => { -// let instr = AsmInstruction::BNEI(block, rhs.0, lhs); -// builder.push(instr); -// } -// (Symbolic::Const(lhs), Symbolic::Value(rhs), false) => { -// let instr = AsmInstruction::BEQI(block, rhs.0, lhs); -// builder.push(instr); -// } -// (Symbolic::Const(lhs), rhs, true) => { -// let rhs = builder.eval(rhs); -// let instr = AsmInstruction::BNEI(block, rhs.0, lhs); -// builder.push(instr); -// } -// (Symbolic::Const(lhs), rhs, false) => { -// let rhs = builder.eval(rhs); -// let instr = AsmInstruction::BEQI(block, rhs.0, lhs); -// builder.push(instr); -// } -// (Symbolic::Value(lhs), Symbolic::Const(rhs), true) => { -// let instr = AsmInstruction::BNEI(block, lhs.0, rhs); -// builder.push(instr); -// } -// (Symbolic::Value(lhs), Symbolic::Const(rhs), false) => { -// let instr = AsmInstruction::BEQI(block, lhs.0, rhs); -// builder.push(instr); -// } -// (lhs, Symbolic::Const(rhs), true) => { -// let lhs = builder.eval(lhs); -// let instr = AsmInstruction::BNEI(block, lhs.0, rhs); -// builder.push(instr); -// } -// (lhs, Symbolic::Const(rhs), false) => { -// let lhs = builder.eval(lhs); -// let instr = AsmInstruction::BEQI(block, lhs.0, rhs); -// builder.push(instr); -// } -// (Symbolic::Value(lhs), Symbolic::Value(rhs), true) => { -// let instr = AsmInstruction::BNE(block, lhs.0, rhs.0); -// builder.push(instr); -// } -// (Symbolic::Value(lhs), Symbolic::Value(rhs), false) => { -// let instr = AsmInstruction::BEQ(block, lhs.0, rhs.0); -// builder.push(instr); -// } -// (Symbolic::Value(lhs), rhs, true) => { -// let rhs = builder.eval(rhs); -// let instr = AsmInstruction::BNE(block, lhs.0, rhs.0); -// builder.push(instr); -// } -// (Symbolic::Value(lhs), rhs, false) => { -// let rhs = builder.eval(rhs); -// let instr = AsmInstruction::BEQ(block, lhs.0, rhs.0); -// builder.push(instr); -// } -// (lhs, Symbolic::Value(rhs), true) => { -// let lhs = builder.eval(lhs); -// let instr = AsmInstruction::BNE(block, lhs.0, rhs.0); -// builder.push(instr); -// } -// (lhs, Symbolic::Value(rhs), false) => { -// let lhs = builder.eval(lhs); -// let instr = AsmInstruction::BEQ(block, lhs.0, rhs.0); -// builder.push(instr); -// } -// (lhs, rhs, true) => { -// let lhs = builder.eval(lhs); -// let rhs = builder.eval(rhs); -// let instr = AsmInstruction::BNE(block, lhs.0, rhs.0); -// builder.push(instr); -// } -// (lhs, rhs, false) => { -// let lhs = builder.eval(lhs); -// let rhs = builder.eval(rhs); -// let instr = AsmInstruction::BEQ(block, lhs.0, rhs.0); -// builder.push(instr); -// } -// } -// } -// } - -// /// A builder for a for loop. -// /// -// /// Starting with end < start will lead to undefined behavior! -// pub struct ForBuilder<'a, B: Builder> { -// builder: &'a mut B, -// start: Felt, -// end: Felt, -// loop_var: Felt, -// } - -// impl<'a, B: Builder> Builder for ForBuilder<'a, B> { -// type F = B::F; -// fn get_mem(&mut self, size: usize) -> i32 { -// self.builder.get_mem(size) -// } - -// fn alloc(&mut self, size: Int) -> Int { -// self.builder.alloc(size) -// } - -// fn push(&mut self, instruction: AsmInstruction) { -// self.builder.push(instruction); -// } - -// fn get_block_mut(&mut self, label: Self::F) -> &mut BasicBlock { -// self.builder.get_block_mut(label) -// } - -// fn basic_block(&mut self) { -// self.builder.basic_block(); -// } - -// fn block_label(&mut self) -> B::F { -// self.builder.block_label() -// } -// } - -// impl<'a, B: Builder> ForBuilder<'a, B> { -// pub fn for_each(&mut self, f: Func) -// where -// Func: FnOnce(Felt, &mut Self), -// { -// // The function block structure: -// // - Setting the loop range -// // - Executing the loop body and incrementing the loop variable -// // - the loop condition -// let loop_var = self.loop_var; -// // Set the loop variable to the start of the range. -// self.assign(loop_var, self.start); -// // Save the label of the for loop call -// let loop_call_label = self.block_label(); -// // A basic block for the loop body -// self.basic_block(); -// // Save the loop body label for the loop condition. -// let loop_label = self.block_label(); -// // The loop body. -// f(loop_var, self); -// self.assign(loop_var, loop_var + B::F::one()); -// // Add a basic block for the loop condition. -// self.basic_block(); -// // Jump to loop body if the loop condition still holds. -// let instr = AsmInstruction::BNE(loop_label, loop_var.0, self.end.0); -// self.push(instr); -// // Add a jump instruction to the loop condition in the following block -// let label = self.block_label(); -// let instr = AsmInstruction::j(label, self); -// self.push_to_block(loop_call_label, instr); -// } -// } diff --git a/recursion/compiler/src/gnark/lib/main.go b/recursion/compiler/src/gnark/lib/main.go index 0e6aa4b98..1e7502aaa 100644 --- a/recursion/compiler/src/gnark/lib/main.go +++ b/recursion/compiler/src/gnark/lib/main.go @@ -16,11 +16,11 @@ func (circuit *Circuit) Define(api frontend.API) error { // Variables. var felt2 *babybear.Variable - var var0 frontend.Variable - var backend1 frontend.Variable - var felt1 *babybear.Variable var felt0 *babybear.Variable + var felt1 *babybear.Variable + var var0 frontend.Variable var backend0 frontend.Variable + var backend1 frontend.Variable // Operations. var0 = frontend.Variable(0) diff --git a/recursion/compiler/src/ir/symbolic.rs b/recursion/compiler/src/ir/symbolic.rs index 64e92b2bc..fb9c8adb0 100644 --- a/recursion/compiler/src/ir/symbolic.rs +++ b/recursion/compiler/src/ir/symbolic.rs @@ -43,6 +43,12 @@ pub enum SymbolicExt { Neg(Rc>), } +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum SymbolicUsize { + Const(usize), + Var(SymbolicVar), +} + #[derive(Debug, Clone)] pub enum ExtOperand { Base(F), @@ -1024,3 +1030,171 @@ impl> From> for SymbolicExt { SymbolicExt::Base(Rc::new(SymbolicFelt::Val(value))) } } + +impl> Neg for Ext { + type Output = SymbolicExt; + fn neg(self) -> Self::Output { + SymbolicExt::Neg(Rc::new(SymbolicExt::Val(self))) + } +} + +impl Neg for Felt { + type Output = SymbolicFelt; + + fn neg(self) -> Self::Output { + SymbolicFelt::Neg(Rc::new(SymbolicFelt::Val(self))) + } +} + +impl Neg for Var { + type Output = SymbolicVar; + + fn neg(self) -> Self::Output { + SymbolicVar::Neg(Rc::new(SymbolicVar::Val(self))) + } +} + +impl From for SymbolicUsize { + fn from(n: usize) -> Self { + SymbolicUsize::Const(n) + } +} + +impl From> for SymbolicUsize { + fn from(n: SymbolicVar) -> Self { + SymbolicUsize::Var(n) + } +} + +impl From> for SymbolicUsize { + fn from(n: Var) -> Self { + SymbolicUsize::Var(SymbolicVar::from(n)) + } +} + +impl Add for SymbolicUsize { + type Output = Self; + + fn add(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (SymbolicUsize::Const(a), SymbolicUsize::Const(b)) => SymbolicUsize::Const(a + b), + (SymbolicUsize::Var(a), SymbolicUsize::Const(b)) => { + SymbolicUsize::Var(a + N::from_canonical_usize(b)) + } + (SymbolicUsize::Const(a), SymbolicUsize::Var(b)) => { + SymbolicUsize::Var(b + N::from_canonical_usize(a)) + } + (SymbolicUsize::Var(a), SymbolicUsize::Var(b)) => SymbolicUsize::Var(a + b), + } + } +} + +impl Sub for SymbolicUsize { + type Output = Self; + + fn sub(self, rhs: Self) -> Self::Output { + match (self, rhs) { + (SymbolicUsize::Const(a), SymbolicUsize::Const(b)) => SymbolicUsize::Const(a - b), + (SymbolicUsize::Var(a), SymbolicUsize::Const(b)) => { + SymbolicUsize::Var(a - N::from_canonical_usize(b)) + } + (SymbolicUsize::Const(a), SymbolicUsize::Var(b)) => { + SymbolicUsize::Var(b - N::from_canonical_usize(a)) + } + (SymbolicUsize::Var(a), SymbolicUsize::Var(b)) => SymbolicUsize::Var(a - b), + } + } +} + +impl Add for SymbolicUsize { + type Output = Self; + + fn add(self, rhs: usize) -> Self::Output { + match self { + SymbolicUsize::Const(a) => SymbolicUsize::Const(a + rhs), + SymbolicUsize::Var(a) => SymbolicUsize::Var(a + N::from_canonical_usize(rhs)), + } + } +} + +impl Sub for SymbolicUsize { + type Output = Self; + + fn sub(self, rhs: usize) -> Self::Output { + match self { + SymbolicUsize::Const(a) => SymbolicUsize::Const(a - rhs), + SymbolicUsize::Var(a) => SymbolicUsize::Var(a - N::from_canonical_usize(rhs)), + } + } +} + +impl From> for SymbolicUsize { + fn from(n: Usize) -> Self { + match n { + Usize::Const(n) => SymbolicUsize::Const(n), + Usize::Var(n) => SymbolicUsize::Var(SymbolicVar::from(n)), + } + } +} + +impl Add> for SymbolicUsize { + type Output = SymbolicUsize; + + fn add(self, rhs: Usize) -> Self::Output { + self + Self::from(rhs) + } +} + +impl Sub> for SymbolicUsize { + type Output = SymbolicUsize; + + fn sub(self, rhs: Usize) -> Self::Output { + self - Self::from(rhs) + } +} + +impl Add for Usize { + type Output = SymbolicUsize; + + fn add(self, rhs: usize) -> Self::Output { + SymbolicUsize::from(self) + rhs + } +} + +impl Sub for Usize { + type Output = SymbolicUsize; + + fn sub(self, rhs: usize) -> Self::Output { + SymbolicUsize::from(self) - rhs + } +} + +impl Add> for Usize { + type Output = SymbolicUsize; + + fn add(self, rhs: Usize) -> Self::Output { + SymbolicUsize::from(self) + rhs + } +} + +impl Sub> for Usize { + type Output = SymbolicUsize; + + fn sub(self, rhs: Usize) -> Self::Output { + SymbolicUsize::from(self) - rhs + } +} + +impl MulAssign> for SymbolicFelt { + fn mul_assign(&mut self, rhs: Felt) { + *self = Self::from(rhs); + } +} + +impl Mul> for Felt { + type Output = SymbolicFelt; + + fn mul(self, rhs: SymbolicFelt) -> Self::Output { + SymbolicFelt::::from(self) * rhs + } +} diff --git a/recursion/compiler/src/ir/types.rs b/recursion/compiler/src/ir/types.rs index 5eaa71ef5..7b7a0e449 100644 --- a/recursion/compiler/src/ir/types.rs +++ b/recursion/compiler/src/ir/types.rs @@ -11,6 +11,7 @@ use std::hash::Hash; use super::MemVariable; use super::Ptr; +use super::SymbolicUsize; use super::{Builder, Config, DslIR, SymbolicExt, SymbolicFelt, SymbolicVar, Variable}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -19,8 +20,7 @@ pub struct Var(pub u32, pub PhantomData); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct Felt(pub u32, pub PhantomData); -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] - +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] pub struct Ext(pub u32, pub PhantomData<(F, EF)>); #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -75,6 +75,13 @@ impl Felt { pub fn id(&self) -> String { format!("felt{}", self.0) } + + pub fn inverse(&self) -> SymbolicFelt + where + F: Field, + { + SymbolicFelt::::one() / *self + } } impl Ext { @@ -85,13 +92,21 @@ impl Ext { pub fn id(&self) -> String { format!("ext{}", self.0) } + + pub fn inverse(&self) -> SymbolicExt + where + F: Field, + EF: ExtensionField, + { + SymbolicExt::::one() / *self + } } impl Variable for Usize { - type Expression = Self; + type Expression = SymbolicUsize; - fn uninit(_: &mut Builder) -> Self { - Usize::Const(0) + fn uninit(builder: &mut Builder) -> Self { + builder.uninit::>().into() } fn assign(&self, src: Self::Expression, builder: &mut Builder) { @@ -100,10 +115,10 @@ impl Variable for Usize { panic!("cannot assign to a constant usize") } Usize::Var(v) => match src { - Usize::Const(src) => { + SymbolicUsize::Const(src) => { builder.assign(*v, C::N::from_canonical_usize(src)); } - Usize::Var(src) => { + SymbolicUsize::Var(src) => { builder.assign(*v, src); } }, @@ -119,18 +134,16 @@ impl Variable for Usize { let rhs = rhs.into(); match (lhs, rhs) { - (Usize::Const(lhs), Usize::Const(rhs)) => { + (SymbolicUsize::Const(lhs), SymbolicUsize::Const(rhs)) => { assert_eq!(lhs, rhs, "constant usizes do not match"); } - (Usize::Const(lhs), Usize::Var(rhs)) => { - builder.push(DslIR::AssertEqVI(rhs, C::N::from_canonical_usize(lhs))); + (SymbolicUsize::Const(lhs), SymbolicUsize::Var(rhs)) => { + builder.assert_var_eq(C::N::from_canonical_usize(lhs), rhs); } - (Usize::Var(lhs), Usize::Const(rhs)) => { - builder.push(DslIR::AssertEqVI(lhs, C::N::from_canonical_usize(rhs))); - } - (Usize::Var(lhs), Usize::Var(rhs)) => { - builder.push(DslIR::AssertEqV(lhs, rhs)); + (SymbolicUsize::Var(lhs), SymbolicUsize::Const(rhs)) => { + builder.assert_var_eq(lhs, C::N::from_canonical_usize(rhs)); } + (SymbolicUsize::Var(lhs), SymbolicUsize::Var(rhs)) => builder.assert_var_eq(lhs, rhs), } } @@ -143,17 +156,17 @@ impl Variable for Usize { let rhs = rhs.into(); match (lhs, rhs) { - (Usize::Const(lhs), Usize::Const(rhs)) => { + (SymbolicUsize::Const(lhs), SymbolicUsize::Const(rhs)) => { assert_ne!(lhs, rhs, "constant usizes do not match"); } - (Usize::Const(lhs), Usize::Var(rhs)) => { - builder.push(DslIR::AssertNeVI(rhs, C::N::from_canonical_usize(lhs))); + (SymbolicUsize::Const(lhs), SymbolicUsize::Var(rhs)) => { + builder.assert_var_ne(C::N::from_canonical_usize(lhs), rhs); } - (Usize::Var(lhs), Usize::Const(rhs)) => { - builder.push(DslIR::AssertNeVI(lhs, C::N::from_canonical_usize(rhs))); + (SymbolicUsize::Var(lhs), SymbolicUsize::Const(rhs)) => { + builder.assert_var_ne(lhs, C::N::from_canonical_usize(rhs)); } - (Usize::Var(lhs), Usize::Var(rhs)) => { - builder.push(DslIR::AssertNeV(lhs, rhs)); + (SymbolicUsize::Var(lhs), SymbolicUsize::Var(rhs)) => { + builder.assert_var_ne(lhs, rhs); } } } diff --git a/recursion/compiler/src/lib.rs b/recursion/compiler/src/lib.rs index ae52d5d1a..a46067a43 100644 --- a/recursion/compiler/src/lib.rs +++ b/recursion/compiler/src/lib.rs @@ -1,7 +1,6 @@ extern crate alloc; pub mod asm; -pub mod builder; pub mod gnark; pub mod ir; pub mod util; diff --git a/recursion/compiler/src/verifier/constraints/domain.rs b/recursion/compiler/src/verifier/constraints/domain.rs new file mode 100644 index 000000000..3b8cf2712 --- /dev/null +++ b/recursion/compiler/src/verifier/constraints/domain.rs @@ -0,0 +1,218 @@ +use p3_commit::LagrangeSelectors; + +use crate::{ + ir::{Config, Felt, Usize}, + prelude::{Builder, Ext, SymbolicFelt, Var}, +}; +use p3_field::{AbstractField, TwoAdicField}; + +/// Reference: https://github.com/Plonky3/Plonky3/blob/main/commit/src/domain.rs#L55 +pub struct TwoAdicMultiplicativeCoset { + pub log_n: Usize, + pub size: Usize, + pub shift: Felt, + pub g: Felt, +} + +impl TwoAdicMultiplicativeCoset { + /// Reference: https://github.com/Plonky3/Plonky3/blob/main/commit/src/domain.rs#L74 + pub fn first_point(&self) -> Felt { + self.shift + } + + pub fn size(&self) -> Usize { + self.size + } + + pub fn gen(&self) -> Felt { + self.g + } +} + +impl Builder { + pub fn const_domain( + &mut self, + domain: &p3_commit::TwoAdicMultiplicativeCoset, + ) -> TwoAdicMultiplicativeCoset + where + C::F: TwoAdicField, + { + let log_d_val = domain.log_n as u32; + let g_val = C::F::two_adic_generator(domain.log_n); + // Initialize a domain. + TwoAdicMultiplicativeCoset:: { + log_n: self + .eval::, _>(C::N::from_canonical_u32(log_d_val)) + .into(), + size: self + .eval::, _>(C::N::from_canonical_u32(1 << (log_d_val))) + .into(), + shift: self.eval(domain.shift), + g: self.eval(g_val), + } + } + /// Reference: https://github.com/Plonky3/Plonky3/blob/main/commit/src/domain.rs#L77 + pub fn next_point( + &mut self, + domain: &TwoAdicMultiplicativeCoset, + point: Ext, + ) -> Ext { + self.eval(point * domain.gen()) + } + + /// Reference: https://github.com/Plonky3/Plonky3/blob/main/commit/src/domain.rs#L112 + pub fn selectors_at_point( + &mut self, + domain: &TwoAdicMultiplicativeCoset, + point: Ext, + ) -> LagrangeSelectors> { + let unshifted_point: Ext<_, _> = self.eval(point * domain.shift.inverse()); + let z_h_expr = + self.exp_power_of_2_v::>(unshifted_point, domain.log_n) - C::EF::one(); + let z_h: Ext<_, _> = self.eval(z_h_expr); + + LagrangeSelectors { + is_first_row: self.eval(z_h / (unshifted_point - C::EF::one())), + is_last_row: self.eval(z_h / (unshifted_point - domain.gen().inverse())), + is_transition: self.eval(unshifted_point - domain.gen().inverse()), + inv_zeroifier: self.eval(z_h.inverse()), + } + } + + /// Reference: https://github.com/Plonky3/Plonky3/blob/main/commit/src/domain.rs#L87 + pub fn zp_at_point( + &mut self, + domain: &TwoAdicMultiplicativeCoset, + point: Ext, + ) -> Ext { + // Compute (point * domain.shift.inverse()).exp_power_of_2(domain.log_n) - Ext::one() + let unshifted_power = + self.exp_power_of_2_v::>(point * domain.shift.inverse(), domain.log_n); + self.eval(unshifted_power - C::EF::one()) + } + + pub fn split_domains( + &mut self, + domain: &TwoAdicMultiplicativeCoset, + log_num_chunks: usize, + ) -> Vec> { + let num_chunks = 1 << log_num_chunks; + let log_n = self.eval(domain.log_n - log_num_chunks); + let size = self.power_of_two_usize(log_n); + + let g_dom = domain.gen(); + + let domain_power = |i| { + let mut result = SymbolicFelt::from(g_dom); + for _ in 0..i { + result *= g_dom; + } + result + }; + + // We can compute a generator for the domain by computing g_dom^{log_num_chunks} + let g = self.exp_power_of_2_v::>(g_dom, log_num_chunks.into()); + (0..num_chunks) + .map(|i| TwoAdicMultiplicativeCoset { + log_n, + size, + shift: self.eval(domain.shift * domain_power(i)), + g, + }) + .collect() + } +} + +#[cfg(test)] +mod tests { + use crate::asm::VmBuilder; + use crate::prelude::ExtConst; + + use super::*; + use p3_commit::{Pcs, PolynomialSpace}; + use p3_field::TwoAdicField; + use rand::{thread_rng, Rng}; + use sp1_core::stark::Dom; + use sp1_core::{stark::StarkGenericConfig, utils::BabyBearPoseidon2}; + use sp1_recursion_core::runtime::Runtime; + + fn domain_assertions>( + builder: &mut Builder, + domain: &TwoAdicMultiplicativeCoset, + domain_val: &p3_commit::TwoAdicMultiplicativeCoset, + zeta_val: C::EF, + ) { + // Get a random point. + let zeta: Ext<_, _> = builder.eval(zeta_val.cons()); + + // Compare the selector values of the reference and the builder. + let sels_expected = domain_val.selectors_at_point(zeta_val); + let sels = builder.selectors_at_point(domain, zeta); + builder.assert_ext_eq(sels.is_first_row, sels_expected.is_first_row.cons()); + builder.assert_ext_eq(sels.is_last_row, sels_expected.is_last_row.cons()); + builder.assert_ext_eq(sels.is_transition, sels_expected.is_transition.cons()); + + let zp_val = domain_val.zp_at_point(zeta_val); + let zp = builder.zp_at_point(domain, zeta); + builder.assert_ext_eq(zp, zp_val.cons()); + } + + #[test] + fn test_domain() { + type SC = BabyBearPoseidon2; + type F = ::Val; + type EF = ::Challenge; + type Challenger = ::Challenger; + type ScPcs = ::Pcs; + + let mut rng = thread_rng(); + let config = SC::default(); + let pcs = config.pcs(); + let natural_domain_for_degree = |degree: usize| -> Dom { + >::natural_domain_for_degree(pcs, degree) + }; + + // Initialize a builder. + let mut builder = VmBuilder::::default(); + for i in 0..5 { + let log_d_val = 10 + i; + + let log_quotient_degree = 2; + + // Initialize a reference doamin. + let domain_val = natural_domain_for_degree(1 << log_d_val); + let domain = builder.const_domain(&domain_val); + let zeta_val = rng.gen::(); + domain_assertions(&mut builder, &domain, &domain_val, zeta_val); + + // Try a shifted domain. + let disjoint_domain_val = + domain_val.create_disjoint_domain(1 << (log_d_val + log_quotient_degree)); + let disjoint_domain = builder.const_domain(&disjoint_domain_val); + domain_assertions( + &mut builder, + &disjoint_domain, + &disjoint_domain_val, + zeta_val, + ); + + // Now try splited domains + let qc_domains_val = disjoint_domain_val.split_domains(1 << log_quotient_degree); + for dom_val in qc_domains_val.iter() { + let dom = builder.const_domain(dom_val); + domain_assertions(&mut builder, &dom, dom_val, zeta_val); + } + + // Test the splitting of domains by the builder. + let qc_domains = builder.split_domains(&disjoint_domain, log_quotient_degree); + for (dom, dom_val) in qc_domains.iter().zip(qc_domains_val.iter()) { + domain_assertions(&mut builder, dom, dom_val, zeta_val); + } + } + + let program = builder.compile(); + + let mut runtime = Runtime::::new(&program); + runtime.run(); + } +} diff --git a/recursion/compiler/src/verifier/constraints/mod.rs b/recursion/compiler/src/verifier/constraints/mod.rs index e2f319ef2..111658716 100644 --- a/recursion/compiler/src/verifier/constraints/mod.rs +++ b/recursion/compiler/src/verifier/constraints/mod.rs @@ -1,95 +1,357 @@ +mod domain; +mod opening; pub mod utils; use p3_air::Air; +use p3_commit::LagrangeSelectors; use p3_field::AbstractExtensionField; use p3_field::AbstractField; -use p3_field::Field; -use sp1_core::air::MachineAir; -use sp1_core::stark::ChipOpenedValues; -use sp1_core::stark::{GenericVerifierConstraintFolder, MachineChip, StarkGenericConfig}; -use std::marker::PhantomData; - -use crate::prelude::{Builder, Ext, Felt, SymbolicExt}; -use crate::verifier::StarkGenericBuilderConfig; - -pub fn verify_constraints>( - builder: &mut Builder>, - chip: &MachineChip, - opening: &ChipOpenedValues, - g: Felt, - zeta: Ext, - alpha: Ext, -) where - A: for<'a> Air< - GenericVerifierConstraintFolder< - 'a, - SC::Val, - SC::Challenge, - Ext, - SymbolicExt, - >, - >, -{ - let g_inv: Felt = builder.eval(g / SC::Val::one()); - let z_h: Ext = builder.exp_power_of_2(zeta, opening.log_degree); - let one: Ext = builder.eval(SC::Val::one()); - let is_first_row = builder.eval(z_h / (zeta - one)); - let is_last_row = builder.eval(z_h / (zeta - g_inv)); - let is_transition = builder.eval(zeta - g_inv); - - let preprocessed = builder.const_opened_values(&opening.preprocessed); - let main = builder.const_opened_values(&opening.main); - let perm = builder.const_opened_values(&opening.permutation); - - let zero: Ext = builder.eval(SC::Val::zero()); - let zero_expr: SymbolicExt = zero.into(); - let mut folder = GenericVerifierConstraintFolder::< - SC::Val, - SC::Challenge, - Ext, - SymbolicExt, - > { - preprocessed: preprocessed.view(), - main: main.view(), - perm: perm.view(), - perm_challenges: &[SC::Challenge::zero(), SC::Challenge::zero()], - cumulative_sum: builder.eval(SC::Val::zero()), - is_first_row, - is_last_row, - is_transition, - alpha, - accumulator: zero_expr, - _marker: PhantomData, - }; +use sp1_core::stark::AirOpenedValues; +use sp1_core::stark::{MachineChip, StarkGenericConfig}; + +use crate::prelude::Config; +use crate::prelude::ExtConst; +use crate::prelude::{Builder, Ext, SymbolicExt}; + +pub use domain::*; +pub use opening::*; + +use super::folder::RecursiveVerifierConstraintFolder; + +// pub struct TwoAdicCose + +impl Builder { + pub fn eval_constrains( + &mut self, + chip: &MachineChip, + opening: &ChipOpening, + selectors: &LagrangeSelectors>, + alpha: Ext, + permutation_challenges: &[C::EF], + ) -> Ext + where + SC: StarkGenericConfig, + A: for<'a> Air>, + { + let mut unflatten = |v: &[Ext]| { + v.chunks_exact(SC::Challenge::D) + .map(|chunk| { + self.eval( + chunk + .iter() + .enumerate() + .map(|(e_i, &x)| x * C::EF::monomial(e_i).cons()) + .sum::>(), + ) + }) + .collect::>>() + }; + let perm_opening = AirOpenedValues { + local: unflatten(&opening.permutation.local), + next: unflatten(&opening.permutation.next), + }; - let monomials = (0..SC::Challenge::D) - .map(SC::Challenge::monomial) - .collect::>(); + let zero: Ext = self.eval(SC::Val::zero()); + let mut folder = RecursiveVerifierConstraintFolder { + builder: self, + preprocessed: opening.preprocessed.view(), + main: opening.main.view(), + perm: perm_opening.view(), + perm_challenges: permutation_challenges, + cumulative_sum: opening.cumulative_sum, + is_first_row: selectors.is_first_row, + is_last_row: selectors.is_last_row, + is_transition: selectors.is_transition, + alpha, + accumulator: zero, + }; - let quotient_parts = opening - .quotient - .iter() - .map(|chunk| { - chunk + chip.eval(&mut folder); + folder.accumulator + } + + pub fn recompute_quotient( + &mut self, + opening: &ChipOpening, + qc_domains: Vec>, + zeta: Ext, + ) -> Ext { + let zps = qc_domains + .iter() + .enumerate() + .map(|(i, domain)| { + qc_domains + .iter() + .enumerate() + .filter(|(j, _)| *j != i) + .map(|(_, other_domain)| { + // Calculate: other_domain.zp_at_point(zeta) + // * other_domain.zp_at_point(domain.first_point()).inverse() + let first_point: Ext<_, _> = self.eval(domain.first_point()); + self.zp_at_point(other_domain, zeta) + * self.zp_at_point(other_domain, first_point).inverse() + }) + .product::>() + }) + .collect::>>() + .into_iter() + .map(|x| self.eval(x)) + .collect::>>(); + + self.eval( + opening + .quotient .iter() - .zip(monomials.iter()) - .map(|(x, m)| *x * *m) - .sum() - }) - .collect::>(); - - let mut zeta_powers = zeta; - let quotient: Ext = builder.eval(SC::Val::zero()); - let quotient_expr: SymbolicExt = quotient.into(); - for quotient_part in quotient_parts { - zeta_powers = builder.eval(zeta_powers * zeta); - builder.assign(quotient, zeta_powers * quotient_part); + .enumerate() + .map(|(ch_i, ch)| { + assert_eq!(ch.len(), C::EF::D); + ch.iter() + .enumerate() + .map(|(e_i, &c)| zps[ch_i] * C::EF::monomial(e_i) * c) + .sum::>() + }) + .sum::>(), + ) + } + + /// Reference: `[sp1_core::stark::Verifier::verify_constraints]` + #[allow(clippy::too_many_arguments)] + pub fn verify_constraints( + &mut self, + chip: &MachineChip, + opening: &ChipOpening, + trace_domain: TwoAdicMultiplicativeCoset, + qc_domains: Vec>, + zeta: Ext, + alpha: Ext, + permutation_challenges: &[C::EF], + ) where + SC: StarkGenericConfig, + A: for<'a> Air>, + { + let sels = self.selectors_at_point(&trace_domain, zeta); + + let folded_constraints = + self.eval_constrains::(chip, opening, &sels, alpha, permutation_challenges); + + let quotient: Ext<_, _> = self.recompute_quotient(opening, qc_domains, zeta); + + // Assert that the quotient times the zerofier is equal to the folded constraints. + self.assert_ext_eq(folded_constraints * sels.inv_zeroifier, quotient); + } +} + +#[cfg(test)] +mod tests { + use itertools::{izip, Itertools}; + use serde::{de::DeserializeOwned, Serialize}; + use sp1_core::{ + air::MachineAir, + stark::{ + Chip, Com, Dom, MachineStark, OpeningProof, PcsProverData, RiscvAir, ShardCommitment, + ShardMainData, ShardProof, StarkGenericConfig, Verifier, + }, + utils::BabyBearPoseidon2, + SP1Prover, SP1Stdin, + }; + use sp1_recursion_core::runtime::Runtime; + + use crate::{asm::VmBuilder, prelude::ExtConst}; + use p3_challenger::{CanObserve, FieldChallenger}; + use p3_field::PrimeField32; + + use p3_commit::{Pcs, PolynomialSpace}; + + #[allow(clippy::type_complexity)] + fn get_shard_data<'a, SC>( + machine: &'a MachineStark>, + proof: &ShardProof, + challenger: &mut SC::Challenger, + ) -> ( + Vec<&'a Chip>>, + Vec>, + Vec>>, + Vec, + SC::Challenge, + SC::Challenge, + ) + where + SC: StarkGenericConfig + Default, + SC::Challenger: Clone, + OpeningProof: Send + Sync, + Com: Send + Sync, + PcsProverData: Send + Sync, + ShardMainData: Serialize + DeserializeOwned, + SC::Val: p3_field::PrimeField32, + { + let ShardProof { + commitment, + opened_values, + .. + } = proof; + + let ShardCommitment { + permutation_commit, + quotient_commit, + .. + } = commitment; + + // Extract verification metadata. + let pcs = machine.config().pcs(); + + let permutation_challenges = (0..2) + .map(|_| challenger.sample_ext_element::()) + .collect::>(); + + challenger.observe(permutation_commit.clone()); + + let alpha = challenger.sample_ext_element::(); + + // Observe the quotient commitments. + challenger.observe(quotient_commit.clone()); + + let zeta = challenger.sample_ext_element::(); + + let chips = machine + .chips() + .iter() + .filter(|chip| proof.chip_ids.contains(&chip.name())) + .collect::>(); + + let log_degrees = opened_values + .chips + .iter() + .map(|val| val.log_degree) + .collect::>(); + + let log_quotient_degrees = chips + .iter() + .map(|chip| chip.log_quotient_degree()) + .collect::>(); + + let trace_domains = log_degrees + .iter() + .map(|log_degree| pcs.natural_domain_for_degree(1 << log_degree)) + .collect::>(); + + let quotient_chunk_domains = trace_domains + .iter() + .zip_eq(log_degrees) + .zip_eq(log_quotient_degrees) + .map(|((domain, log_degree), log_quotient_degree)| { + let quotient_degree = 1 << log_quotient_degree; + let quotient_domain = + domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree)); + quotient_domain.split_domains(quotient_degree) + }) + .collect::>(); + + ( + chips, + trace_domains, + quotient_chunk_domains, + permutation_challenges, + alpha, + zeta, + ) } - let quotient: Ext = builder.eval(quotient_expr); - folder.alpha = alpha; - chip.eval(&mut folder); - let folded_constraints = folder.accumulator; - let expected_folded_constraints = z_h * quotient; - builder.assert_ext_eq(folded_constraints, expected_folded_constraints); + #[test] + fn test_verify_constraints() { + type SC = BabyBearPoseidon2; + type F = ::Val; + type EF = ::Challenge; + type A = RiscvAir; + + // Generate a dummy proof. + sp1_core::utils::setup_logger(); + let elf = include_bytes!( + "../../../../../examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf" + ); + + let machine = A::machine(SC::default()); + let mut challenger = machine.config().challenger(); + let proofs = SP1Prover::prove_with_config(elf, SP1Stdin::new(), machine.config().clone()) + .unwrap() + .proof + .shard_proofs; + println!("Proof generated successfully"); + + proofs.iter().for_each(|proof| { + challenger.observe(proof.commitment.main_commit); + }); + + // Run the verify inside the DSL and compare it to the calculated value. + let mut builder = VmBuilder::::default(); + + for proof in proofs.into_iter().take(1) { + let ( + chips, + trace_domains_vals, + quotient_chunk_domains_vals, + permutation_challenges, + alpha_val, + zeta_val, + ) = get_shard_data(&machine, &proof, &mut challenger); + + for (chip, trace_domain_val, qc_domains_vals, values_vals) in izip!( + chips.iter(), + trace_domains_vals, + quotient_chunk_domains_vals, + proof.opened_values.chips.iter(), + ) { + // Compute the expected folded constraints value. + let sels_val = trace_domain_val.selectors_at_point(zeta_val); + let folded_constraints_val = Verifier::::eval_constraints( + chip, + values_vals, + &sels_val, + alpha_val, + &permutation_challenges, + ); + + // Compute the folded constraints value in the DSL. + let values = builder.const_chip_opening(values_vals); + let alpha = builder.eval(alpha_val.cons()); + let zeta = builder.eval(zeta_val.cons()); + let trace_domain = builder.const_domain(&trace_domain_val); + let sels = builder.selectors_at_point(&trace_domain, zeta); + let folded_constraints = builder.eval_constrains::( + chip, + &values, + &sels, + alpha, + permutation_challenges.as_slice(), + ); + + // Assert that the two values are equal. + builder.assert_ext_eq(folded_constraints, folded_constraints_val.cons()); + + // Compute the expected quotient value. + let quotient_val = + Verifier::::recompute_quotient(values_vals, &qc_domains_vals, zeta_val); + + let qc_domains = qc_domains_vals + .iter() + .map(|domain| builder.const_domain(domain)) + .collect::>(); + let quotient = builder.recompute_quotient(&values, qc_domains, zeta); + + // Assert that the two values are equal. + builder.assert_ext_eq(quotient, quotient_val.cons()); + + // Assert that the constraint-quotient relation holds. + builder.assert_ext_eq(folded_constraints * sels.inv_zeroifier, quotient); + } + } + + let program = builder.compile(); + + let mut runtime = Runtime::::new(&program); + runtime.run(); + println!( + "The program executed successfully, number of cycles: {}", + runtime.clk.as_canonical_u32() / 4 + ); + } } diff --git a/recursion/compiler/src/verifier/constraints/opening.rs b/recursion/compiler/src/verifier/constraints/opening.rs new file mode 100644 index 000000000..42ffc3ea4 --- /dev/null +++ b/recursion/compiler/src/verifier/constraints/opening.rs @@ -0,0 +1,30 @@ +use sp1_core::stark::{AirOpenedValues, ChipOpenedValues}; + +use crate::prelude::{Builder, Config, Ext, ExtConst, Usize}; + +#[derive(Debug, Clone)] +pub struct ChipOpening { + pub preprocessed: AirOpenedValues>, + pub main: AirOpenedValues>, + pub permutation: AirOpenedValues>, + pub quotient: Vec>>, + pub cumulative_sum: Ext, + pub log_degree: Usize, +} + +impl Builder { + pub fn const_chip_opening(&mut self, opening: &ChipOpenedValues) -> ChipOpening { + ChipOpening { + preprocessed: self.const_opened_values(&opening.preprocessed), + main: self.const_opened_values(&opening.main), + permutation: self.const_opened_values(&opening.permutation), + quotient: opening + .quotient + .iter() + .map(|q| q.iter().map(|s| self.eval(s.cons())).collect()) + .collect(), + cumulative_sum: self.eval(opening.cumulative_sum.cons()), + log_degree: self.eval(opening.log_degree), + } + } +} diff --git a/recursion/compiler/src/verifier/constraints/utils.rs b/recursion/compiler/src/verifier/constraints/utils.rs index 8520f00b1..dc900bca8 100644 --- a/recursion/compiler/src/verifier/constraints/utils.rs +++ b/recursion/compiler/src/verifier/constraints/utils.rs @@ -1,6 +1,9 @@ +use std::ops::{Add, Mul}; + +use p3_field::AbstractField; use sp1_core::stark::AirOpenedValues; -use crate::prelude::{Builder, Config, Ext, SymbolicExt}; +use crate::prelude::{Builder, Config, Ext, ExtConst, Felt, SymbolicExt, Usize, Var, Variable}; impl Builder { pub fn const_opened_values( @@ -20,4 +23,45 @@ impl Builder { .collect(), } } + + pub fn exp_power_of_2_v>( + &mut self, + base: impl Into, + power_log: Usize, + ) -> V + where + V: Copy + Mul, + { + let result: V = self.eval(base); + self.range(0, power_log) + .for_each(|_, builder| builder.assign(result, result * result)); + result + } + + /// Multiplies `base` by `2^{log_power}`. + pub fn sll>(&mut self, base: impl Into, shift: Usize) -> V + where + V: Copy + Add, + { + let result: V = self.eval(base); + self.range(0, shift) + .for_each(|_, builder| builder.assign(result, result + result)); + result + } + + pub fn power_of_two_usize(&mut self, power: Usize) -> Usize { + self.sll(Usize::Const(1), power) + } + + pub fn power_of_two_var(&mut self, power: Usize) -> Var { + self.sll(C::N::one(), power) + } + + pub fn power_of_two_felt(&mut self, power: Usize) -> Felt { + self.sll(C::F::one(), power) + } + + pub fn power_of_two_expr(&mut self, power: Usize) -> Ext { + self.sll(C::EF::one().cons(), power) + } } diff --git a/recursion/compiler/src/verifier/folder.rs b/recursion/compiler/src/verifier/folder.rs new file mode 100644 index 000000000..d373c4db0 --- /dev/null +++ b/recursion/compiler/src/verifier/folder.rs @@ -0,0 +1,96 @@ +use p3_air::{AirBuilder, ExtensionBuilder, PairBuilder, PermutationAirBuilder, TwoRowMatrixView}; +use sp1_core::air::{EmptyMessageBuilder, MultiTableAirBuilder}; + +use crate::{ + ir::{Builder, Config, Ext}, + prelude::SymbolicExt, +}; + +pub struct RecursiveVerifierConstraintFolder<'a, C: Config> { + pub builder: &'a mut Builder, + pub preprocessed: TwoRowMatrixView<'a, Ext>, + pub main: TwoRowMatrixView<'a, Ext>, + pub perm: TwoRowMatrixView<'a, Ext>, + pub perm_challenges: &'a [C::EF], + pub cumulative_sum: Ext, + pub is_first_row: Ext, + pub is_last_row: Ext, + pub is_transition: Ext, + pub alpha: Ext, + pub accumulator: Ext, +} + +impl<'a, C: Config> AirBuilder for RecursiveVerifierConstraintFolder<'a, C> { + type F = C::F; + type Expr = SymbolicExt; + type Var = Ext; + type M = TwoRowMatrixView<'a, Ext>; + + fn main(&self) -> Self::M { + self.main + } + + fn is_first_row(&self) -> Self::Expr { + self.is_first_row.into() + } + + fn is_last_row(&self) -> Self::Expr { + self.is_last_row.into() + } + + fn is_transition_window(&self, size: usize) -> Self::Expr { + if size == 2 { + self.is_transition.into() + } else { + panic!("uni-stark only supports a window size of 2") + } + } + + fn assert_zero>(&mut self, x: I) { + let x: Self::Expr = x.into(); + self.builder + .assign(self.accumulator, self.accumulator * self.alpha); + self.builder.assign(self.accumulator, self.accumulator + x); + } +} + +impl<'a, C: Config> ExtensionBuilder for RecursiveVerifierConstraintFolder<'a, C> { + type EF = C::EF; + type ExprEF = SymbolicExt; + type VarEF = Ext; + + fn assert_zero_ext(&mut self, x: I) + where + I: Into, + { + self.assert_zero(x) + } +} + +impl<'a, C: Config> PermutationAirBuilder for RecursiveVerifierConstraintFolder<'a, C> { + type MP = TwoRowMatrixView<'a, Self::Var>; + + fn permutation(&self) -> Self::MP { + self.perm + } + + fn permutation_randomness(&self) -> &[Self::EF] { + self.perm_challenges + } +} + +impl<'a, C: Config> MultiTableAirBuilder for RecursiveVerifierConstraintFolder<'a, C> { + type Sum = Self::Var; + + fn cumulative_sum(&self) -> Self::Sum { + self.cumulative_sum + } +} + +impl<'a, C: Config> PairBuilder for RecursiveVerifierConstraintFolder<'a, C> { + fn preprocessed(&self) -> Self::M { + self.preprocessed + } +} + +impl<'a, C: Config> EmptyMessageBuilder for RecursiveVerifierConstraintFolder<'a, C> {} diff --git a/recursion/compiler/src/verifier/fri/pcs.rs b/recursion/compiler/src/verifier/fri/pcs.rs index 0a833fd8b..e946de36e 100644 --- a/recursion/compiler/src/verifier/fri/pcs.rs +++ b/recursion/compiler/src/verifier/fri/pcs.rs @@ -121,12 +121,11 @@ pub fn verify_two_adic_pcs( let g = builder.generator(); let two_adic_generator = builder.two_adic_generator(Usize::Var(log_height)); let g_mul_two_adic_generator = builder.eval(g * two_adic_generator); - let x: SymbolicExt = builder - .exp_usize_f(g_mul_two_adic_generator, Usize::Var(rev_reduced_index)) - .into(); + let x: Felt = builder + .exp_usize_f(g_mul_two_adic_generator, Usize::Var(rev_reduced_index)); builder.range(0, mat_points.len()).for_each(|l, builder| { - let z: SymbolicExt = builder.get(&mat_points, l).into(); + let z: Ext = builder.get(&mat_points, l); let ps_at_z = builder.get(&mat_values, l); builder.range(0, ps_at_z.len()).for_each(|m, builder| { let p_at_x: SymbolicExt = diff --git a/recursion/compiler/src/verifier/mod.rs b/recursion/compiler/src/verifier/mod.rs index 659151a9f..c8a85e278 100644 --- a/recursion/compiler/src/verifier/mod.rs +++ b/recursion/compiler/src/verifier/mod.rs @@ -1,5 +1,6 @@ pub mod challenger; pub mod constraints; +pub mod folder; pub mod fri; pub use constraints::*; diff --git a/recursion/compiler/tests/eval_constraints.rs b/recursion/compiler/tests/eval_constraints.rs new file mode 100644 index 000000000..aadb63918 --- /dev/null +++ b/recursion/compiler/tests/eval_constraints.rs @@ -0,0 +1,165 @@ +use std::marker::PhantomData; + +use sp1_core::stark::VerifierConstraintFolder; +use sp1_recursion_compiler::ir::Ext; +use sp1_recursion_compiler::ir::Felt; +use sp1_recursion_compiler::prelude::Builder; + +use p3_air::Air; +use p3_field::AbstractField; +use p3_field::Field; +use p3_field::PrimeField32; +use rand::thread_rng; +use rand::Rng; +use sp1_core::air::MachineAir; +use sp1_core::stark::ChipOpenedValues; +use sp1_core::stark::MachineChip; +use sp1_core::stark::RiscvAir; +use sp1_core::stark::StarkAir; +use sp1_core::stark::StarkGenericConfig; +use sp1_core::utils; +use sp1_core::utils::BabyBearPoseidon2; +use sp1_core::SP1Prover; +use sp1_core::SP1Stdin; +use sp1_recursion_compiler::asm::VmBuilder; +use sp1_recursion_compiler::ir::Config; +use sp1_recursion_compiler::ir::ExtConst; +use sp1_recursion_compiler::verifier::folder::RecursiveVerifierConstraintFolder; +use sp1_recursion_core::runtime::Runtime; + +pub fn eval_constraints_test( + builder: &mut Builder, + chip: &MachineChip, + opening: &ChipOpenedValues, + g_val: SC::Val, + zeta_val: SC::Challenge, + alpha_val: SC::Challenge, +) where + SC: StarkGenericConfig, + C: Config, + A: MachineAir + StarkAir + for<'a> Air>, +{ + let g_inv_val = g_val.inverse(); + let g: Felt<_> = builder.eval(g_val); + let g_inv: Felt = builder.eval(g.inverse()); + builder.assert_felt_eq(g_inv, g_inv_val); + + let z_h_val = zeta_val.exp_power_of_2(opening.log_degree); + let zeta: Ext = builder.eval(zeta_val.cons()); + let z_h: Ext = builder.exp_power_of_2(zeta, opening.log_degree); + builder.assert_ext_eq(z_h, z_h_val.cons()); + let one: Ext = builder.eval(SC::Val::one()); + let is_first_row: Ext<_, _> = builder.eval(z_h / (zeta - one)); + let is_last_row: Ext<_, _> = builder.eval(z_h / (zeta - g_inv)); + let is_transition: Ext<_, _> = builder.eval(zeta - g_inv); + + let is_first_row_val = z_h_val / (zeta_val - SC::Challenge::one()); + let is_last_row_val = z_h_val / (zeta_val - g_inv_val); + let is_transition_val = zeta_val - g_inv_val; + + builder.assert_ext_eq(is_first_row, is_first_row_val.cons()); + builder.assert_ext_eq(is_last_row, is_last_row_val.cons()); + builder.assert_ext_eq(is_transition, is_transition_val.cons()); + + let preprocessed = builder.const_opened_values(&opening.preprocessed); + let main = builder.const_opened_values(&opening.main); + let perm = builder.const_opened_values(&opening.permutation); + + let zero: Ext = builder.eval(SC::Val::zero()); + let cumulative_sum = builder.eval(SC::Val::zero()); + let alpha = builder.eval(alpha_val.cons()); + let mut folder = RecursiveVerifierConstraintFolder { + builder, + preprocessed: preprocessed.view(), + main: main.view(), + perm: perm.view(), + perm_challenges: &[SC::Challenge::one(), SC::Challenge::one()], + cumulative_sum, + is_first_row, + is_last_row, + is_transition, + alpha, + accumulator: zero, + }; + + chip.eval(&mut folder); + let folded_constraints = folder.accumulator; + + let mut test_folder = VerifierConstraintFolder:: { + preprocessed: opening.preprocessed.view(), + main: opening.main.view(), + perm: opening.permutation.view(), + perm_challenges: &[SC::Challenge::one(), SC::Challenge::one()], + cumulative_sum: SC::Challenge::zero(), + is_first_row: is_first_row_val, + is_last_row: is_last_row_val, + is_transition: is_transition_val, + alpha: alpha_val, + accumulator: SC::Challenge::zero(), + _marker: PhantomData, + }; + + chip.eval(&mut test_folder); + let folded_constraints_val = test_folder.accumulator; + + builder.assert_ext_eq(folded_constraints, folded_constraints_val.cons()); +} + +#[test] +fn test_compiler_eval_constraints() { + type SC = BabyBearPoseidon2; + type F = ::Val; + type EF = ::Challenge; + + let mut rng = thread_rng(); + + // Generate a dummy proof. + utils::setup_logger(); + let elf = include_bytes!("../../../examples/fibonacci/program/elf/riscv32im-succinct-zkvm-elf"); + let proofs = SP1Prover::prove(elf, SP1Stdin::new()) + .unwrap() + .proof + .shard_proofs; + + println!("Proof generated successfully"); + + // Extract verification metadata. + let machine = RiscvAir::machine(SC::new()); + + // Run the verify inside the DSL. + let mut builder = VmBuilder::::default(); + let g_val = F::one(); + + let zeta_val = rng.gen::(); + let alpha_val = rng.gen::(); + + for shard_proof in proofs.into_iter().take(1) { + let chips = machine + .chips() + .iter() + .filter(|chip| shard_proof.chip_ids.contains(&chip.name())) + .collect::>(); + for (chip, values) in chips + .into_iter() + .zip(shard_proof.opened_values.chips.iter()) + { + eval_constraints_test::<_, SC, _>( + &mut builder, + chip, + values, + g_val, + zeta_val, + alpha_val, + ) + } + } + + let program = builder.compile(); + + let mut runtime = Runtime::::new(&program); + runtime.run(); + println!( + "The program executed successfully, number of cycles: {}", + runtime.clk.as_canonical_u32() / 4 + ); +} diff --git a/recursion/core/src/air/block.rs b/recursion/core/src/air/block.rs new file mode 100644 index 000000000..5a1354df9 --- /dev/null +++ b/recursion/core/src/air/block.rs @@ -0,0 +1,104 @@ +use p3_air::AirBuilder; +use p3_field::AbstractField; +use p3_field::ExtensionField; +use p3_field::Field; +use p3_field::PrimeField32; +use sp1_core::air::{BinomialExtension, SP1AirBuilder}; +use sp1_derive::AlignedBorrow; + +use std::ops::Index; + +use crate::runtime::D; + +/// The smallest unit of memory that can be read and written to. +#[derive(AlignedBorrow, Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] +#[repr(C)] +pub struct Block(pub [T; D]); + +pub trait BlockBuilder: AirBuilder { + fn assert_block_eq, Rhs: Into>( + &mut self, + lhs: Block, + rhs: Block, + ) { + for (l, r) in lhs.0.into_iter().zip(rhs.0) { + self.assert_eq(l, r); + } + } +} + +impl BlockBuilder for AB {} + +impl Block { + pub fn map(self, f: F) -> Block + where + F: FnMut(T) -> U, + { + Block(self.0.map(f)) + } + + pub fn ext(&self) -> E + where + T: Field, + E: ExtensionField, + { + E::from_base_slice(&self.0) + } +} + +impl Block { + pub fn as_extension>(&self) -> BinomialExtension { + let arr: [AB::Expr; 4] = self.0.clone().map(|x| AB::Expr::zero() + x); + BinomialExtension(arr) + } + + pub fn as_extension_from_base>( + &self, + base: AB::Expr, + ) -> BinomialExtension { + let mut arr: [AB::Expr; 4] = self.0.clone().map(|_| AB::Expr::zero()); + arr[0] = base; + + BinomialExtension(arr) + } +} + +impl From<[T; D]> for Block { + fn from(arr: [T; D]) -> Self { + Self(arr) + } +} + +impl From for Block { + fn from(value: F) -> Self { + Self([value, F::zero(), F::zero(), F::zero()]) + } +} + +impl From<&[T]> for Block { + fn from(slice: &[T]) -> Self { + let arr: [T; D] = slice.try_into().unwrap(); + Self(arr) + } +} + +impl Index for Block +where + [T]: Index, +{ + type Output = <[T] as Index>::Output; + + #[inline] + fn index(&self, index: I) -> &Self::Output { + Index::index(&self.0, index) + } +} + +impl IntoIterator for Block { + type Item = T; + type IntoIter = std::array::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.0.into_iter() + } +} diff --git a/recursion/core/src/air/extension.rs b/recursion/core/src/air/extension.rs index c9b643c8b..179dca3b6 100644 --- a/recursion/core/src/air/extension.rs +++ b/recursion/core/src/air/extension.rs @@ -1,7 +1,11 @@ +use p3_field::extension::{BinomialExtensionField, BinomiallyExtendable}; +use p3_field::{AbstractExtensionField, Field}; use sp1_core::air::BinomialExtension; use super::Block; +use crate::runtime::D; + pub trait BinomialExtensionUtils { fn from_block(block: Block) -> Self; @@ -17,3 +21,17 @@ impl BinomialExtensionUtils for BinomialExtension { Block(self.0.clone()) } } + +impl BinomialExtensionUtils for BinomialExtensionField +where + AF: Field, + AF::F: BinomiallyExtendable, +{ + fn from_block(block: Block) -> Self { + Self::from_base_slice(&block.0) + } + + fn as_block(&self) -> Block { + Block(self.as_base_slice().try_into().unwrap()) + } +} diff --git a/recursion/core/src/air/is_ext_zero.rs b/recursion/core/src/air/is_ext_zero.rs new file mode 100644 index 000000000..73502d416 --- /dev/null +++ b/recursion/core/src/air/is_ext_zero.rs @@ -0,0 +1,93 @@ +//! An operation to check if the input is 0. +//! +//! This is guaranteed to return 1 if and only if the input is 0. +//! +//! The idea is that 1 - input * inverse is exactly the boolean value indicating whether the input +//! is 0. +use crate::air::Block; +use p3_air::AirBuilder; +use p3_field::extension::BinomialExtensionField; +use p3_field::extension::BinomiallyExtendable; +use p3_field::AbstractField; +use p3_field::Field; +use sp1_core::air::BinomialExtension; +use sp1_derive::AlignedBorrow; + +use crate::air::extension::BinomialExtensionUtils; +use sp1_core::air::SP1AirBuilder; + +use crate::runtime::D; + +/// A set of columns needed to compute whether the given word is 0. +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct IsExtZeroOperation { + /// The inverse of the input. + pub inverse: Block, + + /// Result indicating whether the input is 0. This equals `inverse * input == 0`. + pub result: T, +} + +impl> IsExtZeroOperation { + pub fn populate(&mut self, a: Block) -> F { + let a = BinomialExtensionField::::from_block(a); + + let (inverse, result) = if a.is_zero() { + (BinomialExtensionField::zero(), F::one()) + } else { + (a.inverse(), F::zero()) + }; + + self.inverse = inverse.as_block(); + self.result = result; + + let prod = inverse * a; + debug_assert!(prod == BinomialExtensionField::::one() || prod.is_zero()); + + result + } + + pub fn eval( + builder: &mut AB, + a: BinomialExtension, + cols: IsExtZeroOperation, + is_real: AB::Expr, + ) { + // Assert that the `is_real` is a boolean. + builder.assert_bool(is_real.clone()); + // Assert that the result is boolean. + builder.assert_bool(cols.result); + + // 1. Input == 0 => is_zero = 1 regardless of the inverse. + // 2. Input != 0 + // 2.1. inverse is correctly set => is_zero = 0. + // 2.2. inverse is incorrect + // 2.2.1 inverse is nonzero => is_zero isn't bool, it fails. + // 2.2.2 inverse is 0 => is_zero is 1. But then we would assert that a = 0. And that + // assert fails. + + // If the input is 0, then any product involving it is 0. If it is nonzero and its inverse + // is correctly set, then the product is 1. + let one_ext = BinomialExtension::::from_base(AB::Expr::one()); + + let inverse = cols.inverse.as_extension::(); + + let is_zero = one_ext.clone() - inverse * a.clone(); + let result_ext = BinomialExtension::::from_base(cols.result.into()); + + for (eq_z, res) in is_zero.into_iter().zip(result_ext.0) { + builder.when(is_real.clone()).assert_eq(eq_z, res); + } + + builder.when(is_real.clone()).assert_bool(cols.result); + + // If the result is 1, then the input is 0. + for x in a { + builder + .when(is_real.clone()) + .when(cols.result) + .assert_zero(x.clone()); + } + } +} diff --git a/recursion/core/src/air/mod.rs b/recursion/core/src/air/mod.rs index 3688692ec..5bd24f7d8 100644 --- a/recursion/core/src/air/mod.rs +++ b/recursion/core/src/air/mod.rs @@ -1,89 +1,7 @@ +mod block; mod extension; +mod is_ext_zero; -use std::ops::Index; - +pub use block::*; pub use extension::*; - -use p3_air::AirBuilder; -use p3_field::AbstractField; -use p3_field::ExtensionField; -use p3_field::Field; -use p3_field::PrimeField32; -use sp1_core::air::{BinomialExtension, SP1AirBuilder}; -use sp1_derive::AlignedBorrow; - -use crate::runtime::D; - -/// The smallest unit of memory that can be read and written to. -#[derive(AlignedBorrow, Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] -#[repr(C)] -pub struct Block(pub [T; D]); - -pub trait BlockBuilder: AirBuilder { - fn assert_block_eq, Rhs: Into>( - &mut self, - lhs: Block, - rhs: Block, - ) { - for (l, r) in lhs.0.into_iter().zip(rhs.0) { - self.assert_eq(l, r); - } - } -} - -impl BlockBuilder for AB {} - -impl Block { - pub fn map(self, f: F) -> Block - where - F: FnMut(T) -> U, - { - Block(self.0.map(f)) - } - - pub fn ext(&self) -> E - where - T: Field, - E: ExtensionField, - { - E::from_base_slice(&self.0) - } -} - -impl Block { - pub fn as_extension>(&self) -> BinomialExtension { - let arr: [AB::Expr; 4] = self.0.clone().map(|x| AB::Expr::zero() + x); - BinomialExtension(arr) - } -} - -impl From<[T; D]> for Block { - fn from(arr: [T; D]) -> Self { - Self(arr) - } -} - -impl From for Block { - fn from(value: F) -> Self { - Self([value, F::zero(), F::zero(), F::zero()]) - } -} - -impl From<&[T]> for Block { - fn from(slice: &[T]) -> Self { - let arr: [T; D] = slice.try_into().unwrap(); - Self(arr) - } -} - -impl Index for Block -where - [T]: Index, -{ - type Output = <[T] as Index>::Output; - - #[inline] - fn index(&self, index: I) -> &Self::Output { - Index::index(&self.0, index) - } -} +pub use is_ext_zero::*; diff --git a/recursion/core/src/cpu/air.rs b/recursion/core/src/cpu/air.rs index 30ed0e422..5803cc1fc 100644 --- a/recursion/core/src/cpu/air.rs +++ b/recursion/core/src/cpu/air.rs @@ -1,7 +1,5 @@ -use crate::air::BinomialExtensionUtils; use crate::air::BlockBuilder; use crate::cpu::CpuChip; -use crate::runtime::Opcode; use core::mem::size_of; use p3_air::Air; use p3_air::AirBuilder; @@ -12,9 +10,7 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use p3_matrix::MatrixRowSlices; use sp1_core::air::AirInteraction; -use sp1_core::air::BinomialExtension; use sp1_core::lookup::InteractionKind; -use sp1_core::operations::IsZeroOperation; use sp1_core::stark::SP1AirBuilder; use sp1_core::utils::indices_arr; use sp1_core::{air::MachineAir, utils::pad_to_power_of_two}; @@ -50,7 +46,7 @@ impl MachineAir for CpuChip { .cpu_events .iter() .enumerate() - .map(|(i, event)| { + .map(|(_, event)| { let mut row = [F::zero(); NUM_CPU_COLS]; let cols: &mut CpuCols = row.as_mut_slice().borrow_mut(); @@ -64,24 +60,6 @@ impl MachineAir for CpuChip { cols.instruction.op_c = event.instruction.op_c; cols.instruction.imm_b = F::from_canonical_u32(event.instruction.imm_b as u32); cols.instruction.imm_c = F::from_canonical_u32(event.instruction.imm_c as u32); - match event.instruction.opcode { - Opcode::ADD => { - cols.is_add = F::one(); - } - Opcode::SUB => { - cols.is_sub = F::one(); - } - Opcode::MUL => { - cols.is_mul = F::one(); - } - Opcode::BEQ => { - cols.is_beq = F::one(); - } - Opcode::BNE => { - cols.is_bne = F::one(); - } - _ => {} - }; if let Some(record) = &event.a_record { cols.a.populate(record); @@ -97,25 +75,25 @@ impl MachineAir for CpuChip { cols.c.value = event.instruction.op_c; } - cols.add_scratch = cols.b.value.0[0] + cols.c.value.0[0]; - cols.sub_scratch = cols.b.value.0[0] - cols.c.value.0[0]; - cols.mul_scratch = cols.b.value.0[0] * cols.c.value.0[0]; - cols.add_ext_scratch = (BinomialExtension::from_block(cols.b.value) - + BinomialExtension::from_block(cols.c.value)) - .as_block(); - cols.sub_ext_scratch = (BinomialExtension::from_block(cols.b.value) - - BinomialExtension::from_block(cols.c.value)) - .as_block(); - cols.mul_ext_scratch = (BinomialExtension::from_block(cols.b.value) - * BinomialExtension::from_block(cols.c.value)) - .as_block(); - - cols.a_eq_b - .populate((cols.a.value.0[0] - cols.b.value.0[0]).as_canonical_u32()); - - let is_last_row = F::from_bool(i == input.cpu_events.len() - 1); - cols.beq = cols.is_beq * cols.a_eq_b.result * (F::one() - is_last_row); - cols.bne = cols.is_bne * (F::one() - cols.a_eq_b.result) * (F::one() - is_last_row); + // cols.add_scratch = cols.b.value.0[0] + cols.c.value.0[0]; + // cols.sub_scratch = cols.b.value.0[0] - cols.c.value.0[0]; + // cols.mul_scratch = cols.b.value.0[0] * cols.c.value.0[0]; + // cols.add_ext_scratch = (BinomialExtension::from_block(cols.b.value) + // + BinomialExtension::from_block(cols.c.value)) + // .as_block(); + // cols.sub_ext_scratch = (BinomialExtension::from_block(cols.b.value) + // - BinomialExtension::from_block(cols.c.value)) + // .as_block(); + // cols.mul_ext_scratch = (BinomialExtension::from_block(cols.b.value) + // * BinomialExtension::from_block(cols.c.value)) + // .as_block(); + + // cols.a_eq_b + // .populate((cols.a.value.0[0] - cols.b.value.0[0]).as_canonical_u32()); + + // let is_last_row = F::from_bool(i == input.cpu_events.len() - 1); + // cols.beq = cols.is_beq * cols.a_eq_b.result * (F::one() - is_last_row); + // cols.bne = cols.is_bne * (F::one() - cols.a_eq_b.result) * (F::one() - is_last_row); cols.is_real = F::one(); row @@ -155,6 +133,12 @@ where AB: SP1AirBuilder, { fn eval(&self, builder: &mut AB) { + // Constraints for the CPU chip. + // + // - Constraints for fetching the instruction. + // - Constraints for incrementing the internal state consisting of the program counter + // and the clock. + let main = builder.main(); let local: &CpuCols = main.row_slice(0).borrow(); let next: &CpuCols = main.row_slice(1).borrow(); @@ -165,14 +149,14 @@ where .when(local.is_real) .assert_eq(local.clk + AB::F::from_canonical_u32(4), next.clk); - // Increment pc by 1 every cycle unless it is a branch instruction that is satisfied. - builder - .when_transition() - .when(next.is_real * (AB::Expr::one() - (local.is_beq + local.is_bne))) - .assert_eq(local.pc + AB::F::one(), next.pc); - builder - .when(local.beq + local.bne) - .assert_eq(next.pc, local.pc + local.c.value.0[0]); + // // Increment pc by 1 every cycle unless it is a branch instruction that is satisfied. + // builder + // .when_transition() + // .when(next.is_real * (AB::Expr::one() - (local.is_beq + local.is_bne))) + // .assert_eq(local.pc + AB::F::one(), next.pc); + // builder + // .when(local.beq + local.bne) + // .assert_eq(next.pc, local.pc + local.c.value.0[0]); // Connect immediates. builder @@ -183,71 +167,71 @@ where .assert_block_eq::(local.c.value, local.instruction.op_c); // Compute ALU. - builder.assert_eq(local.b.value.0[0] + local.c.value.0[0], local.add_scratch); - builder.assert_eq(local.b.value.0[0] - local.c.value.0[0], local.sub_scratch); - builder.assert_eq(local.b.value.0[0] * local.c.value.0[0], local.mul_scratch); - - // Compute extension ALU. - builder.assert_ext_eq( - local.b.value.as_extension::() + local.c.value.as_extension::(), - local.add_ext_scratch.as_extension::(), - ); - builder.assert_ext_eq( - local.b.value.as_extension::() - local.c.value.as_extension::(), - local.sub_ext_scratch.as_extension::(), - ); - builder.assert_ext_eq( - local.b.value.as_extension::() * local.c.value.as_extension::(), - local.mul_ext_scratch.as_extension::(), - ); - - // Connect ALU to CPU. - builder - .when(local.is_add) - .assert_eq(local.a.value.0[0], local.add_scratch); - builder - .when(local.is_add) - .assert_eq(local.a.value.0[1], AB::F::zero()); - builder - .when(local.is_add) - .assert_eq(local.a.value.0[2], AB::F::zero()); - builder - .when(local.is_add) - .assert_eq(local.a.value.0[3], AB::F::zero()); - - builder - .when(local.is_sub) - .assert_eq(local.a.value.0[0], local.sub_scratch); - builder - .when(local.is_sub) - .assert_eq(local.a.value.0[1], AB::F::zero()); - builder - .when(local.is_sub) - .assert_eq(local.a.value.0[2], AB::F::zero()); - builder - .when(local.is_sub) - .assert_eq(local.a.value.0[3], AB::F::zero()); - - builder - .when(local.is_mul) - .assert_eq(local.a.value.0[0], local.mul_scratch); - builder - .when(local.is_mul) - .assert_eq(local.a.value.0[1], AB::F::zero()); - builder - .when(local.is_mul) - .assert_eq(local.a.value.0[2], AB::F::zero()); - builder - .when(local.is_mul) - .assert_eq(local.a.value.0[3], AB::F::zero()); + // builder.assert_eq(local.b.value.0[0] + local.c.value.0[0], local.add_scratch); + // builder.assert_eq(local.b.value.0[0] - local.c.value.0[0], local.sub_scratch); + // builder.assert_eq(local.b.value.0[0] * local.c.value.0[0], local.mul_scratch); + + // // Compute extension ALU. + // builder.assert_ext_eq( + // local.b.value.as_extension::() + local.c.value.as_extension::(), + // local.add_ext_scratch.as_extension::(), + // ); + // builder.assert_ext_eq( + // local.b.value.as_extension::() - local.c.value.as_extension::(), + // local.sub_ext_scratch.as_extension::(), + // ); + // builder.assert_ext_eq( + // local.b.value.as_extension::() * local.c.value.as_extension::(), + // local.mul_ext_scratch.as_extension::(), + // ); + + // // Connect ALU to CPU. + // builder + // .when(local.is_add) + // .assert_eq(local.a.value.0[0], local.add_scratch); + // builder + // .when(local.is_add) + // .assert_eq(local.a.value.0[1], AB::F::zero()); + // builder + // .when(local.is_add) + // .assert_eq(local.a.value.0[2], AB::F::zero()); + // builder + // .when(local.is_add) + // .assert_eq(local.a.value.0[3], AB::F::zero()); + + // builder + // .when(local.is_sub) + // .assert_eq(local.a.value.0[0], local.sub_scratch); + // builder + // .when(local.is_sub) + // .assert_eq(local.a.value.0[1], AB::F::zero()); + // builder + // .when(local.is_sub) + // .assert_eq(local.a.value.0[2], AB::F::zero()); + // builder + // .when(local.is_sub) + // .assert_eq(local.a.value.0[3], AB::F::zero()); + + // builder + // .when(local.is_mul) + // .assert_eq(local.a.value.0[0], local.mul_scratch); + // builder + // .when(local.is_mul) + // .assert_eq(local.a.value.0[1], AB::F::zero()); + // builder + // .when(local.is_mul) + // .assert_eq(local.a.value.0[2], AB::F::zero()); + // builder + // .when(local.is_mul) + // .assert_eq(local.a.value.0[3], AB::F::zero()); // Compute if a == b. - IsZeroOperation::::eval::( - builder, - local.a.value.0[0] - local.b.value.0[0], - local.a_eq_b, - local.is_real.into(), - ); + // IsZeroOperation::::eval::( + // builder, + // local.a.value.0[0] - local.b.value.0[0], + // local.a_eq_b, + // local.is_real.into(), + // ); // Receive C. builder.receive(AirInteraction::new( diff --git a/recursion/core/src/cpu/columns.rs b/recursion/core/src/cpu/columns.rs deleted file mode 100644 index 3f20a765b..000000000 --- a/recursion/core/src/cpu/columns.rs +++ /dev/null @@ -1,59 +0,0 @@ -use crate::{air::Block, memory::MemoryReadWriteCols}; -use sp1_core::operations::IsZeroOperation; -use sp1_derive::AlignedBorrow; - -/// The column layout for the chip. -#[derive(AlignedBorrow, Default, Clone, Debug)] -#[repr(C)] -pub struct CpuCols { - pub clk: T, - pub pc: T, - pub fp: T, - - pub a: MemoryReadWriteCols, - pub b: MemoryReadWriteCols, - pub c: MemoryReadWriteCols, - - pub instruction: InstructionCols, - pub is_add: T, - pub is_sub: T, - pub is_mul: T, - pub is_beq: T, - pub is_bne: T, - - pub beq: T, - pub bne: T, - - // c = a + b; - pub add_scratch: T, - - // c = a - b; - pub sub_scratch: T, - - // c = a * b; - pub mul_scratch: T, - - // ext(c) = ext(a) + ext(b); - pub add_ext_scratch: Block, - - // ext(c) = ext(a) - ext(b); - pub sub_ext_scratch: Block, - - // ext(c) = ext(a) * ext(b); - pub mul_ext_scratch: Block, - - // c = a == b; - pub a_eq_b: IsZeroOperation, - - pub is_real: T, -} - -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] -pub struct InstructionCols { - pub opcode: T, - pub op_a: T, - pub op_b: Block, - pub op_c: Block, - pub imm_b: T, - pub imm_c: T, -} diff --git a/recursion/core/src/cpu/columns/alu.rs b/recursion/core/src/cpu/columns/alu.rs new file mode 100644 index 000000000..c45603f56 --- /dev/null +++ b/recursion/core/src/cpu/columns/alu.rs @@ -0,0 +1,28 @@ +use crate::air::Block; +use sp1_derive::AlignedBorrow; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct AluCols { + pub ext_a: Block, + + pub ext_b: Block, + + // c = a + b; + pub add_scratch: T, + + // c = a - b; + pub sub_scratch: T, + + // c = a * b; + pub mul_scratch: T, + + // ext(c) = ext(a) + ext(b); + pub add_ext_scratch: Block, + + // ext(c) = ext(a) - ext(b); + pub sub_ext_scratch: Block, + + // ext(c) = ext(a) * ext(b); + pub mul_ext_scratch: Block, +} diff --git a/recursion/core/src/cpu/columns/branch.rs b/recursion/core/src/cpu/columns/branch.rs new file mode 100644 index 000000000..2b7df19de --- /dev/null +++ b/recursion/core/src/cpu/columns/branch.rs @@ -0,0 +1,14 @@ +use sp1_derive::AlignedBorrow; +use std::mem::size_of; + +use crate::air::IsExtZeroOperation; + +#[allow(dead_code)] +pub const NUM_BRANCH_COLS: usize = size_of::>(); + +/// The column layout for branching. +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct BranchCols { + is_eq_zero: IsExtZeroOperation, +} diff --git a/recursion/core/src/cpu/columns/instruction.rs b/recursion/core/src/cpu/columns/instruction.rs new file mode 100644 index 000000000..2220c7b1f --- /dev/null +++ b/recursion/core/src/cpu/columns/instruction.rs @@ -0,0 +1,38 @@ +use crate::{air::Block, cpu::Instruction}; +use p3_field::PrimeField; +use sp1_derive::AlignedBorrow; +use std::{iter::once, vec::IntoIter}; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct InstructionCols { + pub opcode: T, + pub op_a: T, + pub op_b: Block, + pub op_c: Block, + pub imm_b: T, + pub imm_c: T, +} + +impl InstructionCols { + pub fn populate(&mut self, instruction: Instruction) { + self.opcode = instruction.opcode.as_field::(); + self.op_a = instruction.op_a; + self.op_b = instruction.op_b; + self.op_c = instruction.op_c; + } +} + +impl IntoIterator for InstructionCols { + type Item = T; + type IntoIter = IntoIter; + + fn into_iter(self) -> Self::IntoIter { + once(self.opcode) + .chain(once(self.op_a)) + .chain(self.op_b) + .chain(self.op_c) + .collect::>() + .into_iter() + } +} diff --git a/recursion/core/src/cpu/columns/jump.rs b/recursion/core/src/cpu/columns/jump.rs new file mode 100644 index 000000000..afc65cd98 --- /dev/null +++ b/recursion/core/src/cpu/columns/jump.rs @@ -0,0 +1,15 @@ +use sp1_derive::AlignedBorrow; +use std::mem::size_of; + +#[allow(dead_code)] +pub const NUM_JUMP_COLS: usize = size_of::>(); + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct JumpCols { + /// The current program counter. + pub pc: T, + + /// THe next program counter. + pub next_pc: T, +} diff --git a/recursion/core/src/cpu/columns/mod.rs b/recursion/core/src/cpu/columns/mod.rs new file mode 100644 index 000000000..da14d923b --- /dev/null +++ b/recursion/core/src/cpu/columns/mod.rs @@ -0,0 +1,36 @@ +use crate::{air::IsExtZeroOperation, memory::MemoryReadWriteCols}; +use sp1_derive::AlignedBorrow; + +mod alu; +mod branch; +mod instruction; +mod jump; +mod opcode; +mod opcode_specific; + +pub use alu::*; +pub use instruction::*; +pub use opcode::*; + +/// The column layout for the chip. +#[derive(AlignedBorrow, Default, Clone, Debug)] +#[repr(C)] +pub struct CpuCols { + pub clk: T, + pub pc: T, + pub fp: T, + + pub instruction: InstructionCols, + pub selectors: OpcodeSelectorCols, + + pub a: MemoryReadWriteCols, + pub b: MemoryReadWriteCols, + pub c: MemoryReadWriteCols, + + pub alu: AluCols, + + // result = operand_1 == operand_2; + pub eq_1_2: IsExtZeroOperation, + + pub is_real: T, +} diff --git a/recursion/core/src/cpu/columns/opcode.rs b/recursion/core/src/cpu/columns/opcode.rs new file mode 100644 index 000000000..1c1ef0705 --- /dev/null +++ b/recursion/core/src/cpu/columns/opcode.rs @@ -0,0 +1,129 @@ +use p3_field::Field; +use sp1_derive::AlignedBorrow; + +use crate::{cpu::Instruction, runtime::Opcode}; + +const OPCODE_COUNT: usize = core::mem::size_of::>(); + +/// Selectors for the opcode. +/// +/// This contains selectors for the different opcodes corresponding to variants of the [`Opcode`] +/// enum. +#[derive(AlignedBorrow, Clone, Copy, Default, Debug)] +#[repr(C)] +pub struct OpcodeSelectorCols { + // Arithmetic field instructions. + pub is_add: T, + pub is_sub: T, + pub is_mul: T, + pub is_div: T, + + // Arithmetic field extension operations. + pub is_eadd: T, + pub is_esub: T, + pub is_emul: T, + pub is_ediv: T, + + // Mixed arithmetic operations. + pub is_efadd: T, + pub is_efsub: T, + pub is_fesub: T, + pub is_efmul: T, + pub is_efdiv: T, + pub is_fediv: T, + + // Memory instructions. + pub is_lw: T, + pub is_sw: T, + pub is_le: T, + pub is_se: T, + + // Branch instructions. + pub is_beq: T, + pub is_bne: T, + pub is_ebeq: T, + pub is_ebne: T, + + // Jump instructions. + pub is_jal: T, + pub is_jalr: T, + + // System instructions. + pub is_trap: T, + pub is_noop: T, +} + +impl OpcodeSelectorCols { + /// Populates the opcode columns with the given instruction. + /// + /// The opcode flag should be set to 1 for the relevant opcode and 0 for the rest. We already + /// assume that the state of the columns is set to zero at the start of the function, so we only + /// need to set the relevant opcode column to 1. + pub fn populate(&mut self, instruction: Instruction) { + match instruction.opcode { + Opcode::ADD => self.is_add = F::one(), + Opcode::SUB => self.is_sub = F::one(), + Opcode::MUL => self.is_mul = F::one(), + Opcode::DIV => self.is_div = F::one(), + Opcode::EADD => self.is_eadd = F::one(), + Opcode::ESUB => self.is_esub = F::one(), + Opcode::EMUL => self.is_emul = F::one(), + Opcode::EDIV => self.is_ediv = F::one(), + Opcode::EFADD => self.is_efadd = F::one(), + Opcode::EFSUB => self.is_efsub = F::one(), + Opcode::FESUB => self.is_fesub = F::one(), + Opcode::EFMUL => self.is_efmul = F::one(), + Opcode::EFDIV => self.is_efdiv = F::one(), + Opcode::FEDIV => self.is_fediv = F::one(), + Opcode::LW => self.is_lw = F::one(), + Opcode::SW => self.is_sw = F::one(), + Opcode::LE => self.is_le = F::one(), + Opcode::SE => self.is_se = F::one(), + Opcode::BEQ => self.is_beq = F::one(), + Opcode::BNE => self.is_bne = F::one(), + Opcode::EBEQ => self.is_ebeq = F::one(), + Opcode::EBNE => self.is_ebne = F::one(), + Opcode::JAL => self.is_jal = F::one(), + Opcode::JALR => self.is_jalr = F::one(), + Opcode::TRAP => self.is_trap = F::one(), + } + } +} + +impl IntoIterator for OpcodeSelectorCols { + type Item = T; + + type IntoIter = std::array::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + [ + self.is_add, + self.is_sub, + self.is_mul, + self.is_div, + self.is_eadd, + self.is_esub, + self.is_emul, + self.is_ediv, + self.is_efadd, + self.is_efsub, + self.is_fesub, + self.is_efmul, + self.is_efdiv, + self.is_fediv, + self.is_lw, + self.is_sw, + self.is_le, + self.is_se, + self.is_beq, + self.is_bne, + self.is_ebeq, + self.is_ebne, + self.is_jal, + self.is_jalr, + self.is_trap, + self.is_noop, + ] + .into_iter() + } +} diff --git a/recursion/core/src/cpu/columns/opcode_specific.rs b/recursion/core/src/cpu/columns/opcode_specific.rs new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/recursion/core/src/cpu/columns/opcode_specific.rs @@ -0,0 +1 @@ + diff --git a/recursion/core/src/cpu/mod.rs b/recursion/core/src/cpu/mod.rs index e842c14d9..f32dd933e 100644 --- a/recursion/core/src/cpu/mod.rs +++ b/recursion/core/src/cpu/mod.rs @@ -4,6 +4,9 @@ pub mod columns; use crate::air::Block; pub use crate::{memory::MemoryRecord, runtime::Instruction}; +pub use air::*; +pub use columns::*; + #[derive(Debug, Clone)] pub struct CpuEvent { pub clk: F, diff --git a/recursion/core/src/program/mod.rs b/recursion/core/src/program/mod.rs index 7cc6e56e2..29f7252ef 100644 --- a/recursion/core/src/program/mod.rs +++ b/recursion/core/src/program/mod.rs @@ -1,4 +1,4 @@ -use crate::{cpu::columns::InstructionCols, runtime::ExecutionRecord}; +use crate::{cpu::InstructionCols, runtime::ExecutionRecord}; use core::mem::size_of; use p3_air::{Air, BaseAir}; use p3_field::PrimeField32; diff --git a/recursion/core/src/runtime/instruction.rs b/recursion/core/src/runtime/instruction.rs index c4b9c6bb1..3640f63f3 100644 --- a/recursion/core/src/runtime/instruction.rs +++ b/recursion/core/src/runtime/instruction.rs @@ -57,6 +57,7 @@ impl Instruction { | Opcode::EFSUB | Opcode::EFMUL | Opcode::EDIV + | Opcode::EFDIV | Opcode::EBNE | Opcode::EBEQ ) @@ -70,11 +71,9 @@ impl Instruction { | Opcode::EADD | Opcode::EMUL | Opcode::ESUB + | Opcode::FESUB | Opcode::EDIV - | Opcode::EFADD - | Opcode::EFSUB - | Opcode::EFMUL - | Opcode::EFDIV + | Opcode::FEDIV | Opcode::EBNE | Opcode::EBEQ ) diff --git a/recursion/core/src/runtime/mod.rs b/recursion/core/src/runtime/mod.rs index 258bc614a..95cab9d79 100644 --- a/recursion/core/src/runtime/mod.rs +++ b/recursion/core/src/runtime/mod.rs @@ -245,14 +245,14 @@ impl> Runtime { self.mw(a_ptr, a_val, MemoryAccessPosition::A); (a, b, c) = (a_val, b_val, c_val); } - Opcode::ESUB | Opcode::EFSUB => { + Opcode::ESUB | Opcode::EFSUB | Opcode::FESUB => { let (a_ptr, b_val, c_val) = self.alu_rr(&instruction); let diff = EF::from_base_slice(&b_val.0) - EF::from_base_slice(&c_val.0); let a_val = Block::from(diff.as_base_slice()); self.mw(a_ptr, a_val, MemoryAccessPosition::A); (a, b, c) = (a_val, b_val, c_val); } - Opcode::EDIV | Opcode::EFDIV => { + Opcode::EDIV | Opcode::EFDIV | Opcode::FEDIV => { let (a_ptr, b_val, c_val) = self.alu_rr(&instruction); let quotient = EF::from_base_slice(&b_val.0) / EF::from_base_slice(&c_val.0); let a_val = Block::from(quotient.as_base_slice()); diff --git a/recursion/core/src/runtime/opcode.rs b/recursion/core/src/runtime/opcode.rs index 9b3c100f9..da68ba2f4 100644 --- a/recursion/core/src/runtime/opcode.rs +++ b/recursion/core/src/runtime/opcode.rs @@ -1,3 +1,5 @@ +use p3_field::AbstractField; + #[allow(clippy::upper_case_acronyms)] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Opcode { @@ -16,8 +18,10 @@ pub enum Opcode { // Mixed arithmetic operations. EFADD = 20, EFSUB = 21, + FESUB = 24, EFMUL = 22, EFDIV = 23, + FEDIV = 25, // Memory instructions. LW = 4, @@ -38,3 +42,9 @@ pub enum Opcode { // System instructions. TRAP = 30, } + +impl Opcode { + pub fn as_field(&self) -> F { + F::from_canonical_u32(*self as u32) + } +} diff --git a/recursion/core/src/stark/mod.rs b/recursion/core/src/stark/mod.rs index 3b9a4d3e6..f42b8b0fb 100644 --- a/recursion/core/src/stark/mod.rs +++ b/recursion/core/src/stark/mod.rs @@ -3,22 +3,23 @@ use crate::{ memory::{MemoryChipKind, MemoryGlobalChip}, program::ProgramChip, }; -use p3_field::PrimeField32; +use p3_field::{extension::BinomiallyExtendable, PrimeField32}; use sp1_core::stark::{Chip, MachineStark, StarkGenericConfig}; use sp1_derive::MachineAir; +use crate::runtime::D; + #[derive(MachineAir)] #[sp1_core_path = "sp1_core"] #[execution_record_path = "crate::runtime::ExecutionRecord"] -pub enum RecursionAir { +pub enum RecursionAir> { Program(ProgramChip), Cpu(CpuChip), MemoryInit(MemoryGlobalChip), MemoryFinalize(MemoryGlobalChip), } -#[allow(dead_code)] -impl RecursionAir { +impl> RecursionAir { pub fn machine>(config: SC) -> MachineStark { let chips = Self::get_all() .into_iter()