Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor AIR #937

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion provers/stark/src/constraints/evaluator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use crate::trace::LDETraceTable;
use crate::traits::AIR;
use crate::{frame::Frame, prover::evaluate_polynomial_on_lde_domain};
use itertools::Itertools;
#[cfg(all(debug_assertions, not(feature = "parallel")))]
#[cfg(not(feature = "parallel"))]
use lambdaworks_math::polynomial::Polynomial;
use lambdaworks_math::{fft::errors::FFTError, field::element::FieldElement, traits::AsBytes};
#[cfg(feature = "parallel")]
Expand Down
3 changes: 2 additions & 1 deletion provers/stark/src/constraints/transition.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::ops::Div;

use crate::domain::Domain;
use crate::frame::Frame;
use crate::prover::evaluate_polynomial_on_lde_domain;
Expand All @@ -6,7 +8,6 @@ use lambdaworks_math::field::element::FieldElement;
use lambdaworks_math::field::traits::{IsFFTField, IsField, IsSubFieldOf};
use lambdaworks_math::polynomial::Polynomial;
use num_integer::Integer;
use std::ops::Div;
/// TransitionConstraint represents the behaviour that a transition constraint
/// over the computation that wants to be proven must comply with.
pub trait TransitionConstraint<F, E>: Send + Sync
Expand Down
13 changes: 0 additions & 13 deletions provers/stark/src/context.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::collections::HashSet;

use super::proof::options::ProofOptions;

#[derive(Clone, Debug)]
Expand All @@ -13,22 +11,11 @@ pub struct AirContext {
/// offsets that are needed to compute EVERY transition constraint, even if some
/// constraints don't use all of the indexes in said offsets.
pub transition_offsets: Vec<usize>,
pub transition_exemptions: Vec<usize>,
pub num_transition_constraints: usize,
}

impl AirContext {
pub fn num_transition_constraints(&self) -> usize {
self.num_transition_constraints
}

/// Returns the number of non-trivial different
/// transition exemptions.
pub fn num_transition_exemptions(&self) -> usize {
self.transition_exemptions
.iter()
.filter(|&x| *x != 0)
.collect::<HashSet<_>>()
.len()
}
}
10 changes: 4 additions & 6 deletions provers/stark/src/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,10 @@ pub fn validate_trace<A: AIR>(

// --------- VALIDATE TRANSITION CONSTRAINTS -----------
let n_transition_constraints = air.context().num_transition_constraints();
let transition_exemptions = &air.context().transition_exemptions;

let exemption_steps: Vec<usize> = vec![lde_trace.num_rows(); n_transition_constraints]
.iter()
.zip(transition_exemptions)
.map(|(trace_steps, exemptions)| trace_steps - exemptions)
let exemption_steps: Vec<usize> = std::iter::repeat(lde_trace.num_steps())
.take(n_transition_constraints)
.zip(air.transition_constraints())
.map(|(trace_steps, constraint)| trace_steps - constraint.end_exemptions())
.collect();

// Iterate over trace and compute transitions
Expand Down
20 changes: 8 additions & 12 deletions provers/stark/src/examples/bit_flags.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,10 +90,6 @@ impl TransitionConstraint<StarkField, StarkField> for ZeroFlagConstraint {
16
}

fn offset(&self) -> usize {
15
}

fn evaluate(
&self,
frame: &Frame<StarkField, StarkField>,
Expand Down Expand Up @@ -130,17 +126,12 @@ impl AIR for BitFlagsAIR {
let flag_constraint = Box::new(ZeroFlagConstraint::new());
let constraints: Vec<Box<dyn TransitionConstraint<Self::Field, Self::FieldExtension>>> =
vec![bit_constraint, flag_constraint];
// vec![flag_constraint];
// vec![bit_constraint];

let num_transition_constraints = constraints.len();
let transition_exemptions: Vec<_> =
constraints.iter().map(|c| c.end_exemptions()).collect();

let context = AirContext {
proof_options: proof_options.clone(),
trace_columns: 1,
transition_exemptions,
trace_columns: 2,
transition_offsets: vec![0],
num_transition_constraints,
};
Expand Down Expand Up @@ -195,7 +186,7 @@ impl AIR for BitFlagsAIR {
}
}

pub fn bit_prefix_flag_trace(num_steps: usize) -> TraceTable<StarkField> {
pub fn bit_prefix_flag_trace(num_steps: usize) -> TraceTable<StarkField, StarkField> {
debug_assert!(num_steps.is_power_of_two());
let step: Vec<Felt252> = [
1031u64, 515, 257, 128, 64, 32, 16, 8, 4, 2, 1, 0, 0, 0, 0, 0,
Expand All @@ -207,5 +198,10 @@ pub fn bit_prefix_flag_trace(num_steps: usize) -> TraceTable<StarkField> {
let mut data: Vec<Felt252> = iter::repeat(step).take(num_steps).flatten().collect();
data[0] = Felt252::from(1030);

TraceTable::new(data, 1, 0, 16)
let mut dummy_column = (0..16).map(Felt252::from).collect();
dummy_column = iter::repeat(dummy_column)
.take(num_steps)
.flatten()
.collect();
TraceTable::from_columns_main(vec![data, dummy_column], 16)
}
9 changes: 2 additions & 7 deletions provers/stark/src/examples/dummy_air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,6 @@ impl AIR for DummyAIR {
let context = AirContext {
proof_options: proof_options.clone(),
trace_columns: 2,
transition_exemptions: vec![0, 2],
transition_offsets: vec![0, 1, 2],
num_transition_constraints: 2,
};
Expand Down Expand Up @@ -198,7 +197,7 @@ impl AIR for DummyAIR {
}
}

pub fn dummy_trace<F: IsFFTField>(trace_length: usize) -> TraceTable<F> {
pub fn dummy_trace<F: IsFFTField>(trace_length: usize) -> TraceTable<F, F> {
let mut ret: Vec<FieldElement<F>> = vec![];

let a0 = FieldElement::one();
Expand All @@ -211,9 +210,5 @@ pub fn dummy_trace<F: IsFFTField>(trace_length: usize) -> TraceTable<F> {
ret.push(ret[i - 1].clone() + ret[i - 2].clone());
}

TraceTable::from_columns(
vec![vec![FieldElement::<F>::one(); trace_length], ret],
2,
1,
)
TraceTable::from_columns_main(vec![vec![FieldElement::<F>::one(); trace_length], ret], 1)
}
46 changes: 4 additions & 42 deletions provers/stark/src/examples/fibonacci_2_cols_shifted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,6 @@ where

let context = AirContext {
proof_options: proof_options.clone(),
transition_exemptions: vec![1, 1],
transition_offsets: vec![0, 1],
num_transition_constraints: 2,
trace_columns: 2,
Expand Down Expand Up @@ -238,7 +237,7 @@ where
pub fn compute_trace<F: IsFFTField>(
initial_value: FieldElement<F>,
trace_length: usize,
) -> TraceTable<F> {
) -> TraceTable<F, F> {
let mut x = FieldElement::one();
let mut y = initial_value;
let mut col0 = vec![x.clone()];
Expand All @@ -250,7 +249,7 @@ pub fn compute_trace<F: IsFFTField>(
col1.push(y.clone());
}

TraceTable::from_columns(vec![col0, col1], 2, 1)
TraceTable::from_columns_main(vec![col0, col1], 1)
}

#[cfg(test)]
Expand All @@ -264,46 +263,9 @@ mod tests {
#[test]
fn trace_has_expected_rows() {
let trace = compute_trace(FieldElement::<Stark252PrimeField>::one(), 8);
assert_eq!(trace.n_rows(), 8);
assert_eq!(trace.num_rows(), 8);

let trace = compute_trace(FieldElement::<Stark252PrimeField>::one(), 64);
assert_eq!(trace.n_rows(), 64);
}

#[test]
fn trace_of_8_rows_is_correctly_calculated() {
let trace = compute_trace(FieldElement::<Stark252PrimeField>::one(), 8);
assert_eq!(
trace.get_row(0),
vec![FieldElement::one(), FieldElement::one()]
);
assert_eq!(
trace.get_row(1),
vec![FieldElement::one(), FieldElement::from(2)]
);
assert_eq!(
trace.get_row(2),
vec![FieldElement::from(2), FieldElement::from(3)]
);
assert_eq!(
trace.get_row(3),
vec![FieldElement::from(3), FieldElement::from(5)]
);
assert_eq!(
trace.get_row(4),
vec![FieldElement::from(5), FieldElement::from(8)]
);
assert_eq!(
trace.get_row(5),
vec![FieldElement::from(8), FieldElement::from(13)]
);
assert_eq!(
trace.get_row(6),
vec![FieldElement::from(13), FieldElement::from(21)]
);
assert_eq!(
trace.get_row(7),
vec![FieldElement::from(21), FieldElement::from(34)]
);
assert_eq!(trace.num_rows(), 64);
}
}
5 changes: 2 additions & 3 deletions provers/stark/src/examples/fibonacci_2_columns.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ where

let context = AirContext {
proof_options: proof_options.clone(),
transition_exemptions: vec![1, 1],
transition_offsets: vec![0, 1],
num_transition_constraints: constraints.len(),
trace_columns: 2,
Expand Down Expand Up @@ -209,7 +208,7 @@ where
pub fn compute_trace<F: IsFFTField>(
initial_values: [FieldElement<F>; 2],
trace_length: usize,
) -> TraceTable<F> {
) -> TraceTable<F, F> {
let mut ret1: Vec<FieldElement<F>> = vec![];
let mut ret2: Vec<FieldElement<F>> = vec![];

Expand All @@ -222,5 +221,5 @@ pub fn compute_trace<F: IsFFTField>(
ret2.push(new_val + ret2[i - 1].clone());
}

TraceTable::from_columns(vec![ret1, ret2], 2, 1)
TraceTable::from_columns_main(vec![ret1, ret2], 1)
}
32 changes: 18 additions & 14 deletions provers/stark/src/examples/fibonacci_rap.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,13 +166,10 @@ where
Box::new(PermutationConstraint::new()),
];

let exemptions = 3 + trace_length - pub_inputs.steps - 1;

let context = AirContext {
proof_options: proof_options.clone(),
trace_columns: 3,
transition_offsets: vec![0, 1, 2],
transition_exemptions: vec![exemptions, 1],
num_transition_constraints: transition_constraints.len(),
};

Expand All @@ -186,15 +183,15 @@ where

fn build_auxiliary_trace(
&self,
main_trace: &TraceTable<Self::Field>,
trace: &mut TraceTable<Self::Field, Self::FieldExtension>,
challenges: &[FieldElement<F>],
) -> TraceTable<Self::Field> {
let main_segment_cols = main_trace.columns();
) {
let main_segment_cols = trace.columns_main();
let not_perm = &main_segment_cols[0];
let perm = &main_segment_cols[1];
let gamma = &challenges[0];

let trace_len = main_trace.n_rows();
let trace_len = trace.num_rows();

let mut aux_col = Vec::new();
for i in 0..trace_len {
Expand All @@ -208,7 +205,10 @@ where
aux_col.push(z_i * n_p_term.div(p_term));
}
}
TraceTable::from_columns(vec![aux_col], 0, 1)

for (i, aux_elem) in aux_col.iter().enumerate().take(trace.num_rows()) {
trace.set_aux(i, 0, aux_elem.clone())
}
}

fn build_rap_challenges(
Expand Down Expand Up @@ -236,7 +236,6 @@ where
let a0_aux = BoundaryConstraint::new_aux(0, 0, FieldElement::<Self::FieldExtension>::one());

BoundaryConstraints::from_constraints(vec![a0, a1, a0_aux])
// BoundaryConstraints::from_constraints(vec![a0, a1])
}

fn transition_constraints(
Expand Down Expand Up @@ -274,9 +273,8 @@ where
pub fn fibonacci_rap_trace<F: IsFFTField>(
initial_values: [FieldElement<F>; 2],
trace_length: usize,
) -> TraceTable<F> {
) -> TraceTable<F, F> {
let mut fib_seq: Vec<FieldElement<F>> = vec![];

fib_seq.push(initial_values[0].clone());
fib_seq.push(initial_values[1].clone());

Expand All @@ -294,7 +292,13 @@ pub fn fibonacci_rap_trace<F: IsFFTField>(
let mut trace_cols = vec![fib_seq, fib_permuted];
resize_to_next_power_of_two(&mut trace_cols);

TraceTable::from_columns(trace_cols, 2, 1)
let mut trace = TraceTable::allocate_with_zeros(trace_cols[0].len(), 2, 1, 1);
for i in 0..trace.num_rows() {
trace.set_main(i, 0, trace_cols[0][i].clone());
trace.set_main(i, 1, trace_cols[1][i].clone());
}

trace
}

#[cfg(test)]
Expand Down Expand Up @@ -337,13 +341,13 @@ mod test {
];
resize_to_next_power_of_two(&mut expected_trace);

assert_eq!(trace.columns(), expected_trace);
assert_eq!(trace.columns_main(), expected_trace);
}

#[test]
fn aux_col() {
let trace = fibonacci_rap_trace([FE17::from(1), FE17::from(1)], 64);
let trace_cols = trace.columns();
let trace_cols = trace.columns_main();

let not_perm = trace_cols[0].clone();
let perm = trace_cols[1].clone();
Expand Down
5 changes: 2 additions & 3 deletions provers/stark/src/examples/quadratic_air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,6 @@ where
let context = AirContext {
proof_options: proof_options.clone(),
trace_columns: 1,
transition_exemptions: vec![1],
transition_offsets: vec![0, 1],
num_transition_constraints: constraints.len(),
};
Expand Down Expand Up @@ -161,7 +160,7 @@ where
pub fn quadratic_trace<F: IsFFTField>(
initial_value: FieldElement<F>,
trace_length: usize,
) -> TraceTable<F> {
) -> TraceTable<F, F> {
let mut ret: Vec<FieldElement<F>> = vec![];

ret.push(initial_value);
Expand All @@ -170,5 +169,5 @@ pub fn quadratic_trace<F: IsFFTField>(
ret.push(ret[i - 1].clone() * ret[i - 1].clone());
}

TraceTable::from_columns(vec![ret], 1, 1)
TraceTable::from_columns_main(vec![ret], 1)
}
5 changes: 2 additions & 3 deletions provers/stark/src/examples/simple_fibonacci.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ where
let context = AirContext {
proof_options: proof_options.clone(),
trace_columns: 1,
transition_exemptions: vec![2],
transition_offsets: vec![0, 1, 2],
num_transition_constraints: constraints.len(),
};
Expand Down Expand Up @@ -162,7 +161,7 @@ where
pub fn fibonacci_trace<F: IsFFTField>(
initial_values: [FieldElement<F>; 2],
trace_length: usize,
) -> TraceTable<F> {
) -> TraceTable<F, F> {
let mut ret: Vec<FieldElement<F>> = vec![];

ret.push(initial_values[0].clone());
Expand All @@ -172,5 +171,5 @@ pub fn fibonacci_trace<F: IsFFTField>(
ret.push(ret[i - 1].clone() + ret[i - 2].clone());
}

TraceTable::from_columns(vec![ret], 1, 1)
TraceTable::from_columns_main(vec![ret], 1)
}
Loading
Loading