diff --git a/Cargo.lock b/Cargo.lock index b75a8e566..51f737b3e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4859,9 +4859,11 @@ dependencies = [ "num-bigint 0.4.5", "p3-baby-bear", "p3-field", + "p3-symmetric", "rand", "serde", "serde_json", + "sp1-core", "sp1-recursion-compiler", "tempfile", ] diff --git a/core/Cargo.toml b/core/Cargo.toml index df02550f1..548fcde6c 100644 --- a/core/Cargo.toml +++ b/core/Cargo.toml @@ -58,6 +58,7 @@ web-time = "1.1.0" rayon-scan = "0.1.1" thiserror = "1.0.60" num-bigint = { version = "0.4.3", default-features = false } +rand = "0.8.5" [dev-dependencies] tiny-keccak = { version = "2.0.2", features = ["keccak"] } diff --git a/core/src/air/builder.rs b/core/src/air/builder.rs index eba567d92..d957f43cd 100644 --- a/core/src/air/builder.rs +++ b/core/src/air/builder.rs @@ -308,6 +308,7 @@ pub trait AluAirBuilder: BaseAirBuilder { c: Word>, shard: impl Into, channel: impl Into, + nonce: impl Into, multiplicity: impl Into, ) { let values = once(opcode.into()) @@ -316,6 +317,7 @@ pub trait AluAirBuilder: BaseAirBuilder { .chain(c.0.into_iter().map(Into::into)) .chain(once(shard.into())) .chain(once(channel.into())) + .chain(once(nonce.into())) .collect(); self.send(AirInteraction::new( @@ -335,6 +337,7 @@ pub trait AluAirBuilder: BaseAirBuilder { c: Word>, shard: impl Into, channel: impl Into, + nonce: impl Into, multiplicity: impl Into, ) { let values = once(opcode.into()) @@ -343,6 +346,7 @@ pub trait AluAirBuilder: BaseAirBuilder { .chain(c.0.into_iter().map(Into::into)) .chain(once(shard.into())) .chain(once(channel.into())) + .chain(once(nonce.into())) .collect(); self.receive(AirInteraction::new( @@ -359,6 +363,7 @@ pub trait AluAirBuilder: BaseAirBuilder { shard: impl Into + Clone, channel: impl Into + Clone, clk: impl Into + Clone, + nonce: impl Into + Clone, syscall_id: impl Into + Clone, arg1: impl Into + Clone, arg2: impl Into + Clone, @@ -369,6 +374,7 @@ pub trait AluAirBuilder: BaseAirBuilder { shard.clone().into(), 
channel.clone().into(), clk.clone().into(), + nonce.clone().into(), syscall_id.clone().into(), arg1.clone().into(), arg2.clone().into(), @@ -385,6 +391,7 @@ pub trait AluAirBuilder: BaseAirBuilder { shard: impl Into + Clone, channel: impl Into + Clone, clk: impl Into + Clone, + nonce: impl Into + Clone, syscall_id: impl Into + Clone, arg1: impl Into + Clone, arg2: impl Into + Clone, @@ -395,6 +402,7 @@ pub trait AluAirBuilder: BaseAirBuilder { shard.clone().into(), channel.clone().into(), clk.clone().into(), + nonce.clone().into(), syscall_id.clone().into(), arg1.clone().into(), arg2.clone().into(), diff --git a/core/src/alu/add_sub/mod.rs b/core/src/alu/add_sub/mod.rs index 2321427c5..3179d4d77 100644 --- a/core/src/alu/add_sub/mod.rs +++ b/core/src/alu/add_sub/mod.rs @@ -1,7 +1,8 @@ use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; -use p3_air::{Air, BaseAir}; +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_field::AbstractField; use p3_field::PrimeField; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; @@ -38,6 +39,9 @@ pub struct AddSubCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// Instance of `AddOperation` to handle addition logic in `AddSubChip`'s ALU operations. /// It's result will be `a` for the add operation and `b` for the sub operation. pub add_operation: AddOperation, @@ -129,6 +133,13 @@ impl MachineAir for AddSubChip { // Pad the trace to a power of two. pad_to_power_of_two::(&mut trace.values); + // Write the nonces to the trace. 
+ for i in 0..trace.height() { + let cols: &mut AddSubCols = + trace.values[i * NUM_ADD_SUB_COLS..(i + 1) * NUM_ADD_SUB_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -151,6 +162,14 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &AddSubCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &AddSubCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); // Evaluate the addition operation. AddOperation::::eval( @@ -172,6 +191,7 @@ where local.operand_2, local.shard, local.channel, + local.nonce, local.is_add, ); @@ -183,6 +203,7 @@ where local.operand_2, local.shard, local.channel, + local.nonce, local.is_sub, ); diff --git a/core/src/alu/bitwise/mod.rs b/core/src/alu/bitwise/mod.rs index 3e7227b70..81163b11e 100644 --- a/core/src/alu/bitwise/mod.rs +++ b/core/src/alu/bitwise/mod.rs @@ -1,7 +1,9 @@ use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; +use p3_air::AirBuilder; use p3_air::{Air, BaseAir}; +use p3_field::AbstractField; use p3_field::PrimeField; use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; @@ -31,6 +33,9 @@ pub struct BitwiseCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// The output operand. pub a: Word, @@ -111,6 +116,12 @@ impl MachineAir for BitwiseChip { // Pad the trace to a power of two. 
pad_to_power_of_two::(&mut trace.values); + for i in 0..trace.height() { + let cols: &mut BitwiseCols = + trace.values[i * NUM_BITWISE_COLS..(i + 1) * NUM_BITWISE_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -133,6 +144,14 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &BitwiseCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &BitwiseCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); // Get the opcode for the operation. let opcode = local.is_xor * ByteOpcode::XOR.as_field::() @@ -166,6 +185,7 @@ where local.c, local.shard, local.channel, + local.nonce, local.is_xor + local.is_or + local.is_and, ); diff --git a/core/src/alu/divrem/mod.rs b/core/src/alu/divrem/mod.rs index b54206469..84198b16b 100644 --- a/core/src/alu/divrem/mod.rs +++ b/core/src/alu/divrem/mod.rs @@ -64,6 +64,7 @@ mod utils; use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; +use std::collections::HashMap; use p3_air::{Air, AirBuilder, BaseAir}; use p3_field::AbstractField; @@ -72,11 +73,10 @@ use p3_matrix::dense::RowMajorMatrix; use p3_matrix::Matrix; use sp1_derive::AlignedBorrow; -use self::utils::eval_abs_value; use crate::air::MachineAir; use crate::air::{SP1AirBuilder, Word}; use crate::alu::divrem::utils::{get_msb, get_quotient_and_remainder, is_signed_operation}; -use crate::alu::AluEvent; +use crate::alu::{create_alu_lookups, AluEvent}; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; use crate::disassembler::WORD_SIZE; @@ -107,6 +107,9 @@ pub struct DivRemCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// The output operand. 
pub a: Word, @@ -184,6 +187,23 @@ pub struct DivRemCols { /// Flag to indicate whether `c` is negative. pub c_neg: T, + /// The lower nonce of the operation. + pub lower_nonce: T, + + /// The upper nonce of the operation. + pub upper_nonce: T, + + /// The absolute nonce of the operation. + pub abs_nonce: T, + + /// Selector to determine whether an ALU Event is sent for absolute value computation of `c`. + pub abs_c_alu_event: T, + pub abs_c_alu_event_nonce: T, + + /// Selector to determine whether an ALU Event is sent for absolute value computation of `rem`. + pub abs_rem_alu_event: T, + pub abs_rem_alu_event_nonce: T, + /// Selector to know whether this row is enabled. pub is_real: T, @@ -259,6 +279,24 @@ impl MachineAir for DivRemChip { cols.max_abs_c_or_1 = Word::from(u32::max(1, event.c)); } + // Set the `alu_event` flags. + cols.abs_c_alu_event = cols.c_neg * cols.is_real; + cols.abs_c_alu_event_nonce = F::from_canonical_u32( + input + .nonce_lookup + .get(&event.sub_lookups[4]) + .copied() + .unwrap_or_default(), + ); + cols.abs_rem_alu_event = cols.rem_neg * cols.is_real; + cols.abs_rem_alu_event_nonce = F::from_canonical_u32( + input + .nonce_lookup + .get(&event.sub_lookups[5]) + .copied() + .unwrap_or_default(), + ); + // Insert the MSB lookup events. { let words = [event.b, event.c, remainder]; @@ -281,7 +319,7 @@ impl MachineAir for DivRemChip { // Calculate the modified multiplicity { - cols.remainder_check_multiplicity = cols.is_real * cols.is_c_0.result; + cols.remainder_check_multiplicity = cols.is_real * (F::one() - cols.is_c_0.result); } // Calculate c * quotient + remainder. @@ -321,6 +359,40 @@ impl MachineAir for DivRemChip { // mul and LT upon which div depends. This ordering is critical as mul and LT // require all the mul and LT events be added before we can call generate_trace. { + // Insert the absolute value computation events. 
+ { + let mut add_events: Vec = vec![]; + if cols.abs_c_alu_event == F::one() { + add_events.push(AluEvent { + lookup_id: event.sub_lookups[4], + shard: event.shard, + channel: event.channel, + clk: event.clk, + opcode: Opcode::ADD, + a: 0, + b: event.c, + c: (event.c as i32).abs() as u32, + sub_lookups: create_alu_lookups(), + }) + } + if cols.abs_rem_alu_event == F::one() { + add_events.push(AluEvent { + lookup_id: event.sub_lookups[5], + shard: event.shard, + channel: event.channel, + clk: event.clk, + opcode: Opcode::ADD, + a: 0, + b: remainder, + c: (remainder as i32).abs() as u32, + sub_lookups: create_alu_lookups(), + }) + } + let mut alu_events = HashMap::new(); + alu_events.insert(Opcode::ADD, add_events); + output.add_alu_events(alu_events); + } + let mut lower_word = 0; for i in 0..WORD_SIZE { lower_word += (c_times_quotient[i] as u32) << (i * BYTE_SIZE); @@ -332,6 +404,7 @@ impl MachineAir for DivRemChip { } let lower_multiplication = AluEvent { + lookup_id: event.sub_lookups[0], shard: event.shard, channel: event.channel, clk: event.clk, @@ -339,10 +412,19 @@ impl MachineAir for DivRemChip { a: lower_word, c: event.c, b: quotient, + sub_lookups: create_alu_lookups(), }; + cols.lower_nonce = F::from_canonical_u32( + input + .nonce_lookup + .get(&event.sub_lookups[0]) + .copied() + .unwrap_or_default(), + ); output.add_mul_event(lower_multiplication); let upper_multiplication = AluEvent { + lookup_id: event.sub_lookups[1], shard: event.shard, channel: event.channel, clk: event.clk, @@ -356,22 +438,45 @@ impl MachineAir for DivRemChip { a: upper_word, c: event.c, b: quotient, + sub_lookups: create_alu_lookups(), }; - + cols.upper_nonce = F::from_canonical_u32( + input + .nonce_lookup + .get(&event.sub_lookups[1]) + .copied() + .unwrap_or_default(), + ); output.add_mul_event(upper_multiplication); - let lt_event = if is_signed_operation(event.opcode) { + cols.abs_nonce = F::from_canonical_u32( + input + .nonce_lookup + .get(&event.sub_lookups[2]) + 
.copied() + .unwrap_or_default(), + ); AluEvent { + lookup_id: event.sub_lookups[2], shard: event.shard, channel: event.channel, - opcode: Opcode::SLT, + opcode: Opcode::SLTU, a: 1, b: (remainder as i32).abs() as u32, c: u32::max(1, (event.c as i32).abs() as u32), clk: event.clk, + sub_lookups: create_alu_lookups(), } } else { + cols.abs_nonce = F::from_canonical_u32( + input + .nonce_lookup + .get(&event.sub_lookups[3]) + .copied() + .unwrap_or_default(), + ); AluEvent { + lookup_id: event.sub_lookups[3], shard: event.shard, channel: event.channel, opcode: Opcode::SLTU, @@ -379,8 +484,10 @@ impl MachineAir for DivRemChip { b: remainder, c: u32::max(1, event.c), clk: event.clk, + sub_lookups: create_alu_lookups(), } }; + if cols.remainder_check_multiplicity == F::one() { output.add_lt_event(lt_event); } @@ -430,6 +537,13 @@ impl MachineAir for DivRemChip { trace.values[i] = padded_row_template[i % NUM_DIVREM_COLS]; } + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut DivRemCols = + trace.values[i * NUM_DIVREM_COLS..(i + 1) * NUM_DIVREM_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -452,10 +566,18 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &DivRemCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &DivRemCols = (*next).borrow(); let base = AB::F::from_canonical_u32(1 << 8); let one: AB::Expr = AB::F::one().into(); let zero: AB::Expr = AB::F::zero().into(); + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + // Calculate whether b, remainder, and c are negative. { // Negative if and only if op code is signed & MSB = 1. 
@@ -490,6 +612,7 @@ where local.c, local.shard, local.channel, + local.lower_nonce, local.is_real, ); @@ -515,6 +638,7 @@ where local.c, local.shard, local.channel, + local.upper_nonce, local.is_real, ); } @@ -659,18 +783,37 @@ where // Range check remainder. (i.e., |remainder| < |c| when not is_c_0) { - eval_abs_value( - builder, - local.remainder.borrow(), - local.abs_remainder.borrow(), - local.rem_neg.borrow(), + // For each of `c` and `rem`, assert that the absolute value is equal to the original value, + // if the original value is non-negative or the minimum i32. + for i in 0..WORD_SIZE { + builder + .when_not(local.c_neg) + .assert_eq(local.c[i], local.abs_c[i]); + builder + .when_not(local.rem_neg) + .assert_eq(local.remainder[i], local.abs_remainder[i]); + } + // In the case that `c` or `rem` is negative, instead check that their sum is zero by + // sending an AddEvent. + builder.send_alu( + AB::Expr::from_canonical_u32(Opcode::ADD as u32), + Word([zero.clone(), zero.clone(), zero.clone(), zero.clone()]), + local.c, + local.abs_c, + local.shard, + local.channel, + local.abs_c_alu_event_nonce, + local.abs_c_alu_event, ); - - eval_abs_value( - builder, - local.c.borrow(), - local.abs_c.borrow(), - local.c_neg.borrow(), + builder.send_alu( + AB::Expr::from_canonical_u32(Opcode::ADD as u32), + Word([zero.clone(), zero.clone(), zero.clone(), zero.clone()]), + local.remainder, + local.abs_remainder, + local.shard, + local.channel, + local.abs_rem_alu_event_nonce, + local.abs_rem_alu_event, ); // max(abs(c), 1) = abs(c) * (1 - is_c_0) + 1 * is_c_0 @@ -691,29 +834,31 @@ where builder.assert_eq(local.max_abs_c_or_1[i], max_abs_c_or_1[i].clone()); } - let opcode = { - let is_signed = local.is_div + local.is_rem; - let is_unsigned = local.is_divu + local.is_remu; - let slt = AB::Expr::from_canonical_u32(Opcode::SLT as u32); - let sltu = AB::Expr::from_canonical_u32(Opcode::SLTU as u32); - is_signed * slt + is_unsigned * sltu - }; - - // Check that the event 
multiplicity column is computed correctly. + // Handle cases: + // - If is_real == 0 then remainder_check_multiplicity == 0 is forced. + // - If is_real == 1 then is_c_0_result must be the expected one, so + // remainder_check_multiplicity = (1 - is_c_0_result) * is_real. builder.assert_eq( + (AB::Expr::one() - local.is_c_0.result) * local.is_real, local.remainder_check_multiplicity, - local.is_c_0.result * local.is_real, ); + // the cleaner idea is simply remainder_check_multiplicity == (1 - is_c_0_result) * is_real + + // Check that the absolute value selector columns are computed correctly. + builder.assert_eq(local.abs_c_alu_event, local.c_neg * local.is_real); + builder.assert_eq(local.abs_rem_alu_event, local.rem_neg * local.is_real); + // Dispatch abs(remainder) < max(abs(c), 1), this is equivalent to abs(remainder) < // abs(c) if not division by 0. builder.send_alu( - opcode, + AB::Expr::from_canonical_u32(Opcode::SLTU as u32), Word([one.clone(), zero.clone(), zero.clone(), zero.clone()]), local.abs_remainder, local.max_abs_c_or_1, local.shard, local.channel, + local.abs_nonce, local.remainder_check_multiplicity, ); } @@ -783,6 +928,8 @@ where local.rem_neg, local.c_neg, local.is_real, + local.abs_c_alu_event, + local.abs_rem_alu_event, ]; for flag in bool_flags.iter() { @@ -817,6 +964,7 @@ where local.c, local.shard, local.channel, + local.nonce, local.is_real, ); } diff --git a/core/src/alu/divrem/utils.rs b/core/src/alu/divrem/utils.rs index f3a7b7070..d71c35aad 100644 --- a/core/src/alu/divrem/utils.rs +++ b/core/src/alu/divrem/utils.rs @@ -1,7 +1,3 @@ -use p3_air::AirBuilder; -use p3_field::AbstractField; - -use crate::air::{SP1AirBuilder, Word, WORD_SIZE}; use crate::runtime::Opcode; /// Returns `true` if the given `opcode` is a signed operation. 
@@ -32,47 +28,3 @@ pub fn get_quotient_and_remainder(b: u32, c: u32, opcode: Opcode) -> (u32, u32) pub const fn get_msb(a: u32) -> u8 { ((a >> 31) & 1) as u8 } - -/// Verifies that `abs_value = abs(value)` using `is_negative` as a flag. -/// -/// `abs(value) + value = 0` if `value` is negative. `abs(value) = value` otherwise. -/// -/// In two's complement arithmetic, the negation involves flipping its bits and adding 1. Therefore, -/// for a negative number, `abs(value) + value` equals 0. This is because `abs(value)` is the two's -/// complement (negation) of `value`. For a positive number, `abs(value)` is the same as `value`. -/// -/// The function iterates over each limb of the `value` and `abs_value`, checking the following -/// conditions: -/// -/// 1. If `value` is non-negative, it checks that each limb in `value` and `abs_value` is identical. -/// 2. If `value` is negative, it checks that the sum of each corresponding limb in `value` and -/// `abs_value` equals the expected sum for a two's complement representation. The least -/// significant limb (first limb) should add up to `0xff + 1` (to account for the +1 in two's -/// complement negation), and other limbs should add up to `0xff` (as the rest of the limbs just -/// have their bits flipped). -pub fn eval_abs_value( - builder: &mut AB, - value: &Word, - abs_value: &Word, - is_negative: &AB::Var, -) where - AB: SP1AirBuilder, -{ - for i in 0..WORD_SIZE { - let exp_sum_if_negative = AB::Expr::from_canonical_u32({ - if i == 0 { - 0xff + 1 - } else { - 0xff - } - }); - - builder - .when(*is_negative) - .assert_eq(value[i] + abs_value[i], exp_sum_if_negative.clone()); - - builder - .when_not(*is_negative) - .assert_eq(value[i], abs_value[i]); - } -} diff --git a/core/src/alu/lt/mod.rs b/core/src/alu/lt/mod.rs index 91b504181..54d5768c2 100644 --- a/core/src/alu/lt/mod.rs +++ b/core/src/alu/lt/mod.rs @@ -34,6 +34,9 @@ pub struct LtCols { /// The channel number, used for byte lookup table. 
pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// If the opcode is SLT. pub is_slt: T, @@ -220,6 +223,13 @@ impl MachineAir for LtChip { // Pad the trace to a power of two. pad_to_power_of_two::(&mut trace.values); + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut LtCols = + trace.values[i * NUM_LT_COLS..(i + 1) * NUM_LT_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -242,6 +252,14 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &LtCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &LtCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); let is_real = local.is_slt + local.is_sltu; @@ -431,6 +449,7 @@ where local.c, local.shard, local.channel, + local.nonce, is_real, ); } diff --git a/core/src/alu/mod.rs b/core/src/alu/mod.rs index c667c612c..a67d1ff90 100644 --- a/core/src/alu/mod.rs +++ b/core/src/alu/mod.rs @@ -11,6 +11,7 @@ pub use bitwise::*; pub use divrem::*; pub use lt::*; pub use mul::*; +use rand::Rng; pub use sll::*; pub use sr::*; @@ -21,6 +22,9 @@ use crate::runtime::Opcode; /// A standard format for describing ALU operations that need to be proven. #[derive(Debug, Clone, Copy, Serialize, Deserialize)] pub struct AluEvent { + /// The lookup id of the event. + pub lookup_id: usize, + /// The shard number, used for byte lookup table. pub shard: u32, @@ -41,12 +45,15 @@ pub struct AluEvent { // The second input operand. pub c: u32, + + pub sub_lookups: [usize; 6], } impl AluEvent { /// Creates a new `AluEvent`. 
pub fn new(shard: u32, channel: u32, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32) -> Self { Self { + lookup_id: 0, shard, channel, clk, @@ -54,6 +61,24 @@ impl AluEvent { a, b, c, + sub_lookups: create_alu_lookups(), } } } + +pub fn create_alu_lookup_id() -> usize { + let mut rng = rand::thread_rng(); + rng.gen() +} + +pub fn create_alu_lookups() -> [usize; 6] { + let mut rng = rand::thread_rng(); + [ + rng.gen(), + rng.gen(), + rng.gen(), + rng.gen(), + rng.gen(), + rng.gen(), + ] +} diff --git a/core/src/alu/mul/mod.rs b/core/src/alu/mul/mod.rs index c30a59c4f..1351e78c3 100644 --- a/core/src/alu/mul/mod.rs +++ b/core/src/alu/mul/mod.rs @@ -79,6 +79,9 @@ pub struct MulCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// The output operand. pub a: Word, @@ -270,6 +273,13 @@ impl MachineAir for MulChip { // Pad the trace to a power of two. pad_to_power_of_two::(&mut trace.values); + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut MulCols = + trace.values[i * NUM_MUL_COLS..(i + 1) * NUM_MUL_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -292,12 +302,20 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &MulCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &MulCols = (*next).borrow(); let base = AB::F::from_canonical_u32(1 << 8); let zero: AB::Expr = AB::F::zero().into(); let one: AB::Expr = AB::F::one().into(); let byte_mask = AB::F::from_canonical_u8(BYTE_MASK); + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + // Calculate the MSBs. 
let (b_msb, c_msb) = { let msb_pairs = [ @@ -412,14 +430,6 @@ where .when(local.c_sign_extend) .assert_eq(local.c_msb, one.clone()); - // If the opcode doesn't allow sign extension for an operand, we must not extend their sign. - builder - .when(local.is_mul + local.is_mulhu) - .assert_zero(local.b_sign_extend + local.c_sign_extend); - builder - .when(local.is_mul + local.is_mulhsu + local.is_mulhsu) - .assert_zero(local.c_sign_extend); - // Calculate the opcode. let opcode = { // Exactly one of the op codes must be on. @@ -455,6 +465,7 @@ where local.c, local.shard, local.channel, + local.nonce, local.is_real, ); } diff --git a/core/src/alu/sll/mod.rs b/core/src/alu/sll/mod.rs index d87ee780d..b5a711542 100644 --- a/core/src/alu/sll/mod.rs +++ b/core/src/alu/sll/mod.rs @@ -67,6 +67,9 @@ pub struct ShiftLeftCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// The output operand. pub a: Word, @@ -199,6 +202,12 @@ impl MachineAir for ShiftLeft { trace.values[i] = padded_row_template[i % NUM_SHIFT_LEFT_COLS]; } + for i in 0..trace.height() { + let cols: &mut ShiftLeftCols = + trace.values[i * NUM_SHIFT_LEFT_COLS..(i + 1) * NUM_SHIFT_LEFT_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -221,11 +230,19 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &ShiftLeftCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &ShiftLeftCols = (*next).borrow(); let zero: AB::Expr = AB::F::zero().into(); let one: AB::Expr = AB::F::one().into(); let base: AB::Expr = AB::F::from_canonical_u32(1 << BYTE_SIZE).into(); + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + // We first "bit shift" and next we "byte shift". Then we compare the results with a. // Finally, we perform some misc checks. 
@@ -354,6 +371,7 @@ where local.c, local.shard, local.channel, + local.nonce, local.is_real, ); } diff --git a/core/src/alu/sr/mod.rs b/core/src/alu/sr/mod.rs index 8f9ea721e..bd7a91d52 100644 --- a/core/src/alu/sr/mod.rs +++ b/core/src/alu/sr/mod.rs @@ -85,6 +85,9 @@ pub struct ShiftRightCols { /// The channel number, used for byte lookup table. pub channel: T, + /// The nonce of the operation. + pub nonce: T, + /// The output operand. pub a: Word, @@ -283,6 +286,13 @@ impl MachineAir for ShiftRightChip { trace.values[i] = padded_row_template[i % NUM_SHIFT_RIGHT_COLS]; } + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut ShiftRightCols = + trace.values[i * NUM_SHIFT_RIGHT_COLS..(i + 1) * NUM_SHIFT_RIGHT_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + trace } @@ -305,9 +315,17 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &ShiftRightCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &ShiftRightCols = (*next).borrow(); let zero: AB::Expr = AB::F::zero().into(); let one: AB::Expr = AB::F::one().into(); + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + // Check that the MSB of most_significant_byte matches local.b_msb using lookup. { let byte = local.b[WORD_SIZE - 1]; @@ -464,6 +482,9 @@ where for shift_by_n_bit in local.shift_by_n_bits.iter() { builder.assert_bool(*shift_by_n_bit); } + for bit in local.c_least_sig_byte.iter() { + builder.assert_bool(*bit); + } } // Range check bytes. @@ -485,6 +506,9 @@ where builder.assert_bool(local.is_sra); builder.assert_bool(local.is_real); + // Check that is_real is the sum of the two operation flags. + builder.assert_eq(local.is_srl + local.is_sra, local.is_real); + // Receive the arguments. 
builder.receive_alu( local.is_srl * AB::F::from_canonical_u32(Opcode::SRL as u32) @@ -494,6 +518,7 @@ where local.c, local.shard, local.channel, + local.nonce, local.is_real, ); } diff --git a/core/src/bytes/mod.rs b/core/src/bytes/mod.rs index f6d5bc482..2cf8c8fb1 100644 --- a/core/src/bytes/mod.rs +++ b/core/src/bytes/mod.rs @@ -24,7 +24,7 @@ use crate::bytes::trace::NUM_ROWS; pub const NUM_BYTE_OPS: usize = 9; /// The number of different byte lookup channels. -pub const NUM_BYTE_LOOKUP_CHANNELS: u32 = 4; +pub const NUM_BYTE_LOOKUP_CHANNELS: u32 = 16; /// A chip for computing byte operations. /// diff --git a/core/src/bytes/trace.rs b/core/src/bytes/trace.rs index 39f2b72ff..d6ba3921b 100644 --- a/core/src/bytes/trace.rs +++ b/core/src/bytes/trace.rs @@ -28,7 +28,7 @@ impl MachineAir for ByteChip { } fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option> { - // TODO: We should be able to make this a constant. Also, trace / map should be separate. + // OPT: We should be able to make this a constant. Also, trace / map should be separate. // Since we only need the trace and not the map, we can just pass 0 as the shard. let (trace, _) = Self::trace_and_map(0); diff --git a/core/src/cpu/air/branch.rs b/core/src/cpu/air/branch.rs index fad654de3..60bba4175 100644 --- a/core/src/cpu/air/branch.rs +++ b/core/src/cpu/air/branch.rs @@ -3,6 +3,7 @@ use p3_field::AbstractField; use crate::air::{BaseAirBuilder, SP1AirBuilder, Word, WordAirBuilder}; use crate::cpu::columns::{CpuCols, OpcodeSelectorCols}; +use crate::operations::BabyBearWordRangeChecker; use crate::{cpu::CpuChip, runtime::Opcode}; impl CpuChip { @@ -57,6 +58,20 @@ impl CpuChip { .when(local.branching) .assert_eq(branch_cols.next_pc.reduce::(), local.next_pc); + // Range check branch_cols.pc and branch_cols.next_pc. 
+ BabyBearWordRangeChecker::::range_check( + builder, + branch_cols.pc, + branch_cols.pc_range_checker, + is_branch_instruction.clone(), + ); + BabyBearWordRangeChecker::::range_check( + builder, + branch_cols.next_pc, + branch_cols.next_pc_range_checker, + is_branch_instruction.clone(), + ); + // When we are branching, calculate branch_cols.next_pc <==> branch_cols.pc + c. builder.send_alu( Opcode::ADD.as_field::(), @@ -65,6 +80,7 @@ impl CpuChip { local.op_c_val(), local.shard, local.channel, + branch_cols.next_pc_nonce, local.branching, ); @@ -83,15 +99,21 @@ impl CpuChip { .when(local.is_real) .when(local.not_branching) .assert_eq(local.pc + AB::Expr::from_canonical_u8(4), local.next_pc); - } - // Evaluate branching value constraints. - { - // Assert that local.is_branching is a bit. + // Assert that either we are branching or not branching when the instruction is a branch. + builder + .when(is_branch_instruction.clone()) + .assert_one(local.branching + local.not_branching); builder .when(is_branch_instruction.clone()) .assert_bool(local.branching); + builder + .when(is_branch_instruction.clone()) + .assert_bool(local.not_branching); + } + // Evaluate branching value constraints. + { // When the opcode is BEQ and we are branching, assert that a_eq_b is true. builder .when(local.selectors.is_beq * local.branching) @@ -146,6 +168,11 @@ impl CpuChip { .when(is_branch_instruction.clone() * branch_cols.a_eq_b) .assert_word_eq(local.op_a_val(), local.op_b_val()); + // To prevent this ALU send to be arbitrarily large when is_branch_instruction is false. + builder + .when_not(is_branch_instruction.clone()) + .assert_zero(local.branching); + // Calculate a_lt_b <==> a < b (using appropriate signedness). 
let use_signed_comparison = local.selectors.is_blt + local.selectors.is_bge; builder.send_alu( @@ -157,6 +184,7 @@ impl CpuChip { local.op_b_val(), local.shard, local.channel, + branch_cols.a_lt_b_nonce, is_branch_instruction.clone(), ); @@ -169,6 +197,7 @@ impl CpuChip { local.op_a_val(), local.shard, local.channel, + branch_cols.a_gt_b_nonce, is_branch_instruction.clone(), ); } diff --git a/core/src/cpu/air/ecall.rs b/core/src/cpu/air/ecall.rs index 506b2c7b7..870513b83 100644 --- a/core/src/cpu/air/ecall.rs +++ b/core/src/cpu/air/ecall.rs @@ -35,14 +35,19 @@ impl CpuChip { let syscall_id = syscall_code[0]; let send_to_table = syscall_code[1]; - // When is_ecall_instruction == true AND sent_to_table == true, ecall_mul_send_to_table should be true. - builder - .when(is_ecall_instruction.clone()) - .assert_eq(send_to_table, local.ecall_mul_send_to_table); + // Handle cases: + // - is_ecall_instruction = 1 => ecall_mul_send_to_table == send_to_table + // - is_ecall_instruction = 0 => ecall_mul_send_to_table == 0 + builder.assert_eq( + local.ecall_mul_send_to_table, + send_to_table * is_ecall_instruction.clone(), + ); + builder.send_syscall( local.shard, local.channel, local.clk, + ecall_cols.syscall_nonce, syscall_id, local.op_b_val().reduce::(), local.op_c_val().reduce::(), diff --git a/core/src/cpu/air/memory.rs b/core/src/cpu/air/memory.rs index 6ac1a07c1..707a50ff9 100644 --- a/core/src/cpu/air/memory.rs +++ b/core/src/cpu/air/memory.rs @@ -5,6 +5,7 @@ use crate::air::{BaseAirBuilder, SP1AirBuilder, Word, WordAirBuilder}; use crate::cpu::columns::{CpuCols, MemoryColumns, OpcodeSelectorCols}; use crate::cpu::CpuChip; use crate::memory::MemoryCols; +use crate::operations::BabyBearWordRangeChecker; use crate::runtime::{MemoryAccessPosition, Opcode}; impl CpuChip { @@ -66,6 +67,15 @@ impl CpuChip { local.op_c_val(), local.shard, local.channel, + memory_columns.addr_word_nonce, + is_memory_instruction.clone(), + ); + + // Range check the addr_word to be a valid 
babybear word. + BabyBearWordRangeChecker::::range_check( + builder, + memory_columns.addr_word, + memory_columns.addr_word_range_checker, is_memory_instruction.clone(), ); @@ -88,6 +98,35 @@ impl CpuChip { memory_columns.addr_word.reduce::(), ); + // Verify that the least significant byte of addr_word - addr_offset is divisible by 4. + let offset = [ + memory_columns.offset_is_one, + memory_columns.offset_is_two, + memory_columns.offset_is_three, + ] + .iter() + .enumerate() + .fold(AB::Expr::zero(), |acc, (index, &value)| { + acc + AB::Expr::from_canonical_usize(index + 1) * value + }); + let mut recomposed_byte = AB::Expr::zero(); + memory_columns + .aa_least_sig_byte_decomp + .iter() + .enumerate() + .for_each(|(i, value)| { + builder + .when(is_memory_instruction.clone()) + .assert_bool(*value); + + recomposed_byte = + recomposed_byte.clone() + AB::Expr::from_canonical_usize(1 << (i + 2)) * *value; + }); + + builder + .when(is_memory_instruction.clone()) + .assert_eq(memory_columns.addr_word[0] - offset, recomposed_byte); + // For operations that require reading from memory (not registers), we need to read the // value into the memory columns. builder.eval_memory_access( @@ -98,6 +137,14 @@ impl CpuChip { &memory_columns.memory_access, is_memory_instruction.clone(), ); + + // On memory load instructions, make sure that the memory value is not changed. + builder + .when(self.is_load_instruction::(&local.selectors)) + .assert_word_eq( + *memory_columns.memory_access.value(), + *memory_columns.memory_access.prev_value(), + ); } /// Evaluates constraints related to loading from memory. @@ -121,12 +168,11 @@ impl CpuChip { // Assert that if `is_lb` and `is_lh` are both true, then the most significant byte // matches the value of `local.mem_value_is_neg`. 
- builder - .when(local.selectors.is_lb + local.selectors.is_lh) - .assert_eq( - local.mem_value_is_neg, - memory_columns.most_sig_byte_decomp[7], - ); + builder.assert_eq( + local.mem_value_is_neg, + (local.selectors.is_lb + local.selectors.is_lh) + * memory_columns.most_sig_byte_decomp[7], + ); // When the memory value is negative, use the SUB opcode to compute the signed value of // the memory value and verify that the op_a value is correct. @@ -143,6 +189,7 @@ impl CpuChip { signed_value, local.shard, local.channel, + local.unsigned_mem_val_nonce, local.mem_value_is_neg, ); @@ -195,6 +242,11 @@ impl CpuChip { .when(local.selectors.is_sh) .assert_zero(memory_columns.offset_is_one + memory_columns.offset_is_three); + // When the instruction is SW, ensure that the offset is 0. + builder + .when(local.selectors.is_sw) + .assert_one(offset_is_zero.clone()); + // Compute the expected stored value for a SH instruction. let a_is_lower_half = offset_is_zero; let a_is_upper_half = memory_columns.offset_is_two; @@ -247,6 +299,12 @@ impl CpuChip { builder .when(local.selectors.is_lh + local.selectors.is_lhu) .assert_zero(memory_columns.offset_is_one + memory_columns.offset_is_three); + + // When the instruction is LW, ensure that the offset is zero. 
+ builder + .when(local.selectors.is_lw) + .assert_one(offset_is_zero.clone()); + let use_lower_half = offset_is_zero; let use_upper_half = memory_columns.offset_is_two; let half_value = Word([ @@ -273,9 +331,12 @@ impl CpuChip { local: &CpuCols, unsigned_mem_val: &Word, ) { + let is_mem = self.is_memory_instruction::(&local.selectors); let mut recomposed_byte = AB::Expr::zero(); for i in 0..8 { - builder.assert_bool(memory_columns.most_sig_byte_decomp[i]); + builder + .when(is_mem.clone()) + .assert_bool(memory_columns.most_sig_byte_decomp[i]); recomposed_byte += memory_columns.most_sig_byte_decomp[i] * AB::Expr::from_canonical_u8(1 << i); } diff --git a/core/src/cpu/air/mod.rs b/core/src/cpu/air/mod.rs index 11a985bb5..4caebbbf7 100644 --- a/core/src/cpu/air/mod.rs +++ b/core/src/cpu/air/mod.rs @@ -22,13 +22,16 @@ use crate::bytes::ByteOpcode; use crate::cpu::columns::OpcodeSelectorCols; use crate::cpu::columns::{CpuCols, NUM_CPU_COLS}; use crate::cpu::CpuChip; +use crate::operations::BabyBearWordRangeChecker; use crate::runtime::Opcode; use super::columns::eval_channel_selectors; +use super::columns::OPCODE_SELECTORS_COL_MAP; impl Air for CpuChip where AB: SP1AirBuilder + AirBuilderWithPublicValues, + AB::Var: Sized, { #[inline(never)] fn eval(&self, builder: &mut AB) { @@ -84,6 +87,7 @@ where local.op_c_val(), local.shard, local.channel, + local.nonce, is_alu_instruction, ); @@ -121,6 +125,27 @@ where // Check that the is_real flag is correct. self.eval_is_real(builder, local, next); + + // Check that when `is_real=0` that all flags that send interactions are zero. 
+ local + .selectors + .into_iter() + .enumerate() + .for_each(|(i, selector)| { + if i == OPCODE_SELECTORS_COL_MAP.imm_b { + builder + .when(AB::Expr::one() - local.is_real) + .assert_one(local.selectors.imm_b); + } else if i == OPCODE_SELECTORS_COL_MAP.imm_c { + builder + .when(AB::Expr::one() - local.is_real) + .assert_one(local.selectors.imm_c); + } else { + builder + .when(AB::Expr::one() - local.is_real) + .assert_zero(selector); + } + }); } } @@ -175,6 +200,26 @@ impl CpuChip { .when(is_jump_instruction.clone()) .assert_eq(jump_columns.next_pc.reduce::(), local.next_pc); + // Range check op_a, pc, and next_pc. + BabyBearWordRangeChecker::::range_check( + builder, + local.op_a_val(), + jump_columns.op_a_range_checker, + is_jump_instruction.clone(), + ); + BabyBearWordRangeChecker::::range_check( + builder, + jump_columns.pc, + jump_columns.pc_range_checker, + local.selectors.is_jal.into(), + ); + BabyBearWordRangeChecker::::range_check( + builder, + jump_columns.next_pc, + jump_columns.next_pc_range_checker, + is_jump_instruction.clone(), + ); + // Verify that the new pc is calculated correctly for JAL instructions. builder.send_alu( AB::Expr::from_canonical_u32(Opcode::ADD as u32), @@ -183,6 +228,7 @@ impl CpuChip { local.op_b_val(), local.shard, local.channel, + jump_columns.jal_nonce, local.selectors.is_jal, ); @@ -194,6 +240,7 @@ impl CpuChip { local.op_c_val(), local.shard, local.channel, + jump_columns.jalr_nonce, local.selectors.is_jalr, ); } @@ -208,6 +255,14 @@ impl CpuChip { .when(local.selectors.is_auipc) .assert_eq(auipc_columns.pc.reduce::(), local.pc); + // Range check the pc. + BabyBearWordRangeChecker::::range_check( + builder, + auipc_columns.pc, + auipc_columns.pc_range_checker, + local.selectors.is_auipc.into(), + ); + // Verify that op_a == pc + op_b. 
builder.send_alu( AB::Expr::from_canonical_u32(Opcode::ADD as u32), @@ -216,6 +271,7 @@ impl CpuChip { local.op_b_val(), local.shard, local.channel, + auipc_columns.auipc_nonce, local.selectors.is_auipc, ); } @@ -288,17 +344,16 @@ impl CpuChip { next: &CpuCols, is_branch_instruction: AB::Expr, ) { - // Verify that if is_sequential_instr is true, assert that local.is_real is true. - // This is needed for the following constraint, which is already degree 3. - builder - .when(local.is_sequential_instr) - .assert_one(local.is_real); - // When is_sequential_instr is true, assert that instruction is not branch, jump, or halt. // Note that the condition `when(local_is_real)` is implied from the previous constraint. let is_halt = self.get_is_halt_syscall::(builder, local); - builder.when(local.is_sequential_instr).assert_zero( - is_branch_instruction + local.selectors.is_jal + local.selectors.is_jalr + is_halt, + builder.when(local.is_real).assert_eq( + local.is_sequential_instr, + AB::Expr::one() + - (is_branch_instruction + + local.selectors.is_jal + + local.selectors.is_jalr + + is_halt), ); // Verify that the pc increments by 4 for all instructions except branch, jump and halt instructions. diff --git a/core/src/cpu/air/register.rs b/core/src/cpu/air/register.rs index e0b989c2b..23b6551d1 100644 --- a/core/src/cpu/air/register.rs +++ b/core/src/cpu/air/register.rs @@ -57,6 +57,15 @@ impl CpuChip { local.is_real, ); + // Always range check the word value in `op_a`, as JUMP instructions may witness + // an invalid word and write it to memory. + builder.slice_range_check_u8( + &local.op_a_access.access.value.0, + local.shard, + local.channel, + local.is_real, + ); + // If we are performing a branch or a store, then the value of `a` is the previous value. 
builder .when(is_branch_instruction.clone() + self.is_store_instruction::(&local.selectors)) diff --git a/core/src/cpu/columns/auipc.rs b/core/src/cpu/columns/auipc.rs index a6eb410e7..fa6871c21 100644 --- a/core/src/cpu/columns/auipc.rs +++ b/core/src/cpu/columns/auipc.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::air::Word; +use crate::{air::Word, operations::BabyBearWordRangeChecker}; pub const NUM_AUIPC_COLS: usize = size_of::>(); @@ -10,4 +10,6 @@ pub const NUM_AUIPC_COLS: usize = size_of::>(); pub struct AuipcCols { /// The current program counter. pub pc: Word, + pub pc_range_checker: BabyBearWordRangeChecker, + pub auipc_nonce: T, } diff --git a/core/src/cpu/columns/branch.rs b/core/src/cpu/columns/branch.rs index 06a77ad30..c6298ef0f 100644 --- a/core/src/cpu/columns/branch.rs +++ b/core/src/cpu/columns/branch.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::air::Word; +use crate::{air::Word, operations::BabyBearWordRangeChecker}; pub const NUM_BRANCH_COLS: usize = size_of::>(); @@ -11,9 +11,11 @@ pub const NUM_BRANCH_COLS: usize = size_of::>(); pub struct BranchCols { /// The current program counter. pub pc: Word, + pub pc_range_checker: BabyBearWordRangeChecker, /// The next program counter. pub next_pc: Word, + pub next_pc_range_checker: BabyBearWordRangeChecker, /// Whether a equals b. pub a_eq_b: T, @@ -23,4 +25,13 @@ pub struct BranchCols { /// Whether a is less than b. pub a_lt_b: T, + + /// The nonce of the operation to compute `a_lt_b`. + pub a_lt_b_nonce: T, + + /// The nonce of the operation to compute `a_gt_b`. + pub a_gt_b_nonce: T, + + /// The nonce of the operation to compute `next_pc`. 
+ pub next_pc_nonce: T, } diff --git a/core/src/cpu/columns/ecall.rs b/core/src/cpu/columns/ecall.rs index 927b70614..5d91622c3 100644 --- a/core/src/cpu/columns/ecall.rs +++ b/core/src/cpu/columns/ecall.rs @@ -26,4 +26,7 @@ pub struct EcallCols { /// Field to store the word index passed into the COMMIT ecall. index_bitmap[word index] should /// be set to 1 and everything else set to 0. pub index_bitmap: [T; PV_DIGEST_NUM_WORDS], + + /// The nonce of the syscall operation. + pub syscall_nonce: T, } diff --git a/core/src/cpu/columns/jump.rs b/core/src/cpu/columns/jump.rs index ca94f3eca..0e1b5701f 100644 --- a/core/src/cpu/columns/jump.rs +++ b/core/src/cpu/columns/jump.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::air::Word; +use crate::{air::Word, operations::BabyBearWordRangeChecker}; pub const NUM_JUMP_COLS: usize = size_of::>(); @@ -10,7 +10,15 @@ pub const NUM_JUMP_COLS: usize = size_of::>(); pub struct JumpCols { /// The current program counter. pub pc: Word, + pub pc_range_checker: BabyBearWordRangeChecker, - /// THe next program counter. + /// The next program counter. pub next_pc: Word, + pub next_pc_range_checker: BabyBearWordRangeChecker, + + // A range checker for `op_a` which may contain `pc + 4`. 
+ pub op_a_range_checker: BabyBearWordRangeChecker, + + pub jal_nonce: T, + pub jalr_nonce: T, } diff --git a/core/src/cpu/columns/memory.rs b/core/src/cpu/columns/memory.rs index fc54de34c..baab9e1fc 100644 --- a/core/src/cpu/columns/memory.rs +++ b/core/src/cpu/columns/memory.rs @@ -1,7 +1,7 @@ use sp1_derive::AlignedBorrow; use std::mem::size_of; -use crate::{air::Word, memory::MemoryReadWriteCols}; +use crate::{air::Word, memory::MemoryReadWriteCols, operations::BabyBearWordRangeChecker}; pub const NUM_MEMORY_COLUMNS: usize = size_of::>(); @@ -17,7 +17,11 @@ pub struct MemoryColumns { // addr_offset = addr_word % 4 // Note that this all needs to be verified in the AIR pub addr_word: Word, + pub addr_word_range_checker: BabyBearWordRangeChecker, + pub addr_aligned: T, + /// The LE bit decomp of the least significant byte of address aligned. + pub aa_least_sig_byte_decomp: [T; 6], pub addr_offset: T, pub memory_access: MemoryReadWriteCols, @@ -28,4 +32,7 @@ pub struct MemoryColumns { // LE bit decomposition for the most significant byte of memory value. This is used to determine // the sign for that value (used for LB and LH). pub most_sig_byte_decomp: [T; 8], + + pub addr_word_nonce: T, + pub unsigned_mem_val_nonce: T, } diff --git a/core/src/cpu/columns/mod.rs b/core/src/cpu/columns/mod.rs index d81bd806f..968c58362 100644 --- a/core/src/cpu/columns/mod.rs +++ b/core/src/cpu/columns/mod.rs @@ -40,6 +40,8 @@ pub struct CpuCols { /// The channel value, used for byte lookup multiplicity. pub channel: T, + pub nonce: T, + /// The clock cycle value. This should be within 24 bits. pub clk: T, /// The least significant 16 bit limb of clk. @@ -97,6 +99,8 @@ pub struct CpuCols { /// memory opcodes (i.e. LB, LH, LW, LBU, and LHU). pub unsigned_mem_val: Word, + pub unsigned_mem_val_nonce: T, + /// The result of selectors.is_ecall * the send_to_table column for the ECALL opcode. 
pub ecall_mul_send_to_table: T, diff --git a/core/src/cpu/columns/opcode.rs b/core/src/cpu/columns/opcode.rs index ac67c6934..80fd63ad3 100644 --- a/core/src/cpu/columns/opcode.rs +++ b/core/src/cpu/columns/opcode.rs @@ -1,11 +1,23 @@ use p3_field::PrimeField; use sp1_derive::AlignedBorrow; -use std::mem::size_of; +use std::mem::{size_of, transmute}; use std::vec::IntoIter; -use crate::runtime::{Instruction, Opcode}; +use crate::{ + runtime::{Instruction, Opcode}, + utils::indices_arr, +}; pub const NUM_OPCODE_SELECTOR_COLS: usize = size_of::>(); +pub const OPCODE_SELECTORS_COL_MAP: OpcodeSelectorCols = make_selectors_col_map(); + +/// Creates the column map for the CPU. +const fn make_selectors_col_map() -> OpcodeSelectorCols { + let indices_arr = indices_arr::(); + unsafe { + transmute::<[usize; NUM_OPCODE_SELECTOR_COLS], OpcodeSelectorCols>(indices_arr) + } +} /// The column layout for opcode selectors. #[derive(AlignedBorrow, Clone, Copy, Default, Debug)] @@ -98,7 +110,7 @@ impl IntoIterator for OpcodeSelectorCols { type IntoIter = IntoIter; fn into_iter(self) -> Self::IntoIter { - vec![ + let columns = vec![ self.imm_b, self.imm_c, self.is_alu, @@ -121,7 +133,8 @@ impl IntoIterator for OpcodeSelectorCols { self.is_jal, self.is_auipc, self.is_unimpl, - ] - .into_iter() + ]; + assert_eq!(columns.len(), NUM_OPCODE_SELECTOR_COLS); + columns.into_iter() } } diff --git a/core/src/cpu/event.rs b/core/src/cpu/event.rs index 2170d91d5..cdd38f476 100644 --- a/core/src/cpu/event.rs +++ b/core/src/cpu/event.rs @@ -51,4 +51,15 @@ pub struct CpuEvent { /// Exit code called with halt. 
pub exit_code: u32, + + pub alu_lookup_id: usize, + pub syscall_lookup_id: usize, + pub memory_add_lookup_id: usize, + pub memory_sub_lookup_id: usize, + pub branch_gt_lookup_id: usize, + pub branch_lt_lookup_id: usize, + pub branch_add_lookup_id: usize, + pub jump_jal_lookup_id: usize, + pub jump_jalr_lookup_id: usize, + pub auipc_lookup_id: usize, } diff --git a/core/src/cpu/trace.rs b/core/src/cpu/trace.rs index b65c4e43c..893faa385 100644 --- a/core/src/cpu/trace.rs +++ b/core/src/cpu/trace.rs @@ -1,3 +1,4 @@ +use std::array; use std::borrow::BorrowMut; use std::collections::HashMap; @@ -11,6 +12,8 @@ use tracing::instrument; use super::columns::{CPU_COL_MAP, NUM_CPU_COLS}; use super::{CpuChip, CpuEvent}; use crate::air::MachineAir; +use crate::air::Word; +use crate::alu::create_alu_lookups; use crate::alu::{self, AluEvent}; use crate::bytes::event::ByteRecord; use crate::bytes::{ByteLookupEvent, ByteOpcode}; @@ -42,7 +45,7 @@ impl MachineAir for CpuChip { let mut rows_with_events = input .cpu_events .par_iter() - .map(|op: &CpuEvent| self.event_to_row::(*op)) + .map(|op: &CpuEvent| self.event_to_row::(*op, &input.nonce_lookup)) .collect::>(); // No need to sort by the shard, since the cpu events are already partitioned by that. @@ -91,7 +94,7 @@ impl MachineAir for CpuChip { let mut alu = HashMap::new(); let mut blu: Vec<_> = Vec::default(); ops.iter().for_each(|op| { - let (_, alu_events, blu_events) = self.event_to_row::(*op); + let (_, alu_events, blu_events) = self.event_to_row::(*op, &HashMap::new()); alu_events.into_iter().for_each(|(key, value)| { alu.entry(key).or_insert(Vec::default()).extend(value); }); @@ -124,6 +127,7 @@ impl CpuChip { fn event_to_row( &self, event: CpuEvent, + nonce_lookup: &HashMap, ) -> ( [F; NUM_CPU_COLS], HashMap>, @@ -138,6 +142,14 @@ impl CpuChip { // Populate shard and clk columns. self.populate_shard_clk(cols, event, &mut new_blu_events); + // Populate the nonce. 
+ cols.nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.alu_lookup_id) + .copied() + .unwrap_or_default(), + ); + // Populate basic fields. cols.pc = F::from_canonical_u32(event.pc); cols.next_pc = F::from_canonical_u32(event.next_pc); @@ -150,17 +162,45 @@ impl CpuChip { // Populate memory accesses for a, b, and c. if let Some(record) = event.a_record { cols.op_a_access - .populate(event.channel, record, &mut new_blu_events) + .populate(event.channel, record, &mut new_blu_events); } if let Some(MemoryRecordEnum::Read(record)) = event.b_record { cols.op_b_access - .populate(event.channel, record, &mut new_blu_events) + .populate(event.channel, record, &mut new_blu_events); } if let Some(MemoryRecordEnum::Read(record)) = event.c_record { cols.op_c_access - .populate(event.channel, record, &mut new_blu_events) + .populate(event.channel, record, &mut new_blu_events); } + // Populate range checks for a. + let a_bytes = cols + .op_a_access + .access + .value + .0 + .iter() + .map(|x| x.as_canonical_u32()) + .collect::>(); + new_blu_events.push(ByteLookupEvent { + shard: event.shard, + channel: event.channel, + opcode: ByteOpcode::U8Range, + a1: 0, + a2: 0, + b: a_bytes[0], + c: a_bytes[1], + }); + new_blu_events.push(ByteLookupEvent { + shard: event.shard, + channel: event.channel, + opcode: ByteOpcode::U8Range, + a1: 0, + a2: 0, + b: a_bytes[2], + c: a_bytes[3], + }); + // Populate memory accesses for reading from memory. assert_eq!(event.memory_record.is_some(), event.memory.is_some()); let memory_columns = cols.opcode_specific_columns.memory_mut(); @@ -171,19 +211,23 @@ impl CpuChip { } // Populate memory, branch, jump, and auipc specific fields. 
- self.populate_memory(cols, event, &mut new_alu_events, &mut new_blu_events); - self.populate_branch(cols, event, &mut new_alu_events); - self.populate_jump(cols, event, &mut new_alu_events); - self.populate_auipc(cols, event, &mut new_alu_events); - let is_halt = self.populate_ecall(cols, event); - - if !event.instruction.is_branch_instruction() - && !event.instruction.is_jump_instruction() - && !event.instruction.is_ecall_instruction() - && !is_halt - { - cols.is_sequential_instr = F::one(); - } + self.populate_memory( + cols, + event, + &mut new_alu_events, + &mut new_blu_events, + nonce_lookup, + ); + self.populate_branch(cols, event, &mut new_alu_events, nonce_lookup); + self.populate_jump(cols, event, &mut new_alu_events, nonce_lookup); + self.populate_auipc(cols, event, &mut new_alu_events, nonce_lookup); + let is_halt = self.populate_ecall(cols, event, nonce_lookup); + + cols.is_sequential_instr = F::from_bool( + !event.instruction.is_branch_instruction() + && !event.instruction.is_jump_instruction() + && !is_halt, + ); // Assert that the instruction is not a no-op. cols.is_real = F::one(); @@ -243,6 +287,7 @@ impl CpuChip { event: CpuEvent, new_alu_events: &mut HashMap>, new_blu_events: &mut Vec, + nonce_lookup: &HashMap, ) { if !matches!( event.instruction.opcode, @@ -261,12 +306,20 @@ impl CpuChip { // Populate addr_word and addr_aligned columns. let memory_columns = cols.opcode_specific_columns.memory_mut(); let memory_addr = event.b.wrapping_add(event.c); + let aligned_addr = memory_addr - memory_addr % WORD_SIZE as u32; memory_columns.addr_word = memory_addr.into(); - memory_columns.addr_aligned = - F::from_canonical_u32(memory_addr - memory_addr % WORD_SIZE as u32); + memory_columns.addr_word_range_checker.populate(memory_addr); + memory_columns.addr_aligned = F::from_canonical_u32(aligned_addr); + + // Populate the aa_least_sig_byte_decomp columns. 
+ assert!(aligned_addr % 4 == 0); + let aligned_addr_ls_byte = (aligned_addr & 0x000000FF) as u8; + let bits: [bool; 8] = array::from_fn(|i| aligned_addr_ls_byte & (1 << i) != 0); + memory_columns.aa_least_sig_byte_decomp = array::from_fn(|i| F::from_bool(bits[i + 2])); // Add event to ALU check to check that addr == b + c let add_event = AluEvent { + lookup_id: event.memory_add_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -274,11 +327,18 @@ impl CpuChip { a: memory_addr, b: event.b, c: event.c, + sub_lookups: create_alu_lookups(), }; new_alu_events .entry(Opcode::ADD) .and_modify(|op_new_events| op_new_events.push(add_event)) .or_insert(vec![add_event]); + memory_columns.addr_word_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.memory_add_lookup_id) + .copied() + .unwrap_or_default(), + ); // Populate memory offsets. let addr_offset = (memory_addr % WORD_SIZE as u32) as u8; @@ -332,6 +392,7 @@ impl CpuChip { if memory_columns.most_sig_byte_decomp[7] == F::one() { cols.mem_value_is_neg = F::one(); let sub_event = AluEvent { + lookup_id: event.memory_sub_lookup_id, channel: event.channel, shard: event.shard, clk: event.clk, @@ -339,7 +400,14 @@ impl CpuChip { a: event.a, b: cols.unsigned_mem_val.to_u32(), c: sign_value, + sub_lookups: create_alu_lookups(), }; + cols.unsigned_mem_val_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.memory_sub_lookup_id) + .copied() + .unwrap_or_default(), + ); new_alu_events .entry(Opcode::SUB) @@ -370,6 +438,7 @@ impl CpuChip { cols: &mut CpuCols, event: CpuEvent, alu_events: &mut HashMap>, + nonce_lookup: &HashMap, ) { if event.instruction.is_branch_instruction() { let branch_columns = cols.opcode_specific_columns.branch_mut(); @@ -395,8 +464,10 @@ impl CpuChip { } else { Opcode::SLTU }; + // Add the ALU events for the comparisons let lt_comp_event = AluEvent { + lookup_id: event.branch_lt_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -404,7 +475,14 @@ 
impl CpuChip { a: a_lt_b as u32, b: event.a, c: event.b, + sub_lookups: create_alu_lookups(), }; + branch_columns.a_lt_b_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.branch_lt_lookup_id) + .copied() + .unwrap_or_default(), + ); alu_events .entry(alu_op_code) @@ -412,6 +490,7 @@ impl CpuChip { .or_insert(vec![lt_comp_event]); let gt_comp_event = AluEvent { + lookup_id: event.branch_gt_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -419,7 +498,14 @@ impl CpuChip { a: a_gt_b as u32, b: event.b, c: event.a, + sub_lookups: create_alu_lookups(), }; + branch_columns.a_gt_b_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.branch_gt_lookup_id) + .copied() + .unwrap_or_default(), + ); alu_events .entry(alu_op_code) @@ -438,14 +524,17 @@ impl CpuChip { _ => unreachable!(), }; - if branching { - let next_pc = event.pc.wrapping_add(event.c); + let next_pc = event.pc.wrapping_add(event.c); + branch_columns.pc = Word::from(event.pc); + branch_columns.next_pc = Word::from(next_pc); + branch_columns.pc_range_checker.populate(event.pc); + branch_columns.next_pc_range_checker.populate(next_pc); + if branching { cols.branching = F::one(); - branch_columns.pc = event.pc.into(); - branch_columns.next_pc = next_pc.into(); let add_event = AluEvent { + lookup_id: event.branch_add_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -453,7 +542,14 @@ impl CpuChip { a: next_pc, b: event.pc, c: event.c, + sub_lookups: create_alu_lookups(), }; + branch_columns.next_pc_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.branch_add_lookup_id) + .copied() + .unwrap_or_default(), + ); alu_events .entry(Opcode::ADD) @@ -471,6 +567,7 @@ impl CpuChip { cols: &mut CpuCols, event: CpuEvent, alu_events: &mut HashMap>, + nonce_lookup: &HashMap, ) { if event.instruction.is_jump_instruction() { let jump_columns = cols.opcode_specific_columns.jump_mut(); @@ -478,10 +575,14 @@ impl CpuChip { match event.instruction.opcode { 
Opcode::JAL => { let next_pc = event.pc.wrapping_add(event.b); - jump_columns.pc = event.pc.into(); - jump_columns.next_pc = next_pc.into(); + jump_columns.op_a_range_checker.populate(event.a); + jump_columns.pc = Word::from(event.pc); + jump_columns.pc_range_checker.populate(event.pc); + jump_columns.next_pc = Word::from(next_pc); + jump_columns.next_pc_range_checker.populate(next_pc); let add_event = AluEvent { + lookup_id: event.jump_jal_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -489,7 +590,14 @@ impl CpuChip { a: next_pc, b: event.pc, c: event.b, + sub_lookups: create_alu_lookups(), }; + jump_columns.jal_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.jump_jal_lookup_id) + .copied() + .unwrap_or_default(), + ); alu_events .entry(Opcode::ADD) @@ -498,9 +606,12 @@ impl CpuChip { } Opcode::JALR => { let next_pc = event.b.wrapping_add(event.c); - jump_columns.next_pc = next_pc.into(); + jump_columns.op_a_range_checker.populate(event.a); + jump_columns.next_pc = Word::from(next_pc); + jump_columns.next_pc_range_checker.populate(next_pc); let add_event = AluEvent { + lookup_id: event.jump_jalr_lookup_id, shard: event.shard, channel: event.channel, clk: event.clk, @@ -508,7 +619,14 @@ impl CpuChip { a: next_pc, b: event.b, c: event.c, + sub_lookups: create_alu_lookups(), }; + jump_columns.jalr_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.jump_jalr_lookup_id) + .copied() + .unwrap_or_default(), + ); alu_events .entry(Opcode::ADD) @@ -526,13 +644,16 @@ impl CpuChip { cols: &mut CpuCols, event: CpuEvent, alu_events: &mut HashMap>, + nonce_lookup: &HashMap, ) { if matches!(event.instruction.opcode, Opcode::AUIPC) { let auipc_columns = cols.opcode_specific_columns.auipc_mut(); - auipc_columns.pc = event.pc.into(); + auipc_columns.pc = Word::from(event.pc); + auipc_columns.pc_range_checker.populate(event.pc); let add_event = AluEvent { + lookup_id: event.auipc_lookup_id, shard: event.shard, channel: event.channel, 
clk: event.clk, @@ -540,7 +661,14 @@ impl CpuChip { a: event.a, b: event.pc, c: event.b, + sub_lookups: create_alu_lookups(), }; + auipc_columns.auipc_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.auipc_lookup_id) + .copied() + .unwrap_or_default(), + ); alu_events .entry(Opcode::ADD) @@ -550,7 +678,12 @@ impl CpuChip { } /// Populate columns related to ECALL. - fn populate_ecall(&self, cols: &mut CpuCols, _: CpuEvent) -> bool { + fn populate_ecall( + &self, + cols: &mut CpuCols, + event: CpuEvent, + nonce_lookup: &HashMap, + ) -> bool { let mut is_halt = false; if cols.selectors.is_ecall == F::one() { @@ -604,6 +737,14 @@ impl CpuChip { ecall_cols.index_bitmap[digest_idx] = F::one(); } + // Write the syscall nonce. + ecall_cols.syscall_nonce = F::from_canonical_u32( + nonce_lookup + .get(&event.syscall_lookup_id) + .copied() + .unwrap_or_default(), + ); + is_halt = syscall_id == F::from_canonical_u32(SyscallCode::HALT.syscall_id()); } @@ -640,41 +781,41 @@ mod tests { use super::*; - use crate::runtime::{tests::simple_program, Instruction, Runtime}; + use crate::runtime::{tests::simple_program, Runtime}; use crate::utils::{run_test, setup_logger, SP1CoreOpts}; - #[test] - fn generate_trace() { - let mut shard = ExecutionRecord::default(); - shard.cpu_events = vec![CpuEvent { - shard: 1, - channel: 0, - clk: 6, - pc: 1, - next_pc: 5, - instruction: Instruction { - opcode: Opcode::ADD, - op_a: 0, - op_b: 1, - op_c: 2, - imm_b: false, - imm_c: false, - }, - a: 1, - a_record: None, - b: 2, - b_record: None, - c: 3, - c_record: None, - memory: None, - memory_record: None, - exit_code: 0, - }]; - let chip = CpuChip::default(); - let trace: RowMajorMatrix = - chip.generate_trace(&shard, &mut ExecutionRecord::default()); - println!("{:?}", trace.values); - } + // #[test] + // fn generate_trace() { + // let mut shard = ExecutionRecord::default(); + // shard.cpu_events = vec![CpuEvent { + // shard: 1, + // channel: 0, + // clk: 6, + // pc: 1, + // next_pc: 5, 
+ // instruction: Instruction { + // opcode: Opcode::ADD, + // op_a: 0, + // op_b: 1, + // op_c: 2, + // imm_b: false, + // imm_c: false, + // }, + // a: 1, + // a_record: None, + // b: 2, + // b_record: None, + // c: 3, + // c_record: None, + // memory: None, + // memory_record: None, + // exit_code: 0, + // }]; + // let chip = CpuChip::default(); + // let trace: RowMajorMatrix = + // chip.generate_trace(&shard, &mut ExecutionRecord::default()); + // println!("{:?}", trace.values); + // } #[test] fn generate_trace_simple_program() { diff --git a/core/src/lookup/interaction.rs b/core/src/lookup/interaction.rs index 74b7a9fc0..1c20938cc 100644 --- a/core/src/lookup/interaction.rs +++ b/core/src/lookup/interaction.rs @@ -74,7 +74,6 @@ impl Interaction { } } -// TODO: add debug for VirtualPairCol so that we can derive Debug for Interaction. impl Debug for Interaction { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("Interaction") diff --git a/core/src/memory/global.rs b/core/src/memory/global.rs index 3786dd4ca..a6cf49d02 100644 --- a/core/src/memory/global.rs +++ b/core/src/memory/global.rs @@ -1,5 +1,6 @@ use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; +use std::array; use p3_air::BaseAir; use p3_air::{Air, AirBuilder}; @@ -10,8 +11,9 @@ use p3_matrix::Matrix; use sp1_derive::AlignedBorrow; use super::MemoryInitializeFinalizeEvent; -use crate::air::{AirInteraction, SP1AirBuilder, Word}; -use crate::air::{MachineAir, WordAirBuilder}; +use crate::air::MachineAir; +use crate::air::{AirInteraction, BaseAirBuilder, SP1AirBuilder}; +use crate::operations::BabyBearBitDecomposition; use crate::runtime::{ExecutionRecord, Program}; use crate::utils::pad_to_power_of_two; @@ -62,7 +64,7 @@ impl MachineAir for MemoryChip { MemoryChipType::Finalize => input.memory_finalize_events.clone(), }; memory_events.sort_by_key(|event| event.addr); - let rows: Vec<[F; 8]> = (0..memory_events.len()) // TODO: change this back to par_iter + 
let rows: Vec<[F; NUM_MEMORY_INIT_COLS]> = (0..memory_events.len()) // OPT: change this to par_iter .map(|i| { let MemoryInitializeFinalizeEvent { addr, @@ -71,14 +73,37 @@ impl MachineAir for MemoryChip { timestamp, used, } = memory_events[i]; + let mut row = [F::zero(); NUM_MEMORY_INIT_COLS]; let cols: &mut MemoryInitCols = row.as_mut_slice().borrow_mut(); cols.addr = F::from_canonical_u32(addr); + cols.addr_bits.populate(addr); cols.shard = F::from_canonical_u32(shard); cols.timestamp = F::from_canonical_u32(timestamp); - cols.value = value.into(); + cols.value = array::from_fn(|i| F::from_canonical_u32((value >> i) & 1)); cols.is_real = F::from_canonical_u32(used); + if i != memory_events.len() - 1 { + let next_addr = memory_events[i + 1].addr; + assert_ne!(next_addr, addr); + + cols.addr_bits.populate(addr); + + cols.seen_diff_bits[0] = F::zero(); + for j in 0..32 { + let rev_j = 32 - j - 1; + let next_bit = ((next_addr >> rev_j) & 1) == 1; + let local_bit = ((addr >> rev_j) & 1) == 1; + cols.match_bits[j] = + F::from_bool((local_bit && next_bit) || (!local_bit && !next_bit)); + cols.seen_diff_bits[j + 1] = cols.seen_diff_bits[j] + + (F::one() - cols.seen_diff_bits[j]) * (F::one() - cols.match_bits[j]); + cols.not_match_and_not_seen_diff_bits[j] = + (F::one() - cols.match_bits[j]) * (F::one() - cols.seen_diff_bits[j]); + } + assert_eq!(cols.seen_diff_bits[cols.seen_diff_bits.len() - 1], F::one()); + } + row }) .collect::>(); @@ -101,7 +126,7 @@ impl MachineAir for MemoryChip { } } -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[derive(AlignedBorrow, Debug, Clone, Copy)] #[repr(C)] pub struct MemoryInitCols { /// The shard number of the memory access. @@ -113,8 +138,20 @@ pub struct MemoryInitCols { /// The address of the memory access. pub addr: T, + /// A bit decomposition of `addr`. + pub addr_bits: BabyBearBitDecomposition, + + // Whether the i'th bit matches the next addr's bit. 
+ pub match_bits: [T; 32], + + // Whether we've seen a different bit in the comparison. + pub seen_diff_bits: [T; 33], + + // Whether the i'th bit doesn't match the next addr's bit and we haven't seen a diff bitn yet. + pub not_match_and_not_seen_diff_bits: [T; 32], + /// The value of the memory access. - pub value: Word, + pub value: [T; 32], /// Whether the memory access is a real access. pub is_real: T, @@ -130,10 +167,29 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &MemoryInitCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &MemoryInitCols = (*next).borrow(); + + builder.assert_bool(local.is_real); + for i in 0..32 { + builder.assert_bool(local.value[i]); + } + + let mut byte1 = AB::Expr::zero(); + let mut byte2 = AB::Expr::zero(); + let mut byte3 = AB::Expr::zero(); + let mut byte4 = AB::Expr::zero(); + for i in 0..8 { + byte1 += local.value[i].into() * AB::F::from_canonical_u8(1 << i); + byte2 += local.value[i + 8].into() * AB::F::from_canonical_u8(1 << i); + byte3 += local.value[i + 16].into() * AB::F::from_canonical_u8(1 << i); + byte4 += local.value[i + 24].into() * AB::F::from_canonical_u8(1 << i); + } + let value = [byte1, byte2, byte3, byte4]; if self.kind == MemoryChipType::Initialize { let mut values = vec![AB::Expr::zero(), AB::Expr::zero(), local.addr.into()]; - values.extend(local.value.map(Into::into)); + values.extend(value.map(Into::into)); builder.receive(AirInteraction::new( values, local.is_real.into(), @@ -145,7 +201,7 @@ where local.timestamp.into(), local.addr.into(), ]; - values.extend(local.value.map(Into::into)); + values.extend(value); builder.send(AirInteraction::new( values, local.is_real.into(), @@ -153,16 +209,106 @@ where )); } + // We want to assert addr < addr'. Assume seen_diff_0 = 0. 
+ // + // match_i = (addr_i & addr'_i) || (!addr_i & !addr'_i) + // => + // match_i == addr_i * addr_i + (1 - addr_i) * (1 - addr'_i) + // + // when !match_i and !seen_diff_i, then enforce (addr_i == 0) and (addr'_i == 1). + // if seen_diff_i: + // seen_diff_{i+1} = 1 + // else: + // seen_diff_{i+1} = !match_i + // => + // builder.when(!match_i * !seen_diff_i).assert_zero(addr_i) + // builder.when(!match_i * !seen_diff_i).assert_one(addr'_i) + // seen_diff_bit_{i+1} == seen_diff_i + (1-seen_diff_i) * (1 - match_i) + // + // at the end of the algorithm, assert that we've seen a diff bit. + // => + // seen_diff_bit_{last} == 1 + + // Assert that we start with assuming that we haven't seen a diff bit. + builder.assert_zero(local.seen_diff_bits[0]); + + for i in 0..local.addr_bits.bits.len() { + // Compute the i'th msb bit's index. + let rev_i = local.addr_bits.bits.len() - i - 1; + + // Compute whether the i'th msb bit matches. + let match_i = local.addr_bits.bits[rev_i] * next.addr_bits.bits[rev_i] + + (AB::Expr::one() - local.addr_bits.bits[rev_i]) + * (AB::Expr::one() - next.addr_bits.bits[rev_i]); + builder + .when_transition() + .when(next.is_real) + .assert_eq(match_i.clone(), local.match_bits[i]); + + // Compute whether it's not a match and we haven't seen a diff bit. + let not_match_and_not_seen_diff_i = (AB::Expr::one() - local.match_bits[i]) + * (AB::Expr::one() - local.seen_diff_bits[i]); + builder.when_transition().when(next.is_real).assert_eq( + local.not_match_and_not_seen_diff_bits[i], + not_match_and_not_seen_diff_i, + ); + + // If the i'th msb bit doesn't match and it's the first time we've seen a diff bit, + // then enforce that the next bit is one and the current bit is zero. 
+ builder + .when_transition() + .when(local.not_match_and_not_seen_diff_bits[i]) + .when(next.is_real) + .assert_zero(local.addr_bits.bits[rev_i]); + builder + .when_transition() + .when(local.not_match_and_not_seen_diff_bits[i]) + .when(next.is_real) + .assert_one(next.addr_bits.bits[rev_i]); + + // Update the seen diff bits. + builder.when_transition().assert_eq( + local.seen_diff_bits[i + 1], + local.seen_diff_bits[i] + local.not_match_and_not_seen_diff_bits[i], + ); + } + + // Assert that on rows where the next row is real, we've seen a diff bit. + builder + .when_transition() + .when(next.is_real) + .assert_one(local.seen_diff_bits[local.addr_bits.bits.len()]); + + // Canonically decompose the address into bits so we can do comparisons. + BabyBearBitDecomposition::::range_check( + builder, + local.addr, + local.addr_bits, + local.is_real.into(), + ); + + // Assert that the real rows are all padded to the top. + builder + .when_transition() + .when_not(local.is_real) + .assert_zero(next.is_real); + + if self.kind == MemoryChipType::Initialize { + builder + .when(local.is_real) + .assert_eq(local.timestamp, AB::F::one()); + } + // Register %x0 should always be 0. See 2.6 Load and Store Instruction on // P.18 of the RISC-V spec. To ensure that, we expect that the first row of the Initialize // and Finalize global memory chip is for register %x0 (i.e. addr = 0x0), and that those rows // have a value of 0. Additionally, in the CPU air, we ensure that whenever op_a is set to // %x0, its value is 0. - // - // TODO: Add a similar check for MemoryChipType::Initialize. 
- if self.kind == MemoryChipType::Finalize { + if self.kind == MemoryChipType::Initialize || self.kind == MemoryChipType::Finalize { builder.when_first_row().assert_zero(local.addr); - builder.when_first_row().assert_word_zero(local.value); + for i in 0..32 { + builder.when_first_row().assert_zero(local.value[i]); + } } } } diff --git a/core/src/memory/mod.rs b/core/src/memory/mod.rs index 7acdee1fb..4246db14c 100644 --- a/core/src/memory/mod.rs +++ b/core/src/memory/mod.rs @@ -27,7 +27,7 @@ impl MemoryInitializeFinalizeEvent { addr, value, shard: 0, - timestamp: 0, + timestamp: 1, used: if used { 1 } else { 0 }, } } diff --git a/core/src/memory/program.rs b/core/src/memory/program.rs index 3d922c4ae..64eeb25a2 100644 --- a/core/src/memory/program.rs +++ b/core/src/memory/program.rs @@ -1,6 +1,6 @@ use core::borrow::{Borrow, BorrowMut}; use core::mem::size_of; -use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder}; +use p3_air::{Air, AirBuilderWithPublicValues, BaseAir, PairBuilder}; use p3_field::AbstractField; use p3_field::PrimeField; use p3_matrix::dense::RowMajorMatrix; @@ -10,7 +10,6 @@ use sp1_derive::AlignedBorrow; use crate::air::{AirInteraction, PublicValues, SP1AirBuilder}; use crate::air::{MachineAir, Word}; -use crate::operations::IsZeroOperation; use crate::runtime::{ExecutionRecord, Program}; use crate::utils::pad_to_power_of_two; @@ -31,10 +30,10 @@ pub struct MemoryProgramPreprocessedCols { #[derive(AlignedBorrow, Clone, Copy, Default)] #[repr(C)] pub struct MemoryProgramMultCols { - /// The multiplicity of the event, must be 1 in the first shard and 0 otherwise. + /// The multiplicity of the event. + /// + /// This column is technically redundant with `is_real`, but it's included for clarity. pub multiplicity: T, - /// Columns to see if current shard is 1. - pub is_first_shard: IsZeroOperation, } /// Chip that initializes memory that is provided from the program. 
The table is preprocessed and @@ -120,8 +119,6 @@ impl MachineAir for MemoryProgramChip { let mut row = [F::zero(); NUM_MEMORY_PROGRAM_MULT_COLS]; let cols: &mut MemoryProgramMultCols = row.as_mut_slice().borrow_mut(); cols.multiplicity = mult; - IsZeroOperation::populate(&mut cols.is_first_shard, input.index - 1); - row }) .collect::>(); @@ -138,8 +135,8 @@ impl MachineAir for MemoryProgramChip { trace } - fn included(&self, _: &Self::Record) -> bool { - true + fn included(&self, record: &Self::Record) -> bool { + record.index == 1 } } @@ -171,24 +168,15 @@ where .map(|elm| (*elm).into()) .collect::>(), ); - IsZeroOperation::::eval( - builder, - public_values.shard - AB::Expr::one(), - mult_local.is_first_shard, - prep_local.is_real.into(), - ); - let is_first_shard = mult_local.is_first_shard.result; // Multiplicity must be either 0 or 1. builder.assert_bool(mult_local.multiplicity); + // If first shard and preprocessed is real, multiplicity must be one. - builder - .when(is_first_shard * prep_local.is_real) - .assert_one(mult_local.multiplicity); - // If not first shard or preprocessed is not real, multiplicity must be zero. - builder - .when((AB::Expr::one() - is_first_shard) + (AB::Expr::one() - prep_local.is_real)) - .assert_zero(mult_local.multiplicity); + builder.assert_eq(mult_local.multiplicity, prep_local.is_real.into()); + + // The shard this chip is contained in must be one. 
+ builder.assert_one(public_values.shard); let mut values = vec![AB::Expr::zero(), AB::Expr::zero(), prep_local.addr.into()]; values.extend(prep_local.value.map(Into::into)); diff --git a/core/src/operations/baby_bear_range.rs b/core/src/operations/baby_bear_range.rs new file mode 100644 index 000000000..7e1ad0ef4 --- /dev/null +++ b/core/src/operations/baby_bear_range.rs @@ -0,0 +1,88 @@ +use std::array; + +use p3_air::AirBuilder; +use p3_field::{AbstractField, Field}; +use sp1_derive::AlignedBorrow; + +use crate::stark::SP1AirBuilder; + +#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct BabyBearBitDecomposition { + /// The bit decoposition of the`value`. + pub bits: [T; 32], + + /// The product of the the bits 3 to 5 in `most_sig_byte_decomp`. + pub and_most_sig_byte_decomp_3_to_5: T, + + /// The product of the the bits 3 to 6 in `most_sig_byte_decomp`. + pub and_most_sig_byte_decomp_3_to_6: T, + + /// The product of the the bits 3 to 7 in `most_sig_byte_decomp`. + pub and_most_sig_byte_decomp_3_to_7: T, +} + +impl BabyBearBitDecomposition { + pub fn populate(&mut self, value: u32) { + self.bits = array::from_fn(|i| F::from_canonical_u32((value >> i) & 1)); + let most_sig_byte_decomp = &self.bits[24..32]; + self.and_most_sig_byte_decomp_3_to_5 = most_sig_byte_decomp[3] * most_sig_byte_decomp[4]; + self.and_most_sig_byte_decomp_3_to_6 = + self.and_most_sig_byte_decomp_3_to_5 * most_sig_byte_decomp[5]; + self.and_most_sig_byte_decomp_3_to_7 = + self.and_most_sig_byte_decomp_3_to_6 * most_sig_byte_decomp[6]; + } + + pub fn range_check( + builder: &mut AB, + value: AB::Var, + cols: BabyBearBitDecomposition, + is_real: AB::Expr, + ) { + let mut reconstructed_value = AB::Expr::zero(); + for (i, bit) in cols.bits.iter().enumerate() { + builder.when(is_real.clone()).assert_bool(*bit); + reconstructed_value += AB::Expr::from_wrapped_u32(1 << i) * *bit; + } + + // Assert that bits2num(bits) == value. 
+ builder + .when(is_real.clone()) + .assert_eq(reconstructed_value, value); + + // Range check that value is less than baby bear modulus. To do this, it is sufficient + // to just do comparisons for the most significant byte. BabyBear's modulus is (in big endian binary) + // 01111000_00000000_00000000_00000001. So we need to check the following conditions: + // 1) if most_sig_byte > 01111000, then fail. + // 2) if most_sig_byte == 01111000, then value's lower sig bytes must all be 0. + // 3) if most_sig_byte < 01111000, then pass. + let most_sig_byte_decomp = &cols.bits[24..32]; + builder + .when(is_real.clone()) + .assert_zero(most_sig_byte_decomp[7]); + + // Compute the product of the "top bits". + builder.when(is_real.clone()).assert_eq( + cols.and_most_sig_byte_decomp_3_to_5, + most_sig_byte_decomp[3] * most_sig_byte_decomp[4], + ); + builder.when(is_real.clone()).assert_eq( + cols.and_most_sig_byte_decomp_3_to_6, + cols.and_most_sig_byte_decomp_3_to_5 * most_sig_byte_decomp[5], + ); + builder.when(is_real.clone()).assert_eq( + cols.and_most_sig_byte_decomp_3_to_7, + cols.and_most_sig_byte_decomp_3_to_6 * most_sig_byte_decomp[6], + ); + + // If the top bits are all 0, then the lower bits must all be 0. + let mut lower_bits_sum: AB::Expr = AB::Expr::zero(); + for bit in cols.bits[0..27].iter() { + lower_bits_sum = lower_bits_sum + *bit; + } + builder + .when(is_real) + .when(cols.and_most_sig_byte_decomp_3_to_7) + .assert_zero(lower_bits_sum); + } +} diff --git a/core/src/operations/baby_bear_word.rs b/core/src/operations/baby_bear_word.rs new file mode 100644 index 000000000..2e773b3e6 --- /dev/null +++ b/core/src/operations/baby_bear_word.rs @@ -0,0 +1,94 @@ +use std::array; + +use p3_air::AirBuilder; +use p3_field::{AbstractField, Field}; +use sp1_derive::AlignedBorrow; + +use crate::{air::Word, stark::SP1AirBuilder}; + +/// A set of columns needed to compute the add of two words. 
+#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] +#[repr(C)] +pub struct BabyBearWordRangeChecker { + /// Most sig byte LE bit decomposition. + pub most_sig_byte_decomp: [T; 8], + + /// The product of the the bits 3 to 5 in `most_sig_byte_decomp`. + pub and_most_sig_byte_decomp_3_to_5: T, + + /// The product of the the bits 3 to 6 in `most_sig_byte_decomp`. + pub and_most_sig_byte_decomp_3_to_6: T, + + /// The product of the the bits 3 to 7 in `most_sig_byte_decomp`. + pub and_most_sig_byte_decomp_3_to_7: T, +} + +impl BabyBearWordRangeChecker { + pub fn populate(&mut self, value: u32) { + self.most_sig_byte_decomp = array::from_fn(|i| F::from_bool(value & (1 << (i + 24)) != 0)); + self.and_most_sig_byte_decomp_3_to_5 = + self.most_sig_byte_decomp[3] * self.most_sig_byte_decomp[4]; + self.and_most_sig_byte_decomp_3_to_6 = + self.and_most_sig_byte_decomp_3_to_5 * self.most_sig_byte_decomp[5]; + self.and_most_sig_byte_decomp_3_to_7 = + self.and_most_sig_byte_decomp_3_to_6 * self.most_sig_byte_decomp[6]; + } + + pub fn range_check( + builder: &mut AB, + value: Word, + cols: BabyBearWordRangeChecker, + is_real: AB::Expr, + ) { + let mut recomposed_byte = AB::Expr::zero(); + cols.most_sig_byte_decomp + .iter() + .enumerate() + .for_each(|(i, value)| { + builder.when(is_real.clone()).assert_bool(*value); + recomposed_byte = + recomposed_byte.clone() + AB::Expr::from_canonical_usize(1 << i) * *value; + }); + + builder + .when(is_real.clone()) + .assert_eq(recomposed_byte, value[3]); + + // Range check that value is less than baby bear modulus. To do this, it is sufficient + // to just do comparisons for the most significant byte. BabyBear's modulus is (in big endian binary) + // 01111000_00000000_00000000_00000001. So we need to check the following conditions: + // 1) if most_sig_byte > 01111000, then fail. + // 2) if most_sig_byte == 01111000, then value's lower sig bytes must all be 0. + // 3) if most_sig_byte < 01111000, then pass. 
+ builder + .when(is_real.clone()) + .assert_zero(cols.most_sig_byte_decomp[7]); + + // Compute the product of the "top bits". + builder.when(is_real.clone()).assert_eq( + cols.and_most_sig_byte_decomp_3_to_5, + cols.most_sig_byte_decomp[3] * cols.most_sig_byte_decomp[4], + ); + builder.when(is_real.clone()).assert_eq( + cols.and_most_sig_byte_decomp_3_to_6, + cols.and_most_sig_byte_decomp_3_to_5 * cols.most_sig_byte_decomp[5], + ); + builder.when(is_real.clone()).assert_eq( + cols.and_most_sig_byte_decomp_3_to_7, + cols.and_most_sig_byte_decomp_3_to_6 * cols.most_sig_byte_decomp[6], + ); + + let bottom_bits: AB::Expr = cols.most_sig_byte_decomp[0..3] + .iter() + .map(|bit| (*bit).into()) + .sum(); + builder + .when(is_real.clone()) + .when(cols.and_most_sig_byte_decomp_3_to_7) + .assert_zero(bottom_bits); + builder + .when(is_real) + .when(cols.and_most_sig_byte_decomp_3_to_7) + .assert_zero(value[0] + value[1] + value[2]); + } +} diff --git a/core/src/operations/field/field_op.rs b/core/src/operations/field/field_op.rs index ae04e2b9b..995142c2f 100644 --- a/core/src/operations/field/field_op.rs +++ b/core/src/operations/field/field_op.rs @@ -445,7 +445,6 @@ mod tests { let mut challenger = config.challenger(); - // TODO: test with other fields let chip: FieldOpChip = FieldOpChip::new(*op); let shard = ExecutionRecord::default(); let trace: RowMajorMatrix = diff --git a/core/src/operations/field/field_sqrt.rs b/core/src/operations/field/field_sqrt.rs index c0401a1d4..e16de147b 100644 --- a/core/src/operations/field/field_sqrt.rs +++ b/core/src/operations/field/field_sqrt.rs @@ -83,6 +83,20 @@ impl FieldSqrtCols { }; record.add_byte_lookup_event(and_event); + // Add the byte range check for `sqrt`. 
+ record.add_u8_range_checks( + shard, + channel, + self.multiplication + .result + .0 + .as_slice() + .iter() + .map(|x| x.as_canonical_u32() as u8) + .collect::>() + .as_slice(), + ); + sqrt } } @@ -129,6 +143,14 @@ where is_real.clone(), ); + // Range check that `sqrt` limbs are bytes. + builder.slice_range_check_u8( + sqrt.0.as_slice(), + shard.clone(), + channel.clone(), + is_real.clone(), + ); + // Assert that the square root is the positive one, i.e., with least significant bit 0. // This is done by computing LSB = least_significant_byte & 1. builder.assert_bool(self.lsb); diff --git a/core/src/operations/mod.rs b/core/src/operations/mod.rs index 242c9100b..e3fbcc78b 100644 --- a/core/src/operations/mod.rs +++ b/core/src/operations/mod.rs @@ -8,6 +8,8 @@ mod add; mod add4; mod add5; mod and; +mod baby_bear_range; +mod baby_bear_word; pub mod field; mod fixed_rotate_right; mod fixed_shift_right; @@ -22,6 +24,8 @@ pub use add::*; pub use add4::*; pub use add5::*; pub use and::*; +pub use baby_bear_range::*; +pub use baby_bear_word::*; pub use fixed_rotate_right::*; pub use fixed_shift_right::*; pub use is_equal_word::*; diff --git a/core/src/operations/or.rs b/core/src/operations/or.rs index 8cb3f0019..b30821532 100644 --- a/core/src/operations/or.rs +++ b/core/src/operations/or.rs @@ -10,8 +10,6 @@ use crate::disassembler::WORD_SIZE; use crate::runtime::ExecutionRecord; /// A set of columns needed to compute the or of two words. -/// -/// TODO: This is currently not in use, and thus not tested thoroughly yet. 
#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] #[repr(C)] pub struct OrOperation { diff --git a/core/src/runtime/mod.rs b/core/src/runtime/mod.rs index 003d35e8c..3eb8c4ead 100644 --- a/core/src/runtime/mod.rs +++ b/core/src/runtime/mod.rs @@ -30,6 +30,8 @@ use std::sync::Arc; use thiserror::Error; +use crate::alu::create_alu_lookup_id; +use crate::alu::create_alu_lookups; use crate::bytes::NUM_BYTE_LOOKUP_CHANNELS; use crate::memory::MemoryInitializeFinalizeEvent; use crate::utils::SP1CoreOpts; @@ -445,6 +447,8 @@ impl Runtime { memory_store_value: Option, record: MemoryAccessRecord, exit_code: u32, + lookup_id: usize, + syscall_lookup_id: usize, ) { let cpu_event = CpuEvent { shard, @@ -462,14 +466,25 @@ impl Runtime { memory: memory_store_value, memory_record: record.memory, exit_code, + alu_lookup_id: lookup_id, + syscall_lookup_id, + memory_add_lookup_id: create_alu_lookup_id(), + memory_sub_lookup_id: create_alu_lookup_id(), + branch_lt_lookup_id: create_alu_lookup_id(), + branch_gt_lookup_id: create_alu_lookup_id(), + branch_add_lookup_id: create_alu_lookup_id(), + jump_jal_lookup_id: create_alu_lookup_id(), + jump_jalr_lookup_id: create_alu_lookup_id(), + auipc_lookup_id: create_alu_lookup_id(), }; self.record.cpu_events.push(cpu_event); } /// Emit an ALU event. - fn emit_alu(&mut self, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32) { + fn emit_alu(&mut self, clk: u32, opcode: Opcode, a: u32, b: u32, c: u32, lookup_id: usize) { let event = AluEvent { + lookup_id, shard: self.shard(), clk, channel: self.channel(), @@ -477,6 +492,7 @@ impl Runtime { a, b, c, + sub_lookups: create_alu_lookups(), }; match opcode { Opcode::ADD => { @@ -530,10 +546,18 @@ impl Runtime { } /// Set the destination register with the result and emit an ALU event. 
- fn alu_rw(&mut self, instruction: Instruction, rd: Register, a: u32, b: u32, c: u32) { + fn alu_rw( + &mut self, + instruction: Instruction, + rd: Register, + a: u32, + b: u32, + c: u32, + lookup_id: usize, + ) { self.rw(rd, a); if self.emit_events { - self.emit_alu(self.state.clk, instruction.opcode, a, b, c); + self.emit_alu(self.state.clk, instruction.opcode, a, b, c, lookup_id); } } @@ -586,6 +610,9 @@ impl Runtime { let mut memory_store_value: Option = None; self.memory_accesses = MemoryAccessRecord::default(); + let lookup_id = create_alu_lookup_id(); + let syscall_lookup_id = create_alu_lookup_id(); + if self.should_report && !self.unconstrained { self.report .instruction_counts @@ -599,52 +626,52 @@ impl Runtime { Opcode::ADD => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_add(c); - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::SUB => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_sub(c); - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::XOR => { (rd, b, c) = self.alu_rr(instruction); a = b ^ c; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::OR => { (rd, b, c) = self.alu_rr(instruction); a = b | c; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::AND => { (rd, b, c) = self.alu_rr(instruction); a = b & c; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::SLL => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_shl(c); - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::SRL => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_shr(c); - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::SRA => { (rd, b, c) = self.alu_rr(instruction); a = (b as 
i32).wrapping_shr(c) as u32; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::SLT => { (rd, b, c) = self.alu_rr(instruction); a = if (b as i32) < (c as i32) { 1 } else { 0 }; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::SLTU => { (rd, b, c) = self.alu_rr(instruction); a = if b < c { 1 } else { 0 }; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } // Load instructions. @@ -818,6 +845,7 @@ impl Runtime { let syscall_impl = self.get_syscall(syscall).cloned(); let mut precompile_rt = SyscallContext::new(self); + precompile_rt.syscall_lookup_id = syscall_lookup_id; let (precompile_next_pc, precompile_cycles, returned_exit_code) = if let Some(syscall_impl) = syscall_impl { // Executing a syscall optionally returns a value to write to the t0 register. @@ -862,22 +890,22 @@ impl Runtime { Opcode::MUL => { (rd, b, c) = self.alu_rr(instruction); a = b.wrapping_mul(c); - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::MULH => { (rd, b, c) = self.alu_rr(instruction); a = (((b as i32) as i64).wrapping_mul((c as i32) as i64) >> 32) as u32; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::MULHU => { (rd, b, c) = self.alu_rr(instruction); a = ((b as u64).wrapping_mul(c as u64) >> 32) as u32; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::MULHSU => { (rd, b, c) = self.alu_rr(instruction); a = (((b as i32) as i64).wrapping_mul(c as i64) >> 32) as u32; - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::DIV => { (rd, b, c) = self.alu_rr(instruction); @@ -886,7 +914,7 @@ impl Runtime { } else { a = (b as i32).wrapping_div(c as i32) as u32; } - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, 
c, lookup_id); } Opcode::DIVU => { (rd, b, c) = self.alu_rr(instruction); @@ -895,7 +923,7 @@ impl Runtime { } else { a = b.wrapping_div(c); } - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::REM => { (rd, b, c) = self.alu_rr(instruction); @@ -904,7 +932,7 @@ impl Runtime { } else { a = (b as i32).wrapping_rem(c as i32) as u32; } - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } Opcode::REMU => { (rd, b, c) = self.alu_rr(instruction); @@ -913,7 +941,7 @@ impl Runtime { } else { a = b.wrapping_rem(c); } - self.alu_rw(instruction, rd, a, b, c); + self.alu_rw(instruction, rd, a, b, c, lookup_id); } // See https://github.com/riscv-non-isa/riscv-asm-manual/blob/master/riscv-asm.md#instruction-aliases @@ -950,6 +978,8 @@ impl Runtime { memory_store_value, self.memory_accesses, exit_code, + lookup_id, + syscall_lookup_id, ); }; Ok(()) @@ -1111,7 +1141,7 @@ impl Runtime { None => &MemoryRecord { value: 0, shard: 0, - timestamp: 0, + timestamp: 1, }, }; memory_finalize_events.push(MemoryInitializeFinalizeEvent::finalize_from_record( diff --git a/core/src/runtime/record.rs b/core/src/runtime/record.rs index 67a2464f9..006c235a1 100644 --- a/core/src/runtime/record.rs +++ b/core/src/runtime/record.rs @@ -17,7 +17,6 @@ use crate::cpu::CpuEvent; use crate::runtime::MemoryInitializeFinalizeEvent; use crate::runtime::MemoryRecordEnum; use crate::stark::MachineRecord; -use crate::syscall::precompiles::blake3::Blake3CompressInnerEvent; use crate::syscall::precompiles::edwards::EdDecompressEvent; use crate::syscall::precompiles::keccak256::KeccakPermuteEvent; use crate::syscall::precompiles::sha256::{ShaCompressEvent, ShaExtendEvent}; @@ -87,8 +86,6 @@ pub struct ExecutionRecord { pub k256_decompress_events: Vec, - pub blake3_compress_inner_events: Vec, - pub bls12381_add_events: Vec, pub bls12381_double_events: Vec, @@ -103,6 +100,8 @@ pub struct ExecutionRecord { /// The public 
values. pub public_values: PublicValues, + + pub nonce_lookup: HashMap, } pub struct ShardingConfig { @@ -220,10 +219,6 @@ impl MachineRecord for ExecutionRecord { "k256_decompress_events".to_string(), self.k256_decompress_events.len(), ); - stats.insert( - "blake3_compress_inner_events".to_string(), - self.blake3_compress_inner_events.len(), - ); stats.insert( "bls12381_add_events".to_string(), self.bls12381_add_events.len(), @@ -272,8 +267,6 @@ impl MachineRecord for ExecutionRecord { .append(&mut other.bn254_double_events); self.k256_decompress_events .append(&mut other.k256_decompress_events); - self.blake3_compress_inner_events - .append(&mut other.blake3_compress_inner_events); self.bls12381_add_events .append(&mut other.bls12381_add_events); self.bls12381_double_events @@ -356,22 +349,15 @@ impl MachineRecord for ExecutionRecord { } } - // Shard all the other events according to the configuration. - // Shard the ADD events. for (add_chunk, shard) in take(&mut self.add_events) .chunks_mut(config.add_len) .zip(shards.iter_mut()) { shard.add_events.extend_from_slice(add_chunk); - } - - // Shard the MUL events. - for (mul_chunk, shard) in take(&mut self.mul_events) - .chunks_mut(config.mul_len) - .zip(shards.iter_mut()) - { - shard.mul_events.extend_from_slice(mul_chunk); + for (i, event) in add_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Shard the SUB events. @@ -380,6 +366,21 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.sub_events.extend_from_slice(sub_chunk); + for (i, event) in sub_chunk.iter().enumerate() { + self.nonce_lookup + .insert(event.lookup_id, shard.add_events.len() as u32 + i as u32); + } + } + + // Shard the MUL events. 
+ for (mul_chunk, shard) in take(&mut self.mul_events) + .chunks_mut(config.mul_len) + .zip(shards.iter_mut()) + { + shard.mul_events.extend_from_slice(mul_chunk); + for (i, event) in mul_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Shard the bitwise events. @@ -388,6 +389,9 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.bitwise_events.extend_from_slice(bitwise_chunk); + for (i, event) in bitwise_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Shard the shift left events. @@ -396,6 +400,9 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.shift_left_events.extend_from_slice(shift_left_chunk); + for (i, event) in shift_left_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Shard the shift right events. @@ -406,6 +413,9 @@ impl MachineRecord for ExecutionRecord { shard .shift_right_events .extend_from_slice(shift_right_chunk); + for (i, event) in shift_right_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Shard the divrem events. @@ -414,6 +424,9 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.divrem_events.extend_from_slice(divrem_chunk); + for (i, event) in divrem_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Shard the LT events. @@ -422,6 +435,9 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.lt_events.extend_from_slice(lt_chunk); + for (i, event) in lt_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Keccak-256 permute events. 
@@ -430,6 +446,9 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.keccak_permute_events.extend_from_slice(keccak_chunk); + for (i, event) in keccak_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, (i * 24) as u32); + } } // secp256k1 curve add events. @@ -440,6 +459,9 @@ impl MachineRecord for ExecutionRecord { shard .secp256k1_add_events .extend_from_slice(secp256k1_add_chunk); + for (i, event) in secp256k1_add_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // secp256k1 curve double events. @@ -450,6 +472,9 @@ impl MachineRecord for ExecutionRecord { shard .secp256k1_double_events .extend_from_slice(secp256k1_double_chunk); + for (i, event) in secp256k1_double_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // bn254 curve add events. @@ -458,6 +483,9 @@ impl MachineRecord for ExecutionRecord { .zip(shards.iter_mut()) { shard.bn254_add_events.extend_from_slice(bn254_add_chunk); + for (i, event) in bn254_add_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // bn254 curve double events. @@ -468,6 +496,9 @@ impl MachineRecord for ExecutionRecord { shard .bn254_double_events .extend_from_slice(bn254_double_chunk); + for (i, event) in bn254_double_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // BLS12-381 curve add events. @@ -478,6 +509,9 @@ impl MachineRecord for ExecutionRecord { shard .bls12381_add_events .extend_from_slice(bls12381_add_chunk); + for (i, event) in bls12381_add_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // BLS12-381 curve double events. 
@@ -488,6 +522,9 @@ impl MachineRecord for ExecutionRecord { shard .bls12381_double_events .extend_from_slice(bls12381_double_chunk); + for (i, event) in bls12381_double_chunk.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } } // Put the precompile events in the first shard. @@ -495,27 +532,45 @@ impl MachineRecord for ExecutionRecord { // SHA-256 extend events. first.sha_extend_events = std::mem::take(&mut self.sha_extend_events); + for (i, event) in first.sha_extend_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, (i * 48) as u32); + } // SHA-256 compress events. first.sha_compress_events = std::mem::take(&mut self.sha_compress_events); + for (i, event) in first.sha_compress_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, (i * 80) as u32); + } // Edwards curve add events. first.ed_add_events = std::mem::take(&mut self.ed_add_events); + for (i, event) in first.ed_add_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } // Edwards curve decompress events. first.ed_decompress_events = std::mem::take(&mut self.ed_decompress_events); + for (i, event) in first.ed_decompress_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } // K256 curve decompress events. first.k256_decompress_events = std::mem::take(&mut self.k256_decompress_events); - - // Blake3 compress events . - first.blake3_compress_inner_events = std::mem::take(&mut self.blake3_compress_inner_events); + for (i, event) in first.k256_decompress_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } // Uint256 mul arithmetic events. first.uint256_mul_events = std::mem::take(&mut self.uint256_mul_events); + for (i, event) in first.uint256_mul_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } // Bls12-381 decompress events . 
first.bls12381_decompress_events = std::mem::take(&mut self.bls12381_decompress_events); + for (i, event) in first.bls12381_decompress_events.iter().enumerate() { + self.nonce_lookup.insert(event.lookup_id, i as u32); + } // Put the memory records in the last shard. let last_shard = shards.last_mut().unwrap(); @@ -527,6 +582,11 @@ impl MachineRecord for ExecutionRecord { .memory_finalize_events .extend_from_slice(&self.memory_finalize_events); + // Copy the nonce lookup to all shards. + for shard in shards.iter_mut() { + shard.nonce_lookup.clone_from(&self.nonce_lookup); + } + shards } diff --git a/core/src/runtime/syscall.rs b/core/src/runtime/syscall.rs index c320c7e28..7cb1cb338 100644 --- a/core/src/runtime/syscall.rs +++ b/core/src/runtime/syscall.rs @@ -5,7 +5,6 @@ use std::sync::Arc; use strum_macros::EnumIter; use crate::runtime::{Register, Runtime}; -use crate::stark::Blake3CompressInnerChip; use crate::syscall::precompiles::edwards::EdAddAssignChip; use crate::syscall::precompiles::edwards::EdDecompressChip; use crate::syscall::precompiles::keccak256::KeccakPermuteChip; @@ -68,9 +67,6 @@ pub enum SyscallCode { /// Executes the `SECP256K1_DECOMPRESS` precompile. SECP256K1_DECOMPRESS = 0x00_00_01_0C, - /// Executes the `BLAKE3_COMPRESS_INNER` precompile. - BLAKE3_COMPRESS_INNER = 0x00_38_01_0D, - /// Executes the `BN254_ADD` precompile. 
BN254_ADD = 0x00_01_01_0E, @@ -121,7 +117,6 @@ impl SyscallCode { 0x00_01_01_0A => SyscallCode::SECP256K1_ADD, 0x00_00_01_0B => SyscallCode::SECP256K1_DOUBLE, 0x00_00_01_0C => SyscallCode::SECP256K1_DECOMPRESS, - 0x00_38_01_0D => SyscallCode::BLAKE3_COMPRESS_INNER, 0x00_01_01_0E => SyscallCode::BN254_ADD, 0x00_00_01_0F => SyscallCode::BN254_DOUBLE, 0x00_01_01_1E => SyscallCode::BLS12381_ADD, @@ -180,6 +175,7 @@ pub struct SyscallContext<'a> { /// This is the exit_code used for the HALT syscall pub(crate) exit_code: u32, pub(crate) rt: &'a mut Runtime, + pub syscall_lookup_id: usize, } impl<'a> SyscallContext<'a> { @@ -192,6 +188,7 @@ impl<'a> SyscallContext<'a> { next_pc: runtime.state.pc.wrapping_add(4), exit_code: 0, rt: runtime, + syscall_lookup_id: 0, } } @@ -304,10 +301,6 @@ pub fn default_syscall_map() -> HashMap> { SyscallCode::BN254_DOUBLE, Arc::new(WeierstrassDoubleAssignChip::::new()), ); - syscall_map.insert( - SyscallCode::BLAKE3_COMPRESS_INNER, - Arc::new(Blake3CompressInnerChip::new()), - ); syscall_map.insert( SyscallCode::BLS12381_ADD, Arc::new(WeierstrassAddAssignChip::::new()), @@ -316,10 +309,6 @@ pub fn default_syscall_map() -> HashMap> { SyscallCode::BLS12381_DOUBLE, Arc::new(WeierstrassDoubleAssignChip::::new()), ); - syscall_map.insert( - SyscallCode::BLAKE3_COMPRESS_INNER, - Arc::new(Blake3CompressInnerChip::new()), - ); syscall_map.insert(SyscallCode::UINT256_MUL, Arc::new(Uint256MulChip::new())); syscall_map.insert( SyscallCode::ENTER_UNCONSTRAINED, @@ -359,10 +348,6 @@ mod tests { fn test_syscalls_in_default_map() { let default_syscall_map = default_syscall_map(); for code in SyscallCode::iter() { - if code == SyscallCode::BLAKE3_COMPRESS_INNER { - // Blake3 is currently disabled. 
- continue; - } default_syscall_map.get(&code).unwrap(); } } @@ -412,9 +397,6 @@ mod tests { SyscallCode::SECP256K1_DOUBLE => { assert_eq!(code as u32, sp1_zkvm::syscalls::SECP256K1_DOUBLE) } - SyscallCode::BLAKE3_COMPRESS_INNER => { - assert_eq!(code as u32, sp1_zkvm::syscalls::BLAKE3_COMPRESS_INNER) - } SyscallCode::BLS12381_ADD => { assert_eq!(code as u32, sp1_zkvm::syscalls::BLS12381_ADD) } diff --git a/core/src/stark/air.rs b/core/src/stark/air.rs index dc181b18d..558ccec5a 100644 --- a/core/src/stark/air.rs +++ b/core/src/stark/air.rs @@ -21,7 +21,6 @@ pub(crate) mod riscv_chips { pub use crate::cpu::CpuChip; pub use crate::memory::MemoryChip; pub use crate::program::ProgramChip; - pub use crate::syscall::precompiles::blake3::Blake3CompressInnerChip; pub use crate::syscall::precompiles::edwards::EdAddAssignChip; pub use crate::syscall::precompiles::edwards::EdDecompressChip; pub use crate::syscall::precompiles::keccak256::KeccakPermuteChip; @@ -88,8 +87,6 @@ pub enum RiscvAir { Secp256k1Double(WeierstrassDoubleAssignChip>), /// A precompile for the Keccak permutation. KeccakP(KeccakPermuteChip), - /// A precompile for the Blake3 compression function. (Disabled by default.) - Blake3Compress(Blake3CompressInnerChip), /// A precompile for addition on the Elliptic curve bn254. Bn254Add(WeierstrassAddAssignChip>), /// A precompile for doubling a point on the Elliptic curve bn254. 
@@ -152,12 +149,12 @@ impl RiscvAir { chips.push(RiscvAir::Uint256Mul(uint256_mul)); let bls12381_decompress = WeierstrassDecompressChip::>::new(); chips.push(RiscvAir::Bls12381Decompress(bls12381_decompress)); + let div_rem = DivRemChip::default(); + chips.push(RiscvAir::DivRem(div_rem)); let add = AddSubChip::default(); chips.push(RiscvAir::Add(add)); let bitwise = BitwiseChip::default(); chips.push(RiscvAir::Bitwise(bitwise)); - let div_rem = DivRemChip::default(); - chips.push(RiscvAir::DivRem(div_rem)); let mul = MulChip::default(); chips.push(RiscvAir::Mul(mul)); let shift_right = ShiftRightChip::default(); diff --git a/core/src/stark/chip.rs b/core/src/stark/chip.rs index 4a3164643..7248e73a4 100644 --- a/core/src/stark/chip.rs +++ b/core/src/stark/chip.rs @@ -61,12 +61,24 @@ where where A: MachineAir + Air> + Air>, { - // Todo: correct values let mut builder = InteractionBuilder::new(air.preprocessed_width(), air.width()); air.eval(&mut builder); let (sends, receives) = builder.interactions(); - // TODO: enable different numbers of public values. 
+ let nb_byte_sends = sends + .iter() + .filter(|s| s.kind == InteractionKind::Byte) + .count(); + let nb_byte_receives = receives + .iter() + .filter(|r| r.kind == InteractionKind::Byte) + .count(); + tracing::debug!( + "chip {} has {} byte interactions", + air.name(), + nb_byte_sends + nb_byte_receives + ); + let mut max_constraint_degree = get_max_constraint_degree(&air, air.preprocessed_width(), PROOF_MAX_NUM_PVS); diff --git a/core/src/stark/machine.rs b/core/src/stark/machine.rs index 672226030..9f61460a0 100644 --- a/core/src/stark/machine.rs +++ b/core/src/stark/machine.rs @@ -473,6 +473,8 @@ pub enum MachineVerificationError { DebugInteractionsFailed, EmptyProof, InvalidPublicValues(&'static str), + TooManyShards, + InvalidChipOccurence(String), } impl Debug for MachineVerificationError { @@ -499,6 +501,12 @@ impl Debug for MachineVerificationError { MachineVerificationError::InvalidPublicValues(s) => { write!(f, "Invalid public values: {}", s) } + MachineVerificationError::TooManyShards => { + write!(f, "Too many shards") + } + MachineVerificationError::InvalidChipOccurence(s) => { + write!(f, "Invalid chip occurence: {}", s) + } } } } diff --git a/core/src/syscall/precompiles/blake3/compress/air.rs b/core/src/syscall/precompiles/blake3/compress/air.rs deleted file mode 100644 index a5876866e..000000000 --- a/core/src/syscall/precompiles/blake3/compress/air.rs +++ /dev/null @@ -1,235 +0,0 @@ -use core::borrow::Borrow; - -use p3_air::{Air, AirBuilder, BaseAir}; -use p3_field::AbstractField; -use p3_matrix::Matrix; - -use super::columns::{Blake3CompressInnerCols, NUM_BLAKE3_COMPRESS_INNER_COLS}; -use super::g::GOperation; -use super::{ - Blake3CompressInnerChip, G_INDEX, MSG_SCHEDULE, NUM_MSG_WORDS_PER_CALL, - NUM_STATE_WORDS_PER_CALL, OPERATION_COUNT, ROUND_COUNT, -}; -use crate::air::{BaseAirBuilder, SP1AirBuilder, WORD_SIZE}; -use crate::runtime::SyscallCode; - -impl BaseAir for Blake3CompressInnerChip { - fn width(&self) -> usize { - 
NUM_BLAKE3_COMPRESS_INNER_COLS - } -} - -impl Air for Blake3CompressInnerChip -where - AB: SP1AirBuilder, -{ - fn eval(&self, builder: &mut AB) { - let main = builder.main(); - let (local, next) = (main.row_slice(0), main.row_slice(1)); - let local: &Blake3CompressInnerCols = (*local).borrow(); - let next: &Blake3CompressInnerCols = (*next).borrow(); - - self.constrain_control_flow_flags(builder, local, next); - - self.constrain_memory(builder, local); - - self.constrain_g_operation(builder, local); - - // TODO: constraint ecall_receive column. - // TODO: constraint clk column to increment by 1 within same invocation of syscall. - builder.receive_syscall( - local.shard, - local.channel, - local.clk, - AB::F::from_canonical_u32(SyscallCode::BLAKE3_COMPRESS_INNER.syscall_id()), - local.state_ptr, - local.message_ptr, - local.ecall_receive, - ); - } -} - -impl Blake3CompressInnerChip { - /// Constrains the given index is correct for the given selector. The `selector` is an - /// `n`-dimensional boolean array whose `i`-th element is true if and only if the index is `i`. - fn constrain_index_selector( - &self, - builder: &mut AB, - selector: &[AB::Var], - index: AB::Var, - is_real: AB::Var, - ) { - let mut acc: AB::Expr = AB::F::zero().into(); - for i in 0..selector.len() { - acc += selector[i].into(); - builder.assert_bool(selector[i]) - } - builder - .when(is_real) - .assert_eq(acc, AB::F::from_canonical_usize(1)); - for i in 0..selector.len() { - builder - .when(selector[i]) - .assert_eq(index, AB::F::from_canonical_usize(i)); - } - } - - /// Constrains the control flow flags such as the operation index and the round index. - fn constrain_control_flow_flags( - &self, - builder: &mut AB, - local: &Blake3CompressInnerCols, - next: &Blake3CompressInnerCols, - ) { - // If this is the i-th operation, then the next row should be the (i+1)-th operation. 
- for i in 0..OPERATION_COUNT { - builder.when_transition().when(next.is_real).assert_eq( - local.is_operation_index_n[i], - next.is_operation_index_n[(i + 1) % OPERATION_COUNT], - ); - } - - // If this is the last operation, the round index should be incremented. Otherwise, the - // round index should remain the same. - for i in 0..OPERATION_COUNT { - if i + 1 < OPERATION_COUNT { - builder - .when_transition() - .when(local.is_operation_index_n[i]) - .assert_eq(local.round_index, next.round_index); - } else { - builder - .when_transition() - .when(local.is_operation_index_n[i]) - .when_not(local.is_round_index_n[ROUND_COUNT - 1]) - .assert_eq( - local.round_index + AB::F::from_canonical_u16(1), - next.round_index, - ); - - builder - .when_transition() - .when(local.is_operation_index_n[i]) - .when(local.is_round_index_n[ROUND_COUNT - 1]) - .assert_zero(next.round_index); - } - } - } - - /// Constrain the memory access for the state and the message. - fn constrain_memory( - &self, - builder: &mut AB, - local: &Blake3CompressInnerCols, - ) { - // Calculate the 4 indices to read from the state. This corresponds to a, b, c, and d. - for i in 0..NUM_STATE_WORDS_PER_CALL { - let index_to_read = { - self.constrain_index_selector( - builder, - &local.is_operation_index_n, - local.operation_index, - local.is_real, - ); - - let mut acc = AB::Expr::from_canonical_usize(0); - for operation in 0..OPERATION_COUNT { - acc += AB::Expr::from_canonical_usize(G_INDEX[operation][i]) - * local.is_operation_index_n[operation]; - } - acc - }; - builder.assert_eq(local.state_index[i], index_to_read); - } - - // Read & write the state. - for i in 0..NUM_STATE_WORDS_PER_CALL { - builder.eval_memory_access( - local.shard, - local.channel, - local.clk, - local.state_ptr + local.state_index[i] * AB::F::from_canonical_usize(WORD_SIZE), - &local.state_reads_writes[i], - local.is_real, - ); - } - - // Calculate the indices to read from the message. 
- for i in 0..NUM_MSG_WORDS_PER_CALL { - let index_to_read = { - self.constrain_index_selector( - builder, - &local.is_round_index_n, - local.round_index, - local.is_real, - ); - - let mut acc = AB::Expr::from_canonical_usize(0); - - for round in 0..ROUND_COUNT { - for operation in 0..OPERATION_COUNT { - acc += - AB::Expr::from_canonical_usize(MSG_SCHEDULE[round][2 * operation + i]) - * local.is_operation_index_n[operation] - * local.is_round_index_n[round]; - } - } - acc - }; - builder.assert_eq(local.msg_schedule[i], index_to_read); - } - - // Read the message. - for i in 0..NUM_MSG_WORDS_PER_CALL { - builder.eval_memory_access( - local.shard, - local.channel, - local.clk, - local.message_ptr + local.msg_schedule[i] * AB::F::from_canonical_usize(WORD_SIZE), - &local.message_reads[i], - local.is_real, - ); - } - } - - /// Constrains the input and the output of the `g` operation. - fn constrain_g_operation( - &self, - builder: &mut AB, - local: &Blake3CompressInnerCols, - ) { - builder.assert_bool(local.is_real); - - // Call g and write the result to the state. - { - let input = [ - local.state_reads_writes[0].prev_value, - local.state_reads_writes[1].prev_value, - local.state_reads_writes[2].prev_value, - local.state_reads_writes[3].prev_value, - local.message_reads[0].access.value, - local.message_reads[1].access.value, - ]; - - // Call the g function. - GOperation::::eval( - builder, - input, - local.g, - local.shard, - local.channel, - local.is_real, - ); - - // Finally, the results of the g function should be written to the memory. 
- for i in 0..NUM_STATE_WORDS_PER_CALL { - for j in 0..WORD_SIZE { - builder.when(local.is_real).assert_eq( - local.state_reads_writes[i].access.value[j], - local.g.result[i][j], - ); - } - } - } - } -} diff --git a/core/src/syscall/precompiles/blake3/compress/columns.rs b/core/src/syscall/precompiles/blake3/compress/columns.rs deleted file mode 100644 index bf7bbe4e1..000000000 --- a/core/src/syscall/precompiles/blake3/compress/columns.rs +++ /dev/null @@ -1,55 +0,0 @@ -use std::mem::size_of; - -use sp1_derive::AlignedBorrow; - -use crate::memory::MemoryReadCols; -use crate::memory::MemoryReadWriteCols; - -use super::g::GOperation; -use super::NUM_MSG_WORDS_PER_CALL; -use super::NUM_STATE_WORDS_PER_CALL; -use super::OPERATION_COUNT; -use super::ROUND_COUNT; - -pub const NUM_BLAKE3_COMPRESS_INNER_COLS: usize = size_of::>(); - -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] -#[repr(C)] -pub struct Blake3CompressInnerCols { - pub shard: T, - pub channel: T, - pub clk: T, - pub ecall_receive: T, - - /// The pointer to the state. - pub state_ptr: T, - - /// The pointer to the message. - pub message_ptr: T, - - /// Reads and writes a part of the state. - pub state_reads_writes: [MemoryReadWriteCols; NUM_STATE_WORDS_PER_CALL], - - /// Reads a part of the message. - pub message_reads: [MemoryReadCols; NUM_MSG_WORDS_PER_CALL], - - /// Indicates which call of `g` is being performed. - pub operation_index: T, - pub is_operation_index_n: [T; OPERATION_COUNT], - - /// Indicates which call of `round` is being performed. - pub round_index: T, - pub is_round_index_n: [T; ROUND_COUNT], - - /// The indices to pass to `g`. - pub state_index: [T; NUM_STATE_WORDS_PER_CALL], - - /// The two values from `MSG_SCHEDULE` to pass to `g`. - pub msg_schedule: [T; NUM_MSG_WORDS_PER_CALL], - - /// The `g` operation to perform. - pub g: GOperation, - - /// Indicates if the current call is real or not. 
- pub is_real: T, -} diff --git a/core/src/syscall/precompiles/blake3/compress/execute.rs b/core/src/syscall/precompiles/blake3/compress/execute.rs deleted file mode 100644 index 35298b041..000000000 --- a/core/src/syscall/precompiles/blake3/compress/execute.rs +++ /dev/null @@ -1,76 +0,0 @@ -use crate::runtime::Syscall; -use crate::runtime::{MemoryReadRecord, MemoryWriteRecord}; -use crate::syscall::precompiles::blake3::{ - g_func, Blake3CompressInnerChip, Blake3CompressInnerEvent, G_INDEX, MSG_SCHEDULE, - NUM_MSG_WORDS_PER_CALL, NUM_STATE_WORDS_PER_CALL, OPERATION_COUNT, ROUND_COUNT, -}; -use crate::syscall::precompiles::SyscallContext; - -impl Syscall for Blake3CompressInnerChip { - fn num_extra_cycles(&self) -> u32 { - (ROUND_COUNT * OPERATION_COUNT) as u32 - } - - fn execute(&self, rt: &mut SyscallContext, arg1: u32, arg2: u32) -> Option { - let state_ptr = arg1; - let message_ptr = arg2; - - let start_clk = rt.clk; - let mut message_reads = - [[[MemoryReadRecord::default(); NUM_MSG_WORDS_PER_CALL]; OPERATION_COUNT]; ROUND_COUNT]; - let mut state_writes = [[[MemoryWriteRecord::default(); NUM_STATE_WORDS_PER_CALL]; - OPERATION_COUNT]; ROUND_COUNT]; - - for round in 0..ROUND_COUNT { - for operation in 0..OPERATION_COUNT { - let state_index = G_INDEX[operation]; - let message_index: [usize; NUM_MSG_WORDS_PER_CALL] = [ - MSG_SCHEDULE[round][2 * operation], - MSG_SCHEDULE[round][2 * operation + 1], - ]; - - let mut input = vec![]; - // Read the input to g. - { - for index in state_index.iter() { - input.push(rt.word_unsafe(state_ptr + (*index as u32) * 4)); - } - for i in 0..NUM_MSG_WORDS_PER_CALL { - let (record, value) = rt.mr(message_ptr + (message_index[i] as u32) * 4); - message_reads[round][operation][i] = record; - input.push(value); - } - } - - // Call g. - let results = g_func(input.try_into().unwrap()); - - // Write the state. 
- for i in 0..NUM_STATE_WORDS_PER_CALL { - state_writes[round][operation][i] = - rt.mw(state_ptr + (state_index[i] as u32) * 4, results[i]); - } - - // Increment the clock for the next call of g. - rt.clk += 1; - } - } - - let shard = rt.current_shard(); - let channel = rt.current_channel(); - - rt.record_mut() - .blake3_compress_inner_events - .push(Blake3CompressInnerEvent { - shard, - channel, - clk: start_clk, - state_ptr, - message_reads, - state_writes, - message_ptr, - }); - - None - } -} diff --git a/core/src/syscall/precompiles/blake3/compress/g.rs b/core/src/syscall/precompiles/blake3/compress/g.rs deleted file mode 100644 index 06e8c3034..000000000 --- a/core/src/syscall/precompiles/blake3/compress/g.rs +++ /dev/null @@ -1,277 +0,0 @@ -use p3_field::Field; -use sp1_derive::AlignedBorrow; - -use crate::air::SP1AirBuilder; -use crate::air::Word; -use crate::air::WORD_SIZE; -use crate::operations::AddOperation; -use crate::operations::FixedRotateRightOperation; -use crate::operations::XorOperation; -use crate::runtime::ExecutionRecord; - -use super::g_func; -/// A set of columns needed to compute the `g` of the input state. -/// ``` ignore -/// fn g(state: &mut BlockWords, a: usize, b: usize, c: usize, d: usize, x: u32, y: u32) { -/// state[a] = state[a].wrapping_add(state[b]).wrapping_add(x); -/// state[d] = (state[d] ^ state[a]).rotate_right(16); -/// state[c] = state[c].wrapping_add(state[d]); -/// state[b] = (state[b] ^ state[c]).rotate_right(12); -/// state[a] = state[a].wrapping_add(state[b]).wrapping_add(y); -/// state[d] = (state[d] ^ state[a]).rotate_right(8); -/// state[c] = state[c].wrapping_add(state[d]); -/// state[b] = (state[b] ^ state[c]).rotate_right(7); -/// } -/// ``` -#[derive(AlignedBorrow, Default, Debug, Clone, Copy)] -#[repr(C)] -pub struct GOperation { - pub a_plus_b: AddOperation, - pub a_plus_b_plus_x: AddOperation, - pub d_xor_a: XorOperation, - // Rotate right by 16 bits by just shifting bytes. 
- pub c_plus_d: AddOperation, - pub b_xor_c: XorOperation, - pub b_xor_c_rotate_right_12: FixedRotateRightOperation, - pub a_plus_b_2: AddOperation, - pub a_plus_b_2_add_y: AddOperation, - // Rotate right by 8 bits by just shifting bytes. - pub d_xor_a_2: XorOperation, - pub c_plus_d_2: AddOperation, - pub b_xor_c_2: XorOperation, - pub b_xor_c_2_rotate_right_7: FixedRotateRightOperation, - /// `state[a]`, `state[b]`, `state[c]`, `state[d]` after all the steps. - pub result: [Word; 4], -} - -impl GOperation { - pub fn populate( - &mut self, - record: &mut ExecutionRecord, - shard: u32, - channel: u32, - input: [u32; 6], - ) -> [u32; 4] { - let mut a = input[0]; - let mut b = input[1]; - let mut c = input[2]; - let mut d = input[3]; - let x = input[4]; - let y = input[5]; - - // First 4 steps. - { - // a = a + b + x. - a = self.a_plus_b.populate(record, shard, channel, a, b); - a = self.a_plus_b_plus_x.populate(record, shard, channel, a, x); - - // d = (d ^ a).rotate_right(16). - d = self.d_xor_a.populate(record, shard, channel, d, a); - d = d.rotate_right(16); - - // c = c + d. - c = self.c_plus_d.populate(record, shard, channel, c, d); - - // b = (b ^ c).rotate_right(12). - b = self.b_xor_c.populate(record, shard, channel, b, c); - b = self - .b_xor_c_rotate_right_12 - .populate(record, shard, channel, b, 12); - } - - // Second 4 steps. - { - // a = a + b + y. - a = self.a_plus_b_2.populate(record, shard, channel, a, b); - a = self.a_plus_b_2_add_y.populate(record, shard, channel, a, y); - - // d = (d ^ a).rotate_right(8). - d = self.d_xor_a_2.populate(record, shard, channel, d, a); - d = d.rotate_right(8); - - // c = c + d. - c = self.c_plus_d_2.populate(record, shard, channel, c, d); - - // b = (b ^ c).rotate_right(7). 
- b = self.b_xor_c_2.populate(record, shard, channel, b, c); - b = self - .b_xor_c_2_rotate_right_7 - .populate(record, shard, channel, b, 7); - } - - let result = [a, b, c, d]; - assert_eq!(result, g_func(input)); - self.result = result.map(Word::from); - result - } - - pub fn eval( - builder: &mut AB, - input: [Word; 6], - cols: GOperation, - shard: AB::Var, - channel: impl Into + Clone, - is_real: AB::Var, - ) { - builder.assert_bool(is_real); - let mut a = input[0]; - let mut b = input[1]; - let mut c = input[2]; - let mut d = input[3]; - let x = input[4]; - let y = input[5]; - - // First 4 steps. - { - // a = a + b + x. - AddOperation::::eval( - builder, - a, - b, - cols.a_plus_b, - shard, - channel.clone(), - is_real.into(), - ); - a = cols.a_plus_b.value; - AddOperation::::eval( - builder, - a, - x, - cols.a_plus_b_plus_x, - shard, - channel.clone(), - is_real.into(), - ); - a = cols.a_plus_b_plus_x.value; - - // d = (d ^ a).rotate_right(16). - XorOperation::::eval( - builder, - d, - a, - cols.d_xor_a, - shard, - channel.clone(), - is_real, - ); - d = cols.d_xor_a.value; - // Rotate right by 16 bits. - d = Word([d[2], d[3], d[0], d[1]]); - - // c = c + d. - AddOperation::::eval( - builder, - c, - d, - cols.c_plus_d, - shard, - channel.clone(), - is_real.into(), - ); - c = cols.c_plus_d.value; - - // b = (b ^ c).rotate_right(12). - XorOperation::::eval( - builder, - b, - c, - cols.b_xor_c, - shard, - channel.clone(), - is_real, - ); - b = cols.b_xor_c.value; - FixedRotateRightOperation::::eval( - builder, - b, - 12, - cols.b_xor_c_rotate_right_12, - shard, - channel.clone(), - is_real, - ); - b = cols.b_xor_c_rotate_right_12.value; - } - - // Second 4 steps. - { - // a = a + b + y. 
- AddOperation::::eval( - builder, - a, - b, - cols.a_plus_b_2, - shard, - channel.clone(), - is_real.into(), - ); - a = cols.a_plus_b_2.value; - AddOperation::::eval( - builder, - a, - y, - cols.a_plus_b_2_add_y, - shard, - channel.clone(), - is_real.into(), - ); - a = cols.a_plus_b_2_add_y.value; - - // d = (d ^ a).rotate_right(8). - XorOperation::::eval( - builder, - d, - a, - cols.d_xor_a_2, - shard, - channel.clone(), - is_real, - ); - d = cols.d_xor_a_2.value; - // Rotate right by 8 bits. - d = Word([d[1], d[2], d[3], d[0]]); - - // c = c + d. - AddOperation::::eval( - builder, - c, - d, - cols.c_plus_d_2, - shard, - channel.clone(), - is_real.into(), - ); - c = cols.c_plus_d_2.value; - - // b = (b ^ c).rotate_right(7). - XorOperation::::eval( - builder, - b, - c, - cols.b_xor_c_2, - shard, - channel.clone(), - is_real, - ); - b = cols.b_xor_c_2.value; - FixedRotateRightOperation::::eval( - builder, - b, - 7, - cols.b_xor_c_2_rotate_right_7, - shard, - channel.clone(), - is_real, - ); - b = cols.b_xor_c_2_rotate_right_7.value; - } - - let results = [a, b, c, d]; - for i in 0..4 { - for j in 0..WORD_SIZE { - builder.assert_eq(cols.result[i][j], results[i][j]); - } - } - } -} diff --git a/core/src/syscall/precompiles/blake3/compress/mod.rs b/core/src/syscall/precompiles/blake3/compress/mod.rs deleted file mode 100644 index a89b9bcc3..000000000 --- a/core/src/syscall/precompiles/blake3/compress/mod.rs +++ /dev/null @@ -1,179 +0,0 @@ -//! This module contains the implementation of the `blake3_compress_inner` precompile based on the -//! implementation of the `blake3` hash function in BLAKE3. -//! -//! Pseudo-code. -//! -//! state = [0u32; 16] -//! message = [0u32; 16] -//! -//! for round in 0..7 { -//! for operation in 0..8 { -//! // * Pick 4 indices a, b, c, d for the state, based on the operation index. -//! // * Pick 2 indices x, y for the message, based on both the round and the operation index. -//! // -//! 
// g takes those 6 values, and updates the 4 state values, at indices a, b, c, d. -//! // -//! // Each call of g becomes one row in the trace. -//! g(&mut state[a], &mut state[b], &mut state[c], &mut state[d], message[x], message[y]); -//! } -//! } -//! -//! Note that this precompile is only the blake3 compress inner function. The Blake3 compress -//! function has a series of 8 XOR operations after the compress inner function. -mod air; -mod columns; -mod execute; -mod g; -mod trace; -use crate::runtime::{MemoryReadRecord, MemoryWriteRecord}; - -use serde::{Deserialize, Serialize}; - -/// The number of `Word`s in the message of the compress inner operation. -pub(crate) const MSG_SIZE: usize = 16; - -/// The number of times we call `round` in the compress inner operation. -pub(crate) const ROUND_COUNT: usize = 7; - -/// The number of times we call `g` in the compress inner operation. -pub(crate) const OPERATION_COUNT: usize = 8; - -/// The number of `Word`s in the state that we pass to `g`. -pub(crate) const NUM_STATE_WORDS_PER_CALL: usize = 4; - -/// The number of `Word`s in the message that we pass to `g`. -pub(crate) const NUM_MSG_WORDS_PER_CALL: usize = 2; - -/// The number of `Word`s in the input of `g`. -pub(crate) const G_INPUT_SIZE: usize = NUM_MSG_WORDS_PER_CALL + NUM_STATE_WORDS_PER_CALL; - -/// 2-dimensional array specifying which message values `g` should access. Values at `(i, 2 * j)` -/// and `(i, 2 * j + 1)` are the indices of the message values that `g` should access in the `j`-th -/// call of the `i`-th round. 
-pub(crate) const MSG_SCHEDULE: [[usize; MSG_SIZE]; ROUND_COUNT] = [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], - [2, 6, 3, 10, 7, 0, 4, 13, 1, 11, 12, 5, 9, 14, 15, 8], - [3, 4, 10, 12, 13, 2, 7, 14, 6, 5, 9, 0, 11, 15, 8, 1], - [10, 7, 12, 9, 14, 3, 13, 15, 4, 0, 11, 2, 5, 8, 1, 6], - [12, 13, 9, 11, 15, 10, 14, 8, 7, 2, 5, 3, 0, 1, 6, 4], - [9, 14, 11, 5, 8, 12, 15, 1, 13, 3, 0, 10, 2, 6, 4, 7], - [11, 15, 5, 0, 1, 9, 8, 6, 14, 10, 2, 12, 3, 4, 7, 13], -]; - -/// The `i`-th row of `G_INDEX` is the indices used for the `i`-th call to `g`. -pub(crate) const G_INDEX: [[usize; NUM_STATE_WORDS_PER_CALL]; OPERATION_COUNT] = [ - [0, 4, 8, 12], - [1, 5, 9, 13], - [2, 6, 10, 14], - [3, 7, 11, 15], - [0, 5, 10, 15], - [1, 6, 11, 12], - [2, 7, 8, 13], - [3, 4, 9, 14], -]; - -pub(crate) const fn g_func(input: [u32; 6]) -> [u32; 4] { - let mut a = input[0]; - let mut b = input[1]; - let mut c = input[2]; - let mut d = input[3]; - let x = input[4]; - let y = input[5]; - a = a.wrapping_add(b).wrapping_add(x); - d = (d ^ a).rotate_right(16); - c = c.wrapping_add(d); - b = (b ^ c).rotate_right(12); - a = a.wrapping_add(b).wrapping_add(y); - d = (d ^ a).rotate_right(8); - c = c.wrapping_add(d); - b = (b ^ c).rotate_right(7); - [a, b, c, d] -} - -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct Blake3CompressInnerEvent { - pub clk: u32, - pub shard: u32, - pub channel: u32, - pub state_ptr: u32, - pub message_ptr: u32, - pub message_reads: [[[MemoryReadRecord; NUM_MSG_WORDS_PER_CALL]; OPERATION_COUNT]; ROUND_COUNT], - pub state_writes: - [[[MemoryWriteRecord; NUM_STATE_WORDS_PER_CALL]; OPERATION_COUNT]; ROUND_COUNT], -} - -pub struct Blake3CompressInnerChip {} - -impl Blake3CompressInnerChip { - pub const fn new() -> Self { - Self {} - } -} - -#[cfg(test)] -pub mod compress_tests { - use crate::runtime::Instruction; - use crate::runtime::Opcode; - use crate::runtime::Register; - use crate::runtime::SyscallCode; - use crate::Program; - - use 
super::MSG_SIZE; - - /// The number of `Word`s in the state of the compress inner operation. - const STATE_SIZE: usize = 16; - - pub fn blake3_compress_internal_program() -> Program { - let state_ptr = 100; - let msg_ptr = 500; - let mut instructions = vec![]; - - for i in 0..STATE_SIZE { - // Store 1000 + i in memory for the i-th word of the state. 1000 + i is an arbitrary - // number that is easy to spot while debugging. - instructions.extend(vec![ - Instruction::new(Opcode::ADD, 29, 0, 1000 + i as u32, false, true), - Instruction::new(Opcode::ADD, 30, 0, state_ptr + i as u32 * 4, false, true), - Instruction::new(Opcode::SW, 29, 30, 0, false, true), - ]); - } - for i in 0..MSG_SIZE { - // Store 2000 + i in memory for the i-th word of the message. 2000 + i is an arbitrary - // number that is easy to spot while debugging. - instructions.extend(vec![ - Instruction::new(Opcode::ADD, 29, 0, 2000 + i as u32, false, true), - Instruction::new(Opcode::ADD, 30, 0, msg_ptr + i as u32 * 4, false, true), - Instruction::new(Opcode::SW, 29, 30, 0, false, true), - ]); - } - instructions.extend(vec![ - Instruction::new( - Opcode::ADD, - 5, - 0, - SyscallCode::BLAKE3_COMPRESS_INNER as u32, - false, - true, - ), - Instruction::new(Opcode::ADD, Register::X10 as u32, 0, state_ptr, false, true), - Instruction::new(Opcode::ADD, Register::X11 as u32, 0, msg_ptr, false, true), - Instruction::new(Opcode::ECALL, 5, 10, 11, false, false), - ]); - Program::new(instructions, 0, 0) - } - - // Tests disabled because syscall is not enabled in default runtime/chip configs. 
- // #[test] - // fn prove_babybear() { - // setup_logger(); - // let program = blake3_compress_internal_program(); - // run_test(program).unwrap(); - // } - - // #[test] - // fn test_blake3_compress_inner_elf() { - // setup_logger(); - // let program = Program::from(BLAKE3_COMPRESS_ELF); - // run_test(program).unwrap(); - // } -} diff --git a/core/src/syscall/precompiles/blake3/compress/trace.rs b/core/src/syscall/precompiles/blake3/compress/trace.rs deleted file mode 100644 index 14994cb03..000000000 --- a/core/src/syscall/precompiles/blake3/compress/trace.rs +++ /dev/null @@ -1,131 +0,0 @@ -use std::borrow::BorrowMut; - -use p3_field::PrimeField32; -use p3_matrix::dense::RowMajorMatrix; - -use super::columns::Blake3CompressInnerCols; -use super::{ - G_INDEX, G_INPUT_SIZE, MSG_SCHEDULE, NUM_MSG_WORDS_PER_CALL, NUM_STATE_WORDS_PER_CALL, - OPERATION_COUNT, -}; -use crate::air::MachineAir; -use crate::bytes::event::ByteRecord; -use crate::runtime::ExecutionRecord; -use crate::runtime::MemoryRecordEnum; -use crate::runtime::Program; -use crate::syscall::precompiles::blake3::compress::columns::NUM_BLAKE3_COMPRESS_INNER_COLS; -use crate::syscall::precompiles::blake3::{Blake3CompressInnerChip, ROUND_COUNT}; -use crate::utils::pad_rows; - -impl MachineAir for Blake3CompressInnerChip { - type Record = ExecutionRecord; - type Program = Program; - - fn name(&self) -> String { - "Blake3CompressInner".to_string() - } - - fn generate_trace( - &self, - input: &ExecutionRecord, - output: &mut ExecutionRecord, - ) -> RowMajorMatrix { - let mut rows = Vec::new(); - - let mut new_byte_lookup_events = Vec::new(); - - for i in 0..input.blake3_compress_inner_events.len() { - let event = input.blake3_compress_inner_events[i].clone(); - let shard = event.shard; - let channel = event.channel; - let mut clk = event.clk; - for round in 0..ROUND_COUNT { - for operation in 0..OPERATION_COUNT { - let mut row = [F::zero(); NUM_BLAKE3_COMPRESS_INNER_COLS]; - let cols: &mut 
Blake3CompressInnerCols = row.as_mut_slice().borrow_mut(); - - // Assign basic values to the columns. - { - cols.shard = F::from_canonical_u32(event.shard); - cols.channel = F::from_canonical_u32(event.channel); - cols.clk = F::from_canonical_u32(clk); - - cols.round_index = F::from_canonical_u32(round as u32); - cols.is_round_index_n[round] = F::one(); - - cols.operation_index = F::from_canonical_u32(operation as u32); - cols.is_operation_index_n[operation] = F::one(); - - for i in 0..NUM_STATE_WORDS_PER_CALL { - cols.state_index[i] = F::from_canonical_usize(G_INDEX[operation][i]); - } - - for i in 0..NUM_MSG_WORDS_PER_CALL { - cols.msg_schedule[i] = - F::from_canonical_usize(MSG_SCHEDULE[round][2 * operation + i]); - } - - if round == 0 && operation == 0 { - cols.ecall_receive = F::one(); - } - } - - // Memory columns. - { - cols.message_ptr = F::from_canonical_u32(event.message_ptr); - for i in 0..NUM_MSG_WORDS_PER_CALL { - cols.message_reads[i].populate( - channel, - event.message_reads[round][operation][i], - &mut new_byte_lookup_events, - ); - } - - cols.state_ptr = F::from_canonical_u32(event.state_ptr); - for i in 0..NUM_STATE_WORDS_PER_CALL { - cols.state_reads_writes[i].populate( - channel, - MemoryRecordEnum::Write(event.state_writes[round][operation][i]), - &mut new_byte_lookup_events, - ); - } - } - - // Apply the `g` operation. 
- { - let input: [u32; G_INPUT_SIZE] = [ - event.state_writes[round][operation][0].prev_value, - event.state_writes[round][operation][1].prev_value, - event.state_writes[round][operation][2].prev_value, - event.state_writes[round][operation][3].prev_value, - event.message_reads[round][operation][0].value, - event.message_reads[round][operation][1].value, - ]; - - cols.g.populate(output, shard, channel, input); - } - - clk += 1; - - cols.is_real = F::one(); - - rows.push(row); - } - } - } - - output.add_byte_lookup_events(new_byte_lookup_events); - - pad_rows(&mut rows, || [F::zero(); NUM_BLAKE3_COMPRESS_INNER_COLS]); - - // Convert the trace to a row major matrix. - RowMajorMatrix::new( - rows.into_iter().flatten().collect::>(), - NUM_BLAKE3_COMPRESS_INNER_COLS, - ) - } - - fn included(&self, shard: &Self::Record) -> bool { - !shard.blake3_compress_inner_events.is_empty() - } -} diff --git a/core/src/syscall/precompiles/blake3/mod.rs b/core/src/syscall/precompiles/blake3/mod.rs deleted file mode 100644 index 8b286ad17..000000000 --- a/core/src/syscall/precompiles/blake3/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod compress; - -pub use compress::*; diff --git a/core/src/syscall/precompiles/edwards/ed_add.rs b/core/src/syscall/precompiles/edwards/ed_add.rs index f4e15423a..44d29edb8 100644 --- a/core/src/syscall/precompiles/edwards/ed_add.rs +++ b/core/src/syscall/precompiles/edwards/ed_add.rs @@ -6,6 +6,7 @@ use std::marker::PhantomData; use num::BigUint; use num::Zero; +use p3_air::AirBuilder; use p3_air::{Air, BaseAir}; use p3_field::AbstractField; use p3_field::PrimeField32; @@ -54,6 +55,7 @@ pub struct EdAddAssignCols { pub shard: T, pub channel: T, pub clk: T, + pub nonce: T, pub p_ptr: T, pub q_ptr: T, pub p_access: [MemoryWriteCols; WORDS_CURVE_POINT], @@ -238,10 +240,19 @@ impl MachineAir for Ed }); // Convert the trace to a row major matrix. 
- RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), NUM_ED_ADD_COLS, - ) + ); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut EdAddAssignCols = + trace.values[i * NUM_ED_ADD_COLS..(i + 1) * NUM_ED_ADD_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { @@ -261,141 +272,150 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let row = main.row_slice(0); - let row: &EdAddAssignCols = (*row).borrow(); + let local = main.row_slice(0); + let local: &EdAddAssignCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &EdAddAssignCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); - let x1 = limbs_from_prev_access(&row.p_access[0..8]); - let x2 = limbs_from_prev_access(&row.q_access[0..8]); - let y1 = limbs_from_prev_access(&row.p_access[8..16]); - let y2 = limbs_from_prev_access(&row.q_access[8..16]); + let x1 = limbs_from_prev_access(&local.p_access[0..8]); + let x2 = limbs_from_prev_access(&local.q_access[0..8]); + let y1 = limbs_from_prev_access(&local.p_access[8..16]); + let y2 = limbs_from_prev_access(&local.q_access[8..16]); // x3_numerator = x1 * y2 + x2 * y1. - row.x3_numerator.eval( + local.x3_numerator.eval( builder, &[x1, x2], &[y2, y1], - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); // y3_numerator = y1 * y2 + x1 * x2. - row.y3_numerator.eval( + local.y3_numerator.eval( builder, &[y1, x1], &[y2, x2], - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); // f = x1 * x2 * y1 * y2. 
- row.x1_mul_y1.eval( + local.x1_mul_y1.eval( builder, &x1, &y1, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.x2_mul_y2.eval( + local.x2_mul_y2.eval( builder, &x2, &y2, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - let x1_mul_y1 = row.x1_mul_y1.result; - let x2_mul_y2 = row.x2_mul_y2.result; - row.f.eval( + let x1_mul_y1 = local.x1_mul_y1.result; + let x2_mul_y2 = local.x2_mul_y2.result; + local.f.eval( builder, &x1_mul_y1, &x2_mul_y2, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); // d * f. - let f = row.f.result; + let f = local.f.result; let d_biguint = E::d_biguint(); let d_const = E::BaseField::to_limbs_field::(&d_biguint); - row.d_mul_f.eval( + local.d_mul_f.eval( builder, &f, &d_const, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - let d_mul_f = row.d_mul_f.result; + let d_mul_f = local.d_mul_f.result; // x3 = x3_numerator / (1 + d * f). - row.x3_ins.eval( + local.x3_ins.eval( builder, - &row.x3_numerator.result, + &local.x3_numerator.result, &d_mul_f, true, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); // y3 = y3_numerator / (1 - d * f). - row.y3_ins.eval( + local.y3_ins.eval( builder, - &row.y3_numerator.result, + &local.y3_numerator.result, &d_mul_f, false, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); // Constraint self.p_access.value = [self.x3_ins.result, self.y3_ins.result] // This is to ensure that p_access is updated with the new value. 
- let p_access_vec = value_as_limbs(&row.p_access); + let p_access_vec = value_as_limbs(&local.p_access); builder - .when(row.is_real) - .assert_all_eq(row.x3_ins.result, p_access_vec[0..NUM_LIMBS].to_vec()); - builder.when(row.is_real).assert_all_eq( - row.y3_ins.result, + .when(local.is_real) + .assert_all_eq(local.x3_ins.result, p_access_vec[0..NUM_LIMBS].to_vec()); + builder.when(local.is_real).assert_all_eq( + local.y3_ins.result, p_access_vec[NUM_LIMBS..NUM_LIMBS * 2].to_vec(), ); builder.eval_memory_access_slice( - row.shard, - row.channel, - row.clk.into(), - row.q_ptr, - &row.q_access, - row.is_real, + local.shard, + local.channel, + local.clk.into(), + local.q_ptr, + &local.q_access, + local.is_real, ); builder.eval_memory_access_slice( - row.shard, - row.channel, - row.clk + AB::F::from_canonical_u32(1), - row.p_ptr, - &row.p_access, - row.is_real, + local.shard, + local.channel, + local.clk + AB::F::from_canonical_u32(1), + local.p_ptr, + &local.p_access, + local.is_real, ); builder.receive_syscall( - row.shard, - row.channel, - row.clk, + local.shard, + local.channel, + local.clk, + local.nonce, AB::F::from_canonical_u32(SyscallCode::ED_ADD.syscall_id()), - row.p_ptr, - row.q_ptr, - row.is_real, + local.p_ptr, + local.q_ptr, + local.is_real, ); } } diff --git a/core/src/syscall/precompiles/edwards/ed_decompress.rs b/core/src/syscall/precompiles/edwards/ed_decompress.rs index be62467c0..a0618137c 100644 --- a/core/src/syscall/precompiles/edwards/ed_decompress.rs +++ b/core/src/syscall/precompiles/edwards/ed_decompress.rs @@ -53,6 +53,7 @@ use super::{WordsFieldElement, WORDS_FIELD_ELEMENT}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct EdDecompressEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -78,6 +79,7 @@ pub struct EdDecompressCols { pub shard: T, pub channel: T, pub clk: T, + pub nonce: T, pub ptr: T, pub sign: T, pub x_access: GenericArray, WordsFieldElement>, @@ -104,6 +106,13 @@ impl 
EdDecompressCols { self.channel = F::from_canonical_u32(event.channel); self.clk = F::from_canonical_u32(event.clk); self.ptr = F::from_canonical_u32(event.ptr); + self.nonce = F::from_canonical_u32( + record + .nonce_lookup + .get(&event.lookup_id) + .copied() + .unwrap_or_default(), + ); self.sign = F::from_bool(event.sign); for i in 0..8 { self.x_access[i].populate( @@ -276,6 +285,7 @@ impl EdDecompressCols { self.shard, self.channel, self.clk, + self.nonce, AB::F::from_canonical_u32(SyscallCode::ED_DECOMPRESS.syscall_id()), self.ptr, self.sign, @@ -326,11 +336,13 @@ impl Syscall for EdDecompressChip { let x_memory_records_vec = rt.mw_slice(slice_ptr, &decompressed_x_words); let x_memory_records: [MemoryWriteRecord; 8] = x_memory_records_vec.try_into().unwrap(); + let lookup_id = rt.syscall_lookup_id; let shard = rt.current_shard(); let channel = rt.current_channel(); rt.record_mut() .ed_decompress_events .push(EdDecompressEvent { + lookup_id, shard, channel, clk: start_clk, @@ -390,10 +402,20 @@ impl MachineAir for EdDecompressChip>(), NUM_ED_DECOMPRESS_COLS, - ) + ); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut EdDecompressCols = trace.values + [i * NUM_ED_DECOMPRESS_COLS..(i + 1) * NUM_ED_DECOMPRESS_COLS] + .borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { @@ -413,9 +435,18 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let row = main.row_slice(0); - let row: &EdDecompressCols = (*row).borrow(); - row.eval::(builder); + let local = main.row_slice(0); + let local: &EdDecompressCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &EdDecompressCols = (*next).borrow(); + + // Constrain the incrementing nonce. 
+ builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + + local.eval::(builder); } } diff --git a/core/src/syscall/precompiles/keccak256/air.rs b/core/src/syscall/precompiles/keccak256/air.rs index 9e67c1249..164761679 100644 --- a/core/src/syscall/precompiles/keccak256/air.rs +++ b/core/src/syscall/precompiles/keccak256/air.rs @@ -32,6 +32,12 @@ where let local: &KeccakMemCols = (*local).borrow(); let next: &KeccakMemCols = (*next).borrow(); + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + let first_step = local.keccak.step_flags[0]; let final_step = local.keccak.step_flags[NUM_ROUNDS - 1]; let not_final_step = AB::Expr::one() - final_step; @@ -68,6 +74,7 @@ where local.shard, local.channel, local.clk, + local.nonce, AB::F::from_canonical_u32(SyscallCode::KECCAK_PERMUTE.syscall_id()), local.state_addr, AB::Expr::zero(), @@ -79,6 +86,7 @@ where let mut transition_not_final_builder = transition_builder.when(not_final_step); transition_not_final_builder.assert_eq(local.shard, next.shard); transition_not_final_builder.assert_eq(local.clk, next.clk); + transition_not_final_builder.assert_eq(local.channel, next.channel); transition_not_final_builder.assert_eq(local.state_addr, next.state_addr); transition_not_final_builder.assert_eq(local.is_real, next.is_real); @@ -123,6 +131,16 @@ where } } + // Range check all the values in `state_mem` to be bytes. 
+ for i in 0..STATE_NUM_WORDS { + builder.slice_range_check_u8( + &local.state_mem[i].value().0, + local.shard, + local.channel, + local.do_memory_check, + ); + } + let mut sub_builder = SubAirBuilder::::new(builder, 0..NUM_KECCAK_COLS); diff --git a/core/src/syscall/precompiles/keccak256/columns.rs b/core/src/syscall/precompiles/keccak256/columns.rs index a3e2dd304..ad3aa5f09 100644 --- a/core/src/syscall/precompiles/keccak256/columns.rs +++ b/core/src/syscall/precompiles/keccak256/columns.rs @@ -20,6 +20,7 @@ pub(crate) struct KeccakMemCols { pub shard: T, pub channel: T, pub clk: T, + pub nonce: T, pub state_addr: T, /// Memory columns for the state. diff --git a/core/src/syscall/precompiles/keccak256/execute.rs b/core/src/syscall/precompiles/keccak256/execute.rs index d6c306c45..eecc747be 100644 --- a/core/src/syscall/precompiles/keccak256/execute.rs +++ b/core/src/syscall/precompiles/keccak256/execute.rs @@ -99,9 +99,11 @@ impl Syscall for KeccakPermuteChip { // Push the Keccak permute event. 
let shard = rt.current_shard(); let channel = rt.current_channel(); + let lookup_id = rt.syscall_lookup_id; rt.record_mut() .keccak_permute_events .push(KeccakPermuteEvent { + lookup_id, shard, channel, clk: start_clk, diff --git a/core/src/syscall/precompiles/keccak256/mod.rs b/core/src/syscall/precompiles/keccak256/mod.rs index 4110707a8..2b95b8b40 100644 --- a/core/src/syscall/precompiles/keccak256/mod.rs +++ b/core/src/syscall/precompiles/keccak256/mod.rs @@ -15,6 +15,7 @@ const STATE_NUM_WORDS: usize = STATE_SIZE * 2; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct KeccakPermuteEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, diff --git a/core/src/syscall/precompiles/keccak256/trace.rs b/core/src/syscall/precompiles/keccak256/trace.rs index 01b07fb74..e4700fe97 100644 --- a/core/src/syscall/precompiles/keccak256/trace.rs +++ b/core/src/syscall/precompiles/keccak256/trace.rs @@ -83,8 +83,12 @@ impl MachineAir for KeccakPermuteChip { *read_record, &mut new_byte_lookup_events, ); + new_byte_lookup_events.add_u8_range_checks( + shard, + channel, + &read_record.value.to_le_bytes(), + ); } - cols.do_memory_check = F::one(); cols.receive_ecall = F::one(); } @@ -99,8 +103,12 @@ impl MachineAir for KeccakPermuteChip { *write_record, &mut new_byte_lookup_events, ); + new_byte_lookup_events.add_u8_range_checks( + shard, + channel, + &write_record.value.to_le_bytes(), + ); } - cols.do_memory_check = F::one(); } @@ -147,10 +155,19 @@ impl MachineAir for KeccakPermuteChip { } // Convert the trace to a row major matrix. - RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), NUM_KECCAK_MEM_COLS, - ) + ); + + // Write the nonce to the trace. 
+ for i in 0..trace.height() { + let cols: &mut KeccakMemCols = + trace.values[i * NUM_KECCAK_MEM_COLS..(i + 1) * NUM_KECCAK_MEM_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { diff --git a/core/src/syscall/precompiles/mod.rs b/core/src/syscall/precompiles/mod.rs index 7b107e2d5..bc08c6856 100644 --- a/core/src/syscall/precompiles/mod.rs +++ b/core/src/syscall/precompiles/mod.rs @@ -1,4 +1,3 @@ -pub mod blake3; pub mod edwards; pub mod keccak256; pub mod sha256; @@ -20,6 +19,7 @@ use serde::{Deserialize, Serialize}; /// Elliptic curve add event. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ECAddEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -67,7 +67,9 @@ pub fn create_ec_add_event( let p_memory_records = rt.mw_slice(p_ptr, &result_words); + println!("ec-add lookup id {:?}", rt.syscall_lookup_id); ECAddEvent { + lookup_id: rt.syscall_lookup_id, shard: rt.current_shard(), channel: rt.current_channel(), clk: start_clk, @@ -83,6 +85,7 @@ pub fn create_ec_add_event( /// Elliptic curve double event. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ECDoubleEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -119,6 +122,7 @@ pub fn create_ec_double_event( let p_memory_records = rt.mw_slice(p_ptr, &result_words); ECDoubleEvent { + lookup_id: rt.syscall_lookup_id, shard: rt.current_shard(), channel: rt.current_channel(), clk: start_clk, @@ -131,6 +135,7 @@ pub fn create_ec_double_event( /// Elliptic curve point decompress event. 
#[derive(Debug, Clone, Serialize, Deserialize)] pub struct ECDecompressEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -176,6 +181,7 @@ pub fn create_ec_decompress_event( let y_memory_records = rt.mw_slice(slice_ptr, &y_words); ECDecompressEvent { + lookup_id: rt.syscall_lookup_id, shard: rt.current_shard(), channel: rt.current_channel(), clk: start_clk, diff --git a/core/src/syscall/precompiles/sha256/compress/air.rs b/core/src/syscall/precompiles/sha256/compress/air.rs index 2f4bd5000..7a28a456b 100644 --- a/core/src/syscall/precompiles/sha256/compress/air.rs +++ b/core/src/syscall/precompiles/sha256/compress/air.rs @@ -30,6 +30,12 @@ where let local: &ShaCompressCols = (*local).borrow(); let next: &ShaCompressCols = (*next).borrow(); + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + self.eval_control_flow_flags(builder, local, next); self.eval_memory(builder, local); @@ -46,6 +52,7 @@ where local.shard, local.channel, local.clk, + local.nonce, AB::F::from_canonical_u32(SyscallCode::SHA_COMPRESS.syscall_id()), local.w_ptr, local.h_ptr, @@ -71,19 +78,15 @@ impl ShaCompressChip { for i in 0..8 { octet_sum += local.octet[i].into(); } - builder.when(local.is_real).assert_one(octet_sum); + builder.assert_one(octet_sum); // Verify that the first row's octet value is correct. - builder - .when_first_row() - .when(local.is_real) - .assert_one(local.octet[0]); + builder.when_first_row().assert_one(local.octet[0]); // Verify correct transition for octet column. 
for i in 0..8 { builder .when_transition() - .when(next.is_real) .when(local.octet[i]) .assert_one(next.octet[(i + 1) % 8]) } @@ -98,19 +101,15 @@ impl ShaCompressChip { for i in 0..10 { octet_num_sum += local.octet_num[i].into(); } - builder.when(local.is_real).assert_one(octet_num_sum); + builder.assert_one(octet_num_sum); // The first row should have octet_num[0] = 1 if it's real. - builder - .when_first_row() - .when(local.is_real) - .assert_one(local.octet_num[0]); + builder.when_first_row().assert_one(local.octet_num[0]); // If current row is not last of an octet and next row is real, octet_num should be the same. for i in 0..10 { builder .when_transition() - .when(next.is_real) .when_not(local.octet[7]) .assert_eq(local.octet_num[i], next.octet_num[i]); } @@ -119,7 +118,6 @@ impl ShaCompressChip { for i in 0..10 { builder .when_transition() - .when(next.is_real) .when(local.octet[7]) .assert_eq(local.octet_num[i], next.octet_num[(i + 1) % 10]); } @@ -146,19 +144,26 @@ impl ShaCompressChip { .assert_word_eq(*var, *local.mem.value()); } + // Assert that the is_initialize flag is correct. + builder.assert_eq(local.is_initialize, local.octet_num[0] * local.is_real); + // Assert that the is_compression flag is correct. builder.assert_eq( local.is_compression, - local.octet_num[1] + (local.octet_num[1] + local.octet_num[2] + local.octet_num[3] + local.octet_num[4] + local.octet_num[5] + local.octet_num[6] + local.octet_num[7] - + local.octet_num[8], + + local.octet_num[8]) + * local.is_real, ); + // Assert that the is_finalize flag is correct. 
+ builder.assert_eq(local.is_finalize, local.octet_num[9] * local.is_real); + builder.assert_eq( local.is_last_row.into(), local.octet[7] * local.octet_num[9], @@ -175,6 +180,10 @@ impl ShaCompressChip { .when(local.is_real) .when_not(local.is_last_row) .assert_eq(local.clk, next.clk); + builder + .when_transition() + .when_not(local.is_last_row) + .assert_eq(local.channel, next.channel); builder .when_transition() .when(local.is_real) @@ -186,6 +195,9 @@ impl ShaCompressChip { .when_not(local.is_last_row) .assert_eq(local.h_ptr, next.h_ptr); + // Assert that is_real is a bool. + builder.assert_bool(local.is_real); + // If this row is real and not the last cycle, then next row should also be real. builder .when_transition() @@ -193,6 +205,12 @@ impl ShaCompressChip { .when_not(local.is_last_row) .assert_one(next.is_real); + // Once the is_real flag is changed to false, it should not be changed back. + builder + .when_transition() + .when_not(local.is_real) + .assert_zero(next.is_real); + // Assert that the table ends in nonreal columns. Since each compress ecall is 80 cycles and // the table is padded to a power of 2, the last row of the table should always be padding. builder.when_last_row().assert_zero(local.is_real); @@ -200,15 +218,13 @@ impl ShaCompressChip { /// Constrains that memory address is correct and that memory is correctly written/read. fn eval_memory(&self, builder: &mut AB, local: &ShaCompressCols) { - let is_initialize = local.octet_num[0]; - let is_finalize = local.octet_num[9]; builder.eval_memory_access( local.shard, local.channel, - local.clk + is_finalize, + local.clk + local.is_finalize, local.mem_addr, &local.mem, - is_initialize + local.is_compression + is_finalize, + local.is_initialize + local.is_compression + local.is_finalize, ); // Calculate the current cycle_num. 
@@ -224,7 +240,7 @@ impl ShaCompressChip { } // Verify correct mem address for initialize phase - builder.when(is_initialize).assert_eq( + builder.when(local.is_initialize).assert_eq( local.mem_addr, local.h_ptr + cycle_step.clone() * AB::Expr::from_canonical_u32(4), ); @@ -239,7 +255,7 @@ impl ShaCompressChip { ); // Verify correct mem address for finalize phase - builder.when(is_finalize).assert_eq( + builder.when(local.is_finalize).assert_eq( local.mem_addr, local.h_ptr + cycle_step.clone() * AB::Expr::from_canonical_u32(4), ); @@ -251,11 +267,11 @@ impl ShaCompressChip { ]; for (i, var) in vars.iter().enumerate() { builder - .when(is_initialize) + .when(local.is_initialize) .when(local.octet[i]) .assert_word_eq(*var, *local.mem.prev_value()); builder - .when(is_initialize) + .when(local.is_initialize) .when(local.octet[i]) .assert_word_eq(*var, *local.mem.value()); } @@ -267,7 +283,7 @@ impl ShaCompressChip { // In the finalize phase, verify that the correct value is written to memory. builder - .when(is_finalize) + .when(local.is_finalize) .assert_word_eq(*local.mem.value(), local.finalize_add.value); } @@ -579,7 +595,6 @@ impl ShaCompressChip { builder: &mut AB, local: &ShaCompressCols, ) { - let is_finalize = local.octet_num[9]; // In the finalize phase, need to execute h[0] + a, h[1] + b, ..., h[7] + h, for each of the // phase's 8 rows. // We can get the needed operand (a,b,c,...,h) by doing an inner product between octet and @@ -596,7 +611,7 @@ impl ShaCompressChip { } builder - .when(is_finalize) + .when(local.is_finalize) .assert_word_eq(filtered_operand, local.finalized_operand.map(|x| x.into())); // finalize_add.result = h[i] + finalized_operand @@ -607,7 +622,7 @@ impl ShaCompressChip { local.finalize_add, local.shard, local.channel, - is_finalize.into(), + local.is_finalize.into(), ); // Memory write is constrained in constrain_memory. 
diff --git a/core/src/syscall/precompiles/sha256/compress/columns.rs b/core/src/syscall/precompiles/sha256/compress/columns.rs index 94a200aed..0fd7a7fbf 100644 --- a/core/src/syscall/precompiles/sha256/compress/columns.rs +++ b/core/src/syscall/precompiles/sha256/compress/columns.rs @@ -26,6 +26,7 @@ pub struct ShaCompressCols { /// Inputs. pub shard: T, pub channel: T, + pub nonce: T, pub clk: T, pub w_ptr: T, pub h_ptr: T, @@ -102,7 +103,9 @@ pub struct ShaCompressCols { pub finalized_operand: Word, pub finalize_add: AddOperation, + pub is_initialize: T, pub is_compression: T, + pub is_finalize: T, pub is_last_row: T, pub is_real: T, diff --git a/core/src/syscall/precompiles/sha256/compress/execute.rs b/core/src/syscall/precompiles/sha256/compress/execute.rs index a019abbd4..5ed33dd2b 100644 --- a/core/src/syscall/precompiles/sha256/compress/execute.rs +++ b/core/src/syscall/precompiles/sha256/compress/execute.rs @@ -76,9 +76,11 @@ impl Syscall for ShaCompressChip { } // Push the SHA extend event. 
+ let lookup_id = rt.syscall_lookup_id; let shard = rt.current_shard(); let channel = rt.current_channel(); rt.record_mut().sha_compress_events.push(ShaCompressEvent { + lookup_id, shard, channel, clk: start_clk, diff --git a/core/src/syscall/precompiles/sha256/compress/mod.rs b/core/src/syscall/precompiles/sha256/compress/mod.rs index fd6c50f0f..47401a25b 100644 --- a/core/src/syscall/precompiles/sha256/compress/mod.rs +++ b/core/src/syscall/precompiles/sha256/compress/mod.rs @@ -20,6 +20,7 @@ pub const SHA_COMPRESS_K: [u32; 64] = [ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ShaCompressEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, diff --git a/core/src/syscall/precompiles/sha256/compress/trace.rs b/core/src/syscall/precompiles/sha256/compress/trace.rs index bd0b8f817..6cd524fbd 100644 --- a/core/src/syscall/precompiles/sha256/compress/trace.rs +++ b/core/src/syscall/precompiles/sha256/compress/trace.rs @@ -2,6 +2,7 @@ use std::borrow::BorrowMut; use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::Matrix; use super::{ columns::{ShaCompressCols, NUM_SHA_COMPRESS_COLS}, @@ -53,6 +54,7 @@ impl MachineAir for ShaCompressChip { cols.octet[j] = F::one(); cols.octet_num[octet_num_idx] = F::one(); + cols.is_initialize = F::one(); cols.mem.populate_read( channel, @@ -207,6 +209,7 @@ impl MachineAir for ShaCompressChip { cols.octet[j] = F::one(); cols.octet_num[octet_num_idx] = F::one(); + cols.is_finalize = F::one(); cols.finalize_add .populate(output, shard, channel, og_h[j], event.h[j]); @@ -249,13 +252,48 @@ impl MachineAir for ShaCompressChip { output.add_byte_lookup_events(new_byte_lookup_events); + let num_real_rows = rows.len(); + pad_rows(&mut rows, || [F::zero(); NUM_SHA_COMPRESS_COLS]); + // Set the octet_num and octet columns for the padded rows. 
+ let mut octet_num = 0; + let mut octet = 0; + for row in rows[num_real_rows..].iter_mut() { + let cols: &mut ShaCompressCols = row.as_mut_slice().borrow_mut(); + cols.octet_num[octet_num] = F::one(); + cols.octet[octet] = F::one(); + + // If in the compression phase, set the k value. + if octet_num != 0 && octet_num != 9 { + let compression_idx = octet_num - 1; + let k_idx = compression_idx * 8 + octet; + cols.k = Word::from(SHA_COMPRESS_K[k_idx]); + } + + octet = (octet + 1) % 8; + if octet == 0 { + octet_num = (octet_num + 1) % 10; + } + + cols.is_last_row = cols.octet[7] * cols.octet_num[9]; + } + // Convert the trace to a row major matrix. - RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), NUM_SHA_COMPRESS_COLS, - ) + ); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut ShaCompressCols = trace.values + [i * NUM_SHA_COMPRESS_COLS..(i + 1) * NUM_SHA_COMPRESS_COLS] + .borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { diff --git a/core/src/syscall/precompiles/sha256/extend/air.rs b/core/src/syscall/precompiles/sha256/extend/air.rs index 9da604804..69f38c355 100644 --- a/core/src/syscall/precompiles/sha256/extend/air.rs +++ b/core/src/syscall/precompiles/sha256/extend/air.rs @@ -27,6 +27,13 @@ where let (local, next) = (main.row_slice(0), main.row_slice(1)); let local: &ShaExtendCols = (*local).borrow(); let next: &ShaExtendCols = (*next).borrow(); + + // Constrain the incrementing nonce. 
+ builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); + let i_start = AB::F::from_canonical_u32(16); let nb_bytes_in_word = AB::F::from_canonical_u32(4); @@ -42,6 +49,10 @@ where .when_transition() .when_not(local.cycle_16_end.result * local.cycle_48[2]) .assert_eq(local.clk, next.clk); + builder + .when_transition() + .when_not(local.cycle_16_end.result * local.cycle_48[2]) + .assert_eq(local.channel, next.channel); builder .when_transition() .when_not(local.cycle_16_end.result * local.cycle_48[2]) @@ -214,22 +225,28 @@ where local.is_real, ); + builder.assert_word_eq(*local.w_i.value(), local.s2.value); + // Receive syscall event in first row of 48-cycle. builder.receive_syscall( local.shard, local.channel, local.clk, + local.nonce, AB::F::from_canonical_u32(SyscallCode::SHA_EXTEND.syscall_id()), local.w_ptr, AB::Expr::zero(), local.cycle_48_start, ); - // If this row is real and not the last cycle, then next row should also be real. + // Assert that is_real is a bool. + builder.assert_bool(local.is_real); + + // Ensure that all rows in a 48-row cycle have the same `is_real` value. builder .when_transition() - .when(local.is_real - local.cycle_48_end) - .assert_one(next.is_real); + .when_not(local.cycle_48_end) + .assert_eq(local.is_real, next.is_real); // Assert that the table ends in nonreal columns. Since each extend ecall is 48 cycles and // the table is padded to a power of 2, the last row of the table should always be padding. diff --git a/core/src/syscall/precompiles/sha256/extend/columns.rs b/core/src/syscall/precompiles/sha256/extend/columns.rs index 5eb99e1f4..0855b4413 100644 --- a/core/src/syscall/precompiles/sha256/extend/columns.rs +++ b/core/src/syscall/precompiles/sha256/extend/columns.rs @@ -18,6 +18,7 @@ pub struct ShaExtendCols { /// Inputs. 
pub shard: T, pub channel: T, + pub nonce: T, pub clk: T, pub w_ptr: T, @@ -36,8 +37,9 @@ pub struct ShaExtendCols { /// Flags for when in the first, second, or third 16-row cycle. pub cycle_48: [T; 3], - /// Whether the current row is the first of a 48-row cycle. + /// Whether the current row is the first of a 48-row cycle and is real. pub cycle_48_start: T, + /// Whether the current row is the end of a 48-row cycle and is real. pub cycle_48_end: T, /// Inputs to `s0`. diff --git a/core/src/syscall/precompiles/sha256/extend/execute.rs b/core/src/syscall/precompiles/sha256/extend/execute.rs index bd163c26c..d9b1a70e0 100644 --- a/core/src/syscall/precompiles/sha256/extend/execute.rs +++ b/core/src/syscall/precompiles/sha256/extend/execute.rs @@ -60,9 +60,11 @@ impl Syscall for ShaExtendChip { } // Push the SHA extend event. + let lookup_id = rt.syscall_lookup_id; let shard = rt.current_shard(); let channel = rt.current_channel(); rt.record_mut().sha_extend_events.push(ShaExtendEvent { + lookup_id, shard, channel, clk: clk_init, diff --git a/core/src/syscall/precompiles/sha256/extend/flags.rs b/core/src/syscall/precompiles/sha256/extend/flags.rs index 2f97dc92f..a06f117e3 100644 --- a/core/src/syscall/precompiles/sha256/extend/flags.rs +++ b/core/src/syscall/precompiles/sha256/extend/flags.rs @@ -7,6 +7,7 @@ use p3_field::PrimeField32; use p3_field::TwoAdicField; use p3_matrix::Matrix; +use crate::air::BaseAirBuilder; use crate::air::SP1AirBuilder; use crate::operations::IsZeroOperation; @@ -70,7 +71,7 @@ impl ShaExtendChip { builder, local.cycle_16 - AB::Expr::from(g), local.cycle_16_start, - local.is_real.into(), + one.clone(), ); // Constrain `cycle_16_end.result` to be `cycle_16 - 1 == 0`. Intuitively g^16 is 1. @@ -78,7 +79,7 @@ impl ShaExtendChip { builder, local.cycle_16 - AB::Expr::one(), local.cycle_16_end, - local.is_real.into(), + one.clone(), ); // Constrain `cycle_48` to be [1, 0, 0] in the first row. 
@@ -123,10 +124,10 @@ impl ShaExtendChip { .when(local.cycle_16_end.result * local.cycle_48[2]) .assert_eq(next.i, AB::F::from_canonical_u32(16)); - // When it's not the end of a 16-cycle, the next `i` must be the current plus one. + // When it's not the end of a 48-cycle, the next `i` must be the current plus one. builder .when_transition() - .when(one.clone() - local.cycle_16_end.result) + .when_not(local.cycle_16_end.result * local.cycle_48[2]) .assert_eq(local.i + one.clone(), next.i); } } diff --git a/core/src/syscall/precompiles/sha256/extend/mod.rs b/core/src/syscall/precompiles/sha256/extend/mod.rs index 4caff508b..7868cabd8 100644 --- a/core/src/syscall/precompiles/sha256/extend/mod.rs +++ b/core/src/syscall/precompiles/sha256/extend/mod.rs @@ -11,6 +11,7 @@ use serde::{Deserialize, Serialize}; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ShaExtendEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, diff --git a/core/src/syscall/precompiles/sha256/extend/trace.rs b/core/src/syscall/precompiles/sha256/extend/trace.rs index 2a976ef0d..2dcf88226 100644 --- a/core/src/syscall/precompiles/sha256/extend/trace.rs +++ b/core/src/syscall/precompiles/sha256/extend/trace.rs @@ -1,7 +1,7 @@ -use std::borrow::BorrowMut; - use p3_field::PrimeField32; use p3_matrix::dense::RowMajorMatrix; +use p3_matrix::Matrix; +use std::borrow::BorrowMut; use crate::{ air::MachineAir, @@ -156,10 +156,19 @@ impl MachineAir for ShaExtendChip { } // Convert the trace to a row major matrix. - RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), NUM_SHA_EXTEND_COLS, - ) + ); + + // Write the nonces to the trace. 
+ for i in 0..trace.height() { + let cols: &mut ShaExtendCols = + trace.values[i * NUM_SHA_EXTEND_COLS..(i + 1) * NUM_SHA_EXTEND_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { diff --git a/core/src/syscall/precompiles/uint256/air.rs b/core/src/syscall/precompiles/uint256/air.rs index 498ac78c6..dd8cce29e 100644 --- a/core/src/syscall/precompiles/uint256/air.rs +++ b/core/src/syscall/precompiles/uint256/air.rs @@ -17,6 +17,7 @@ use crate::utils::{ use generic_array::GenericArray; use num::Zero; use num::{BigUint, One}; +use p3_air::AirBuilder; use p3_air::{Air, BaseAir}; use p3_field::AbstractField; use p3_field::PrimeField32; @@ -33,6 +34,7 @@ const NUM_COLS: usize = size_of::>(); #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Uint256MulEvent { + pub lookup_id: usize, pub shard: u32, pub channel: u32, pub clk: u32, @@ -71,6 +73,9 @@ pub struct Uint256MulCols { /// The clock cycle of the syscall. pub clk: T, + /// The none of the operation. + pub nonce: T, + /// The pointer to the first input. pub x_ptr: T, @@ -201,7 +206,17 @@ impl MachineAir for Uint256MulChip { }); // Convert the trace to a row major matrix. - RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_COLS) + let mut trace = + RowMajorMatrix::new(rows.into_iter().flatten().collect::>(), NUM_COLS); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut Uint256MulCols = + trace.values[i * NUM_COLS..(i + 1) * NUM_COLS].borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { @@ -257,10 +272,12 @@ impl Syscall for Uint256MulChip { // Write the result to x and keep track of the memory records. 
let x_memory_records = rt.mw_slice(x_ptr, &result); + let lookup_id = rt.syscall_lookup_id; let shard = rt.current_shard(); let channel = rt.current_channel(); let clk = rt.clk; rt.record_mut().uint256_mul_events.push(Uint256MulEvent { + lookup_id, shard, channel, clk, @@ -293,6 +310,14 @@ where let main = builder.main(); let local = main.row_slice(0); let local: &Uint256MulCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &Uint256MulCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); // We are computing (x * y) % modulus. The value of x is stored in the "prev_value" of // the x_memory, since we write to it later. @@ -368,6 +393,7 @@ where local.shard, local.channel, local.clk, + local.nonce, AB::F::from_canonical_u32(SyscallCode::UINT256_MUL.syscall_id()), local.x_ptr, local.y_ptr, diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs index adbab629f..2eef29c6c 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_add.rs @@ -52,6 +52,7 @@ pub struct WeierstrassAddAssignCols { pub is_real: T, pub shard: T, pub channel: T, + pub nonce: T, pub clk: T, pub p_ptr: T, pub q_ptr: T, @@ -302,10 +303,21 @@ impl MachineAir }); // Convert the trace to a row major matrix. - RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), num_weierstrass_add_cols::(), - ) + ); + + // Write the nonces to the trace. 
+ for i in 0..trace.height() { + let cols: &mut WeierstrassAddAssignCols = trace.values[i + * num_weierstrass_add_cols::() + ..(i + 1) * num_weierstrass_add_cols::()] + .borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { @@ -331,117 +343,125 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let row = main.row_slice(0); - let row: &WeierstrassAddAssignCols = (*row).borrow(); + let local = main.row_slice(0); + let local: &WeierstrassAddAssignCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &WeierstrassAddAssignCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); let num_words_field_element = ::Limbs::USIZE / 4; - let p_x = limbs_from_prev_access(&row.p_access[0..num_words_field_element]); - let p_y = limbs_from_prev_access(&row.p_access[num_words_field_element..]); + let p_x = limbs_from_prev_access(&local.p_access[0..num_words_field_element]); + let p_y = limbs_from_prev_access(&local.p_access[num_words_field_element..]); - let q_x = limbs_from_prev_access(&row.q_access[0..num_words_field_element]); - let q_y = limbs_from_prev_access(&row.q_access[num_words_field_element..]); + let q_x = limbs_from_prev_access(&local.q_access[0..num_words_field_element]); + let q_y = limbs_from_prev_access(&local.q_access[num_words_field_element..]); // slope = (q.y - p.y) / (q.x - p.x). 
let slope = { - row.slope_numerator.eval( + local.slope_numerator.eval( builder, &q_y, &p_y, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.slope_denominator.eval( + local.slope_denominator.eval( builder, &q_x, &p_x, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.slope.eval( + local.slope.eval( builder, - &row.slope_numerator.result, - &row.slope_denominator.result, + &local.slope_numerator.result, + &local.slope_denominator.result, FieldOperation::Div, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - &row.slope.result + &local.slope.result }; // x = slope * slope - self.x - other.x. let x = { - row.slope_squared.eval( + local.slope_squared.eval( builder, slope, slope, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.p_x_plus_q_x.eval( + local.p_x_plus_q_x.eval( builder, &p_x, &q_x, FieldOperation::Add, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.x3_ins.eval( + local.x3_ins.eval( builder, - &row.slope_squared.result, - &row.p_x_plus_q_x.result, + &local.slope_squared.result, + &local.p_x_plus_q_x.result, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - &row.x3_ins.result + &local.x3_ins.result }; // y = slope * (p.x - x_3n) - q.y. 
{ - row.p_x_minus_x.eval( + local.p_x_minus_x.eval( builder, &p_x, x, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.slope_times_p_x_minus_x.eval( + local.slope_times_p_x_minus_x.eval( builder, slope, - &row.p_x_minus_x.result, + &local.p_x_minus_x.result, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.y3_ins.eval( + local.y3_ins.eval( builder, - &row.slope_times_p_x_minus_x.result, + &local.slope_times_p_x_minus_x.result, &p_y, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); } @@ -449,29 +469,29 @@ where // ensure that p_access is updated with the new value. for i in 0..E::BaseField::NB_LIMBS { builder - .when(row.is_real) - .assert_eq(row.x3_ins.result[i], row.p_access[i / 4].value()[i % 4]); - builder.when(row.is_real).assert_eq( - row.y3_ins.result[i], - row.p_access[num_words_field_element + i / 4].value()[i % 4], + .when(local.is_real) + .assert_eq(local.x3_ins.result[i], local.p_access[i / 4].value()[i % 4]); + builder.when(local.is_real).assert_eq( + local.y3_ins.result[i], + local.p_access[num_words_field_element + i / 4].value()[i % 4], ); } builder.eval_memory_access_slice( - row.shard, - row.channel, - row.clk.into(), - row.q_ptr, - &row.q_access, - row.is_real, + local.shard, + local.channel, + local.clk.into(), + local.q_ptr, + &local.q_access, + local.is_real, ); builder.eval_memory_access_slice( - row.shard, - row.channel, - row.clk + AB::F::from_canonical_u32(1), // We read p at +1 since p, q could be the same. - row.p_ptr, - &row.p_access, - row.is_real, + local.shard, + local.channel, + local.clk + AB::F::from_canonical_u32(1), // We read p at +1 since p, q could be the same. + local.p_ptr, + &local.p_access, + local.is_real, ); // Fetch the syscall id for the curve type. 
@@ -487,13 +507,14 @@ where }; builder.receive_syscall( - row.shard, - row.channel, - row.clk, + local.shard, + local.channel, + local.clk, + local.nonce, syscall_id_felt, - row.p_ptr, - row.q_ptr, - row.is_real, + local.p_ptr, + local.q_ptr, + local.is_real, ); } } diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs b/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs index bd38edea8..62958e86c 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_decompress.rs @@ -54,6 +54,7 @@ pub struct WeierstrassDecompressCols { pub shard: T, pub channel: T, pub clk: T, + pub nonce: T, pub ptr: T, pub is_odd: T, pub x_access: GenericArray, P::WordsFieldElement>, @@ -222,10 +223,21 @@ impl MachineAir row }); - RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), num_weierstrass_decompress_cols::(), - ) + ); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut WeierstrassDecompressCols = trace.values[i + * num_weierstrass_decompress_cols::() + ..(i + 1) * num_weierstrass_decompress_cols::()] + .borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { @@ -250,99 +262,108 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let row = main.row_slice(0); - let row: &WeierstrassDecompressCols = (*row).borrow(); + let local = main.row_slice(0); + let local: &WeierstrassDecompressCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &WeierstrassDecompressCols = (*next).borrow(); + + // Constrain the incrementing nonce. 
+ builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); let num_limbs = ::Limbs::USIZE; let num_words_field_element = num_limbs / 4; - builder.assert_bool(row.is_odd); + builder.assert_bool(local.is_odd); let x: Limbs::Limbs> = - limbs_from_prev_access(&row.x_access); - row.range_x - .eval(builder, &x, row.shard, row.channel, row.is_real); - row.x_2.eval( + limbs_from_prev_access(&local.x_access); + local + .range_x + .eval(builder, &x, local.shard, local.channel, local.is_real); + local.x_2.eval( builder, &x, &x, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.x_3.eval( + local.x_3.eval( builder, - &row.x_2.result, + &local.x_2.result, &x, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); let b = E::b_int(); let b_const = E::BaseField::to_limbs_field::(&b); - row.x_3_plus_b.eval( + local.x_3_plus_b.eval( builder, - &row.x_3.result, + &local.x_3.result, &b_const, FieldOperation::Add, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.neg_y.eval( + local.neg_y.eval( builder, &[AB::Expr::zero()].iter(), - &row.y.multiplication.result, + &local.y.multiplication.result, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); // Interpret the lowest bit of Y as whether it is odd or not. 
- let y_is_odd = row.y.lsb; + let y_is_odd = local.y.lsb; - row.y.eval( + local.y.eval( builder, - &row.x_3_plus_b.result, - row.y.lsb, - row.shard, - row.channel, - row.is_real, + &local.x_3_plus_b.result, + local.y.lsb, + local.shard, + local.channel, + local.is_real, ); let y_limbs: Limbs::Limbs> = - limbs_from_access(&row.y_access); + limbs_from_access(&local.y_access); builder - .when(row.is_real) - .when_ne(y_is_odd, AB::Expr::one() - row.is_odd) - .assert_all_eq(row.y.multiplication.result, y_limbs); + .when(local.is_real) + .when_ne(y_is_odd, AB::Expr::one() - local.is_odd) + .assert_all_eq(local.y.multiplication.result, y_limbs); builder - .when(row.is_real) - .when_ne(y_is_odd, row.is_odd) - .assert_all_eq(row.neg_y.result, y_limbs); + .when(local.is_real) + .when_ne(y_is_odd, local.is_odd) + .assert_all_eq(local.neg_y.result, y_limbs); for i in 0..num_words_field_element { builder.eval_memory_access( - row.shard, - row.channel, - row.clk, - row.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4 + num_limbs as u32), - &row.x_access[i], - row.is_real, + local.shard, + local.channel, + local.clk, + local.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4 + num_limbs as u32), + &local.x_access[i], + local.is_real, ); } for i in 0..num_words_field_element { builder.eval_memory_access( - row.shard, - row.channel, - row.clk, - row.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4), - &row.y_access[i], - row.is_real, + local.shard, + local.channel, + local.clk, + local.ptr.into() + AB::F::from_canonical_u32((i as u32) * 4), + &local.y_access[i], + local.is_real, ); } @@ -357,13 +378,14 @@ where }; builder.receive_syscall( - row.shard, - row.channel, - row.clk, + local.shard, + local.channel, + local.clk, + local.nonce, syscall_id, - row.ptr, - row.is_odd, - row.is_real, + local.ptr, + local.is_odd, + local.is_real, ); } } diff --git a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs 
b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs index 50bb0a433..9221d680f 100644 --- a/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs +++ b/core/src/syscall/precompiles/weierstrass/weierstrass_double.rs @@ -53,6 +53,7 @@ pub struct WeierstrassDoubleAssignCols { pub is_real: T, pub shard: T, pub channel: T, + pub nonce: T, pub clk: T, pub p_ptr: T, pub p_access: GenericArray, P::WordsCurvePoint>, @@ -317,10 +318,21 @@ impl MachineAir }); // Convert the trace to a row major matrix. - RowMajorMatrix::new( + let mut trace = RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), num_weierstrass_double_cols::(), - ) + ); + + // Write the nonces to the trace. + for i in 0..trace.height() { + let cols: &mut WeierstrassDoubleAssignCols = trace.values[i + * num_weierstrass_double_cols::() + ..(i + 1) * num_weierstrass_double_cols::()] + .borrow_mut(); + cols.nonce = F::from_canonical_usize(i); + } + + trace } fn included(&self, shard: &Self::Record) -> bool { @@ -346,136 +358,143 @@ where { fn eval(&self, builder: &mut AB) { let main = builder.main(); - let row = main.row_slice(0); - let row: &WeierstrassDoubleAssignCols = (*row).borrow(); + let local = main.row_slice(0); + let local: &WeierstrassDoubleAssignCols = (*local).borrow(); + let next = main.row_slice(1); + let next: &WeierstrassDoubleAssignCols = (*next).borrow(); + + // Constrain the incrementing nonce. + builder.when_first_row().assert_zero(local.nonce); + builder + .when_transition() + .assert_eq(local.nonce + AB::Expr::one(), next.nonce); let num_words_field_element = E::BaseField::NB_LIMBS / 4; - let p_x = limbs_from_prev_access(&row.p_access[0..num_words_field_element]); - let p_y = limbs_from_prev_access(&row.p_access[num_words_field_element..]); + let p_x = limbs_from_prev_access(&local.p_access[0..num_words_field_element]); + let p_y = limbs_from_prev_access(&local.p_access[num_words_field_element..]); - // a in the Weierstrass form: y^2 = x^3 + a * x + b. 
- // TODO: U32 can't be hardcoded here? + // `a` in the Weierstrass form: y^2 = x^3 + a * x + b. let a = E::BaseField::to_limbs_field::(&E::a_int()); // slope = slope_numerator / slope_denominator. let slope = { // slope_numerator = a + (p.x * p.x) * 3. { - row.p_x_squared.eval( + local.p_x_squared.eval( builder, &p_x, &p_x, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.p_x_squared_times_3.eval( + local.p_x_squared_times_3.eval( builder, - &row.p_x_squared.result, + &local.p_x_squared.result, &E::BaseField::to_limbs_field::(&BigUint::from(3u32)), FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.slope_numerator.eval( + local.slope_numerator.eval( builder, &a, - &row.p_x_squared_times_3.result, + &local.p_x_squared_times_3.result, FieldOperation::Add, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); }; // slope_denominator = 2 * y. - row.slope_denominator.eval( + local.slope_denominator.eval( builder, &E::BaseField::to_limbs_field::(&BigUint::from(2u32)), &p_y, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.slope.eval( + local.slope.eval( builder, - &row.slope_numerator.result, - &row.slope_denominator.result, + &local.slope_numerator.result, + &local.slope_denominator.result, FieldOperation::Div, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - &row.slope.result + &local.slope.result }; // x = slope * slope - (p.x + p.x). 
let x = { - row.slope_squared.eval( + local.slope_squared.eval( builder, slope, slope, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.p_x_plus_p_x.eval( + local.p_x_plus_p_x.eval( builder, &p_x, &p_x, FieldOperation::Add, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.x3_ins.eval( + local.x3_ins.eval( builder, - &row.slope_squared.result, - &row.p_x_plus_p_x.result, + &local.slope_squared.result, + &local.p_x_plus_p_x.result, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - &row.x3_ins.result + &local.x3_ins.result }; // y = slope * (p.x - x) - p.y. { - row.p_x_minus_x.eval( + local.p_x_minus_x.eval( builder, &p_x, x, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.slope_times_p_x_minus_x.eval( + local.slope_times_p_x_minus_x.eval( builder, slope, - &row.p_x_minus_x.result, + &local.p_x_minus_x.result, FieldOperation::Mul, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); - row.y3_ins.eval( + local.y3_ins.eval( builder, - &row.slope_times_p_x_minus_x.result, + &local.slope_times_p_x_minus_x.result, &p_y, FieldOperation::Sub, - row.shard, - row.channel, - row.is_real, + local.shard, + local.channel, + local.is_real, ); } @@ -483,21 +502,21 @@ where // ensure that p_access is updated with the new value. 
for i in 0..E::BaseField::NB_LIMBS { builder - .when(row.is_real) - .assert_eq(row.x3_ins.result[i], row.p_access[i / 4].value()[i % 4]); - builder.when(row.is_real).assert_eq( - row.y3_ins.result[i], - row.p_access[num_words_field_element + i / 4].value()[i % 4], + .when(local.is_real) + .assert_eq(local.x3_ins.result[i], local.p_access[i / 4].value()[i % 4]); + builder.when(local.is_real).assert_eq( + local.y3_ins.result[i], + local.p_access[num_words_field_element + i / 4].value()[i % 4], ); } builder.eval_memory_access_slice( - row.shard, - row.channel, - row.clk.into(), - row.p_ptr, - &row.p_access, - row.is_real, + local.shard, + local.channel, + local.clk.into(), + local.p_ptr, + &local.p_access, + local.is_real, ); // Fetch the syscall id for the curve type. @@ -513,13 +532,14 @@ where }; builder.receive_syscall( - row.shard, - row.channel, - row.clk, + local.shard, + local.channel, + local.clk, + local.nonce, syscall_id_felt, - row.p_ptr, + local.p_ptr, AB::Expr::zero(), - row.is_real, + local.is_real, ); } } diff --git a/core/src/syscall/verify.rs b/core/src/syscall/verify.rs index e40639aeb..11b043010 100644 --- a/core/src/syscall/verify.rs +++ b/core/src/syscall/verify.rs @@ -1,6 +1,6 @@ use crate::{ runtime::{Syscall, SyscallContext}, - stark::{RiscvAir, StarkGenericConfig}, + stark::StarkGenericConfig, utils::BabyBearPoseidon2Inner, }; @@ -38,17 +38,6 @@ impl Syscall for SyscallVerifySP1Proof { let config = BabyBearPoseidon2Inner::new(); let mut challenger = config.challenger(); - // TODO: need to use RecursionAir here - let machine = RiscvAir::machine(config); - - // TODO: Need to import PublicValues from recursion. - // Assert the commit in vkey from runtime inputs matches the one from syscall. - // Assert that the public values digest from runtime inputs matches the one from syscall. 
- - // TODO: Verify proof - // machine - // .verify(proof_vk, proof, &mut challenger) - // .expect("proof verification failed"); None } diff --git a/core/src/utils/programs.rs b/core/src/utils/programs.rs index 58af5a08c..a71a7dfef 100644 --- a/core/src/utils/programs.rs +++ b/core/src/utils/programs.rs @@ -34,9 +34,6 @@ pub mod tests { pub const ED25519_ELF: &[u8] = include_bytes!("../../../tests/ed25519/elf/riscv32im-succinct-zkvm-elf"); - pub const BLAKE3_COMPRESS_ELF: &[u8] = - include_bytes!("../../../tests/blake3-compress/elf/riscv32im-succinct-zkvm-elf"); - pub const CYCLE_TRACKER_ELF: &[u8] = include_bytes!("../../../tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf"); diff --git a/prover/src/build.rs b/prover/src/build.rs index 129ec454a..24f152d49 100644 --- a/prover/src/build.rs +++ b/prover/src/build.rs @@ -37,9 +37,6 @@ pub fn try_install_plonk_bn254_artifacts() -> PathBuf { } /// Tries to build the PLONK artifacts inside the development directory. -/// -/// TODO: Maybe add some additional logic here to handle rebuilding the artifacts if they are -/// already built. pub fn try_build_plonk_bn254_artifacts_dev( template_vk: &StarkVerifyingKey, template_proof: &ShardProof, diff --git a/prover/src/install.rs b/prover/src/install.rs index 873c50f68..a469a40b1 100644 --- a/prover/src/install.rs +++ b/prover/src/install.rs @@ -10,7 +10,7 @@ use crate::utils::block_on; pub const PLONK_BN254_ARTIFACTS_URL_BASE: &str = "https://sp1-circuits.s3-us-east-2.amazonaws.com"; /// The current version of the plonk bn254 artifacts. -pub const PLONK_BN254_ARTIFACTS_COMMIT: &str = "e48c01ec"; +pub const PLONK_BN254_ARTIFACTS_COMMIT: &str = "4a525e9f"; /// Install the latest plonk bn254 artifacts. 
/// diff --git a/prover/src/verify.rs b/prover/src/verify.rs index fedd467cf..f829f4d6a 100644 --- a/prover/src/verify.rs +++ b/prover/src/verify.rs @@ -4,6 +4,7 @@ use anyhow::Result; use num_bigint::BigUint; use p3_baby_bear::BabyBear; use p3_field::{AbstractField, PrimeField}; +use sp1_core::air::MachineAir; use sp1_core::{ air::PublicValues, io::SP1PublicValues, @@ -45,7 +46,7 @@ impl SP1Prover { self.core_machine .verify(&vk.vk, &machine_proof, &mut challenger)?; - // Verify shard transitions + // Verify shard transitions. for (i, shard_proof) in proof.0.iter().enumerate() { let public_values = PublicValues::from_vec(shard_proof.public_values.clone()); // Verify shard transitions @@ -100,6 +101,58 @@ impl SP1Prover { } } + // Verify that the number of shards is not too large. + if proof.0.len() > 1 << 16 { + return Err(MachineVerificationError::TooManyShards); + } + + // Verify that the `MemoryInit` and `MemoryFinalize` chips are the last chips in the proof. + for (i, shard_proof) in proof.0.iter().enumerate() { + let chips = self + .core_machine + .shard_chips_ordered(&shard_proof.chip_ordering) + .collect::>(); + let program_memory_init_count = chips + .clone() + .into_iter() + .filter(|chip| chip.name() == "MemoryProgram") + .count(); + let memory_init_count = chips + .clone() + .into_iter() + .filter(|chip| chip.name() == "MemoryInit") + .count(); + let memory_final_count = chips + .into_iter() + .filter(|chip| chip.name() == "MemoryFinalize") + .count(); + + // Assert that the `MemoryProgram` chip only exists in the first shard. 
+ if i == 0 && program_memory_init_count != 1 { + return Err(MachineVerificationError::InvalidChipOccurence( + "memory should exist in the first chip".to_string(), + )); + } + if i != 0 && program_memory_init_count > 0 { + return Err(MachineVerificationError::InvalidChipOccurence( + "memory program should not exist in the first chip".to_string(), + )); + } + + // Assert that the `MemoryInit` and `MemoryFinalize` chips only exist in the last shard. + if i != proof.0.len() - 1 && (memory_final_count > 0 || memory_init_count > 0) { + return Err(MachineVerificationError::InvalidChipOccurence( + "memory init and finalize should not eixst anywhere but the last chip" + .to_string(), + )); + } + if i == proof.0.len() - 1 && (memory_init_count != 1 || memory_final_count != 1) { + return Err(MachineVerificationError::InvalidChipOccurence( + "memory init and finalize should exist the last chip".to_string(), + )); + } + } + Ok(()) } diff --git a/recursion/circuit/Cargo.toml b/recursion/circuit/Cargo.toml index 8843d1185..1b5076d5c 100644 --- a/recursion/circuit/Cargo.toml +++ b/recursion/circuit/Cargo.toml @@ -31,3 +31,6 @@ p3-poseidon2 = { workspace = true } zkhash = { git = "https://github.com/HorizenLabs/poseidon2" } rand = "0.8.5" sp1-recursion-gnark-ffi = { path = "../gnark-ffi" } + +[features] +plonk = ["sp1-recursion-gnark-ffi/plonk"] diff --git a/recursion/circuit/src/poseidon2.rs b/recursion/circuit/src/poseidon2.rs index a5a8cc113..792754014 100644 --- a/recursion/circuit/src/poseidon2.rs +++ b/recursion/circuit/src/poseidon2.rs @@ -1,5 +1,7 @@ //! An implementation of Poseidon2 over BN254. 
+use std::array; + use itertools::Itertools; use p3_field::AbstractField; use p3_field::Field; @@ -16,6 +18,8 @@ pub trait Poseidon2CircuitBuilder { fn p2_permute_mut(&mut self, state: [Var; SPONGE_SIZE]); fn p2_hash(&mut self, input: &[Felt]) -> OuterDigestVariable; fn p2_compress(&mut self, input: [OuterDigestVariable; 2]) -> OuterDigestVariable; + fn p2_babybear_permute_mut(&mut self, state: [Felt; 16]); + fn p2_babybear_hash(&mut self, input: &[Felt]) -> [Felt; 8]; } impl Poseidon2CircuitBuilder for Builder { @@ -52,6 +56,24 @@ impl Poseidon2CircuitBuilder for Builder { self.p2_permute_mut(state); [state[0]; DIGEST_SIZE] } + + fn p2_babybear_permute_mut(&mut self, state: [Felt; 16]) { + self.push(DslIr::CircuitPoseidon2PermuteBabyBear(state)); + } + + fn p2_babybear_hash(&mut self, input: &[Felt]) -> [Felt; 8] { + let mut state: [Felt; 16] = array::from_fn(|_| self.eval(C::F::zero())); + + for block_chunk in &input.iter().chunks(8) { + state + .iter_mut() + .zip(block_chunk) + .for_each(|(s, i)| *s = self.eval(*i)); + self.p2_babybear_permute_mut(state); + } + + array::from_fn(|i| state[i]) + } } #[cfg(test)] @@ -60,6 +82,9 @@ pub mod tests { use p3_bn254_fr::Bn254Fr; use p3_field::AbstractField; use p3_symmetric::{CryptographicHasher, Permutation, PseudoCompressionFunction}; + use rand::thread_rng; + use rand::Rng; + use sp1_core::utils::{inner_perm, InnerHash}; use sp1_recursion_compiler::config::OuterConfig; use sp1_recursion_compiler::constraints::ConstraintCompiler; use sp1_recursion_compiler::ir::{Builder, Felt, Var, Witness}; @@ -95,6 +120,25 @@ pub mod tests { PlonkBn254Prover::test::(constraints.clone(), Witness::default()); } + #[test] + fn test_p2_babybear_permute_mut() { + let mut rng = thread_rng(); + let mut builder = Builder::::default(); + let input: [BabyBear; 16] = [rng.gen(); 16]; + let input_vars: [Felt<_>; 16] = input.map(|x| builder.eval(x)); + builder.p2_babybear_permute_mut(input_vars); + + let perm = inner_perm(); + let result = 
perm.permute(input); + for i in 0..16 { + builder.assert_felt_eq(input_vars[i], result[i]); + } + + let mut backend = ConstraintCompiler::::default(); + let constraints = backend.emit(builder.operations); + PlonkBn254Prover::test::(constraints.clone(), Witness::default()); + } + #[test] fn test_p2_hash() { let perm = outer_perm(); @@ -147,4 +191,53 @@ pub mod tests { let constraints = backend.emit(builder.operations); PlonkBn254Prover::test::(constraints.clone(), Witness::default()); } + + #[test] + fn test_p2_babybear_hash() { + let perm = inner_perm(); + let hasher = InnerHash::new(perm.clone()); + + let input: [BabyBear; 26] = [ + BabyBear::from_canonical_u32(0), + BabyBear::from_canonical_u32(1), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(2), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + BabyBear::from_canonical_u32(3), + ]; + let output = hasher.hash_iter(input); + println!("{:?}", output); + + let mut builder = Builder::::default(); + let input_felts: [Felt<_>; 26] = input.map(|x| builder.eval(x)); + let result = builder.p2_babybear_hash(input_felts.as_slice()); + + for i in 0..8 { + builder.assert_felt_eq(result[i], output[i]); + } + + let mut backend = ConstraintCompiler::::default(); + let constraints = backend.emit(builder.operations); + 
PlonkBn254Prover::test::(constraints.clone(), Witness::default()); + } } diff --git a/recursion/circuit/src/stark.rs b/recursion/circuit/src/stark.rs index 48c800aa1..574a80141 100644 --- a/recursion/circuit/src/stark.rs +++ b/recursion/circuit/src/stark.rs @@ -2,6 +2,7 @@ use std::borrow::Borrow; use std::marker::PhantomData; use crate::fri::verify_two_adic_pcs; +use crate::poseidon2::Poseidon2CircuitBuilder; use crate::types::OuterDigestVariable; use crate::utils::{babybear_bytes_to_bn254, babybears_to_bn254, words_to_bytes}; use crate::witness::Witnessable; @@ -20,7 +21,7 @@ use sp1_recursion_compiler::constraints::{Constraint, ConstraintCompiler}; use sp1_recursion_compiler::ir::{Builder, Config, Ext, Felt, Var}; use sp1_recursion_compiler::ir::{Usize, Witness}; use sp1_recursion_compiler::prelude::SymbolicVar; -use sp1_recursion_core::air::RecursionPublicValues; +use sp1_recursion_core::air::{RecursionPublicValues, NUM_PV_ELMS_TO_HASH}; use sp1_recursion_core::stark::config::{outer_fri_config, BabyBearPoseidon2Outer}; use sp1_recursion_core::stark::RecursionAirSkinnyDeg9; use sp1_recursion_program::commit::PolynomialSpaceVariable; @@ -270,7 +271,9 @@ pub fn build_wrap_circuit( let element = builder.get(&proof.public_values, i); pv_elements.push(element); } + let pv: &RecursionPublicValues<_> = pv_elements.as_slice().borrow(); + let one_felt: Felt<_> = builder.constant(BabyBear::one()); // Proof must be complete. In the reduce program, this will ensure that the SP1 proof has been // fully accumulated. @@ -347,6 +350,13 @@ pub fn build_wrap_circuit( } builder.assert_ext_eq(cumulative_sum, zero_ext); + // Verify the public values digest. 
+ let calculated_digest = builder.p2_babybear_hash(&pv_elements[0..NUM_PV_ELMS_TO_HASH]); + let expected_digest = pv.digest; + for (calculated_elm, expected_elm) in calculated_digest.iter().zip(expected_digest.iter()) { + builder.assert_felt_eq(*expected_elm, *calculated_elm); + } + let mut backend = ConstraintCompiler::::default(); backend.emit(builder.operations) } diff --git a/recursion/compiler/src/constraints/mod.rs b/recursion/compiler/src/constraints/mod.rs index c5c67647a..eb4395135 100644 --- a/recursion/compiler/src/constraints/mod.rs +++ b/recursion/compiler/src/constraints/mod.rs @@ -249,6 +249,10 @@ impl ConstraintCompiler { opcode: ConstraintOpcode::Permute, args: state.iter().map(|x| vec![x.id()]).collect(), }), + DslIr::CircuitPoseidon2PermuteBabyBear(state) => constraints.push(Constraint { + opcode: ConstraintOpcode::PermuteBabyBear, + args: state.iter().map(|x| vec![x.id()]).collect(), + }), DslIr::CircuitSelectV(cond, a, b, out) => { constraints.push(Constraint { opcode: ConstraintOpcode::SelectV, diff --git a/recursion/compiler/src/constraints/opcodes.rs b/recursion/compiler/src/constraints/opcodes.rs index 581b4558d..4911e0f10 100644 --- a/recursion/compiler/src/constraints/opcodes.rs +++ b/recursion/compiler/src/constraints/opcodes.rs @@ -46,4 +46,5 @@ pub enum ConstraintOpcode { CommitVkeyHash, CommitCommitedValuesDigest, CircuitFelts2Ext, + PermuteBabyBear, } diff --git a/recursion/compiler/src/ir/bits.rs b/recursion/compiler/src/ir/bits.rs index f69c8cee1..396fb92be 100644 --- a/recursion/compiler/src/ir/bits.rs +++ b/recursion/compiler/src/ir/bits.rs @@ -26,6 +26,15 @@ impl Builder { output } + /// Range checks a variable to a certain number of bits. 
+ pub fn range_check_v(&mut self, num: Var, num_bits: usize) { + let bits = self.num2bits_v(num); + self.range(num_bits, bits.len()).for_each(|i, builder| { + let bit = builder.get(&bits, i); + builder.assert_var_eq(bit, C::N::zero()); + }); + } + /// Converts a variable to bits inside a circuit. pub fn num2bits_v_circuit(&mut self, num: Var, bits: usize) -> Vec> { let mut output = Vec::new(); diff --git a/recursion/compiler/src/ir/instructions.rs b/recursion/compiler/src/ir/instructions.rs index f7a5cee3e..3826fc327 100644 --- a/recursion/compiler/src/ir/instructions.rs +++ b/recursion/compiler/src/ir/instructions.rs @@ -201,6 +201,8 @@ pub enum DslIr { /// Permutes an array of Bn254 elements using Poseidon2 (output = p2_permute(array)). Should only /// be used when target is a gnark circuit. CircuitPoseidon2Permute([Var; 3]), + /// Permutes an array of BabyBear elements in the circuit. + CircuitPoseidon2PermuteBabyBear([Felt; 16]), // Miscellaneous instructions. /// Decompose hint operation of a usize into an array. (output = num2bits(usize)). 
diff --git a/recursion/core/src/air/builder.rs b/recursion/core/src/air/builder.rs index 6bcf20d40..ab6e6c101 100644 --- a/recursion/core/src/air/builder.rs +++ b/recursion/core/src/air/builder.rs @@ -30,6 +30,8 @@ pub trait RecursionMemoryAirBuilder: RecursionInteractionAirBuilder { is_real: impl Into, ) { let is_real: Self::Expr = is_real.into(); + self.assert_bool(is_real.clone()); + let timestamp: Self::Expr = timestamp.into(); let mem_access = memory_access.access(); @@ -66,6 +68,8 @@ pub trait RecursionMemoryAirBuilder: RecursionInteractionAirBuilder { is_real: impl Into, ) { let is_real: Self::Expr = is_real.into(); + self.assert_bool(is_real.clone()); + let timestamp: Self::Expr = timestamp.into(); let mem_access = memory_access.access(); diff --git a/recursion/core/src/cpu/air/branch.rs b/recursion/core/src/cpu/air/branch.rs index 91bebfa65..105dd771d 100644 --- a/recursion/core/src/cpu/air/branch.rs +++ b/recursion/core/src/cpu/air/branch.rs @@ -3,7 +3,9 @@ use p3_field::{AbstractField, Field}; use sp1_core::air::{BinomialExtension, ExtensionAirBuilder}; use crate::{ - air::{BinomialExtensionUtils, IsExtZeroOperation, SP1RecursionAirBuilder}, + air::{ + BinomialExtensionUtils, Block, BlockBuilder, IsExtZeroOperation, SP1RecursionAirBuilder, + }, cpu::{CpuChip, CpuCols}, memory::MemoryCols, }; @@ -22,18 +24,24 @@ impl CpuChip { let is_branch_instruction = self.is_branch_instruction::(local); let one = AB::Expr::one(); - // If the instruction is a BNEINC, verify that the a value is incremented by one. - builder - .when(local.is_real) - .when(local.selectors.is_bneinc) - .assert_eq(local.a.value()[0], local.a.prev_value()[0] + one.clone()); - // Convert operand values from Block to BinomialExtension. Note that it gets the // previous value of the `a` and `b` operands, since BNENIC will modify `a`. 
+ let a_prev_ext: BinomialExtension = + BinomialExtensionUtils::from_block(local.a.prev_value().map(|x| x.into())); let a_ext: BinomialExtension = BinomialExtensionUtils::from_block(local.a.value().map(|x| x.into())); let b_ext: BinomialExtension = BinomialExtensionUtils::from_block(local.b.value().map(|x| x.into())); + let one_ext: BinomialExtension = + BinomialExtensionUtils::from_block(Block::from(one.clone())); + + let expected_a_ext = a_prev_ext + one_ext; + + // If the instruction is a BNEINC, verify that the a value is incremented by one. + builder + .when(local.is_real) + .when(local.selectors.is_bneinc) + .assert_block_eq(a_ext.as_block(), expected_a_ext.as_block()); let comparison_diff = a_ext - b_ext; diff --git a/recursion/core/src/cpu/air/jump.rs b/recursion/core/src/cpu/air/jump.rs index bf86a70cc..dd5e9b8bb 100644 --- a/recursion/core/src/cpu/air/jump.rs +++ b/recursion/core/src/cpu/air/jump.rs @@ -2,7 +2,7 @@ use p3_air::AirBuilder; use p3_field::{AbstractField, Field}; use crate::{ - air::SP1RecursionAirBuilder, + air::{Block, BlockBuilder, SP1RecursionAirBuilder}, cpu::{CpuChip, CpuCols}, memory::MemoryCols, runtime::STACK_SIZE, @@ -21,19 +21,29 @@ impl CpuChip { ) where AB: SP1RecursionAirBuilder, { + let is_jump_instr = self.is_jump_instruction::(local); + // Verify the next row's fp. builder .when_first_row() .assert_eq(local.fp, F::from_canonical_usize(STACK_SIZE)); - let not_jump_instruction = AB::Expr::one() - self.is_jump_instruction::(local); + let not_jump_instruction = AB::Expr::one() - is_jump_instr.clone(); let expected_next_fp = local.selectors.is_jal * (local.fp + local.c.value()[0]) - + local.selectors.is_jalr * local.a.value()[0] + + local.selectors.is_jalr * local.c.value()[0] + not_jump_instruction * local.fp; builder .when_transition() .when(next.is_real) .assert_eq(next.fp, expected_next_fp); + // Verify the a operand values. 
+ let expected_a_val = local.selectors.is_jal * local.pc + + local.selectors.is_jalr * (local.pc + AB::Expr::one()); + let expected_a_val_block = Block::from(expected_a_val); + builder + .when(is_jump_instr) + .assert_block_eq(*local.a.value(), expected_a_val_block); + // Add to the `next_pc` expression. *next_pc += local.selectors.is_jal * (local.pc + local.b.value()[0]); *next_pc += local.selectors.is_jalr * local.b.value()[0]; diff --git a/recursion/core/src/cpu/air/memory.rs b/recursion/core/src/cpu/air/memory.rs index c0a3a2b63..d1b024130 100644 --- a/recursion/core/src/cpu/air/memory.rs +++ b/recursion/core/src/cpu/air/memory.rs @@ -30,7 +30,7 @@ impl CpuChip { local.clk + AB::F::from_canonical_u32(MemoryAccessPosition::Memory as u32), memory_cols.memory_addr, &memory_cols.memory, - is_memory_instr, + is_memory_instr.clone(), ); // Constraints on the memory column depending on load or store. @@ -41,7 +41,7 @@ impl CpuChip { ); // When there is a store, we ensure that we are writing the value of the a operand to the memory. 
builder - .when(local.selectors.is_store) + .when(is_memory_instr) .assert_block_eq(*local.a.value(), *memory_cols.memory.value()); } } diff --git a/recursion/core/src/multi/mod.rs b/recursion/core/src/multi/mod.rs index 23173f7de..10a9c3db0 100644 --- a/recursion/core/src/multi/mod.rs +++ b/recursion/core/src/multi/mod.rs @@ -190,8 +190,7 @@ where local.poseidon2_receive_table, ); sub_builder.assert_eq( - local.is_poseidon2 - * Poseidon2Chip::do_memory_access::(poseidon2_columns), + local.is_poseidon2 * Poseidon2Chip::do_memory_access::(poseidon2_columns), local.poseidon2_memory_access, ); @@ -201,7 +200,7 @@ where local.poseidon2(), next.poseidon2(), local.poseidon2_receive_table, - local.poseidon2_memory_access.into(), + local.poseidon2_memory_access, ); } } diff --git a/recursion/core/src/poseidon2/columns.rs b/recursion/core/src/poseidon2/columns.rs index 12fa73047..fa12a655f 100644 --- a/recursion/core/src/poseidon2/columns.rs +++ b/recursion/core/src/poseidon2/columns.rs @@ -11,7 +11,10 @@ pub struct Poseidon2Cols { pub left_input: T, pub right_input: T, pub rounds: [T; 24], // 1 round for memory input; 1 round for initialize; 8 rounds for external; 13 rounds for internal; 1 round for memory output + pub do_receive: T, + pub do_memory: T, pub round_specific_cols: RoundSpecificCols, + pub is_real: T, } #[derive(AlignedBorrow, Clone, Copy)] @@ -45,6 +48,7 @@ impl RoundSpecificCols { pub struct ComputationCols { pub input: [T; WIDTH], pub add_rc: [T; WIDTH], + pub sbox_deg_3: [T; WIDTH], pub sbox_deg_7: [T; WIDTH], pub output: [T; WIDTH], } diff --git a/recursion/core/src/poseidon2/external.rs b/recursion/core/src/poseidon2/external.rs index c871bd873..d340ba2b4 100644 --- a/recursion/core/src/poseidon2/external.rs +++ b/recursion/core/src/poseidon2/external.rs @@ -6,7 +6,6 @@ use p3_field::AbstractField; use p3_matrix::Matrix; use sp1_core::air::{BaseAirBuilder, ExtensionAirBuilder, SP1AirBuilder}; use sp1_primitives::RC_16_30_U32; -use std::ops::Add; use 
crate::air::{RecursionInteractionAirBuilder, RecursionMemoryAirBuilder}; use crate::memory::MemoryCols; @@ -40,7 +39,7 @@ impl Poseidon2Chip { local: &Poseidon2Cols, next: &Poseidon2Cols, receive_table: AB::Var, - memory_access: AB::Expr, + memory_access: AB::Var, ) { const NUM_ROUNDS_F: usize = 8; const NUM_ROUNDS_P: usize = 13; @@ -66,6 +65,10 @@ impl Poseidon2Chip { .sum::(); let is_memory_write = local.rounds[local.rounds.len() - 1]; + self.eval_control_flow_and_inputs(builder, local, next); + + self.eval_syscall(builder, local, receive_table); + self.eval_mem( builder, local, @@ -84,16 +87,71 @@ impl Poseidon2Chip { is_internal_layer.clone(), NUM_ROUNDS_F + NUM_ROUNDS_P + 1, ); + } - self.eval_syscall(builder, local, receive_table); - - // Range check all flags. - for i in 0..local.rounds.len() { + fn eval_control_flow_and_inputs( + &self, + builder: &mut AB, + local: &Poseidon2Cols, + next: &Poseidon2Cols, + ) { + let num_total_rounds = local.rounds.len(); + for i in 0..num_total_rounds { + // Verify that the round flags are correct. builder.assert_bool(local.rounds[i]); + + // Assert that the next round is correct. + builder + .when_transition() + .assert_eq(local.rounds[i], next.rounds[(i + 1) % num_total_rounds]); + + if i != num_total_rounds - 1 { + builder + .when_transition() + .when(local.rounds[i]) + .assert_eq(local.clk, next.clk); + builder + .when_transition() + .when(local.rounds[i]) + .assert_eq(local.dst_input, next.dst_input); + builder + .when_transition() + .when(local.rounds[i]) + .assert_eq(local.left_input, next.left_input); + builder + .when_transition() + .when(local.rounds[i]) + .assert_eq(local.right_input, next.right_input); + } } - builder.assert_bool( - is_memory_read + is_initial + is_external_layer + is_internal_layer + is_memory_write, + + // Ensure that at most one of the round flags is set. 
+ let round_acc = local + .rounds + .iter() + .fold(AB::Expr::zero(), |acc, round_flag| acc + *round_flag); + builder.assert_bool(round_acc); + + // Verify the do_memory flag. + builder.assert_eq( + local.do_memory, + local.is_real * (local.rounds[0] + local.rounds[23]), ); + + // Verify the do_receive flag. + builder.assert_eq(local.do_receive, local.is_real * local.rounds[0]); + + // Verify the first row starts at round 0. + builder.when_first_row().assert_one(local.rounds[0]); + // The round count is not a power of 2, so the last row should not be real. + builder.when_last_row().assert_zero(local.is_real); + + // Verify that all is_real flags within a round are equal. + let is_last_round = local.rounds[23]; + builder + .when_transition() + .when_not(is_last_round) + .assert_eq(local.is_real, next.is_real); } fn eval_mem( @@ -103,20 +161,23 @@ impl Poseidon2Chip { next: &Poseidon2Cols, is_memory_read: AB::Var, is_memory_write: AB::Var, - memory_access: AB::Expr, + memory_access: AB::Var, ) { let memory_access_cols = local.round_specific_cols.memory_access(); builder + .when(local.is_real) .when(is_memory_read) .assert_eq(local.left_input, memory_access_cols.addr_first_half); builder + .when(local.is_real) .when(is_memory_read) .assert_eq(local.right_input, memory_access_cols.addr_second_half); builder + .when(local.is_real) .when(is_memory_write) .assert_eq(local.dst_input, memory_access_cols.addr_first_half); - builder.when(is_memory_write).assert_eq( + builder.when(local.is_real).when(is_memory_write).assert_eq( local.dst_input + AB::F::from_canonical_usize(WIDTH / 2), memory_access_cols.addr_second_half, ); @@ -131,7 +192,11 @@ impl Poseidon2Chip { local.clk + AB::Expr::one() * is_memory_write, addr, &memory_access_cols.mem_access[i], - memory_access.clone(), + memory_access, + ); + builder.when(local.is_real).when(is_memory_read).assert_eq( + *memory_access_cols.mem_access[i].value(), + *memory_access_cols.mem_access[i].prev_value(), ); } @@ -139,10 +204,14 
@@ impl Poseidon2Chip { // computation round. let next_computation_col = next.round_specific_cols.computation(); for i in 0..WIDTH { - builder.when_transition().when(is_memory_read).assert_eq( - *memory_access_cols.mem_access[i].value(), - next_computation_col.input[i], - ); + builder + .when_transition() + .when(local.is_real) + .when(is_memory_read) + .assert_eq( + *memory_access_cols.mem_access[i].value(), + next_computation_col.input[i], + ); } } @@ -184,6 +253,7 @@ impl Poseidon2Chip { } } builder + .when(local.is_real) .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) .assert_eq(result, computation_cols.add_rc[i]); } @@ -196,8 +266,15 @@ impl Poseidon2Chip { let sbox_deg_3 = computation_cols.add_rc[i] * computation_cols.add_rc[i] * computation_cols.add_rc[i]; - let sbox_deg_7 = sbox_deg_3.clone() * sbox_deg_3.clone() * computation_cols.add_rc[i]; builder + .when(local.is_real) + .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) + .assert_eq(computation_cols.sbox_deg_3[i], sbox_deg_3); + let sbox_deg_7 = computation_cols.sbox_deg_3[i] + * computation_cols.sbox_deg_3[i] + * computation_cols.add_rc[i]; + builder + .when(local.is_real) .when(is_initial.clone() + is_external_layer.clone() + is_internal_layer.clone()) .assert_eq(sbox_deg_7, computation_cols.sbox_deg_7[i]); } @@ -253,6 +330,7 @@ impl Poseidon2Chip { for i in 0..WIDTH { state[i] += sums[i % 4].clone(); builder + .when(local.is_real) .when(is_external_layer.clone() + is_initial.clone()) .assert_eq(state[i].clone(), computation_cols.output[i]); } @@ -264,6 +342,7 @@ impl Poseidon2Chip { let mut state: [AB::Expr; WIDTH] = sbox_result.clone(); internal_linear_layer(&mut state); builder + .when(local.is_real) .when(is_internal_layer.clone()) .assert_all_eq(state.clone(), computation_cols.output); } @@ -281,6 +360,7 @@ impl Poseidon2Chip { builder .when_transition() + .when(local.is_real) .when(is_initial.clone() + is_external_layer.clone() + 
is_internal_layer.clone()) .assert_eq(computation_cols.output[i], next_round_value); } @@ -307,13 +387,11 @@ impl Poseidon2Chip { } pub const fn do_receive_table(local: &Poseidon2Cols) -> T { - local.rounds[0] + local.do_receive } - pub fn do_memory_access, Output>( - local: &Poseidon2Cols, - ) -> Output { - local.rounds[0] + local.rounds[23] + pub fn do_memory_access(local: &Poseidon2Cols) -> T { + local.do_memory } } @@ -333,7 +411,7 @@ where local, next, Self::do_receive_table::(local), - Self::do_memory_access::(local), + Self::do_memory_access::(local), ); } } diff --git a/recursion/core/src/poseidon2/trace.rs b/recursion/core/src/poseidon2/trace.rs index cc6a41d94..2d5639edd 100644 --- a/recursion/core/src/poseidon2/trace.rs +++ b/recursion/core/src/poseidon2/trace.rs @@ -49,7 +49,9 @@ impl MachineAir for Poseidon2Chip { for r in 0..rounds { let mut row = [F::zero(); NUM_POSEIDON2_COLS]; let cols: &mut Poseidon2Cols = row.as_mut_slice().borrow_mut(); + cols.is_real = F::one(); + let is_receive = r == 0; let is_memory_read = r == 0; let is_initial_layer = r == 1; let is_external_layer = @@ -78,6 +80,10 @@ impl MachineAir for Poseidon2Chip { cols.right_input = poseidon2_event.right; cols.rounds[r] = F::one(); + if is_receive { + cols.do_receive = F::one(); + } + if is_memory_read || is_memory_write { let memory_access_cols = cols.round_specific_cols.memory_access_mut(); @@ -97,6 +103,7 @@ impl MachineAir for Poseidon2Chip { .populate(&poseidon2_event.result_records[i]); } } + cols.do_memory = F::one(); } else { let computation_cols = cols.round_specific_cols.computation_mut(); @@ -131,6 +138,7 @@ impl MachineAir for Poseidon2Chip { let sbox_deg_3 = computation_cols.add_rc[j] * computation_cols.add_rc[j] * computation_cols.add_rc[j]; + computation_cols.sbox_deg_3[j] = sbox_deg_3; computation_cols.sbox_deg_7[j] = sbox_deg_3 * sbox_deg_3 * computation_cols.add_rc[j]; } @@ -163,6 +171,8 @@ impl MachineAir for Poseidon2Chip { } } + let num_real_rows = rows.len(); + 
// Pad the trace to a power of two. pad_rows_fixed( &mut rows, @@ -170,6 +180,14 @@ impl MachineAir for Poseidon2Chip { self.fixed_log2_rows, ); + let mut round_num = 0; + for row in rows[num_real_rows..].iter_mut() { + let cols: &mut Poseidon2Cols = row.as_mut_slice().borrow_mut(); + cols.rounds[round_num] = F::one(); + + round_num = (round_num + 1) % rounds; + } + // Convert the trace to a row major matrix. RowMajorMatrix::new( rows.into_iter().flatten().collect::>(), diff --git a/recursion/gnark-ffi/Cargo.toml b/recursion/gnark-ffi/Cargo.toml index d10ed2fd3..467ba5bc0 100644 --- a/recursion/gnark-ffi/Cargo.toml +++ b/recursion/gnark-ffi/Cargo.toml @@ -5,8 +5,10 @@ edition = "2021" [dependencies] p3-field = { workspace = true } +p3-symmetric = { workspace = true } p3-baby-bear = { workspace = true } sp1-recursion-compiler = { path = "../compiler" } +sp1-core = { path = "../../core" } serde = "1.0.201" serde_json = "1.0.117" tempfile = "3.10.1" diff --git a/recursion/gnark-ffi/go/main.go b/recursion/gnark-ffi/go/main.go index 89bba4a7e..ed782400f 100644 --- a/recursion/gnark-ffi/go/main.go +++ b/recursion/gnark-ffi/go/main.go @@ -17,11 +17,15 @@ import ( "sync" "github.com/consensys/gnark-crypto/ecc" + "github.com/consensys/gnark/backend/groth16" "github.com/consensys/gnark/backend/plonk" "github.com/consensys/gnark/frontend" + "github.com/consensys/gnark/frontend/cs/r1cs" "github.com/consensys/gnark/frontend/cs/scs" "github.com/consensys/gnark/test/unsafekzg" "github.com/succinctlabs/sp1-recursion-gnark/sp1" + "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" + "github.com/succinctlabs/sp1-recursion-gnark/sp1/poseidon2" ) func main() {} @@ -141,3 +145,73 @@ func TestMain() error { return nil } + +//export TestPoseidonBabyBear2 +func TestPoseidonBabyBear2() *C.char { + input := [poseidon2.BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + 
babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + babybear.NewF("0"), + } + + expectedOutput := [poseidon2.BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("348670919"), + babybear.NewF("1568590631"), + babybear.NewF("1535107508"), + babybear.NewF("186917780"), + babybear.NewF("587749971"), + babybear.NewF("1827585060"), + babybear.NewF("1218809104"), + babybear.NewF("691692291"), + babybear.NewF("1480664293"), + babybear.NewF("1491566329"), + babybear.NewF("366224457"), + babybear.NewF("490018300"), + babybear.NewF("732772134"), + babybear.NewF("560796067"), + babybear.NewF("484676252"), + babybear.NewF("405025962"), + } + + circuit := sp1.TestPoseidon2BabyBearCircuit{Input: input, ExpectedOutput: expectedOutput} + assignment := sp1.TestPoseidon2BabyBearCircuit{Input: input, ExpectedOutput: expectedOutput} + + builder := r1cs.NewBuilder + r1cs, err := frontend.Compile(ecc.BN254.ScalarField(), builder, &circuit) + if err != nil { + return C.CString(err.Error()) + } + + var pk groth16.ProvingKey + pk, err = groth16.DummySetup(r1cs) + if err != nil { + return C.CString(err.Error()) + } + + // Generate witness. + witness, err := frontend.NewWitness(&assignment, ecc.BN254.ScalarField()) + if err != nil { + return C.CString(err.Error()) + } + + // Generate the proof. 
+ _, err = groth16.Prove(r1cs, pk, witness) + if err != nil { + return C.CString(err.Error()) + } + + return nil +} diff --git a/recursion/gnark-ffi/go/sp1/poseidon2/constants.go b/recursion/gnark-ffi/go/sp1/poseidon2/constants.go index fb350f180..edb5a5e4a 100644 --- a/recursion/gnark-ffi/go/sp1/poseidon2/constants.go +++ b/recursion/gnark-ffi/go/sp1/poseidon2/constants.go @@ -3,11 +3,22 @@ package poseidon2 import ( "github.com/consensys/gnark/frontend" + "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" ) +// Poseidon2 round constants for a state consisting of three BN254 field elements. var RC3 [NUM_EXTERNAL_ROUNDS + NUM_INTERNAL_ROUNDS][WIDTH]frontend.Variable +// Poseidon2 round constants for a state consisting of 16 BabyBear field elements. + +var RC16 [30][BABYBEAR_WIDTH]babybear.Variable + func init() { + init_rc3() + init_rc16() +} + +func init_rc3() { round := 0 RC3[round] = [WIDTH]frontend.Variable{ @@ -457,3 +468,580 @@ func init() { frontend.Variable("0x0fc1bbceba0590f5abbdffa6d3b35e3297c021a3a409926d0e2d54dc1c84fda6"), } } + +func init_rc16() { + round := 0 + + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2110014213"), + babybear.NewF("3964964605"), + babybear.NewF("2190662774"), + babybear.NewF("2732996483"), + babybear.NewF("640767983"), + babybear.NewF("3403899136"), + babybear.NewF("1716033721"), + babybear.NewF("1606702601"), + babybear.NewF("3759873288"), + babybear.NewF("1466015491"), + babybear.NewF("1498308946"), + babybear.NewF("2844375094"), + babybear.NewF("3042463841"), + babybear.NewF("1969905919"), + babybear.NewF("4109944726"), + babybear.NewF("3925048366"), + } + + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3706859504"), + babybear.NewF("759122502"), + babybear.NewF("3167665446"), + babybear.NewF("1131812921"), + babybear.NewF("1080754908"), + babybear.NewF("4080114493"), + babybear.NewF("893583089"), + babybear.NewF("2019677373"), + babybear.NewF("3128604556"), + 
babybear.NewF("580640471"), + babybear.NewF("3277620260"), + babybear.NewF("842931656"), + babybear.NewF("548879852"), + babybear.NewF("3608554714"), + babybear.NewF("3575647916"), + babybear.NewF("81826002"), + } + + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("4289086263"), + babybear.NewF("1563933798"), + babybear.NewF("1440025885"), + babybear.NewF("184445025"), + babybear.NewF("2598651360"), + babybear.NewF("1396647410"), + babybear.NewF("1575877922"), + babybear.NewF("3303853401"), + babybear.NewF("137125468"), + babybear.NewF("765010148"), + babybear.NewF("633675867"), + babybear.NewF("2037803363"), + babybear.NewF("2573389828"), + babybear.NewF("1895729703"), + babybear.NewF("541515871"), + babybear.NewF("1783382863"), + } + + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2641856484"), + babybear.NewF("3035743342"), + babybear.NewF("3672796326"), + babybear.NewF("245668751"), + babybear.NewF("2025460432"), + babybear.NewF("201609705"), + babybear.NewF("286217151"), + babybear.NewF("4093475563"), + babybear.NewF("2519572182"), + babybear.NewF("3080699870"), + babybear.NewF("2762001832"), + babybear.NewF("1244250808"), + babybear.NewF("606038199"), + babybear.NewF("3182740831"), + babybear.NewF("73007766"), + babybear.NewF("2572204153"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1196780786"), + babybear.NewF("3447394443"), + babybear.NewF("747167305"), + babybear.NewF("2968073607"), + babybear.NewF("1053214930"), + babybear.NewF("1074411832"), + babybear.NewF("4016794508"), + babybear.NewF("1570312929"), + babybear.NewF("113576933"), + babybear.NewF("4042581186"), + babybear.NewF("3634515733"), + babybear.NewF("1032701597"), + babybear.NewF("2364839308"), + babybear.NewF("3840286918"), + babybear.NewF("888378655"), + babybear.NewF("2520191583"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("36046858"), + 
babybear.NewF("2927525953"), + babybear.NewF("3912129105"), + babybear.NewF("4004832531"), + babybear.NewF("193772436"), + babybear.NewF("1590247392"), + babybear.NewF("4125818172"), + babybear.NewF("2516251696"), + babybear.NewF("4050945750"), + babybear.NewF("269498914"), + babybear.NewF("1973292656"), + babybear.NewF("891403491"), + babybear.NewF("1845429189"), + babybear.NewF("2611996363"), + babybear.NewF("2310542653"), + babybear.NewF("4071195740"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3505307391"), + babybear.NewF("786445290"), + babybear.NewF("3815313971"), + babybear.NewF("1111591756"), + babybear.NewF("4233279834"), + babybear.NewF("2775453034"), + babybear.NewF("1991257625"), + babybear.NewF("2940505809"), + babybear.NewF("2751316206"), + babybear.NewF("1028870679"), + babybear.NewF("1282466273"), + babybear.NewF("1059053371"), + babybear.NewF("834521354"), + babybear.NewF("138721483"), + babybear.NewF("3100410803"), + babybear.NewF("3843128331"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3878220780"), + babybear.NewF("4058162439"), + babybear.NewF("1478942487"), + babybear.NewF("799012923"), + babybear.NewF("496734827"), + babybear.NewF("3521261236"), + babybear.NewF("755421082"), + babybear.NewF("1361409515"), + babybear.NewF("392099473"), + babybear.NewF("3178453393"), + babybear.NewF("4068463721"), + babybear.NewF("7935614"), + babybear.NewF("4140885645"), + babybear.NewF("2150748066"), + babybear.NewF("1685210312"), + babybear.NewF("3852983224"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2896943075"), + babybear.NewF("3087590927"), + babybear.NewF("992175959"), + babybear.NewF("970216228"), + babybear.NewF("3473630090"), + babybear.NewF("3899670400"), + babybear.NewF("3603388822"), + babybear.NewF("2633488197"), + babybear.NewF("2479406964"), + babybear.NewF("2420952999"), + babybear.NewF("1852516800"), + 
babybear.NewF("4253075697"), + babybear.NewF("979699862"), + babybear.NewF("1163403191"), + babybear.NewF("1608599874"), + babybear.NewF("3056104448"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3779109343"), + babybear.NewF("536205958"), + babybear.NewF("4183458361"), + babybear.NewF("1649720295"), + babybear.NewF("1444912244"), + babybear.NewF("3122230878"), + babybear.NewF("384301396"), + babybear.NewF("4228198516"), + babybear.NewF("1662916865"), + babybear.NewF("4082161114"), + babybear.NewF("2121897314"), + babybear.NewF("1706239958"), + babybear.NewF("4166959388"), + babybear.NewF("1626054781"), + babybear.NewF("3005858978"), + babybear.NewF("1431907253"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1418914503"), + babybear.NewF("1365856753"), + babybear.NewF("3942715745"), + babybear.NewF("1429155552"), + babybear.NewF("3545642795"), + babybear.NewF("3772474257"), + babybear.NewF("1621094396"), + babybear.NewF("2154399145"), + babybear.NewF("826697382"), + babybear.NewF("1700781391"), + babybear.NewF("3539164324"), + babybear.NewF("652815039"), + babybear.NewF("442484755"), + babybear.NewF("2055299391"), + babybear.NewF("1064289978"), + babybear.NewF("1152335780"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3417648695"), + babybear.NewF("186040114"), + babybear.NewF("3475580573"), + babybear.NewF("2113941250"), + babybear.NewF("1779573826"), + babybear.NewF("1573808590"), + babybear.NewF("3235694804"), + babybear.NewF("2922195281"), + babybear.NewF("1119462702"), + babybear.NewF("3688305521"), + babybear.NewF("1849567013"), + babybear.NewF("667446787"), + babybear.NewF("753897224"), + babybear.NewF("1896396780"), + babybear.NewF("3143026334"), + babybear.NewF("3829603876"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("859661334"), + babybear.NewF("3898844357"), + babybear.NewF("180258337"), + 
babybear.NewF("2321867017"), + babybear.NewF("3599002504"), + babybear.NewF("2886782421"), + babybear.NewF("3038299378"), + babybear.NewF("1035366250"), + babybear.NewF("2038912197"), + babybear.NewF("2920174523"), + babybear.NewF("1277696101"), + babybear.NewF("2785700290"), + babybear.NewF("3806504335"), + babybear.NewF("3518858933"), + babybear.NewF("654843672"), + babybear.NewF("2127120275"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1548195514"), + babybear.NewF("2378056027"), + babybear.NewF("390914568"), + babybear.NewF("1472049779"), + babybear.NewF("1552596765"), + babybear.NewF("1905886441"), + babybear.NewF("1611959354"), + babybear.NewF("3653263304"), + babybear.NewF("3423946386"), + babybear.NewF("340857935"), + babybear.NewF("2208879480"), + babybear.NewF("139364268"), + babybear.NewF("3447281773"), + babybear.NewF("3777813707"), + babybear.NewF("55640413"), + babybear.NewF("4101901741"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("104929687"), + babybear.NewF("1459980974"), + babybear.NewF("1831234737"), + babybear.NewF("457139004"), + babybear.NewF("2581487628"), + babybear.NewF("2112044563"), + babybear.NewF("3567013861"), + babybear.NewF("2792004347"), + babybear.NewF("576325418"), + babybear.NewF("41126132"), + babybear.NewF("2713562324"), + babybear.NewF("151213722"), + babybear.NewF("2891185935"), + babybear.NewF("546846420"), + babybear.NewF("2939794919"), + babybear.NewF("2543469905"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2191909784"), + babybear.NewF("3315138460"), + babybear.NewF("530414574"), + babybear.NewF("1242280418"), + babybear.NewF("1211740715"), + babybear.NewF("3993672165"), + babybear.NewF("2505083323"), + babybear.NewF("3845798801"), + babybear.NewF("538768466"), + babybear.NewF("2063567560"), + babybear.NewF("3366148274"), + babybear.NewF("1449831887"), + babybear.NewF("2408012466"), + 
babybear.NewF("294726285"), + babybear.NewF("3943435493"), + babybear.NewF("924016661"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3633138367"), + babybear.NewF("3222789372"), + babybear.NewF("809116305"), + babybear.NewF("30100013"), + babybear.NewF("2655172876"), + babybear.NewF("2564247117"), + babybear.NewF("2478649732"), + babybear.NewF("4113689151"), + babybear.NewF("4120146082"), + babybear.NewF("2512308515"), + babybear.NewF("650406041"), + babybear.NewF("4240012393"), + babybear.NewF("2683508708"), + babybear.NewF("951073977"), + babybear.NewF("3460081988"), + babybear.NewF("339124269"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("130182653"), + babybear.NewF("2755946749"), + babybear.NewF("542600513"), + babybear.NewF("2816103022"), + babybear.NewF("1931786340"), + babybear.NewF("2044470840"), + babybear.NewF("1709908013"), + babybear.NewF("2938369043"), + babybear.NewF("3640399693"), + babybear.NewF("1374470239"), + babybear.NewF("2191149676"), + babybear.NewF("2637495682"), + babybear.NewF("4236394040"), + babybear.NewF("2289358846"), + babybear.NewF("3833368530"), + babybear.NewF("974546524"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3306659113"), + babybear.NewF("2234814261"), + babybear.NewF("1188782305"), + babybear.NewF("223782844"), + babybear.NewF("2248980567"), + babybear.NewF("2309786141"), + babybear.NewF("2023401627"), + babybear.NewF("3278877413"), + babybear.NewF("2022138149"), + babybear.NewF("575851471"), + babybear.NewF("1612560780"), + babybear.NewF("3926656936"), + babybear.NewF("3318548977"), + babybear.NewF("2591863678"), + babybear.NewF("188109355"), + babybear.NewF("4217723909"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1564209905"), + babybear.NewF("2154197895"), + babybear.NewF("2459687029"), + babybear.NewF("2870634489"), + babybear.NewF("1375012945"), + 
babybear.NewF("1529454825"), + babybear.NewF("306140690"), + babybear.NewF("2855578299"), + babybear.NewF("1246997295"), + babybear.NewF("3024298763"), + babybear.NewF("1915270363"), + babybear.NewF("1218245412"), + babybear.NewF("2479314020"), + babybear.NewF("2989827755"), + babybear.NewF("814378556"), + babybear.NewF("4039775921"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1165280628"), + babybear.NewF("1203983801"), + babybear.NewF("3814740033"), + babybear.NewF("1919627044"), + babybear.NewF("600240215"), + babybear.NewF("773269071"), + babybear.NewF("486685186"), + babybear.NewF("4254048810"), + babybear.NewF("1415023565"), + babybear.NewF("502840102"), + babybear.NewF("4225648358"), + babybear.NewF("510217063"), + babybear.NewF("166444818"), + babybear.NewF("1430745893"), + babybear.NewF("1376516190"), + babybear.NewF("1775891321"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1170945922"), + babybear.NewF("1105391877"), + babybear.NewF("261536467"), + babybear.NewF("1401687994"), + babybear.NewF("1022529847"), + babybear.NewF("2476446456"), + babybear.NewF("2603844878"), + babybear.NewF("3706336043"), + babybear.NewF("3463053714"), + babybear.NewF("1509644517"), + babybear.NewF("588552318"), + babybear.NewF("65252581"), + babybear.NewF("3696502656"), + babybear.NewF("2183330763"), + babybear.NewF("3664021233"), + babybear.NewF("1643809916"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2922875898"), + babybear.NewF("3740690643"), + babybear.NewF("3932461140"), + babybear.NewF("161156271"), + babybear.NewF("2619943483"), + babybear.NewF("4077039509"), + babybear.NewF("2921201703"), + babybear.NewF("2085619718"), + babybear.NewF("2065264646"), + babybear.NewF("2615693812"), + babybear.NewF("3116555433"), + babybear.NewF("246100007"), + babybear.NewF("4281387154"), + babybear.NewF("4046141001"), + babybear.NewF("4027749321"), + 
babybear.NewF("111611860"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2066954820"), + babybear.NewF("2502099969"), + babybear.NewF("2915053115"), + babybear.NewF("2362518586"), + babybear.NewF("366091708"), + babybear.NewF("2083204932"), + babybear.NewF("4138385632"), + babybear.NewF("3195157567"), + babybear.NewF("1318086382"), + babybear.NewF("521723799"), + babybear.NewF("702443405"), + babybear.NewF("2507670985"), + babybear.NewF("1760347557"), + babybear.NewF("2631999893"), + babybear.NewF("1672737554"), + babybear.NewF("1060867760"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2359801781"), + babybear.NewF("2800231467"), + babybear.NewF("3010357035"), + babybear.NewF("1035997899"), + babybear.NewF("1210110952"), + babybear.NewF("1018506770"), + babybear.NewF("2799468177"), + babybear.NewF("1479380761"), + babybear.NewF("1536021911"), + babybear.NewF("358993854"), + babybear.NewF("579904113"), + babybear.NewF("3432144800"), + babybear.NewF("3625515809"), + babybear.NewF("199241497"), + babybear.NewF("4058304109"), + babybear.NewF("2590164234"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("1688530738"), + babybear.NewF("1580733335"), + babybear.NewF("2443981517"), + babybear.NewF("2206270565"), + babybear.NewF("2780074229"), + babybear.NewF("2628739677"), + babybear.NewF("2940123659"), + babybear.NewF("4145206827"), + babybear.NewF("3572278009"), + babybear.NewF("2779607509"), + babybear.NewF("1098718697"), + babybear.NewF("1424913749"), + babybear.NewF("2224415875"), + babybear.NewF("1108922178"), + babybear.NewF("3646272562"), + babybear.NewF("3935186184"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("820046587"), + babybear.NewF("1393386250"), + babybear.NewF("2665818575"), + babybear.NewF("2231782019"), + babybear.NewF("672377010"), + babybear.NewF("1920315467"), + babybear.NewF("1913164407"), + 
babybear.NewF("2029526876"), + babybear.NewF("2629271820"), + babybear.NewF("384320012"), + babybear.NewF("4112320585"), + babybear.NewF("3131824773"), + babybear.NewF("2347818197"), + babybear.NewF("2220997386"), + babybear.NewF("1772368609"), + babybear.NewF("2579960095"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3544930873"), + babybear.NewF("225847443"), + babybear.NewF("3070082278"), + babybear.NewF("95643305"), + babybear.NewF("3438572042"), + babybear.NewF("3312856509"), + babybear.NewF("615850007"), + babybear.NewF("1863868773"), + babybear.NewF("803582265"), + babybear.NewF("3461976859"), + babybear.NewF("2903025799"), + babybear.NewF("1482092434"), + babybear.NewF("3902972499"), + babybear.NewF("3872341868"), + babybear.NewF("1530411808"), + babybear.NewF("2214923584"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("3118792481"), + babybear.NewF("2241076515"), + babybear.NewF("3983669831"), + babybear.NewF("3180915147"), + babybear.NewF("3838626501"), + babybear.NewF("1921630011"), + babybear.NewF("3415351771"), + babybear.NewF("2249953859"), + babybear.NewF("3755081630"), + babybear.NewF("486327260"), + babybear.NewF("1227575720"), + babybear.NewF("3643869379"), + babybear.NewF("2982026073"), + babybear.NewF("2466043731"), + babybear.NewF("1982634375"), + babybear.NewF("3769609014"), + } + round += 1 + RC16[round] = [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2195455495"), + babybear.NewF("2596863283"), + babybear.NewF("4244994973"), + babybear.NewF("1983609348"), + babybear.NewF("4019674395"), + babybear.NewF("3469982031"), + babybear.NewF("1458697570"), + babybear.NewF("1593516217"), + babybear.NewF("1963896497"), + babybear.NewF("3115309118"), + babybear.NewF("1659132465"), + babybear.NewF("2536770756"), + babybear.NewF("3059294171"), + babybear.NewF("2618031334"), + babybear.NewF("2040903247"), + babybear.NewF("3799795076"), + } +} diff --git 
a/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go new file mode 100644 index 000000000..9f8395623 --- /dev/null +++ b/recursion/gnark-ffi/go/sp1/poseidon2/poseidon2_babybear.go @@ -0,0 +1,157 @@ +package poseidon2 + +import ( + "github.com/consensys/gnark/frontend" + "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" +) + +const BABYBEAR_WIDTH = 16 +const BABYBEAR_NUM_EXTERNAL_ROUNDS = 8 +const BABYBEAR_NUM_INTERNAL_ROUNDS = 13 +const BABYBEAR_DEGREE = 7 + +type Poseidon2BabyBearChip struct { + api frontend.API + fieldApi *babybear.Chip +} + +func NewBabyBearChip(api frontend.API) *Poseidon2BabyBearChip { + return &Poseidon2BabyBearChip{ + api: api, + fieldApi: babybear.NewChip(api), + } +} + +func (p *Poseidon2BabyBearChip) PermuteMut(state *[BABYBEAR_WIDTH]babybear.Variable) { + // The initial linear layer. + p.externalLinearLayer(state) + + // The first half of the external rounds. + rounds := BABYBEAR_NUM_EXTERNAL_ROUNDS + BABYBEAR_NUM_INTERNAL_ROUNDS + roundsFBeggining := BABYBEAR_NUM_EXTERNAL_ROUNDS / 2 + for r := 0; r < roundsFBeggining; r++ { + p.addRc(state, RC16[r]) + p.sbox(state) + p.externalLinearLayer(state) + } + + // The internal rounds. + p_end := roundsFBeggining + BABYBEAR_NUM_INTERNAL_ROUNDS + for r := roundsFBeggining; r < p_end; r++ { + state[0] = p.fieldApi.AddF(state[0], RC16[r][0]) + state[0] = p.sboxP(state[0]) + p.diffusionPermuteMut(state) + } + + // The second half of the external rounds. 
+ for r := p_end; r < rounds; r++ { + p.addRc(state, RC16[r]) + p.sbox(state) + p.externalLinearLayer(state) + } +} + +func (p *Poseidon2BabyBearChip) addRc(state *[BABYBEAR_WIDTH]babybear.Variable, rc [BABYBEAR_WIDTH]babybear.Variable) { + for i := 0; i < BABYBEAR_WIDTH; i++ { + state[i] = p.fieldApi.AddF(state[i], rc[i]) + } +} + +func (p *Poseidon2BabyBearChip) sboxP(input babybear.Variable) babybear.Variable { + zero := babybear.NewF("0") + inputCpy := p.fieldApi.AddF(input, zero) + inputCpy = p.fieldApi.ReduceSlow(inputCpy) + inputValue := inputCpy.Value + i2 := p.api.Mul(inputValue, inputValue) + i4 := p.api.Mul(i2, i2) + i6 := p.api.Mul(i4, i2) + i7 := p.api.Mul(i6, inputValue) + i7bb := p.fieldApi.ReduceSlow(babybear.Variable{ + Value: i7, + NbBits: 31 * 7, + }) + return i7bb +} + +func (p *Poseidon2BabyBearChip) sbox(state *[BABYBEAR_WIDTH]babybear.Variable) { + for i := 0; i < BABYBEAR_WIDTH; i++ { + state[i] = p.sboxP(state[i]) + } +} + +func (p *Poseidon2BabyBearChip) mdsLightPermutation4x4(state []babybear.Variable) { + t01 := p.fieldApi.AddF(state[0], state[1]) + t23 := p.fieldApi.AddF(state[2], state[3]) + t0123 := p.fieldApi.AddF(t01, t23) + t01123 := p.fieldApi.AddF(t0123, state[1]) + t01233 := p.fieldApi.AddF(t0123, state[3]) + state[3] = p.fieldApi.AddF(t01233, p.fieldApi.MulFConst(state[0], 2)) + state[1] = p.fieldApi.AddF(t01123, p.fieldApi.MulFConst(state[2], 2)) + state[0] = p.fieldApi.AddF(t01123, t01) + state[2] = p.fieldApi.AddF(t01233, t23) +} + +func (p *Poseidon2BabyBearChip) externalLinearLayer(state *[BABYBEAR_WIDTH]babybear.Variable) { + for i := 0; i < BABYBEAR_WIDTH; i += 4 { + p.mdsLightPermutation4x4(state[i : i+4]) + } + + sums := [4]babybear.Variable{ + state[0], + state[1], + state[2], + state[3], + } + for i := 4; i < BABYBEAR_WIDTH; i += 4 { + sums[0] = p.fieldApi.AddF(sums[0], state[i]) + sums[1] = p.fieldApi.AddF(sums[1], state[i+1]) + sums[2] = p.fieldApi.AddF(sums[2], state[i+2]) + sums[3] = p.fieldApi.AddF(sums[3], 
state[i+3]) + } + + for i := 0; i < BABYBEAR_WIDTH; i++ { + state[i] = p.fieldApi.AddF(state[i], sums[i%4]) + } +} + +func (p *Poseidon2BabyBearChip) diffusionPermuteMut(state *[BABYBEAR_WIDTH]babybear.Variable) { + matInternalDiagM1 := [BABYBEAR_WIDTH]babybear.Variable{ + babybear.NewF("2013265919"), + babybear.NewF("1"), + babybear.NewF("2"), + babybear.NewF("4"), + babybear.NewF("8"), + babybear.NewF("16"), + babybear.NewF("32"), + babybear.NewF("64"), + babybear.NewF("128"), + babybear.NewF("256"), + babybear.NewF("512"), + babybear.NewF("1024"), + babybear.NewF("2048"), + babybear.NewF("4096"), + babybear.NewF("8192"), + babybear.NewF("32768"), + } + montyInverse := babybear.NewF("943718400") + p.matmulInternal(state, &matInternalDiagM1) + for i := 0; i < BABYBEAR_WIDTH; i++ { + state[i] = p.fieldApi.MulF(state[i], montyInverse) + } + +} + +func (p *Poseidon2BabyBearChip) matmulInternal( + state *[BABYBEAR_WIDTH]babybear.Variable, + matInternalDiagM1 *[BABYBEAR_WIDTH]babybear.Variable, +) { + sum := babybear.NewF("0") + for i := 0; i < BABYBEAR_WIDTH; i++ { + sum = p.fieldApi.AddF(sum, state[i]) + } + + for i := 0; i < BABYBEAR_WIDTH; i++ { + state[i] = p.fieldApi.MulF(state[i], matInternalDiagM1[i]) + state[i] = p.fieldApi.AddF(state[i], sum) + } +} diff --git a/recursion/gnark-ffi/go/sp1/sp1.go b/recursion/gnark-ffi/go/sp1/sp1.go index f3f3b24a5..ccde52095 100644 --- a/recursion/gnark-ffi/go/sp1/sp1.go +++ b/recursion/gnark-ffi/go/sp1/sp1.go @@ -68,6 +68,7 @@ func (circuit *Circuit) Define(api frontend.API) error { } hashAPI := poseidon2.NewChip(api) + hashBabyBearAPI := poseidon2.NewBabyBearChip(api) fieldAPI := babybear.NewChip(api) vars := make(map[string]frontend.Variable) felts := make(map[string]babybear.Variable) @@ -132,6 +133,15 @@ func (circuit *Circuit) Define(api frontend.API) error { vars[cs.Args[0][0]] = state[0] vars[cs.Args[1][0]] = state[1] vars[cs.Args[2][0]] = state[2] + case "PermuteBabyBear": + var state [16]babybear.Variable + for i := 
0; i < 16; i++ { + state[i] = felts[cs.Args[i][0]] + } + hashBabyBearAPI.PermuteMut(&state) + for i := 0; i < 16; i++ { + felts[cs.Args[i][0]] = state[i] + } case "SelectV": vars[cs.Args[0][0]] = api.Select(vars[cs.Args[1][0]], vars[cs.Args[2][0]], vars[cs.Args[3][0]]) case "SelectF": diff --git a/recursion/gnark-ffi/go/sp1/test.go b/recursion/gnark-ffi/go/sp1/test.go new file mode 100644 index 000000000..8d2aa8f0a --- /dev/null +++ b/recursion/gnark-ffi/go/sp1/test.go @@ -0,0 +1,31 @@ +package sp1 + +import ( + "github.com/consensys/gnark/frontend" + "github.com/succinctlabs/sp1-recursion-gnark/sp1/babybear" + "github.com/succinctlabs/sp1-recursion-gnark/sp1/poseidon2" +) + +type TestPoseidon2BabyBearCircuit struct { + Input [poseidon2.BABYBEAR_WIDTH]babybear.Variable `gnark:",public"` + ExpectedOutput [poseidon2.BABYBEAR_WIDTH]babybear.Variable `gnark:",public"` +} + +func (circuit *TestPoseidon2BabyBearCircuit) Define(api frontend.API) error { + poseidon2BabyBearChip := poseidon2.NewBabyBearChip(api) + fieldApi := babybear.NewChip(api) + + zero := babybear.NewF("0") + input := [poseidon2.BABYBEAR_WIDTH]babybear.Variable{} + for i := 0; i < poseidon2.BABYBEAR_WIDTH; i++ { + input[i] = fieldApi.AddF(circuit.Input[i], zero) + } + + poseidon2BabyBearChip.PermuteMut(&input) + + for i := 0; i < poseidon2.BABYBEAR_WIDTH; i++ { + fieldApi.AssertIsEqualF(circuit.ExpectedOutput[i], input[i]) + } + + return nil +} diff --git a/recursion/gnark-ffi/src/ffi.rs b/recursion/gnark-ffi/src/ffi.rs index d7ecf9d61..35a279ff2 100644 --- a/recursion/gnark-ffi/src/ffi.rs +++ b/recursion/gnark-ffi/src/ffi.rs @@ -110,6 +110,23 @@ pub fn test_plonk_bn254(witness_json: &str, constraints_json: &str) { } } +pub fn test_babybear_poseidon2() { + cfg_if! { + if #[cfg(feature = "plonk")] { + unsafe { + let err_ptr = bind::TestPoseidonBabyBear2(); + if !err_ptr.is_null() { + // Safety: The error message is returned from the go code and is guaranteed to be valid. 
+ let err = CString::from_raw(err_ptr); + panic!("TestPlonkBn254 failed: {}", err.into_string().unwrap()); + } + } + } else { + panic!("plonk feature not enabled"); + } + } +} + /// Converts a C string into a Rust String. /// /// # Safety @@ -140,3 +157,20 @@ impl C_PlonkBn254Proof { } } } + +#[cfg(test)] +mod tests { + use p3_baby_bear::BabyBear; + use p3_field::AbstractField; + use p3_symmetric::Permutation; + + #[cfg(feature = "plonk")] + #[test] + pub fn test_babybear_poseidon2() { + let perm = sp1_core::utils::inner_perm(); + let zeros = [BabyBear::zero(); 16]; + let result = perm.permute(zeros); + println!("{:?}", result); + super::test_babybear_poseidon2(); + } +} diff --git a/recursion/groth16/constraints.json b/recursion/groth16/constraints.json deleted file mode 100644 index 28fc1fdac..000000000 --- a/recursion/groth16/constraints.json +++ /dev/null @@ -1 +0,0 @@ -[{"opcode":"ImmV","args":[["var0"],["100"]]},{"opcode":"Num2BitsV","args":[["var1","var2","var3","var4","var5","var6","var7","var8","var9","var10","var11","var12","var13","var14","var15","var16","var17","var18","var19","var20","var21","var22","var23","var24","var25","var26","var27","var28","var29","var30","var31","var32"],["var0"],["32"]]},{"opcode":"ImmV","args":[["backend0"],["0"]]},{"opcode":"AssertEqV","args":[["var1"],["backend0"]]},{"opcode":"ImmV","args":[["backend1"],["0"]]},{"opcode":"AssertEqV","args":[["var2"],["backend1"]]},{"opcode":"ImmV","args":[["backend2"],["1"]]},{"opcode":"AssertEqV","args":[["var3"],["backend2"]]},{"opcode":"ImmV","args":[["backend3"],["0"]]},{"opcode":"AssertEqV","args":[["var4"],["backend3"]]},{"opcode":"ImmV","args":[["backend4"],["0"]]},{"opcode":"AssertEqV","args":[["var5"],["backend4"]]},{"opcode":"ImmV","args":[["backend5"],["1"]]},{"opcode":"AssertEqV","args":[["var6"],["backend5"]]},{"opcode":"ImmV","args":[["backend6"],["1"]]},{"opcode":"AssertEqV","args":[["var7"],["backend6"]]},{"opcode":"ImmV","args":[["backend7"],["0"]]},{"opcode":"AssertEqV","ar
gs":[["var8"],["backend7"]]},{"opcode":"ImmV","args":[["backend8"],["0"]]},{"opcode":"AssertEqV","args":[["var9"],["backend8"]]},{"opcode":"ImmV","args":[["backend9"],["0"]]},{"opcode":"AssertEqV","args":[["var10"],["backend9"]]},{"opcode":"ImmV","args":[["backend10"],["0"]]},{"opcode":"AssertEqV","args":[["var11"],["backend10"]]},{"opcode":"ImmV","args":[["backend11"],["0"]]},{"opcode":"AssertEqV","args":[["var12"],["backend11"]]},{"opcode":"ImmV","args":[["backend12"],["0"]]},{"opcode":"AssertEqV","args":[["var13"],["backend12"]]},{"opcode":"ImmV","args":[["backend13"],["0"]]},{"opcode":"AssertEqV","args":[["var14"],["backend13"]]},{"opcode":"ImmV","args":[["backend14"],["0"]]},{"opcode":"AssertEqV","args":[["var15"],["backend14"]]},{"opcode":"ImmV","args":[["backend15"],["0"]]},{"opcode":"AssertEqV","args":[["var16"],["backend15"]]},{"opcode":"ImmV","args":[["backend16"],["0"]]},{"opcode":"AssertEqV","args":[["var17"],["backend16"]]},{"opcode":"ImmV","args":[["backend17"],["0"]]},{"opcode":"AssertEqV","args":[["var18"],["backend17"]]},{"opcode":"ImmV","args":[["backend18"],["0"]]},{"opcode":"AssertEqV","args":[["var19"],["backend18"]]},{"opcode":"ImmV","args":[["backend19"],["0"]]},{"opcode":"AssertEqV","args":[["var20"],["backend19"]]},{"opcode":"ImmV","args":[["backend20"],["0"]]},{"opcode":"AssertEqV","args":[["var21"],["backend20"]]},{"opcode":"ImmV","args":[["backend21"],["0"]]},{"opcode":"AssertEqV","args":[["var22"],["backend21"]]},{"opcode":"ImmV","args":[["backend22"],["0"]]},{"opcode":"AssertEqV","args":[["var23"],["backend22"]]},{"opcode":"ImmV","args":[["backend23"],["0"]]},{"opcode":"AssertEqV","args":[["var24"],["backend23"]]},{"opcode":"ImmV","args":[["backend24"],["0"]]},{"opcode":"AssertEqV","args":[["var25"],["backend24"]]},{"opcode":"ImmV","args":[["backend25"],["0"]]},{"opcode":"AssertEqV","args":[["var26"],["backend25"]]},{"opcode":"ImmV","args":[["backend26"],["0"]]},{"opcode":"AssertEqV","args":[["var27"],["backend26"]]},{"opcode":"ImmV","a
rgs":[["backend27"],["0"]]},{"opcode":"AssertEqV","args":[["var28"],["backend27"]]},{"opcode":"ImmV","args":[["backend28"],["0"]]},{"opcode":"AssertEqV","args":[["var29"],["backend28"]]},{"opcode":"ImmV","args":[["backend29"],["0"]]},{"opcode":"AssertEqV","args":[["var30"],["backend29"]]},{"opcode":"ImmV","args":[["backend30"],["0"]]},{"opcode":"AssertEqV","args":[["var31"],["backend30"]]},{"opcode":"ImmV","args":[["backend31"],["0"]]},{"opcode":"AssertEqV","args":[["var32"],["backend31"]]}] \ No newline at end of file diff --git a/recursion/groth16/lib/libbabybear.a b/recursion/groth16/lib/libbabybear.a deleted file mode 100644 index e047c9496..000000000 Binary files a/recursion/groth16/lib/libbabybear.a and /dev/null differ diff --git a/recursion/groth16/main b/recursion/groth16/main deleted file mode 100755 index 126a88bb4..000000000 Binary files a/recursion/groth16/main and /dev/null differ diff --git a/recursion/groth16/witness.json b/recursion/groth16/witness.json deleted file mode 100644 index ed4386877..000000000 --- a/recursion/groth16/witness.json +++ /dev/null @@ -1 +0,0 @@ -{"vars":["999"],"felts":["999"],"exts":[["999","0","0","0"]]} \ No newline at end of file diff --git a/recursion/program/src/machine/compress.rs b/recursion/program/src/machine/compress.rs index f8fbc857b..406a7a04c 100644 --- a/recursion/program/src/machine/compress.rs +++ b/recursion/program/src/machine/compress.rs @@ -236,6 +236,7 @@ where challenger.observe(builder, element); } // verify proof. + let shard_idx = builder.eval(C::N::one()); StarkVerifier::::verify_shard( builder, &vk, @@ -243,6 +244,7 @@ where machine, &mut challenger, &proof, + shard_idx, ); // Load the public values from the proof. 
diff --git a/recursion/program/src/machine/core.rs b/recursion/program/src/machine/core.rs index 1b8635110..515bb1e7b 100644 --- a/recursion/program/src/machine/core.rs +++ b/recursion/program/src/machine/core.rs @@ -160,12 +160,18 @@ where let cumulative_sum: Ext<_, _> = builder.eval(C::EF::zero().cons()); let current_pc: Felt<_> = builder.uninit(); let exit_code: Felt<_> = builder.uninit(); + + // Range check that the number of proofs is sufficiently small. + let num_shard_proofs: Var<_> = shard_proofs.len().materialize(builder); + builder.range_check_v(num_shard_proofs, 16); + // Verify proofs, validate transitions, and update accumulation variables. builder.range(0, shard_proofs.len()).for_each(|i, builder| { // Load the proof. let proof = builder.get(&shard_proofs, i); // Verify the shard proof. + let shard_idx = builder.eval(i + C::N::one()); let mut challenger = leaf_challenger.copy(builder); StarkVerifier::::verify_shard( builder, @@ -174,6 +180,7 @@ where machine, &mut challenger, &proof, + shard_idx, ); // Extract public values. @@ -263,6 +270,9 @@ where // Assert that exit code is the same for all proofs. builder.assert_felt_eq(exit_code, public_values.exit_code); + // Assert that the exit code is zero (success) for all proofs. + builder.assert_felt_eq(exit_code, C::F::zero()); + // Assert that the deferred proof digest is the same for all proofs. for (digest, current_digest) in deferred_proofs_digest .iter() diff --git a/recursion/program/src/machine/deferred.rs b/recursion/program/src/machine/deferred.rs index 2ae232ab7..be380516b 100644 --- a/recursion/program/src/machine/deferred.rs +++ b/recursion/program/src/machine/deferred.rs @@ -187,7 +187,9 @@ where let element = builder.get(&proof.public_values, j); challenger.observe(builder, element); } - // verify the proof. + + // Verify the proof. 
+ let shard_idx = builder.eval(C::N::one()); StarkVerifier::::verify_shard( builder, &compress_vk, @@ -195,6 +197,7 @@ where machine, &mut challenger, &proof, + shard_idx, ); // Load the public values from the proof. diff --git a/recursion/program/src/machine/root.rs b/recursion/program/src/machine/root.rs index 8e8eb72c6..4b3cb9e88 100644 --- a/recursion/program/src/machine/root.rs +++ b/recursion/program/src/machine/root.rs @@ -107,7 +107,16 @@ where challenger.observe(builder, element); } // verify proof. - StarkVerifier::::verify_shard(builder, &vk, pcs, machine, &mut challenger, proof); + let shard_idx = builder.eval(C::N::one()); + StarkVerifier::::verify_shard( + builder, + &vk, + pcs, + machine, + &mut challenger, + proof, + shard_idx, + ); // Get the public inputs from the proof. let public_values_elements = (0..RECURSIVE_PROOF_NUM_PV_ELTS) diff --git a/recursion/program/src/stark.rs b/recursion/program/src/stark.rs index aec5d9453..f3040fb7f 100644 --- a/recursion/program/src/stark.rs +++ b/recursion/program/src/stark.rs @@ -12,7 +12,6 @@ use sp1_core::stark::StarkMachine; use sp1_core::stark::StarkVerifyingKey; use sp1_recursion_compiler::ir::Array; use sp1_recursion_compiler::ir::Ext; -use sp1_recursion_compiler::ir::ExtConst; use sp1_recursion_compiler::ir::SymbolicExt; use sp1_recursion_compiler::ir::SymbolicVar; use sp1_recursion_compiler::ir::Var; @@ -94,48 +93,6 @@ impl<'a, SC: StarkGenericConfig, A: MachineAir> VerifyingKeyHint<'a, SC } } -impl StarkRecursiveVerifier for StarkMachine -where - C::F: TwoAdicField, - SC: StarkGenericConfig< - Val = C::F, - Challenge = C::EF, - Domain = TwoAdicMultiplicativeCoset, - >, - A: MachineAir + for<'a> Air>, - C::F: TwoAdicField, - C::EF: TwoAdicField, - Com: Into<[SC::Val; DIGEST_SIZE]>, -{ - fn verify_shard( - &self, - builder: &mut Builder, - vk: &VerifyingKeyVariable, - pcs: &TwoAdicFriPcsVariable, - challenger: &mut DuplexChallengerVariable, - proof: &ShardProofVariable, - is_complete: impl Into::N>>, - 
) { - // Verify the shard proof. - StarkVerifier::::verify_shard(builder, vk, pcs, self, challenger, proof); - - // Verify that the cumulative sum of the chip is zero if the shard is complete. - let cumulative_sum: Ext<_, _> = builder.uninit(); - builder - .range(0, proof.opened_values.chips.len()) - .for_each(|i, builder| { - let values = builder.get(&proof.opened_values.chips, i); - builder.assign(cumulative_sum, cumulative_sum + values.cumulative_sum); - }); - - builder - .if_eq(is_complete.into(), C::N::one()) - .then(|builder| { - builder.assert_ext_eq(cumulative_sum, C::EF::zero().cons()); - }); - } -} - pub type RecursiveVerifierConstraintFolder<'a, C> = GenericVerifierConstraintFolder< 'a, ::F, @@ -161,6 +118,7 @@ where machine: &StarkMachine, challenger: &mut DuplexChallengerVariable, proof: &ShardProofVariable, + shard_idx: Var, ) where A: MachineAir + for<'a> Air>, C::F: TwoAdicField, @@ -356,6 +314,39 @@ where builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); } + if chip.name() == "MemoryProgram" { + builder.if_eq(shard_idx, C::N::one()).then_or_else( + |builder| { + builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); + }, + |builder| { + builder.assert_var_eq(index, C::N::from_canonical_usize(EMPTY)); + }, + ); + } + + if chip.name() == "MemoryInit" { + builder.if_eq(shard_idx, C::N::one()).then_or_else( + |builder| { + builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); + }, + |builder| { + builder.assert_var_eq(index, C::N::from_canonical_usize(EMPTY)); + }, + ); + } + + if chip.name() == "MemoryFinalize" { + builder.if_eq(shard_idx, C::N::one()).then_or_else( + |builder| { + builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY)); + }, + |builder| { + builder.assert_var_eq(index, C::N::from_canonical_usize(EMPTY)); + }, + ); + } + builder .if_ne(index, C::N::from_canonical_usize(EMPTY)) .then(|builder| { diff --git a/tests/blake3-compress/Cargo.lock b/tests/blake3-compress/Cargo.lock deleted file mode 
100644 index b45f827d8..000000000 --- a/tests/blake3-compress/Cargo.lock +++ /dev/null @@ -1,760 +0,0 @@ -# This file is automatically @generated by Cargo. -# It is not intended for manual editing. -version = 3 - -[[package]] -name = "anyhow" -version = "1.0.86" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" - -[[package]] -name = "arrayvec" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" - -[[package]] -name = "autocfg" -version = "1.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" - -[[package]] -name = "base16ct" -version = "0.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4c7f02d4ea65f2c1853089ffd8d2787bdbc63de2f0d29dedbcf8ccdfa0ccd4cf" - -[[package]] -name = "base64ct" -version = "1.6.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" - -[[package]] -name = "bincode" -version = "1.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" -dependencies = [ - "serde", -] - -[[package]] -name = "bitvec" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bc2832c24239b0141d5674bb9174f9d68a8b5b3f2753311927c172ca46f7e9c" -dependencies = [ - "funty", - "radium", - "tap", - "wyz", -] - -[[package]] -name = "blake3-compress-test" -version = "0.1.0" -dependencies = [ - "sp1-zkvm", -] - -[[package]] -name = "block-buffer" -version = "0.10.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" 
-dependencies = [ - "generic-array", -] - -[[package]] -name = "byte-slice-cast" -version = "1.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c3ac9f8b63eca6fd385229b3675f6cc0dc5c8a5c8a54a59d4f52ffd670d87b0c" - -[[package]] -name = "cfg-if" -version = "1.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" - -[[package]] -name = "const-oid" -version = "0.9.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" - -[[package]] -name = "cpufeatures" -version = "0.2.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "53fe5e26ff1b7aef8bca9c6080520cfb8d9333c7568e1829cef191a9723e5504" -dependencies = [ - "libc", -] - -[[package]] -name = "crypto-bigint" -version = "0.5.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0dc92fb57ca44df6db8059111ab3af99a63d5d0f8375d9972e319a379c6bab76" -dependencies = [ - "generic-array", - "rand_core", - "subtle", - "zeroize", -] - -[[package]] -name = "crypto-common" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" -dependencies = [ - "generic-array", - "typenum", -] - -[[package]] -name = "der" -version = "0.7.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c" -dependencies = [ - "const-oid", - "zeroize", -] - -[[package]] -name = "derive_more" -version = "0.99.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "digest" -version = "0.10.7" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" -dependencies = [ - "block-buffer", - "const-oid", - "crypto-common", - "subtle", -] - -[[package]] -name = "ecdsa" -version = "0.16.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" -dependencies = [ - "der", - "digest", - "elliptic-curve", - "rfc6979", - "signature", - "spki", -] - -[[package]] -name = "elliptic-curve" -version = "0.13.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" -dependencies = [ - "base16ct", - "crypto-bigint", - "digest", - "ff", - "generic-array", - "group", - "pkcs8", - "rand_core", - "sec1", - "subtle", - "tap", - "zeroize", -] - -[[package]] -name = "equivalent" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" - -[[package]] -name = "ff" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ded41244b729663b1e574f1b4fb731469f69f79c17667b5d776b16cda0479449" -dependencies = [ - "bitvec", - "rand_core", - "subtle", -] - -[[package]] -name = "funty" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6d5a32815ae3f33302d95fdcb2ce17862f8c65363dcfd29360480ba1001fc9c" - -[[package]] -name = "generic-array" -version = "0.14.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" -dependencies = [ - "typenum", - "version_check", - "zeroize", -] - -[[package]] -name = "getrandom" -version = "0.2.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" -dependencies = [ - "cfg-if", - "libc", - "wasi", -] - -[[package]] -name = "group" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0f9ef7462f7c099f518d754361858f86d8a07af53ba9af0fe635bbccb151a63" -dependencies = [ - "ff", - "rand_core", - "subtle", -] - -[[package]] -name = "hashbrown" -version = "0.14.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" - -[[package]] -name = "hex" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" - -[[package]] -name = "hmac" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" -dependencies = [ - "digest", -] - -[[package]] -name = "impl-trait-for-tuples" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "11d7a9f6330b71fea57921c9b61c47ee6e84f72d394754eff6163ae67e7395eb" -dependencies = [ - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "indexmap" -version = "2.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" -dependencies = [ - "equivalent", - "hashbrown", -] - -[[package]] -name = "k256" -version = "0.13.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "956ff9b67e26e1a6a866cb758f12c6f8746208489e3e4a4b5580802f2f0a587b" -dependencies = [ - "cfg-if", - "ecdsa", - "elliptic-curve", - "once_cell", - "sha2", - "signature", -] - -[[package]] -name = "libc" -version = "0.2.155" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" - -[[package]] -name = "libm" -version = "0.2.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" - -[[package]] -name = "memchr" -version = "2.7.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" - -[[package]] -name = "num" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" -dependencies = [ - "num-bigint", - "num-complex", - "num-integer", - "num-iter", - "num-rational", - "num-traits", -] - -[[package]] -name = "num-bigint" -version = "0.4.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c165a9ab64cf766f73521c0dd2cfdff64f488b8f0b3e621face3462d3db536d7" -dependencies = [ - "num-integer", - "num-traits", -] - -[[package]] -name = "num-complex" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-integer" -version = "0.1.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" -dependencies = [ - "num-traits", -] - -[[package]] -name = "num-iter" -version = "0.1.45" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-rational" -version = "0.4.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" -dependencies = [ - 
"num-bigint", - "num-integer", - "num-traits", -] - -[[package]] -name = "num-traits" -version = "0.2.19" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" -dependencies = [ - "autocfg", -] - -[[package]] -name = "once_cell" -version = "1.19.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" - -[[package]] -name = "parity-scale-codec" -version = "3.6.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "881331e34fa842a2fb61cc2db9643a8fedc615e47cfcc52597d1af0db9a7e8fe" -dependencies = [ - "arrayvec", - "byte-slice-cast", - "impl-trait-for-tuples", - "parity-scale-codec-derive", -] - -[[package]] -name = "parity-scale-codec-derive" -version = "3.6.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be30eaf4b0a9fba5336683b38de57bb86d179a35862ba6bfcf57625d006bde5b" -dependencies = [ - "proc-macro-crate 2.0.2", - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "pkcs8" -version = "0.10.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f950b2377845cebe5cf8b5165cb3cc1a5e0fa5cfa3e1f7f55707d8fd82e0a7b7" -dependencies = [ - "der", - "spki", -] - -[[package]] -name = "ppv-lite86" -version = "0.2.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" - -[[package]] -name = "proc-macro-crate" -version = "1.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f4c021e1093a56626774e81216a4ce732a735e5bad4868a03f3ed65ca0c3919" -dependencies = [ - "once_cell", - "toml_edit 0.19.15", -] - -[[package]] -name = "proc-macro-crate" -version = "2.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"b00f26d3400549137f92511a46ac1cd8ce37cb5598a96d382381458b992a5d24" -dependencies = [ - "toml_datetime", - "toml_edit 0.20.2", -] - -[[package]] -name = "proc-macro2" -version = "1.0.78" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae" -dependencies = [ - "unicode-ident", -] - -[[package]] -name = "quote" -version = "1.0.35" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" -dependencies = [ - "proc-macro2", -] - -[[package]] -name = "radium" -version = "0.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc33ff2d4973d518d823d61aa239014831e521c75da58e3df4840d3f47749d09" - -[[package]] -name = "rand" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" -dependencies = [ - "libc", - "rand_chacha", - "rand_core", -] - -[[package]] -name = "rand_chacha" -version = "0.3.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" -dependencies = [ - "ppv-lite86", - "rand_core", -] - -[[package]] -name = "rand_core" -version = "0.6.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" -dependencies = [ - "getrandom", -] - -[[package]] -name = "rfc6979" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8dd2a808d456c4a54e300a23e9f5a67e122c3024119acbfd73e3bf664491cb2" -dependencies = [ - "hmac", - "subtle", -] - -[[package]] -name = "scale-info" -version = "2.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7c453e59a955f81fb62ee5d596b450383d699f152d350e9d23a0db2adb78e4c0" 
-dependencies = [ - "cfg-if", - "derive_more", - "parity-scale-codec", - "scale-info-derive", -] - -[[package]] -name = "scale-info-derive" -version = "2.11.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18cf6c6447f813ef19eb450e985bcce6705f9ce7660db221b59093d15c79c4b7" -dependencies = [ - "proc-macro-crate 1.3.1", - "proc-macro2", - "quote", - "syn 1.0.109", -] - -[[package]] -name = "sec1" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3e97a565f76233a6003f9f5c54be1d9c5bdfa3eccfb189469f11ec4901c47dc" -dependencies = [ - "base16ct", - "der", - "generic-array", - "pkcs8", - "subtle", - "zeroize", -] - -[[package]] -name = "serde" -version = "1.0.203" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7253ab4de971e72fb7be983802300c30b5a7f0c2e56fab8abfc6a214307c0094" -dependencies = [ - "serde_derive", -] - -[[package]] -name = "serde_derive" -version = "1.0.203" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "500cbc0ebeb6f46627f50f3f5811ccf6bf00643be300b4c3eabc0ef55dc5b5ba" -dependencies = [ - "proc-macro2", - "quote", - "syn 2.0.48", -] - -[[package]] -name = "sha2" -version = "0.10.8" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - -[[package]] -name = "signature" -version = "2.2.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" -dependencies = [ - "digest", - "rand_core", -] - -[[package]] -name = "snowbridge-amcl" -version = "1.0.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "460a9ed63cdf03c1b9847e8a12a5f5ba19c4efd5869e4a737e05be25d7c427e5" -dependencies = [ - "parity-scale-codec", - "scale-info", -] - -[[package]] -name = 
"sp1-precompiles" -version = "0.1.0" -dependencies = [ - "anyhow", - "bincode", - "cfg-if", - "getrandom", - "hex", - "k256", - "num", - "rand", - "serde", - "snowbridge-amcl", -] - -[[package]] -name = "sp1-zkvm" -version = "0.1.0" -dependencies = [ - "bincode", - "cfg-if", - "getrandom", - "k256", - "libm", - "once_cell", - "rand", - "serde", - "sha2", - "sp1-precompiles", -] - -[[package]] -name = "spki" -version = "0.7.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d91ed6c858b01f942cd56b37a94b3e0a1798290327d1236e4d9cf4eaca44d29d" -dependencies = [ - "base64ct", - "der", -] - -[[package]] -name = "subtle" -version = "2.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" - -[[package]] -name = "syn" -version = "1.0.109" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "syn" -version = "2.0.48" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" -dependencies = [ - "proc-macro2", - "quote", - "unicode-ident", -] - -[[package]] -name = "tap" -version = "1.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55937e1799185b12863d447f42597ed69d9928686b8d88a1df17376a097d8369" - -[[package]] -name = "toml_datetime" -version = "0.6.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7cda73e2f1397b1262d6dfdcef8aafae14d1de7748d66822d3bfeeb6d03e5e4b" - -[[package]] -name = "toml_edit" -version = "0.19.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b5bb770da30e5cbfde35a2d7b9b8a2c4b8ef89548a7a6aeab5c9a576e3e7421" -dependencies = [ - "indexmap", - "toml_datetime", - "winnow", 
-] - -[[package]] -name = "toml_edit" -version = "0.20.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "396e4d48bbb2b7554c944bde63101b5ae446cff6ec4a24227428f15eb72ef338" -dependencies = [ - "indexmap", - "toml_datetime", - "winnow", -] - -[[package]] -name = "typenum" -version = "1.17.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" - -[[package]] -name = "unicode-ident" -version = "1.0.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" - -[[package]] -name = "version_check" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" - -[[package]] -name = "wasi" -version = "0.11.0+wasi-snapshot-preview1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" - -[[package]] -name = "winnow" -version = "0.5.40" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f593a95398737aeed53e489c785df13f3618e41dbcd6718c6addbf1395aa6876" -dependencies = [ - "memchr", -] - -[[package]] -name = "wyz" -version = "0.5.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "05f360fc0b24296329c78fda852a1e9ae82de9cf7b27dae4b7f62f118f77b9ed" -dependencies = [ - "tap", -] - -[[package]] -name = "zeroize" -version = "1.7.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" diff --git a/tests/blake3-compress/Cargo.toml b/tests/blake3-compress/Cargo.toml deleted file mode 100644 index e5987407c..000000000 --- a/tests/blake3-compress/Cargo.toml +++ /dev/null @@ -1,8 +0,0 @@ -[workspace] -[package] -version = "0.1.0" -name = 
"blake3-compress-test" -edition = "2021" - -[dependencies] -sp1-zkvm = { path = "../../zkvm/entrypoint" } diff --git a/tests/blake3-compress/elf/riscv32im-succinct-zkvm-elf b/tests/blake3-compress/elf/riscv32im-succinct-zkvm-elf deleted file mode 100755 index 4e0fee023..000000000 Binary files a/tests/blake3-compress/elf/riscv32im-succinct-zkvm-elf and /dev/null differ diff --git a/tests/blake3-compress/src/main.rs b/tests/blake3-compress/src/main.rs deleted file mode 100644 index 6bbee4916..000000000 --- a/tests/blake3-compress/src/main.rs +++ /dev/null @@ -1,42 +0,0 @@ -#![no_main] -sp1_zkvm::entrypoint!(main); - -extern "C" { - fn syscall_blake3_compress_inner(p: *mut u32, q: *const u32); -} - -pub fn main() { - // The input message and state are simply 0, 1, ..., 95 followed by some fixed constants. - for _i in 0..10 { - let input_message: [u8; 64] = [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, - 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, - 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, - ]; - - let mut input_state: [u8; 64] = [ - 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, - 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 103, 230, 9, 106, 133, 174, 103, 187, 114, 243, - 110, 60, 58, 245, 79, 165, 96, 0, 0, 0, 0, 0, 0, 0, 64, 0, 0, 0, 97, 0, 0, 0, - ]; - - unsafe { - syscall_blake3_compress_inner( - input_state.as_mut_ptr() as *mut u32, - input_message.as_ptr() as *const u32, - ); - } - - // The expected output state is the result of compress_inner. 
- let output_state: [u8; 64] = [ - 239, 181, 94, 129, 58, 124, 80, 104, 126, 210, 5, 157, 255, 58, 238, 89, 252, 106, 170, - 12, 233, 56, 58, 31, 215, 16, 105, 97, 11, 229, 238, 73, 6, 79, 155, 180, 197, 73, 116, - 0, 127, 22, 16, 39, 116, 174, 85, 5, 61, 94, 87, 6, 236, 10, 36, 238, 119, 171, 207, - 171, 189, 216, 43, 250, - ]; - - assert_eq!(input_state, output_state); - } - - println!("done"); -} diff --git a/tests/bls12381-add/elf/riscv32im-succinct-zkvm-elf b/tests/bls12381-add/elf/riscv32im-succinct-zkvm-elf index fce1bf9ff..6e2c7e686 100755 Binary files a/tests/bls12381-add/elf/riscv32im-succinct-zkvm-elf and b/tests/bls12381-add/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bls12381-add/src/main.rs b/tests/bls12381-add/src/main.rs index 874e9f066..681cf39af 100644 --- a/tests/bls12381-add/src/main.rs +++ b/tests/bls12381-add/src/main.rs @@ -6,44 +6,48 @@ extern "C" { } pub fn main() { - // generator. - // 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 - // 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569 - let mut a: [u8; 96] = [ - 187, 198, 34, 219, 10, 240, 58, 251, 239, 26, 122, 249, 63, 232, 85, 108, 88, 172, 27, 23, - 63, 58, 78, 161, 5, 185, 116, 151, 79, 140, 104, 195, 15, 172, 169, 79, 140, 99, 149, 38, - 148, 215, 151, 49, 167, 211, 241, 23, 225, 231, 197, 70, 41, 35, 170, 12, 228, 138, 136, - 162, 68, 199, 60, 208, 237, 179, 4, 44, 203, 24, 219, 0, 246, 10, 208, 213, 149, 224, 245, - 252, 228, 138, 29, 116, 237, 48, 158, 160, 241, 160, 170, 227, 129, 244, 179, 8, - ]; + for _ in 0..4 { + // generator. 
+ // 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 + // 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569 + let mut a: [u8; 96] = [ + 187, 198, 34, 219, 10, 240, 58, 251, 239, 26, 122, 249, 63, 232, 85, 108, 88, 172, 27, + 23, 63, 58, 78, 161, 5, 185, 116, 151, 79, 140, 104, 195, 15, 172, 169, 79, 140, 99, + 149, 38, 148, 215, 151, 49, 167, 211, 241, 23, 225, 231, 197, 70, 41, 35, 170, 12, 228, + 138, 136, 162, 68, 199, 60, 208, 237, 179, 4, 44, 203, 24, 219, 0, 246, 10, 208, 213, + 149, 224, 245, 252, 228, 138, 29, 116, 237, 48, 158, 160, 241, 160, 170, 227, 129, 244, + 179, 8, + ]; - // 2 * generator. - // 838589206289216005799424730305866328161735431124665289961769162861615689790485775997575391185127590486775437397838 - // 3450209970729243429733164009999191867485184320918914219895632678707687208996709678363578245114137957452475385814312 - let b: [u8; 96] = [ - 78, 15, 191, 41, 85, 140, 154, 195, 66, 124, 28, 143, 187, 117, 143, 226, 42, 166, 88, 195, - 10, 45, 144, 67, 37, 1, 40, 145, 48, 219, 33, 151, 12, 69, 169, 80, 235, 200, 8, 136, 70, - 103, 77, 144, 234, 203, 114, 5, 40, 157, 116, 121, 25, 136, 134, 186, 27, 189, 22, 205, - 212, 217, 86, 76, 106, 215, 95, 29, 2, 185, 59, 247, 97, 228, 112, 134, 203, 62, 186, 34, - 56, 142, 157, 119, 115, 166, 253, 34, 163, 115, 198, 171, 140, 157, 106, 22, - ]; + // 2 * generator. 
+ // 838589206289216005799424730305866328161735431124665289961769162861615689790485775997575391185127590486775437397838 + // 3450209970729243429733164009999191867485184320918914219895632678707687208996709678363578245114137957452475385814312 + let b: [u8; 96] = [ + 78, 15, 191, 41, 85, 140, 154, 195, 66, 124, 28, 143, 187, 117, 143, 226, 42, 166, 88, + 195, 10, 45, 144, 67, 37, 1, 40, 145, 48, 219, 33, 151, 12, 69, 169, 80, 235, 200, 8, + 136, 70, 103, 77, 144, 234, 203, 114, 5, 40, 157, 116, 121, 25, 136, 134, 186, 27, 189, + 22, 205, 212, 217, 86, 76, 106, 215, 95, 29, 2, 185, 59, 247, 97, 228, 112, 134, 203, + 62, 186, 34, 56, 142, 157, 119, 115, 166, 253, 34, 163, 115, 198, 171, 140, 157, 106, + 22, + ]; - unsafe { - syscall_bls12381_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); - } + unsafe { + syscall_bls12381_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); + } - // 3 * generator. - // 1527649530533633684281386512094328299672026648504329745640827351945739272160755686119065091946435084697047221031460 - // 487897572011753812113448064805964756454529228648704488481988876974355015977479905373670519228592356747638779818193 - let c: [u8; 96] = [ - 36, 82, 78, 2, 201, 192, 210, 150, 155, 23, 162, 44, 11, 122, 116, 129, 249, 63, 91, 51, - 81, 10, 120, 243, 241, 165, 233, 155, 31, 214, 18, 177, 151, 150, 169, 236, 45, 33, 101, - 23, 19, 240, 209, 249, 8, 227, 236, 9, 209, 48, 174, 144, 5, 59, 71, 163, 92, 244, 74, 99, - 108, 37, 69, 231, 230, 59, 212, 15, 49, 39, 156, 157, 127, 9, 195, 171, 221, 12, 154, 166, - 12, 248, 197, 137, 51, 98, 132, 138, 159, 176, 245, 166, 211, 128, 43, 3, - ]; + // 3 * generator. 
+ // 1527649530533633684281386512094328299672026648504329745640827351945739272160755686119065091946435084697047221031460 + // 487897572011753812113448064805964756454529228648704488481988876974355015977479905373670519228592356747638779818193 + let c: [u8; 96] = [ + 36, 82, 78, 2, 201, 192, 210, 150, 155, 23, 162, 44, 11, 122, 116, 129, 249, 63, 91, + 51, 81, 10, 120, 243, 241, 165, 233, 155, 31, 214, 18, 177, 151, 150, 169, 236, 45, 33, + 101, 23, 19, 240, 209, 249, 8, 227, 236, 9, 209, 48, 174, 144, 5, 59, 71, 163, 92, 244, + 74, 99, 108, 37, 69, 231, 230, 59, 212, 15, 49, 39, 156, 157, 127, 9, 195, 171, 221, + 12, 154, 166, 12, 248, 197, 137, 51, 98, 132, 138, 159, 176, 245, 166, 211, 128, 43, 3, + ]; - assert_eq!(a, c); + assert_eq!(a, c); + } println!("done"); } diff --git a/tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf b/tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf index 818954dc4..3a8f2e187 100755 Binary files a/tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf and b/tests/bls12381-decompress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bls12381-decompress/src/main.rs b/tests/bls12381-decompress/src/main.rs index 3e93a099f..a359fa9f7 100644 --- a/tests/bls12381-decompress/src/main.rs +++ b/tests/bls12381-decompress/src/main.rs @@ -7,19 +7,22 @@ extern "C" { pub fn main() { let compressed_key: [u8; 48] = sp1_zkvm::io::read_vec().try_into().unwrap(); - let mut decompressed_key: [u8; 96] = [0u8; 96]; - decompressed_key[..48].copy_from_slice(&compressed_key); + for _ in 0..4 { + let mut decompressed_key: [u8; 96] = [0u8; 96]; - println!("before: {:?}", decompressed_key); + decompressed_key[..48].copy_from_slice(&compressed_key); - let is_odd = (decompressed_key[0] & 0b_0010_0000) >> 5 == 0; - decompressed_key[0] &= 0b_0001_1111; + println!("before: {:?}", decompressed_key); - unsafe { - syscall_bls12381_decompress(&mut decompressed_key, is_odd); - } - println!("after: {:?}", decompressed_key); + let is_odd = 
(decompressed_key[0] & 0b_0010_0000) >> 5 == 0; + decompressed_key[0] &= 0b_0001_1111; + + unsafe { + syscall_bls12381_decompress(&mut decompressed_key, is_odd); + } - sp1_zkvm::io::commit_slice(&decompressed_key); + println!("after: {:?}", decompressed_key); + sp1_zkvm::io::commit_slice(&decompressed_key); + } } diff --git a/tests/bls12381-double/elf/riscv32im-succinct-zkvm-elf b/tests/bls12381-double/elf/riscv32im-succinct-zkvm-elf index 5c4706b8f..50470172a 100755 Binary files a/tests/bls12381-double/elf/riscv32im-succinct-zkvm-elf and b/tests/bls12381-double/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bls12381-mul/elf/riscv32im-succinct-zkvm-elf b/tests/bls12381-mul/elf/riscv32im-succinct-zkvm-elf index 313a6226f..d1fe6cdf6 100755 Binary files a/tests/bls12381-mul/elf/riscv32im-succinct-zkvm-elf and b/tests/bls12381-mul/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bls12381-mul/src/main.rs b/tests/bls12381-mul/src/main.rs index 9f90906e9..89169660a 100644 --- a/tests/bls12381-mul/src/main.rs +++ b/tests/bls12381-mul/src/main.rs @@ -6,39 +6,42 @@ use sp1_zkvm::precompiles::utils::AffinePoint; #[sp1_derive::cycle_tracker] pub fn main() { - // generator. - // 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 - // 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569 - let a: [u8; 96] = [ - 187, 198, 34, 219, 10, 240, 58, 251, 239, 26, 122, 249, 63, 232, 85, 108, 88, 172, 27, 23, - 63, 58, 78, 161, 5, 185, 116, 151, 79, 140, 104, 195, 15, 172, 169, 79, 140, 99, 149, 38, - 148, 215, 151, 49, 167, 211, 241, 23, 225, 231, 197, 70, 41, 35, 170, 12, 228, 138, 136, - 162, 68, 199, 60, 208, 237, 179, 4, 44, 203, 24, 219, 0, 246, 10, 208, 213, 149, 224, 245, - 252, 228, 138, 29, 116, 237, 48, 158, 160, 241, 160, 170, 227, 129, 244, 179, 8, - ]; - - let mut a_point = AffinePoint::::from_le_bytes(&a); - - // scalar. 
- // 3 - let scalar: [u32; 12] = [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; - - println!("cycle-tracker-start: bn254_mul"); - a_point.mul_assign(&scalar); - println!("cycle-tracker-end: bn254_mul"); - - // 3 * generator. - // 1527649530533633684281386512094328299672026648504329745640827351945739272160755686119065091946435084697047221031460 - // 487897572011753812113448064805964756454529228648704488481988876974355015977479905373670519228592356747638779818193 - let c: [u8; 96] = [ - 36, 82, 78, 2, 201, 192, 210, 150, 155, 23, 162, 44, 11, 122, 116, 129, 249, 63, 91, 51, - 81, 10, 120, 243, 241, 165, 233, 155, 31, 214, 18, 177, 151, 150, 169, 236, 45, 33, 101, - 23, 19, 240, 209, 249, 8, 227, 236, 9, 209, 48, 174, 144, 5, 59, 71, 163, 92, 244, 74, 99, - 108, 37, 69, 231, 230, 59, 212, 15, 49, 39, 156, 157, 127, 9, 195, 171, 221, 12, 154, 166, - 12, 248, 197, 137, 51, 98, 132, 138, 159, 176, 245, 166, 211, 128, 43, 3, - ]; - - assert_eq!(a_point.to_le_bytes(), c); + for _ in 0..4 { + // generator. + // 3685416753713387016781088315183077757961620795782546409894578378688607592378376318836054947676345821548104185464507 + // 1339506544944476473020471379941921221584933875938349620426543736416511423956333506472724655353366534992391756441569 + let a: [u8; 96] = [ + 187, 198, 34, 219, 10, 240, 58, 251, 239, 26, 122, 249, 63, 232, 85, 108, 88, 172, 27, + 23, 63, 58, 78, 161, 5, 185, 116, 151, 79, 140, 104, 195, 15, 172, 169, 79, 140, 99, + 149, 38, 148, 215, 151, 49, 167, 211, 241, 23, 225, 231, 197, 70, 41, 35, 170, 12, 228, + 138, 136, 162, 68, 199, 60, 208, 237, 179, 4, 44, 203, 24, 219, 0, 246, 10, 208, 213, + 149, 224, 245, 252, 228, 138, 29, 116, 237, 48, 158, 160, 241, 160, 170, 227, 129, 244, + 179, 8, + ]; + + let mut a_point = AffinePoint::::from_le_bytes(&a); + + // scalar. 
+ // 3 + let scalar: [u32; 12] = [3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]; + + println!("cycle-tracker-start: bn254_mul"); + a_point.mul_assign(&scalar); + println!("cycle-tracker-end: bn254_mul"); + + // 3 * generator. + // 1527649530533633684281386512094328299672026648504329745640827351945739272160755686119065091946435084697047221031460 + // 487897572011753812113448064805964756454529228648704488481988876974355015977479905373670519228592356747638779818193 + let c: [u8; 96] = [ + 36, 82, 78, 2, 201, 192, 210, 150, 155, 23, 162, 44, 11, 122, 116, 129, 249, 63, 91, + 51, 81, 10, 120, 243, 241, 165, 233, 155, 31, 214, 18, 177, 151, 150, 169, 236, 45, 33, + 101, 23, 19, 240, 209, 249, 8, 227, 236, 9, 209, 48, 174, 144, 5, 59, 71, 163, 92, 244, + 74, 99, 108, 37, 69, 231, 230, 59, 212, 15, 49, 39, 156, 157, 127, 9, 195, 171, 221, + 12, 154, 166, 12, 248, 197, 137, 51, 98, 132, 138, 159, 176, 245, 166, 211, 128, 43, 3, + ]; + + assert_eq!(a_point.to_le_bytes(), c); + } println!("done"); } diff --git a/tests/bn254-add/elf/riscv32im-succinct-zkvm-elf b/tests/bn254-add/elf/riscv32im-succinct-zkvm-elf index a45b52cd9..a55b917d1 100755 Binary files a/tests/bn254-add/elf/riscv32im-succinct-zkvm-elf and b/tests/bn254-add/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bn254-add/src/main.rs b/tests/bn254-add/src/main.rs index 7e164663d..406681d65 100644 --- a/tests/bn254-add/src/main.rs +++ b/tests/bn254-add/src/main.rs @@ -6,40 +6,42 @@ extern "C" { } pub fn main() { - // generator. - // 1 - // 2 - let mut a: [u8; 64] = [ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, - ]; + for _ in 0..4 { + // generator. 
+ // 1 + // 2 + let mut a: [u8; 64] = [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]; - // 2 * generator. - // 1368015179489954701390400359078579693043519447331113978918064868415326638035 - // 9918110051302171585080402603319702774565515993150576347155970296011118125764 - let b: [u8; 64] = [ - 211, 207, 135, 109, 193, 8, 194, 211, 168, 28, 135, 22, 169, 22, 120, 217, 133, 21, 24, - 104, 91, 4, 133, 155, 2, 26, 19, 46, 231, 68, 6, 3, 196, 162, 24, 90, 122, 191, 62, 255, - 199, 143, 83, 227, 73, 164, 166, 104, 10, 156, 174, 178, 150, 95, 132, 231, 146, 124, 10, - 14, 140, 115, 237, 21, - ]; + // 2 * generator. + // 1368015179489954701390400359078579693043519447331113978918064868415326638035 + // 9918110051302171585080402603319702774565515993150576347155970296011118125764 + let b: [u8; 64] = [ + 211, 207, 135, 109, 193, 8, 194, 211, 168, 28, 135, 22, 169, 22, 120, 217, 133, 21, 24, + 104, 91, 4, 133, 155, 2, 26, 19, 46, 231, 68, 6, 3, 196, 162, 24, 90, 122, 191, 62, + 255, 199, 143, 83, 227, 73, 164, 166, 104, 10, 156, 174, 178, 150, 95, 132, 231, 146, + 124, 10, 14, 140, 115, 237, 21, + ]; - unsafe { - syscall_bn254_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); - } + unsafe { + syscall_bn254_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); + } - // 3 * generator. - // 3353031288059533942658390886683067124040920775575537747144343083137631628272 - // 19321533766552368860946552437480515441416830039777911637913418824951667761761 - let c: [u8; 64] = [ - 240, 171, 21, 25, 150, 85, 211, 242, 121, 230, 184, 21, 71, 216, 21, 147, 21, 189, 182, - 177, 188, 50, 2, 244, 63, 234, 107, 197, 154, 191, 105, 7, 97, 34, 254, 217, 61, 255, 241, - 205, 87, 91, 156, 11, 180, 99, 158, 49, 117, 100, 8, 141, 124, 219, 79, 85, 41, 148, 72, - 224, 190, 153, 183, 42, - ]; + // 3 * generator. 
+ // 3353031288059533942658390886683067124040920775575537747144343083137631628272 + // 19321533766552368860946552437480515441416830039777911637913418824951667761761 + let c: [u8; 64] = [ + 240, 171, 21, 25, 150, 85, 211, 242, 121, 230, 184, 21, 71, 216, 21, 147, 21, 189, 182, + 177, 188, 50, 2, 244, 63, 234, 107, 197, 154, 191, 105, 7, 97, 34, 254, 217, 61, 255, + 241, 205, 87, 91, 156, 11, 180, 99, 158, 49, 117, 100, 8, 141, 124, 219, 79, 85, 41, + 148, 72, 224, 190, 153, 183, 42, + ]; - assert_eq!(a, c); + assert_eq!(a, c); + } println!("done"); } diff --git a/tests/bn254-double/elf/riscv32im-succinct-zkvm-elf b/tests/bn254-double/elf/riscv32im-succinct-zkvm-elf index 2c7bcb623..b571be734 100755 Binary files a/tests/bn254-double/elf/riscv32im-succinct-zkvm-elf and b/tests/bn254-double/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bn254-mul/elf/riscv32im-succinct-zkvm-elf b/tests/bn254-mul/elf/riscv32im-succinct-zkvm-elf index a414416de..dd1506ddc 100755 Binary files a/tests/bn254-mul/elf/riscv32im-succinct-zkvm-elf and b/tests/bn254-mul/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/bn254-mul/src/main.rs b/tests/bn254-mul/src/main.rs index 841de5e4d..3086c3806 100644 --- a/tests/bn254-mul/src/main.rs +++ b/tests/bn254-mul/src/main.rs @@ -6,36 +6,38 @@ use sp1_zkvm::precompiles::utils::AffinePoint; #[sp1_derive::cycle_tracker] pub fn main() { - // generator. - // 1 - // 2 - let a: [u8; 64] = [ - 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, - ]; - - let mut a_point = AffinePoint::::from_le_bytes(&a); - - // scalar. - // 3 - let scalar: [u32; 8] = [3, 0, 0, 0, 0, 0, 0, 0]; - - println!("cycle-tracker-start: bn254_mul"); - a_point.mul_assign(&scalar); - println!("cycle-tracker-end: bn254_mul"); - - // 3 * generator. 
- // 3353031288059533942658390886683067124040920775575537747144343083137631628272 - // 19321533766552368860946552437480515441416830039777911637913418824951667761761 - let c: [u8; 64] = [ - 240, 171, 21, 25, 150, 85, 211, 242, 121, 230, 184, 21, 71, 216, 21, 147, 21, 189, 182, - 177, 188, 50, 2, 244, 63, 234, 107, 197, 154, 191, 105, 7, 97, 34, 254, 217, 61, 255, 241, - 205, 87, 91, 156, 11, 180, 99, 158, 49, 117, 100, 8, 141, 124, 219, 79, 85, 41, 148, 72, - 224, 190, 153, 183, 42, - ]; - - assert_eq!(a_point.to_le_bytes(), c); + for _ in 0..4 { + // generator. + // 1 + // 2 + let a: [u8; 64] = [ + 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, + ]; + + let mut a_point = AffinePoint::::from_le_bytes(&a); + + // scalar. + // 3 + let scalar: [u32; 8] = [3, 0, 0, 0, 0, 0, 0, 0]; + + println!("cycle-tracker-start: bn254_mul"); + a_point.mul_assign(&scalar); + println!("cycle-tracker-end: bn254_mul"); + + // 3 * generator. 
+ // 3353031288059533942658390886683067124040920775575537747144343083137631628272 + // 19321533766552368860946552437480515441416830039777911637913418824951667761761 + let c: [u8; 64] = [ + 240, 171, 21, 25, 150, 85, 211, 242, 121, 230, 184, 21, 71, 216, 21, 147, 21, 189, 182, + 177, 188, 50, 2, 244, 63, 234, 107, 197, 154, 191, 105, 7, 97, 34, 254, 217, 61, 255, + 241, 205, 87, 91, 156, 11, 180, 99, 158, 49, 117, 100, 8, 141, 124, 219, 79, 85, 41, + 148, 72, 224, 190, 153, 183, 42, + ]; + + assert_eq!(a_point.to_le_bytes(), c); + } println!("done"); } diff --git a/tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf b/tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf index ed3121d5d..6e2531ad0 100755 Binary files a/tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf and b/tests/cycle-tracker/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/ecrecover/elf/riscv32im-succinct-zkvm-elf b/tests/ecrecover/elf/riscv32im-succinct-zkvm-elf index d75d1642e..58e50a259 100755 Binary files a/tests/ecrecover/elf/riscv32im-succinct-zkvm-elf and b/tests/ecrecover/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/ed-add/elf/riscv32im-succinct-zkvm-elf b/tests/ed-add/elf/riscv32im-succinct-zkvm-elf index 1f79b12f4..5916c8a80 100755 Binary files a/tests/ed-add/elf/riscv32im-succinct-zkvm-elf and b/tests/ed-add/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/ed-add/src/main.rs b/tests/ed-add/src/main.rs index 057aea482..3deafa0c7 100644 --- a/tests/ed-add/src/main.rs +++ b/tests/ed-add/src/main.rs @@ -6,37 +6,40 @@ extern "C" { } pub fn main() { - // 90393249858788985237231628593243673548167146579814268721945474994541877372611 - // 33321104029277118100578831462130550309254424135206412570121538923759338004303 - let mut a: [u8; 64] = [ - 195, 166, 157, 207, 218, 220, 175, 197, 111, 177, 123, 23, 73, 72, 114, 103, 28, 246, 66, - 207, 66, 146, 187, 234, 136, 238, 133, 145, 47, 196, 216, 199, 79, 31, 224, 30, 179, 122, - 51, 84, 116, 12, 4, 189, 198, 198, 190, 
22, 71, 201, 143, 249, 92, 56, 147, 133, 92, 187, - 130, 33, 152, 19, 171, 73, - ]; + for _ in 0..4 { + // 90393249858788985237231628593243673548167146579814268721945474994541877372611 + // 33321104029277118100578831462130550309254424135206412570121538923759338004303 + let mut a: [u8; 64] = [ + 195, 166, 157, 207, 218, 220, 175, 197, 111, 177, 123, 23, 73, 72, 114, 103, 28, 246, + 66, 207, 66, 146, 187, 234, 136, 238, 133, 145, 47, 196, 216, 199, 79, 31, 224, 30, + 179, 122, 51, 84, 116, 12, 4, 189, 198, 198, 190, 22, 71, 201, 143, 249, 92, 56, 147, + 133, 92, 187, 130, 33, 152, 19, 171, 73, + ]; - // 61717728572175158701898635111983295176935961585742968051419350619945173564869 - // 28137966556353620208933066709998005335145594788896528644015312259959272398451 - let b: [u8; 64] = [ - 197, 189, 200, 77, 201, 212, 57, 105, 191, 133, 123, 170, 167, 50, 114, 38, 37, 102, 188, - 29, 215, 227, 157, 142, 252, 31, 129, 67, 24, 255, 114, 136, 115, 94, 94, 55, 43, 200, 117, - 224, 139, 251, 238, 45, 80, 154, 70, 213, 219, 78, 201, 108, 73, 203, 72, 45, 167, 131, - 199, 47, 82, 134, 53, 62, - ]; + // 61717728572175158701898635111983295176935961585742968051419350619945173564869 + // 28137966556353620208933066709998005335145594788896528644015312259959272398451 + let b: [u8; 64] = [ + 197, 189, 200, 77, 201, 212, 57, 105, 191, 133, 123, 170, 167, 50, 114, 38, 37, 102, + 188, 29, 215, 227, 157, 142, 252, 31, 129, 67, 24, 255, 114, 136, 115, 94, 94, 55, 43, + 200, 117, 224, 139, 251, 238, 45, 80, 154, 70, 213, 219, 78, 201, 108, 73, 203, 72, 45, + 167, 131, 199, 47, 82, 134, 53, 62, + ]; - unsafe { - syscall_ed_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); - } + unsafe { + syscall_ed_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); + } + + // 36213413123116753589144482590359479011148956763279542162278577842046663495729 + // 17093345531692682197799066694073110060588941459686871373458223451938707761683 + let c: [u8; 64] = [ + 49, 144, 129, 197, 86, 163, 62, 48, 
222, 208, 213, 200, 219, 90, 163, 54, 211, 248, + 178, 224, 238, 167, 235, 219, 251, 247, 189, 239, 194, 16, 16, 80, 19, 106, 20, 198, + 72, 56, 103, 111, 68, 201, 29, 107, 75, 208, 193, 232, 181, 186, 175, 22, 213, 187, + 253, 125, 44, 80, 222, 209, 159, 125, 202, 37, + ]; - // 36213413123116753589144482590359479011148956763279542162278577842046663495729 - // 17093345531692682197799066694073110060588941459686871373458223451938707761683 - let c: [u8; 64] = [ - 49, 144, 129, 197, 86, 163, 62, 48, 222, 208, 213, 200, 219, 90, 163, 54, 211, 248, 178, - 224, 238, 167, 235, 219, 251, 247, 189, 239, 194, 16, 16, 80, 19, 106, 20, 198, 72, 56, - 103, 111, 68, 201, 29, 107, 75, 208, 193, 232, 181, 186, 175, 22, 213, 187, 253, 125, 44, - 80, 222, 209, 159, 125, 202, 37, - ]; + assert_eq!(a, c); + } - assert_eq!(a, c); println!("done"); } diff --git a/tests/ed-decompress/elf/riscv32im-succinct-zkvm-elf b/tests/ed-decompress/elf/riscv32im-succinct-zkvm-elf index 10bbf5e06..233f1ab1c 100755 Binary files a/tests/ed-decompress/elf/riscv32im-succinct-zkvm-elf and b/tests/ed-decompress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/ed-decompress/src/main.rs b/tests/ed-decompress/src/main.rs index 32f4eef65..0b6929dde 100644 --- a/tests/ed-decompress/src/main.rs +++ b/tests/ed-decompress/src/main.rs @@ -8,26 +8,28 @@ extern "C" { } pub fn main() { - let pub_bytes = hex!("ec172b93ad5e563bf4932c70e1245034c35467ef2efd4d64ebf819683467e2bf"); + for _ in 0..4 { + let pub_bytes = hex!("ec172b93ad5e563bf4932c70e1245034c35467ef2efd4d64ebf819683467e2bf"); - let mut decompressed = [0_u8; 64]; - decompressed[32..].copy_from_slice(&pub_bytes); + let mut decompressed = [0_u8; 64]; + decompressed[32..].copy_from_slice(&pub_bytes); - println!("before: {:?}", decompressed); + println!("before: {:?}", decompressed); - unsafe { - syscall_ed_decompress(decompressed.as_mut_ptr()); - } + unsafe { + syscall_ed_decompress(decompressed.as_mut_ptr()); + } - let expected: [u8; 64] = [ - 47, 252, 
114, 91, 153, 234, 110, 201, 201, 153, 152, 14, 68, 231, 90, 221, 137, 110, 250, - 67, 10, 64, 37, 70, 163, 101, 111, 223, 185, 1, 180, 88, 236, 23, 43, 147, 173, 94, 86, 59, - 244, 147, 44, 112, 225, 36, 80, 52, 195, 84, 103, 239, 46, 253, 77, 100, 235, 248, 25, 104, - 52, 103, 226, 63, - ]; + let expected: [u8; 64] = [ + 47, 252, 114, 91, 153, 234, 110, 201, 201, 153, 152, 14, 68, 231, 90, 221, 137, 110, + 250, 67, 10, 64, 37, 70, 163, 101, 111, 223, 185, 1, 180, 88, 236, 23, 43, 147, 173, + 94, 86, 59, 244, 147, 44, 112, 225, 36, 80, 52, 195, 84, 103, 239, 46, 253, 77, 100, + 235, 248, 25, 104, 52, 103, 226, 63, + ]; - assert_eq!(decompressed, expected); + assert_eq!(decompressed, expected); + println!("after: {:?}", decompressed); + } - println!("after: {:?}", decompressed); println!("done"); } diff --git a/tests/ed25519/elf/riscv32im-succinct-zkvm-elf b/tests/ed25519/elf/riscv32im-succinct-zkvm-elf index 88c83e3c0..5f149617c 100755 Binary files a/tests/ed25519/elf/riscv32im-succinct-zkvm-elf and b/tests/ed25519/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/fibonacci/elf/riscv32im-succinct-zkvm-elf b/tests/fibonacci/elf/riscv32im-succinct-zkvm-elf index 1c59449d8..7a61102c1 100755 Binary files a/tests/fibonacci/elf/riscv32im-succinct-zkvm-elf and b/tests/fibonacci/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/hint-io/elf/riscv32im-succinct-zkvm-elf b/tests/hint-io/elf/riscv32im-succinct-zkvm-elf index 69fc40b11..ac7a2fc29 100755 Binary files a/tests/hint-io/elf/riscv32im-succinct-zkvm-elf and b/tests/hint-io/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/keccak-permute/elf/riscv32im-succinct-zkvm-elf b/tests/keccak-permute/elf/riscv32im-succinct-zkvm-elf index a843a0779..15dee9915 100755 Binary files a/tests/keccak-permute/elf/riscv32im-succinct-zkvm-elf and b/tests/keccak-permute/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/keccak256/elf/riscv32im-succinct-zkvm-elf 
b/tests/keccak256/elf/riscv32im-succinct-zkvm-elf index 48e4965b3..311da32c1 100755 Binary files a/tests/keccak256/elf/riscv32im-succinct-zkvm-elf and b/tests/keccak256/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/panic/elf/riscv32im-succinct-zkvm-elf b/tests/panic/elf/riscv32im-succinct-zkvm-elf index e68a4a4dc..8debb2189 100755 Binary files a/tests/panic/elf/riscv32im-succinct-zkvm-elf and b/tests/panic/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-add/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-add/elf/riscv32im-succinct-zkvm-elf index 339003c77..bf7a3db10 100755 Binary files a/tests/secp256k1-add/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-add/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-add/src/main.rs b/tests/secp256k1-add/src/main.rs index c45601bcc..9640e4e8c 100644 --- a/tests/secp256k1-add/src/main.rs +++ b/tests/secp256k1-add/src/main.rs @@ -6,41 +6,43 @@ extern "C" { } pub fn main() { - // generator. - // 55066263022277343669578718895168534326250603453777594175500187360389116729240 - // 32670510020758816978083085130507043184471273380659243275938904335757337482424 - let mut a: [u8; 64] = [ - 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, 206, - 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, 208, 71, - 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, 101, 196, 163, - 38, 119, 218, 58, 72, - ]; + for _ in 0..4 { + // generator. 
+ // 55066263022277343669578718895168534326250603453777594175500187360389116729240 + // 32670510020758816978083085130507043184471273380659243275938904335757337482424 + let mut a: [u8; 64] = [ + 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, + 206, 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, + 208, 71, 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, + 101, 196, 163, 38, 119, 218, 58, 72, + ]; - // 2 * generator. - // 89565891926547004231252920425935692360644145829622209833684329913297188986597 - // 12158399299693830322967808612713398636155367887041628176798871954788371653930 - let b: [u8; 64] = [ - 229, 158, 112, 92, 185, 9, 172, 171, 167, 60, 239, 140, 75, 142, 119, 92, 216, 124, 192, - 149, 110, 64, 69, 48, 109, 125, 237, 65, 148, 127, 4, 198, 42, 229, 207, 80, 169, 49, 100, - 35, 225, 208, 102, 50, 101, 50, 246, 247, 238, 234, 108, 70, 25, 132, 197, 163, 57, 195, - 61, 166, 254, 104, 225, 26, - ]; + // 2 * generator. + // 89565891926547004231252920425935692360644145829622209833684329913297188986597 + // 12158399299693830322967808612713398636155367887041628176798871954788371653930 + let b: [u8; 64] = [ + 229, 158, 112, 92, 185, 9, 172, 171, 167, 60, 239, 140, 75, 142, 119, 92, 216, 124, + 192, 149, 110, 64, 69, 48, 109, 125, 237, 65, 148, 127, 4, 198, 42, 229, 207, 80, 169, + 49, 100, 35, 225, 208, 102, 50, 101, 50, 246, 247, 238, 234, 108, 70, 25, 132, 197, + 163, 57, 195, 61, 166, 254, 104, 225, 26, + ]; - unsafe { - syscall_secp256k1_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); - } + unsafe { + syscall_secp256k1_add(a.as_mut_ptr() as *mut u32, b.as_ptr() as *const u32); + } - // 3 * generator. 
- // 112711660439710606056748659173929673102114977341539408544630613555209775888121 - // 25583027980570883691656905877401976406448868254816295069919888960541586679410 - let c: [u8; 64] = [ - 249, 54, 224, 188, 19, 241, 1, 134, 176, 153, 111, 131, 69, 200, 49, 181, 41, 82, 157, 248, - 133, 79, 52, 73, 16, 195, 88, 146, 1, 138, 48, 249, 114, 230, 184, 132, 117, 253, 185, 108, - 27, 35, 194, 52, 153, 169, 0, 101, 86, 243, 55, 42, 230, 55, 227, 15, 20, 232, 45, 99, 15, - 123, 143, 56, - ]; + // 3 * generator. + // 112711660439710606056748659173929673102114977341539408544630613555209775888121 + // 25583027980570883691656905877401976406448868254816295069919888960541586679410 + let c: [u8; 64] = [ + 249, 54, 224, 188, 19, 241, 1, 134, 176, 153, 111, 131, 69, 200, 49, 181, 41, 82, 157, + 248, 133, 79, 52, 73, 16, 195, 88, 146, 1, 138, 48, 249, 114, 230, 184, 132, 117, 253, + 185, 108, 27, 35, 194, 52, 153, 169, 0, 101, 86, 243, 55, 42, 230, 55, 227, 15, 20, + 232, 45, 99, 15, 123, 143, 56, + ]; - assert_eq!(a, c); + assert_eq!(a, c); + } println!("done"); } diff --git a/tests/secp256k1-decompress/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-decompress/elf/riscv32im-succinct-zkvm-elf index e06da48d7..2fae11204 100755 Binary files a/tests/secp256k1-decompress/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-decompress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-decompress/src/main.rs b/tests/secp256k1-decompress/src/main.rs index 6dc18a25c..a603986e0 100644 --- a/tests/secp256k1-decompress/src/main.rs +++ b/tests/secp256k1-decompress/src/main.rs @@ -8,20 +8,22 @@ extern "C" { pub fn main() { let compressed_key: [u8; 33] = sp1_zkvm::io::read_vec().try_into().unwrap(); - let mut decompressed_key: [u8; 64] = [0; 64]; - decompressed_key[..32].copy_from_slice(&compressed_key[1..]); - let is_odd = match compressed_key[0] { - 2 => false, - 3 => true, - _ => panic!("Invalid compressed key"), - }; - unsafe { - syscall_secp256k1_decompress(&mut 
decompressed_key, is_odd); - } + for _ in 0..4 { + let mut decompressed_key: [u8; 64] = [0; 64]; + decompressed_key[..32].copy_from_slice(&compressed_key[1..]); + let is_odd = match compressed_key[0] { + 2 => false, + 3 => true, + _ => panic!("Invalid compressed key"), + }; + unsafe { + syscall_secp256k1_decompress(&mut decompressed_key, is_odd); + } - let mut result: [u8; 65] = [0; 65]; - result[0] = 4; - result[1..].copy_from_slice(&decompressed_key); + let mut result: [u8; 65] = [0; 65]; + result[0] = 4; + result[1..].copy_from_slice(&decompressed_key); - sp1_zkvm::io::commit_slice(&result); + sp1_zkvm::io::commit_slice(&result); + } } diff --git a/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf index 6ad007626..79a156fca 100755 Binary files a/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-double/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-mul/elf/riscv32im-succinct-zkvm-elf b/tests/secp256k1-mul/elf/riscv32im-succinct-zkvm-elf index ec0db8bd0..d3e17ead6 100755 Binary files a/tests/secp256k1-mul/elf/riscv32im-succinct-zkvm-elf and b/tests/secp256k1-mul/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/secp256k1-mul/src/main.rs b/tests/secp256k1-mul/src/main.rs index 731a81b38..a2fb6a3dd 100644 --- a/tests/secp256k1-mul/src/main.rs +++ b/tests/secp256k1-mul/src/main.rs @@ -6,37 +6,39 @@ use sp1_zkvm::precompiles::utils::AffinePoint; #[sp1_derive::cycle_tracker] pub fn main() { - // generator. 
- // 55066263022277343669578718895168534326250603453777594175500187360389116729240 - // 32670510020758816978083085130507043184471273380659243275938904335757337482424 - let a: [u8; 64] = [ - 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, 206, - 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, 208, 71, - 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, 101, 196, 163, - 38, 119, 218, 58, 72, - ]; - - let mut a_point = AffinePoint::::from_le_bytes(&a); - - // scalar. - // 3 - let scalar: [u32; 8] = [3, 0, 0, 0, 0, 0, 0, 0]; - - println!("cycle-tracker-start: secp256k1_mul"); - a_point.mul_assign(&scalar); - println!("cycle-tracker-end: secp256k1_mul"); - - // 3 * generator. - // 112711660439710606056748659173929673102114977341539408544630613555209775888121 - // 25583027980570883691656905877401976406448868254816295069919888960541586679410 - let c: [u8; 64] = [ - 249, 54, 224, 188, 19, 241, 1, 134, 176, 153, 111, 131, 69, 200, 49, 181, 41, 82, 157, 248, - 133, 79, 52, 73, 16, 195, 88, 146, 1, 138, 48, 249, 114, 230, 184, 132, 117, 253, 185, 108, - 27, 35, 194, 52, 153, 169, 0, 101, 86, 243, 55, 42, 230, 55, 227, 15, 20, 232, 45, 99, 15, - 123, 143, 56, - ]; - - assert_eq!(a_point.to_le_bytes(), c); + for _ in 0..4 { + // generator. + // 55066263022277343669578718895168534326250603453777594175500187360389116729240 + // 32670510020758816978083085130507043184471273380659243275938904335757337482424 + let a: [u8; 64] = [ + 152, 23, 248, 22, 91, 129, 242, 89, 217, 40, 206, 45, 219, 252, 155, 2, 7, 11, 135, + 206, 149, 98, 160, 85, 172, 187, 220, 249, 126, 102, 190, 121, 184, 212, 16, 251, 143, + 208, 71, 156, 25, 84, 133, 166, 72, 180, 23, 253, 168, 8, 17, 14, 252, 251, 164, 93, + 101, 196, 163, 38, 119, 218, 58, 72, + ]; + + let mut a_point = AffinePoint::::from_le_bytes(&a); + + // scalar. 
+ // 3 + let scalar: [u32; 8] = [3, 0, 0, 0, 0, 0, 0, 0]; + + println!("cycle-tracker-start: secp256k1_mul"); + a_point.mul_assign(&scalar); + println!("cycle-tracker-end: secp256k1_mul"); + + // 3 * generator. + // 112711660439710606056748659173929673102114977341539408544630613555209775888121 + // 25583027980570883691656905877401976406448868254816295069919888960541586679410 + let c: [u8; 64] = [ + 249, 54, 224, 188, 19, 241, 1, 134, 176, 153, 111, 131, 69, 200, 49, 181, 41, 82, 157, + 248, 133, 79, 52, 73, 16, 195, 88, 146, 1, 138, 48, 249, 114, 230, 184, 132, 117, 253, + 185, 108, 27, 35, 194, 52, 153, 169, 0, 101, 86, 243, 55, 42, 230, 55, 227, 15, 20, + 232, 45, 99, 15, 123, 143, 56, + ]; + + assert_eq!(a_point.to_le_bytes(), c); + } println!("done"); } diff --git a/tests/sha-compress/elf/riscv32im-succinct-zkvm-elf b/tests/sha-compress/elf/riscv32im-succinct-zkvm-elf index 97126f881..f10443e12 100755 Binary files a/tests/sha-compress/elf/riscv32im-succinct-zkvm-elf and b/tests/sha-compress/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/sha-compress/src/main.rs b/tests/sha-compress/src/main.rs index bdddab166..3c306966c 100644 --- a/tests/sha-compress/src/main.rs +++ b/tests/sha-compress/src/main.rs @@ -6,6 +6,10 @@ use sp1_zkvm::syscalls::syscall_sha256_compress; pub fn main() { let mut w = [1u32; 64]; let mut state = [1u32; 8]; - syscall_sha256_compress(w.as_mut_ptr(), state.as_mut_ptr()); + + for _ in 0..4 { + syscall_sha256_compress(w.as_mut_ptr(), state.as_mut_ptr()); + } + println!("{:?}", state); } diff --git a/tests/sha-extend/elf/riscv32im-succinct-zkvm-elf b/tests/sha-extend/elf/riscv32im-succinct-zkvm-elf index 7b8774766..d584e1c35 100755 Binary files a/tests/sha-extend/elf/riscv32im-succinct-zkvm-elf and b/tests/sha-extend/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/sha2/elf/riscv32im-succinct-zkvm-elf b/tests/sha2/elf/riscv32im-succinct-zkvm-elf index ff4661def..2c63e6648 100755 Binary files 
a/tests/sha2/elf/riscv32im-succinct-zkvm-elf and b/tests/sha2/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/tendermint-benchmark/elf/riscv32im-succinct-zkvm-elf b/tests/tendermint-benchmark/elf/riscv32im-succinct-zkvm-elf index d67e8be4d..526fc2f83 100755 Binary files a/tests/tendermint-benchmark/elf/riscv32im-succinct-zkvm-elf and b/tests/tendermint-benchmark/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/uint256-arith/elf/riscv32im-succinct-zkvm-elf b/tests/uint256-arith/elf/riscv32im-succinct-zkvm-elf index 83a521f1b..25b450b66 100755 Binary files a/tests/uint256-arith/elf/riscv32im-succinct-zkvm-elf and b/tests/uint256-arith/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/uint256-arith/src/main.rs b/tests/uint256-arith/src/main.rs index 872af3bf6..4df9947bd 100644 --- a/tests/uint256-arith/src/main.rs +++ b/tests/uint256-arith/src/main.rs @@ -36,27 +36,29 @@ pub fn main() { let a = U256::from(3u8); let b = U256::from(2u8); - println!("cycle-tracker-start: uint256_add"); - let add = uint256_add(black_box(a), black_box(b)); - assert_eq!(add, U256::from(5u8)); - println!("cycle-tracker-end: uint256_add"); - println!("{:?}", add); - - println!("cycle-tracker-start: uint256_sub"); - let sub = uint256_sub(black_box(a), black_box(b)); - assert_eq!(sub, U256::from(1u8)); - println!("cycle-tracker-end: uint256_sub"); - println!("{:?}", sub); - - println!("cycle-tracker-start: uint256_div"); - let div = uint256_div(black_box(a), black_box(b)); - assert_eq!(div, U256::from(1u8)); - println!("cycle-tracker-end: uint256_div"); - println!("{:?}", div); - - println!("cycle-tracker-start: uint256_mul"); - let mul = uint256_mul(black_box(a), black_box(b)); - assert_eq!(mul, U256::from(6u8)); - println!("cycle-tracker-end: uint256_mul"); - println!("{:?}", mul); + for _ in 0..4 { + println!("cycle-tracker-start: uint256_add"); + let add = uint256_add(black_box(a), black_box(b)); + assert_eq!(add, U256::from(5u8)); + println!("cycle-tracker-end: 
uint256_add"); + println!("{:?}", add); + + println!("cycle-tracker-start: uint256_sub"); + let sub = uint256_sub(black_box(a), black_box(b)); + assert_eq!(sub, U256::from(1u8)); + println!("cycle-tracker-end: uint256_sub"); + println!("{:?}", sub); + + println!("cycle-tracker-start: uint256_div"); + let div = uint256_div(black_box(a), black_box(b)); + assert_eq!(div, U256::from(1u8)); + println!("cycle-tracker-end: uint256_div"); + println!("{:?}", div); + + println!("cycle-tracker-start: uint256_mul"); + let mul = uint256_mul(black_box(a), black_box(b)); + assert_eq!(mul, U256::from(6u8)); + println!("cycle-tracker-end: uint256_mul"); + println!("{:?}", mul); + } } diff --git a/tests/uint256-div/elf/riscv32im-succinct-zkvm-elf b/tests/uint256-div/elf/riscv32im-succinct-zkvm-elf index ca4f0641f..067b69ca7 100755 Binary files a/tests/uint256-div/elf/riscv32im-succinct-zkvm-elf and b/tests/uint256-div/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf b/tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf index b49313ca6..2ab53b56d 100755 Binary files a/tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf and b/tests/uint256-mul/elf/riscv32im-succinct-zkvm-elf differ diff --git a/tests/verify-proof/elf/riscv32im-succinct-zkvm-elf b/tests/verify-proof/elf/riscv32im-succinct-zkvm-elf index 63f478712..93c15811c 100755 Binary files a/tests/verify-proof/elf/riscv32im-succinct-zkvm-elf and b/tests/verify-proof/elf/riscv32im-succinct-zkvm-elf differ diff --git a/zkvm/precompiles/src/uint256_div.rs b/zkvm/precompiles/src/uint256_div.rs index c12a07b90..b10e0116c 100644 --- a/zkvm/precompiles/src/uint256_div.rs +++ b/zkvm/precompiles/src/uint256_div.rs @@ -11,7 +11,6 @@ use num::{BigUint, Integer}; /// represented as arrays of bytes in little-endian order. It returns the quotient /// of the division as a 256-bit unsigned integer in the same byte array format. 
pub fn uint256_div(x: &mut [u8; 32], y: &[u8; 32]) -> [u8; 32] { - // TODO: this will panic now. // Assert that the divisor is not zero. assert!(y != &[0; 32], "division by zero"); cfg_if::cfg_if! {