From 817d62bfdf79ae8de186842c5ecbf33926c98c30 Mon Sep 17 00:00:00 2001 From: Kevin Jue Date: Fri, 2 Aug 2024 11:35:48 -0700 Subject: [PATCH] feat: recursion circuit p2 wide constraints (#1205) --- .../core-v2/src/chips/poseidon2_wide/air.rs | 122 ++++++++++++++++- .../poseidon2_wide/columns/permutation.rs | 4 +- .../core-v2/src/chips/poseidon2_wide/trace.rs | 126 ++++++++++-------- 3 files changed, 191 insertions(+), 61 deletions(-) diff --git a/recursion/core-v2/src/chips/poseidon2_wide/air.rs b/recursion/core-v2/src/chips/poseidon2_wide/air.rs index b0320c9a1..6ab6767f4 100644 --- a/recursion/core-v2/src/chips/poseidon2_wide/air.rs +++ b/recursion/core-v2/src/chips/poseidon2_wide/air.rs @@ -1,16 +1,23 @@ //! The air module contains the AIR constraints for the poseidon2 chip. //! At the moment, we're only including memory constraints to test the new memory argument. +use std::array; use std::borrow::Borrow; use p3_air::{Air, BaseAir, PairBuilder}; +use p3_field::AbstractField; use p3_matrix::Matrix; +use sp1_primitives::RC_16_30_U32; +use sp1_recursion_core::poseidon2_wide::NUM_EXTERNAL_ROUNDS; use crate::builder::SP1RecursionAirBuilder; +use super::columns::permutation::Poseidon2; use super::columns::preprocessed::Poseidon2PreprocessedCols; use super::columns::{NUM_POSEIDON2_DEGREE3_COLS, NUM_POSEIDON2_DEGREE9_COLS}; -use super::{Poseidon2WideChip, WIDTH}; +use super::{ + external_linear_layer, internal_linear_layer, Poseidon2WideChip, NUM_INTERNAL_ROUNDS, WIDTH, +}; impl BaseAir for Poseidon2WideChip { fn width(&self) -> usize { @@ -61,5 +68,118 @@ where prep_local.memory_preprocessed[i + WIDTH].write_mult, ) }); + + // Apply the external rounds. + for r in 0..NUM_EXTERNAL_ROUNDS { + self.eval_external_round(builder, local_row.as_ref(), r); + } + + // Apply the internal rounds. + self.eval_internal_rounds(builder, local_row.as_ref()); + } +} + +impl Poseidon2WideChip { + /// Eval the constraints for the external rounds. + fn eval_external_round( + &self, + builder: &mut AB, + local_row: &dyn Poseidon2, + r: usize, + ) where + AB: SP1RecursionAirBuilder + PairBuilder, + { + let mut local_state: [AB::Expr; WIDTH] = + array::from_fn(|i| local_row.external_rounds_state()[r][i].into()); + + // For the first round, apply the linear layer. + if r == 0 { + external_linear_layer(&mut local_state); + } + + // Add the round constants. + let round = if r < NUM_EXTERNAL_ROUNDS / 2 { + r + } else { + r + NUM_INTERNAL_ROUNDS + }; + let add_rc: [AB::Expr; WIDTH] = array::from_fn(|i| { + local_state[i].clone() + AB::F::from_wrapped_u32(RC_16_30_U32[round][i]) + }); + + // Apply the sboxes. + // See `populate_external_round` for why we don't have columns for the sbox output here. + let mut sbox_deg_7: [AB::Expr; WIDTH] = core::array::from_fn(|_| AB::Expr::zero()); + let mut sbox_deg_3: [AB::Expr; WIDTH] = core::array::from_fn(|_| AB::Expr::zero()); + for i in 0..WIDTH { + let calculated_sbox_deg_3 = add_rc[i].clone() * add_rc[i].clone() * add_rc[i].clone(); + + if let Some(external_sbox) = local_row.external_rounds_sbox() { + // builder.assert_eq(external_sbox[r][i].into(), calculated_sbox_deg_3); + sbox_deg_3[i] = external_sbox[r][i].into(); + } else { + sbox_deg_3[i] = calculated_sbox_deg_3; + } + + sbox_deg_7[i] = sbox_deg_3[i].clone() * sbox_deg_3[i].clone() * add_rc[i].clone(); + } + + // Apply the linear layer. + let mut state = sbox_deg_7; + external_linear_layer(&mut state); + + let next_state = if r == (NUM_EXTERNAL_ROUNDS / 2) - 1 { + local_row.internal_rounds_state() + } else if r == NUM_EXTERNAL_ROUNDS - 1 { + local_row.perm_output() + } else { + &local_row.external_rounds_state()[r + 1] + }; + + for i in 0..WIDTH { + builder.assert_eq(next_state[i], state[i].clone()); + } + } + + /// Eval the constraints for the internal rounds. + fn eval_internal_rounds(&self, builder: &mut AB, local_row: &dyn Poseidon2) + where + AB: SP1RecursionAirBuilder + PairBuilder, + { + let state = &local_row.internal_rounds_state(); + let s0 = local_row.internal_rounds_s0(); + let mut state: [AB::Expr; WIDTH] = core::array::from_fn(|i| state[i].into()); + for r in 0..NUM_INTERNAL_ROUNDS { + // Add the round constant. + let round = r + NUM_EXTERNAL_ROUNDS / 2; + let add_rc = if r == 0 { + state[0].clone() + } else { + s0[r - 1].into() + } + AB::Expr::from_wrapped_u32(RC_16_30_U32[round][0]); + + let mut sbox_deg_3 = add_rc.clone() * add_rc.clone() * add_rc.clone(); + if let Some(internal_sbox) = local_row.internal_rounds_sbox() { + builder.assert_eq(internal_sbox[r], sbox_deg_3); + sbox_deg_3 = internal_sbox[r].into(); + } + + // See `populate_internal_rounds` for why we don't have columns for the sbox output here. + let sbox_deg_7 = sbox_deg_3.clone() * sbox_deg_3.clone() * add_rc.clone(); + + // Apply the linear layer. + // See `populate_internal_rounds` for why we don't have columns for the new state here. + state[0] = sbox_deg_7.clone(); + internal_linear_layer(&mut state); + + if r < NUM_INTERNAL_ROUNDS - 1 { + builder.assert_eq(s0[r], state[0].clone()); + } + } + + let external_state = local_row.external_rounds_state()[NUM_EXTERNAL_ROUNDS / 2]; + for i in 0..WIDTH { + builder.assert_eq(external_state[i], state[i].clone()) + } } } diff --git a/recursion/core-v2/src/chips/poseidon2_wide/columns/permutation.rs b/recursion/core-v2/src/chips/poseidon2_wide/columns/permutation.rs index 7bac88dfc..54f54d407 100644 --- a/recursion/core-v2/src/chips/poseidon2_wide/columns/permutation.rs +++ b/recursion/core-v2/src/chips/poseidon2_wide/columns/permutation.rs @@ -21,11 +21,9 @@ pub const fn max(a: usize, b: usize) -> usize { #[derive(AlignedBorrow, Clone, Copy)] #[repr(C)] pub struct PermutationState { - pub external_rounds_state: [[T; WIDTH]; NUM_EXTERNAL_ROUNDS + 1], + pub external_rounds_state: [[T; WIDTH]; NUM_EXTERNAL_ROUNDS], pub internal_rounds_state: [T; WIDTH], pub internal_rounds_s0: [T; NUM_INTERNAL_ROUNDS - 1], - pub external_rounds_sbox: [[T; WIDTH]; NUM_EXTERNAL_ROUNDS], - pub internal_rounds_sbox: [T; NUM_INTERNAL_ROUNDS], pub output_state: [T; WIDTH], } diff --git a/recursion/core-v2/src/chips/poseidon2_wide/trace.rs b/recursion/core-v2/src/chips/poseidon2_wide/trace.rs index 0941cb579..14ec93b07 100644 --- a/recursion/core-v2/src/chips/poseidon2_wide/trace.rs +++ b/recursion/core-v2/src/chips/poseidon2_wide/trace.rs @@ -47,67 +47,18 @@ impl MachineAir for Poseidon2WideChip as BaseAir>::width(self); for event in &input.poseidon2_wide_events { - let mut input_row = vec![F::zero(); num_columns]; - - { - let permutation = permutation_mut::(&mut input_row); - - let ( - external_rounds_state, - internal_rounds_state, - internal_rounds_s0, - mut external_sbox, - mut internal_sbox, - output_state, - ) = permutation.get_cols_mut(); - - external_rounds_state[0] = event.input; - external_rounds_state[1] = - external_linear_layer_immut(&external_rounds_state[0].clone()); - - // Apply the first half of external rounds. - for r in 0..NUM_EXTERNAL_ROUNDS / 2 { - let next_state = - self.populate_external_round(external_rounds_state, &mut external_sbox, r); - if r == NUM_EXTERNAL_ROUNDS / 2 - 1 { - *internal_rounds_state = next_state; - } else { - external_rounds_state[r + 2] = next_state; - } - } - - // Apply the internal rounds. - external_rounds_state[NUM_EXTERNAL_ROUNDS / 2 + 1] = self.populate_internal_rounds( - internal_rounds_state, - internal_rounds_s0, - &mut internal_sbox, - ); - - // Apply the second half of external rounds. - for r in NUM_EXTERNAL_ROUNDS / 2..NUM_EXTERNAL_ROUNDS { - let next_state = - self.populate_external_round(external_rounds_state, &mut external_sbox, r); - if r == NUM_EXTERNAL_ROUNDS - 1 { - for i in 0..WIDTH { - output_state[i] = next_state[i]; - assert_eq!(event.output[i], next_state[i]); - } - } else { - external_rounds_state[r + 2] = next_state; - } - } - } - rows.push(input_row); + let mut row = vec![F::zero(); num_columns]; + self.populate_perm(event.input, Some(event.output), row.as_mut_slice()); + rows.push(row); } if self.pad { // Pad the trace to a power of two. // This will need to be adjusted when the AIR constraints are implemented. - pad_rows_fixed( - &mut rows, - || vec![F::zero(); num_columns], - self.fixed_log2_rows, - ); + let mut dummy_row = vec![F::zero(); num_columns]; + self.populate_perm([F::zero(); WIDTH], None, &mut dummy_row); + + pad_rows_fixed(&mut rows, || dummy_row.clone(), self.fixed_log2_rows); } // Convert the trace to a row major matrix. @@ -188,6 +139,62 @@ impl MachineAir for Poseidon2WideChip Poseidon2WideChip { + fn populate_perm( + &self, + input: [F; WIDTH], + expected_output: Option<[F; WIDTH]>, + input_row: &mut [F], + ) { + { + let permutation = permutation_mut::(input_row); + + let ( + external_rounds_state, + internal_rounds_state, + internal_rounds_s0, + mut external_sbox, + mut internal_sbox, + output_state, + ) = permutation.get_cols_mut(); + + external_rounds_state[0] = input; + + // Apply the first half of external rounds. + for r in 0..NUM_EXTERNAL_ROUNDS / 2 { + let next_state = + self.populate_external_round(external_rounds_state, &mut external_sbox, r); + if r == NUM_EXTERNAL_ROUNDS / 2 - 1 { + *internal_rounds_state = next_state; + } else { + external_rounds_state[r + 1] = next_state; + } + } + + // Apply the internal rounds. + external_rounds_state[NUM_EXTERNAL_ROUNDS / 2] = self.populate_internal_rounds( + internal_rounds_state, + internal_rounds_s0, + &mut internal_sbox, + ); + + // Apply the second half of external rounds. + for r in NUM_EXTERNAL_ROUNDS / 2..NUM_EXTERNAL_ROUNDS { + let next_state = + self.populate_external_round(external_rounds_state, &mut external_sbox, r); + if r == NUM_EXTERNAL_ROUNDS - 1 { + for i in 0..WIDTH { + output_state[i] = next_state[i]; + if let Some(expected_output) = expected_output { + assert_eq!(expected_output[i], next_state[i]); + } + } + } else { + external_rounds_state[r + 1] = next_state; + } + } + } + } + fn populate_external_round( &self, external_rounds_state: &[[F; WIDTH]], @@ -195,7 +202,12 @@ impl Poseidon2WideChip { r: usize, ) -> [F; WIDTH] { let mut state = { - let round_state: &[F; WIDTH] = &external_rounds_state[r + 1]; + // For the first round, apply the linear layer. + let round_state: &[F; WIDTH] = if r == 0 { + &external_linear_layer_immut(&external_rounds_state[r]) + } else { + &external_rounds_state[r] + }; // Add round constants. //