Skip to content

Commit

Permalink
Stark: Prover and Verifier over field extensions (#716)
Browse files Browse the repository at this point in the history
* wip

* cast to fieldextension

* fmt

* clippy

* fmt

* fix parallel

* reverse in place

Co-authored-by: Mario Rugiero <[email protected]>

---------

Co-authored-by: Mario Rugiero <[email protected]>
  • Loading branch information
schouhy and Oppen authored Dec 18, 2023
1 parent b284c34 commit 703db8a
Show file tree
Hide file tree
Showing 26 changed files with 689 additions and 399 deletions.
437 changes: 311 additions & 126 deletions math/src/polynomial.rs

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions provers/cairo/src/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ fn generate_range_check_permutation_argument_column(

impl AIR for CairoAIR {
type Field = Stark252PrimeField;
type FieldExtension = Stark252PrimeField;
type RAPChallenges = CairoRAPChallenges;
type PublicInputs = PublicInputs;

Expand Down
4 changes: 2 additions & 2 deletions provers/cairo/src/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,15 +141,15 @@ fn check_simple_cairo_trace_evaluates_to_zero() {
let program_content = std::fs::read(cairo0_program_path("simple_program.json")).unwrap();
let (main_trace, public_input) =
generate_prover_args(&program_content, CairoLayout::Plain).unwrap();
let mut trace_polys = main_trace.compute_trace_polys();
let mut trace_polys = main_trace.compute_trace_polys::<Stark252PrimeField>();
let mut transcript = StoneProverTranscript::new(&[]);

let proof_options = ProofOptions::default_test_options();
let cairo_air = CairoAIR::new(main_trace.n_rows(), &public_input, &proof_options);
let rap_challenges = cairo_air.build_rap_challenges(&mut transcript);

let aux_trace = cairo_air.build_auxiliary_trace(&main_trace, &rap_challenges);
let aux_polys = aux_trace.compute_trace_polys();
let aux_polys = aux_trace.compute_trace_polys::<Stark252PrimeField>();

trace_polys.extend_from_slice(&aux_polys);

Expand Down
15 changes: 9 additions & 6 deletions provers/plonk/src/prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@ use crate::setup::{
new_strong_fiat_shamir_transcript, CommonPreprocessedInput, VerificationKey, Witness,
};
use lambdaworks_crypto::commitments::traits::IsCommitmentScheme;
use lambdaworks_math::{field::element::FieldElement, polynomial::Polynomial};
use lambdaworks_math::{
field::element::FieldElement,
polynomial::{self, Polynomial},
};
use lambdaworks_math::{field::traits::IsField, traits::ByteConversion};

/// Plonk proof.
Expand Down Expand Up @@ -318,7 +321,7 @@ where
.expect("xs and ys have equal length and xs are unique");

let z_h = Polynomial::new_monomial(FieldElement::one(), common_preprocessed_input.n)
- FieldElement::one();
- FieldElement::<F>::one();
let p_a = self.blind_polynomial(&p_a, &z_h, 2);
let p_b = self.blind_polynomial(&p_b, &z_h, 2);
let p_c = self.blind_polynomial(&p_c, &z_h, 2);
Expand Down Expand Up @@ -366,7 +369,7 @@ where
let p_z = Polynomial::interpolate_fft::<F>(&coefficients)
.expect("xs and ys have equal length and xs are unique");
let z_h = Polynomial::new_monomial(FieldElement::one(), common_preprocessed_input.n)
- FieldElement::one();
- FieldElement::<F>::one();
let p_z = self.blind_polynomial(&p_z, &z_h, 3);
let z_1 = self.commitment_scheme.commit(&p_z);
Round2Result {
Expand All @@ -392,7 +395,7 @@ where

let one = Polynomial::new_monomial(FieldElement::one(), 0);
let p_x = &Polynomial::new_monomial(FieldElement::<F>::one(), 1);
let zh = Polynomial::new_monomial(FieldElement::one(), cpi.n) - &one;
let zh = Polynomial::new_monomial(FieldElement::<F>::one(), cpi.n) - &one;

let z_x_omega_coefficients: Vec<FieldElement<F>> = p_z
.coefficients()
Expand Down Expand Up @@ -503,7 +506,7 @@ where
.collect();
let mut t = Polynomial::interpolate_offset_fft(&c, offset).unwrap();

Polynomial::pad_with_zero_coefficients_to_length(&mut t, 3 * (&cpi.n + 2));
polynomial::pad_with_zero_coefficients_to_length(&mut t, 3 * (&cpi.n + 2));
let p_t_lo = Polynomial::new(&t.coefficients[..&cpi.n + 2]);
let p_t_mid = Polynomial::new(&t.coefficients[&cpi.n + 2..2 * (&cpi.n + 2)]);
let p_t_hi = Polynomial::new(&t.coefficients[2 * (&cpi.n + 2)..3 * (&cpi.n + 2)]);
Expand Down Expand Up @@ -573,7 +576,7 @@ where

let l1_zeta = (&r4.zeta.pow(cpi.n as u64) - FieldElement::<F>::one())
/ (&r4.zeta - FieldElement::<F>::one())
/ FieldElement::from(cpi.n as u64);
/ FieldElement::<F>::from(cpi.n as u64);

let mut p_non_constant = &cpi.qm * &r4.a_zeta * &r4.b_zeta
+ &r4.a_zeta * &cpi.ql
Expand Down
77 changes: 41 additions & 36 deletions provers/stark/src/constraints/evaluator.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use itertools::Itertools;
use lambdaworks_math::{
fft::{cpu::roots_of_unity::get_powers_of_primitive_root_coset, errors::FFTError},
field::{element::FieldElement, traits::IsFFTField},
field::{
element::FieldElement,
traits::{IsFFTField, IsField, IsSubFieldOf},
},
polynomial::Polynomial,
traits::Serializable,
};
Expand All @@ -19,35 +22,36 @@ use crate::trace::TraceTable;
use crate::traits::AIR;
use crate::{frame::Frame, prover::evaluate_polynomial_on_lde_domain};

pub struct ConstraintEvaluator<F: IsFFTField> {
boundary_constraints: BoundaryConstraints<F>,
pub struct ConstraintEvaluator<A: AIR> {
boundary_constraints: BoundaryConstraints<A::FieldExtension>,
}
impl<F: IsFFTField> ConstraintEvaluator<F> {
pub fn new<A: AIR<Field = F>>(air: &A, rap_challenges: &A::RAPChallenges) -> Self {
impl<A: AIR> ConstraintEvaluator<A> {
pub fn new(air: &A, rap_challenges: &A::RAPChallenges) -> Self {
let boundary_constraints = air.boundary_constraints(rap_challenges);

Self {
boundary_constraints,
}
}

pub fn evaluate<A: AIR<Field = F>>(
pub fn evaluate(
&self,
air: &A,
lde_trace: &TraceTable<F>,
domain: &Domain<F>,
transition_coefficients: &[FieldElement<F>],
boundary_coefficients: &[FieldElement<F>],
lde_trace: &TraceTable<A::FieldExtension>,
domain: &Domain<A::Field>,
transition_coefficients: &[FieldElement<A::FieldExtension>],
boundary_coefficients: &[FieldElement<A::FieldExtension>],
rap_challenges: &A::RAPChallenges,
) -> Vec<FieldElement<F>>
) -> Vec<FieldElement<A::FieldExtension>>
where
FieldElement<F>: Serializable + Send + Sync,
FieldElement<A::Field>: Serializable + Send + Sync,
FieldElement<A::FieldExtension>: Serializable + Send + Sync,
A: Send + Sync,
A::RAPChallenges: Send + Sync,
{
let boundary_constraints = &self.boundary_constraints;
let number_of_b_constraints = boundary_constraints.constraints.len();
let boundary_zerofiers_inverse_evaluations: Vec<Vec<FieldElement<F>>> =
let boundary_zerofiers_inverse_evaluations: Vec<Vec<FieldElement<A::Field>>> =
boundary_constraints
.constraints
.iter()
Expand All @@ -57,16 +61,16 @@ impl<F: IsFFTField> ConstraintEvaluator<F> {
.lde_roots_of_unity_coset
.iter()
.map(|v| v.clone() - point)
.collect::<Vec<FieldElement<F>>>();
.collect::<Vec<FieldElement<A::Field>>>();
FieldElement::inplace_batch_inverse(&mut evals).unwrap();
evals
})
.collect::<Vec<Vec<FieldElement<F>>>>();
.collect::<Vec<Vec<FieldElement<A::Field>>>>();

let trace_length = air.trace_length();

#[cfg(all(debug_assertions, not(feature = "parallel")))]
let boundary_polys: Vec<Polynomial<FieldElement<F>>> = Vec::new();
let boundary_polys: Vec<Polynomial<FieldElement<A::Field>>> = Vec::new();

let lde_periodic_columns = air
.get_periodic_column_polynomials()
Expand All @@ -79,7 +83,7 @@ impl<F: IsFFTField> ConstraintEvaluator<F> {
&domain.coset_offset,
)
})
.collect::<Result<Vec<Vec<FieldElement<A::Field>>>, FFTError>>()
.collect::<Result<Vec<Vec<FieldElement<A::FieldExtension>>>, FFTError>>()
.unwrap();

let n_col = lde_trace.n_cols();
Expand All @@ -96,10 +100,10 @@ impl<F: IsFFTField> ConstraintEvaluator<F> {
.skip(col)
.step_by(n_col)
.take(n_elem)
.map(|v| v - &constraint.value)
.collect::<Vec<FieldElement<F>>>()
.map(|v| -&constraint.value + v)
.collect::<Vec<FieldElement<A::FieldExtension>>>()
})
.collect::<Vec<Vec<FieldElement<F>>>>();
.collect::<Vec<Vec<FieldElement<A::FieldExtension>>>>();

#[cfg(feature = "parallel")]
let boundary_eval_iter = (0..domain.lde_roots_of_unity_coset.len()).into_par_iter();
Expand All @@ -117,7 +121,7 @@ impl<F: IsFFTField> ConstraintEvaluator<F> {
* &boundary_polys_evaluations[constraint_index][domain_index]
})
})
.collect::<Vec<FieldElement<F>>>();
.collect::<Vec<FieldElement<A::FieldExtension>>>();

#[cfg(all(debug_assertions, not(feature = "parallel")))]
let boundary_zerofiers = Vec::new();
Expand All @@ -138,10 +142,10 @@ impl<F: IsFFTField> ConstraintEvaluator<F> {

let blowup_factor_order = u64::from(blowup_factor.trailing_zeros());

let offset = FieldElement::<F>::from(air.context().proof_options.coset_offset);
let offset = FieldElement::<A::Field>::from(air.context().proof_options.coset_offset);
let offset_pow = offset.pow(trace_length);
let one = FieldElement::<F>::one();
let mut zerofier_evaluations = get_powers_of_primitive_root_coset(
let one = FieldElement::one();
let mut zerofier_evaluations = get_powers_of_primitive_root_coset::<A::Field>(
blowup_factor_order,
blowup_factor as usize,
&offset_pow,
Expand Down Expand Up @@ -187,7 +191,7 @@ impl<F: IsFFTField> ConstraintEvaluator<F> {

let periodic_values: Vec<_> = lde_periodic_columns
.iter()
.map(|col| col[i].clone())
.map(|col| col[i].clone().to_extension())
.collect();

// Compute all the transition constraints at this
Expand All @@ -212,14 +216,14 @@ impl<F: IsFFTField> ConstraintEvaluator<F> {
// If there's no exemption, then
// the zerofier remains as it was.
if *exemption == 0 {
acc + zerofier * beta * eval
acc + eval * zerofier * beta
} else {
//TODO: change how exemptions are indexed!
if num_exemptions == 1 {
acc + zerofier
* beta
* eval
acc + eval
* &transition_exemptions_evaluations[0][i]
* beta
* zerofier
} else {
// This case is not used for Cairo Programs, it can be improved in the future
let vector = air
Expand All @@ -235,29 +239,30 @@ impl<F: IsFFTField> ConstraintEvaluator<F> {
.position(|elem_2| elem_2 == exemption)
.expect("is there");

acc + zerofier
* beta
* eval
acc + eval
* &transition_exemptions_evaluations[index][i]
* zerofier
* beta
}
}
});
// TODO: Remove clones

acc_transition + boundary
})
.collect::<Vec<FieldElement<F>>>();
.collect::<Vec<FieldElement<A::FieldExtension>>>();

evaluations_t
}
}

fn evaluate_transition_exemptions<F: IsFFTField>(
transition_exemptions: Vec<Polynomial<FieldElement<F>>>,
fn evaluate_transition_exemptions<F: IsFFTField + IsSubFieldOf<E>, E: IsField>(
transition_exemptions: Vec<Polynomial<FieldElement<E>>>,
domain: &Domain<F>,
) -> Vec<Vec<FieldElement<F>>>
) -> Vec<Vec<FieldElement<E>>>
where
FieldElement<F>: Send + Sync + Serializable,
FieldElement<E>: Send + Sync + Serializable,
Polynomial<FieldElement<F>>: Send + Sync,
{
#[cfg(feature = "parallel")]
Expand Down
21 changes: 13 additions & 8 deletions provers/stark/src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,18 @@ use crate::trace::TraceTable;
use super::domain::Domain;
use super::traits::AIR;
use lambdaworks_math::{
field::{element::FieldElement, traits::IsFFTField},
field::{
element::FieldElement,
traits::{IsFFTField, IsField},
},
polynomial::Polynomial,
};
use log::{error, info};

/// Validates that the trace is valid with respect to the supplied AIR constraints
pub fn validate_trace<F: IsFFTField, A: AIR<Field = F>>(
pub fn validate_trace<A: AIR>(
air: &A,
trace_polys: &[Polynomial<FieldElement<A::Field>>],
trace_polys: &[Polynomial<FieldElement<A::FieldExtension>>],
domain: &Domain<A::Field>,
rap_challenges: &A::RAPChallenges,
) -> bool {
Expand All @@ -22,7 +25,8 @@ pub fn validate_trace<F: IsFFTField, A: AIR<Field = F>>(
let trace_columns: Vec<_> = trace_polys
.iter()
.map(|poly| {
Polynomial::evaluate_fft::<F>(poly, 1, Some(domain.interpolation_domain_size)).unwrap()
Polynomial::evaluate_fft::<A::Field>(poly, 1, Some(domain.interpolation_domain_size))
.unwrap()
})
.collect();

Expand All @@ -32,7 +36,8 @@ pub fn validate_trace<F: IsFFTField, A: AIR<Field = F>>(
.get_periodic_column_polynomials()
.iter()
.map(|poly| {
Polynomial::evaluate_fft::<F>(poly, 1, Some(domain.interpolation_domain_size)).unwrap()
Polynomial::evaluate_fft::<A::Field>(poly, 1, Some(domain.interpolation_domain_size))
.unwrap()
})
.collect();

Expand All @@ -46,7 +51,7 @@ pub fn validate_trace<F: IsFFTField, A: AIR<Field = F>>(
let boundary_value = constraint.value.clone();
let trace_value = trace.get(step, col);

if &boundary_value != trace_value {
if &boundary_value.clone().to_extension() != trace_value {
ret = false;
error!("Boundary constraint inconsistency - Expected value {:?} in step {} and column {}, found: {:?}", boundary_value, step, col, trace_value);
}
Expand Down Expand Up @@ -78,7 +83,7 @@ pub fn validate_trace<F: IsFFTField, A: AIR<Field = F>>(
evaluations.iter().enumerate().for_each(|(i, eval)| {
// Check that all the transition constraint evaluations of the trace are zero.
// We don't take into account the transition exemptions.
if step < exemption_steps[i] && eval != &FieldElement::<F>::zero() {
if step < exemption_steps[i] && eval != &FieldElement::zero() {
ret = false;
error!(
"Inconsistent evaluation of transition {} in step {} - expected 0, got {:?}",
Expand Down Expand Up @@ -111,7 +116,7 @@ pub fn check_boundary_polys_divisibility<F: IsFFTField>(
/// array, returning a true when valid and false when not.
pub fn validate_2d_structure<F>(data: &[FieldElement<F>], width: usize) -> bool
where
F: IsFFTField,
F: IsField,
{
let rows: Vec<Vec<FieldElement<F>>> = data.chunks(width).map(|c| c.to_vec()).collect();
rows.iter().all(|r| r.len() == rows[0].len())
Expand Down
3 changes: 2 additions & 1 deletion provers/stark/src/examples/dummy_air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pub struct DummyAIR {

impl AIR for DummyAIR {
type Field = Stark252PrimeField;
type FieldExtension = Stark252PrimeField;
type RAPChallenges = ();
type PublicInputs = ();

Expand Down Expand Up @@ -55,7 +56,7 @@ impl AIR for DummyAIR {

fn build_rap_challenges(
&self,
_transcript: &mut impl IsStarkTranscript<Self::Field>,
_transcript: &mut impl IsStarkTranscript<Self::FieldExtension>,
) -> Self::RAPChallenges {
}
fn compute_transition(
Expand Down
3 changes: 2 additions & 1 deletion provers/stark/src/examples/fibonacci_2_cols_shifted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ where
F: IsFFTField,
{
type Field = F;
type FieldExtension = F;
type RAPChallenges = ();
type PublicInputs = PublicInputs<Self::Field>;

Expand Down Expand Up @@ -88,7 +89,7 @@ where

fn build_rap_challenges(
&self,
_transcript: &mut impl IsStarkTranscript<Self::Field>,
_transcript: &mut impl IsStarkTranscript<Self::FieldExtension>,
) -> Self::RAPChallenges {
}

Expand Down
3 changes: 2 additions & 1 deletion provers/stark/src/examples/fibonacci_2_columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ where
F: IsFFTField,
{
type Field = F;
type FieldExtension = F;
type RAPChallenges = ();
type PublicInputs = FibonacciPublicInputs<Self::Field>;

Expand Down Expand Up @@ -64,7 +65,7 @@ where

fn build_rap_challenges(
&self,
_transcript: &mut impl IsStarkTranscript<Self::Field>,
_transcript: &mut impl IsStarkTranscript<Self::FieldExtension>,
) -> Self::RAPChallenges {
}

Expand Down
Loading

0 comments on commit 703db8a

Please sign in to comment.