Skip to content

Commit

Permalink
Stark: Prover and Verifier over field extensions v3 (#724)
Browse files Browse the repository at this point in the history
* wip

* cast to fieldextension

* fmt

* clippy

* fmt

* fix parallel

* prover with generic air argument

* verifier with generic argument

* clippy, fmt

* wip, prover refactor

* wip

* fix stone serialization

* add safe scope for unwraps

* minor refactor

* clippy, fmt

* remove commented code. Fix typo in docs

* fix compilation error

* fmt

* fix parallel

* remove unnecessary trait bound

* fix wasm

* fix wasm

* change wasm proof bytes

* fix number of bytes

* fix wasm proof

* fix number of bytes

* wip

* wip, compiles

* wip

* wip. valid proofs

* fix trace ood element send order

* fix adapter

* clippy

* clippy

* clippy

* fmt

* remove unnecessary methods

* rename

* minor refactor. Avoid clone

* minor refactor

* add docstrings

* fix parallel

* add docstrings

* clippy

* fix test

* fix debug assert
  • Loading branch information
schouhy authored Dec 27, 2023
1 parent e1a8716 commit c0e190f
Show file tree
Hide file tree
Showing 21 changed files with 965 additions and 617 deletions.
309 changes: 169 additions & 140 deletions provers/cairo/src/air.rs

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions provers/cairo/src/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ fn check_simple_cairo_trace_evaluates_to_zero() {
assert!(validate_trace(
&cairo_air,
&trace_polys,
&aux_polys,
&domain,
&rap_challenges
));
Expand Down
402 changes: 201 additions & 201 deletions provers/cairo/tests/wasm.rs

Large diffs are not rendered by default.

38 changes: 32 additions & 6 deletions provers/stark/src/constraints/boundary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,45 @@ pub struct BoundaryConstraint<F: IsField> {
pub col: usize,
pub step: usize,
pub value: FieldElement<F>,
pub is_aux: bool,
}

impl<F: IsField> BoundaryConstraint<F> {
pub fn new(col: usize, step: usize, value: FieldElement<F>) -> Self {
Self { col, step, value }
pub fn new_main(col: usize, step: usize, value: FieldElement<F>) -> Self {
Self {
col,
step,
value,
is_aux: false,
}
}

pub fn new_aux(col: usize, step: usize, value: FieldElement<F>) -> Self {
Self {
col,
step,
value,
is_aux: true,
}
}

/// Used for creating boundary constraints for a trace with only one column
pub fn new_simple_main(step: usize, value: FieldElement<F>) -> Self {
Self {
col: 0,
step,
value,
is_aux: false,
}
}

/// Used for creating boundary constraints for a trace with only one column
pub fn new_simple(step: usize, value: FieldElement<F>) -> Self {
pub fn new_simple_aux(step: usize, value: FieldElement<F>) -> Self {
Self {
col: 0,
step,
value,
is_aux: true,
}
}
}
Expand Down Expand Up @@ -150,9 +176,9 @@ mod test {
// * a0 = 1
// * a1 = 1
// * a7 = 32
let a0 = BoundaryConstraint::new_simple(0, one);
let a1 = BoundaryConstraint::new_simple(1, one);
let result = BoundaryConstraint::new_simple(7, FieldElement::<PrimeField>::from(32));
let a0 = BoundaryConstraint::new_simple_main(0, one);
let a1 = BoundaryConstraint::new_simple_main(1, one);
let result = BoundaryConstraint::new_simple_main(7, FieldElement::<PrimeField>::from(32));

let constraints = BoundaryConstraints::from_constraints(vec![a0, a1, result]);

Expand Down
46 changes: 24 additions & 22 deletions provers/stark/src/constraints/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@ use rayon::prelude::{
use super::boundary::BoundaryConstraints;
#[cfg(all(debug_assertions, not(feature = "parallel")))]
use crate::debug::check_boundary_polys_divisibility;
use crate::domain::Domain;
use crate::trace::TraceTable;
use crate::traits::AIR;
use crate::{domain::Domain, table::EvaluationTable};
use crate::{frame::Frame, prover::evaluate_polynomial_on_lde_domain};

pub struct ConstraintEvaluator<A: AIR> {
Expand All @@ -34,10 +33,10 @@ impl<A: AIR> ConstraintEvaluator<A> {
}
}

pub fn evaluate(
pub(crate) fn evaluate(
&self,
air: &A,
lde_trace: &TraceTable<A::FieldExtension>,
lde_table: &EvaluationTable<A::Field, A::FieldExtension>,
domain: &Domain<A::Field>,
transition_coefficients: &[FieldElement<A::FieldExtension>],
boundary_coefficients: &[FieldElement<A::FieldExtension>],
Expand Down Expand Up @@ -83,27 +82,30 @@ impl<A: AIR> ConstraintEvaluator<A> {
&domain.coset_offset,
)
})
.collect::<Result<Vec<Vec<FieldElement<A::FieldExtension>>>, FFTError>>()
.collect::<Result<Vec<Vec<FieldElement<A::Field>>>, FFTError>>()
.unwrap();

let n_col = lde_trace.n_cols();
let n_elem = domain.lde_roots_of_unity_coset.len();
let boundary_polys_evaluations = boundary_constraints
.constraints
.iter()
.map(|constraint| {
let col = constraint.col;
lde_trace
.table
.data
.iter()
.skip(col)
.step_by(n_col)
.take(n_elem)
.map(|v| -&constraint.value + v)
.collect::<Vec<FieldElement<A::FieldExtension>>>()
if constraint.is_aux {
(0..lde_table.n_rows())
.map(|row| {
let v = lde_table.get_aux(row, constraint.col);
v - &constraint.value
})
.collect()
} else {
(0..lde_table.n_rows())
.map(|row| {
let v = lde_table.get_main(row, constraint.col);
v - &constraint.value
})
.collect()
}
})
.collect::<Vec<Vec<FieldElement<A::FieldExtension>>>>();
.collect::<Vec<Vec<FieldElement<_>>>>();

#[cfg(feature = "parallel")]
let boundary_eval_iter = (0..domain.lde_roots_of_unity_coset.len()).into_par_iter();
Expand Down Expand Up @@ -182,22 +184,22 @@ impl<A: AIR> ConstraintEvaluator<A> {
.zip(&boundary_evaluation)
.zip(zerofier_iter)
.map(|((i, boundary), zerofier)| {
let frame = Frame::read_from_trace(
lde_trace,
let frame = Frame::<A::Field, A::FieldExtension>::read_from_lde_table(
lde_table,
i,
blowup_factor,
&air.context().transition_offsets,
);

let periodic_values: Vec<_> = lde_periodic_columns
.iter()
.map(|col| col[i].clone().to_extension())
.map(|col| col[i].clone())
.collect();

// Compute all the transition constraints at this
// point of the LDE domain.
let evaluations_transition =
air.compute_transition(&frame, &periodic_values, rap_challenges);
air.compute_transition_prover(&frame, &periodic_values, rap_challenges);

#[cfg(all(debug_assertions, not(feature = "parallel")))]
transition_evaluations.push(evaluations_transition.clone());
Expand Down
48 changes: 35 additions & 13 deletions provers/stark/src/debug.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::frame::Frame;
use crate::trace::TraceTable;
use crate::{frame::Frame, table::EvaluationTable};

use super::domain::Domain;
use super::traits::AIR;
Expand All @@ -15,29 +14,46 @@ use log::{error, info};
/// Validates that the trace is valid with respect to the supplied AIR constraints
pub fn validate_trace<A: AIR>(
air: &A,
trace_polys: &[Polynomial<FieldElement<A::FieldExtension>>],
main_trace_polys: &[Polynomial<FieldElement<A::Field>>],
aux_trace_polys: &[Polynomial<FieldElement<A::FieldExtension>>],
domain: &Domain<A::Field>,
rap_challenges: &A::RAPChallenges,
) -> bool {
info!("Starting constraints validation over trace...");
let mut ret = true;

let trace_columns: Vec<_> = trace_polys
let main_trace_columns: Vec<_> = main_trace_polys
.iter()
.map(|poly| {
Polynomial::<FieldElement<A::Field>>::evaluate_fft::<A::Field>(
poly,
1,
Some(domain.interpolation_domain_size),
)
.unwrap()
})
.collect();
let aux_trace_columns: Vec<_> = aux_trace_polys
.iter()
.map(|poly| {
Polynomial::evaluate_fft::<A::Field>(poly, 1, Some(domain.interpolation_domain_size))
.unwrap()
})
.collect();

let trace = TraceTable::from_columns(trace_columns, A::STEP_SIZE);
let lde_table =
EvaluationTable::from_columns(main_trace_columns, aux_trace_columns, A::STEP_SIZE);

let periodic_columns: Vec<_> = air
.get_periodic_column_polynomials()
.iter()
.map(|poly| {
Polynomial::evaluate_fft::<A::Field>(poly, 1, Some(domain.interpolation_domain_size))
.unwrap()
Polynomial::<FieldElement<A::Field>>::evaluate_fft::<A::Field>(
poly,
1,
Some(domain.interpolation_domain_size),
)
.unwrap()
})
.collect();

Expand All @@ -49,9 +65,14 @@ pub fn validate_trace<A: AIR>(
let col = constraint.col;
let step = constraint.step;
let boundary_value = constraint.value.clone();
let trace_value = trace.get(step, col);

if &boundary_value.clone().to_extension() != trace_value {
let trace_value = if !constraint.is_aux {
lde_table.get_main(step, col).clone().to_extension()
} else {
lde_table.get_aux(step, col).clone()
};

if boundary_value.clone().to_extension() != trace_value {
ret = false;
error!("Boundary constraint inconsistency - Expected value {:?} in step {} and column {}, found: {:?}", boundary_value, step, col, trace_value);
}
Expand All @@ -61,21 +82,22 @@ pub fn validate_trace<A: AIR>(
let n_transition_constraints = air.context().num_transition_constraints();
let transition_exemptions = &air.context().transition_exemptions;

let exemption_steps: Vec<usize> = vec![trace.n_rows(); n_transition_constraints]
let exemption_steps: Vec<usize> = vec![lde_table.n_rows(); n_transition_constraints]
.iter()
.zip(transition_exemptions)
.map(|(trace_steps, exemptions)| trace_steps - exemptions)
.collect();

// Iterate over trace and compute transitions
for step in 0..trace.num_steps() {
let frame = Frame::read_from_trace(&trace, step, 1, &air.context().transition_offsets);
for step in 0..lde_table.num_steps() {
let frame =
Frame::read_from_lde_table(&lde_table, step, 1, &air.context().transition_offsets);

let periodic_values: Vec<_> = periodic_columns
.iter()
.map(|col| col[step].clone())
.collect();
let evaluations = air.compute_transition(&frame, &periodic_values, rap_challenges);
let evaluations = air.compute_transition_prover(&frame, &periodic_values, rap_challenges);

// Iterate over each transition evaluation. When the evaluated step is not from
// the exemption steps corresponding to the transition, it should have zero as a
Expand Down
27 changes: 18 additions & 9 deletions provers/stark/src/examples/dummy_air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,20 @@ impl AIR for DummyAIR {
_transcript: &mut impl IsStarkTranscript<Self::FieldExtension>,
) -> Self::RAPChallenges {
}
fn compute_transition(
fn compute_transition_prover(
&self,
frame: &Frame<Self::Field>,
frame: &Frame<Self::Field, Self::FieldExtension>,
_periodic_values: &[FieldElement<Self::Field>],
_rap_challenges: &Self::RAPChallenges,
) -> Vec<FieldElement<Self::Field>> {
) -> Vec<FieldElement<Self::FieldExtension>> {
let first_step = frame.get_evaluation_step(0);
let second_step = frame.get_evaluation_step(1);
let third_step = frame.get_evaluation_step(2);

let flag = first_step.get_evaluation_element(0, 0);
let a0 = first_step.get_evaluation_element(0, 1);
let a1 = second_step.get_evaluation_element(0, 1);
let a2 = third_step.get_evaluation_element(0, 1);
let flag = first_step.get_main_evaluation_element(0, 0);
let a0 = first_step.get_main_evaluation_element(0, 1);
let a1 = second_step.get_main_evaluation_element(0, 1);
let a2 = third_step.get_main_evaluation_element(0, 1);

let f_constraint = flag * (flag - FieldElement::one());

Expand All @@ -85,8 +85,8 @@ impl AIR for DummyAIR {
&self,
_rap_challenges: &Self::RAPChallenges,
) -> BoundaryConstraints<Self::Field> {
let a0 = BoundaryConstraint::new(1, 0, FieldElement::<Self::Field>::one());
let a1 = BoundaryConstraint::new(1, 1, FieldElement::<Self::Field>::one());
let a0 = BoundaryConstraint::new_main(1, 0, FieldElement::<Self::Field>::one());
let a1 = BoundaryConstraint::new_main(1, 1, FieldElement::<Self::Field>::one());

BoundaryConstraints::from_constraints(vec![a0, a1])
}
Expand All @@ -110,6 +110,15 @@ impl AIR for DummyAIR {
fn pub_inputs(&self) -> &Self::PublicInputs {
&()
}

fn compute_transition_verifier(
&self,
frame: &Frame<Self::FieldExtension, Self::FieldExtension>,
periodic_values: &[FieldElement<Self::FieldExtension>],
rap_challenges: &Self::RAPChallenges,
) -> Vec<FieldElement<Self::Field>> {
self.compute_transition_prover(frame, periodic_values, rap_challenges)
}
}

pub fn dummy_trace<F: IsFFTField>(trace_length: usize) -> TraceTable<F> {
Expand Down
25 changes: 17 additions & 8 deletions provers/stark/src/examples/fibonacci_2_cols_shifted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,20 +93,20 @@ where
) -> Self::RAPChallenges {
}

fn compute_transition(
fn compute_transition_prover(
&self,
frame: &Frame<Self::Field>,
frame: &Frame<Self::Field, Self::FieldExtension>,
_periodic_values: &[FieldElement<Self::Field>],
_rap_challenges: &Self::RAPChallenges,
) -> Vec<FieldElement<Self::Field>> {
let first_row = frame.get_evaluation_step(0);
let second_row = frame.get_evaluation_step(1);

let a0_0 = first_row.get_evaluation_element(0, 0);
let a0_1 = first_row.get_evaluation_element(0, 1);
let a0_0 = first_row.get_main_evaluation_element(0, 0);
let a0_1 = first_row.get_main_evaluation_element(0, 1);

let a1_0 = second_row.get_evaluation_element(0, 0);
let a1_1 = second_row.get_evaluation_element(0, 1);
let a1_0 = second_row.get_main_evaluation_element(0, 0);
let a1_1 = second_row.get_main_evaluation_element(0, 1);

let first_transition = a1_0 - a0_1;
let second_transition = a1_1 - a0_0 - a0_1;
Expand All @@ -122,8 +122,8 @@ where
&self,
_rap_challenges: &Self::RAPChallenges,
) -> BoundaryConstraints<Self::Field> {
let initial_condition = BoundaryConstraint::new(0, 0, FieldElement::one());
let claimed_value_constraint = BoundaryConstraint::new(
let initial_condition = BoundaryConstraint::new_main(0, 0, FieldElement::one());
let claimed_value_constraint = BoundaryConstraint::new_main(
0,
self.pub_inputs.claimed_index,
self.pub_inputs.claimed_value.clone(),
Expand All @@ -147,6 +147,15 @@ where
fn pub_inputs(&self) -> &Self::PublicInputs {
&self.pub_inputs
}

fn compute_transition_verifier(
&self,
frame: &Frame<Self::FieldExtension, Self::FieldExtension>,
periodic_values: &[FieldElement<Self::FieldExtension>],
rap_challenges: &Self::RAPChallenges,
) -> Vec<FieldElement<Self::Field>> {
self.compute_transition_prover(frame, periodic_values, rap_challenges)
}
}

pub fn compute_trace<F: IsFFTField>(
Expand Down
Loading

0 comments on commit c0e190f

Please sign in to comment.