Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Public inputs refactor #676

Merged
merged 35 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
b14e8f0
rename MemorySegment to SegmentName
juan518munoz Nov 9, 2023
efb2435
add SegmentName enum options
juan518munoz Nov 9, 2023
6edc4c3
Segment struct declaration
juan518munoz Nov 9, 2023
ed2be99
initial integratation of struct Segment
juan518munoz Nov 9, 2023
0579bd9
rmv comment
juan518munoz Nov 9, 2023
68497f3
Merge branch 'main' into PublicInputsRefactor
juan518munoz Nov 13, 2023
2c75f5a
update cairo-vm
juan518munoz Nov 13, 2023
f5a601b
change cairo-vm branch
juan518munoz Nov 14, 2023
2a7d3eb
add ecdsa & pedersen
juan518munoz Nov 14, 2023
4284f8a
test pub inputs from vm
juan518munoz Nov 15, 2023
97e1e44
Merge branch 'main' into PublicInputsRefactor
juan518munoz Nov 15, 2023
f73fc39
public inputs most fields derived from vm
juan518munoz Nov 15, 2023
8ba4bcb
clippy & fmt
juan518munoz Nov 15, 2023
845ebd1
update cairo-vm dep to latest rev
juan518munoz Nov 15, 2023
2ae0109
get memory segment from vm
juan518munoz Nov 16, 2023
b556554
rename data_len to codelen
juan518munoz Nov 16, 2023
e13cc69
Merge branch 'main' into PublicInputsRefactor
juan518munoz Nov 16, 2023
14d3c14
Merge branch 'main' into PublicInputsRefactor
juan518munoz Nov 17, 2023
32d75d5
Merge branch 'main' into PublicInputsRefactor
juan518munoz Nov 17, 2023
56024d0
Merge branch 'main' into PublicInputsRefactor
entropidelic Dec 4, 2023
be520f9
Extract public memory directly from public inputs of Cairo VM
entropidelic Dec 4, 2023
666f007
Refactor get_memory_holes function
entropidelic Dec 5, 2023
9d720cd
Merge branch 'main' into PublicInputsRefactor
entropidelic Dec 5, 2023
f495565
Remove unnecessary function and legacy test
entropidelic Dec 5, 2023
b317f1b
Remove some commented code and fix some tests
entropidelic Dec 5, 2023
f1b9293
Solve clippy issues
entropidelic Dec 5, 2023
95168ba
Remove legacy test
entropidelic Dec 5, 2023
f6817ee
Fix some comments on tests
entropidelic Dec 5, 2023
8415e05
Merge branch 'main' into PublicInputsRefactor
entropidelic Dec 5, 2023
7bdce89
Refactor pub addresses in add_pub_memory_in_public_input_section func…
entropidelic Dec 6, 2023
40e5211
Merge branch 'main' into PublicInputsRefactor
entropidelic Dec 6, 2023
17c5b1b
Merge remote-tracking branch 'origin/PublicInputsRefactor' into Publi…
entropidelic Dec 6, 2023
65b8e09
Refactor segment_size method
entropidelic Dec 6, 2023
a0b7612
iterate over addr value pairs in add_pub_memory_in_public_input_secti…
entropidelic Dec 6, 2023
58a5d4e
Merge branch 'main' into PublicInputsRefactor
entropidelic Dec 6, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion provers/cairo/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ thiserror = "1.0.38"
log = "0.4.17"
bincode = { version = "2.0.0-rc.2", tag = "v2.0.0-rc.2", git = "https://github.com/bincode-org/bincode.git", features= ['serde'] }
# NOTE: For cairo 1 compatibility, add the `cairo-1-hints` feature.
cairo-vm = { git = "https://github.com/lambdaclass/cairo-vm", rev = "e763cef", default-features = false }
cairo-vm = { git = "https://github.com/lambdaclass/cairo-vm", branch = "pub-field-apinputs", default-features = false }
sha3 = "0.10.6"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
Expand Down
69 changes: 49 additions & 20 deletions provers/cairo/src/air.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,35 @@ pub const MEM_A_TRACE_OFFSET: usize = 19;
const BUILTIN_OFFSET: usize = 9;

#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum MemorySegment {
pub enum SegmentName {
RangeCheck,
Output,
Program,
Execution,
}

pub type MemorySegmentMap = HashMap<MemorySegment, Range<u64>>;
#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Segment {
pub begin_addr: u64,
pub stop_ptr: u64,
}

impl From<Range<u64>> for Segment {
fn from(range: Range<u64>) -> Self {
Segment {
begin_addr: range.start,
stop_ptr: range.end,
}
}
}

impl From<Segment> for Range<u64> {
fn from(val: Segment) -> Self {
val.begin_addr..val.stop_ptr
}
}

pub type MemorySegmentMap = HashMap<SegmentName, Segment>;

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PublicInputs {
Expand Down Expand Up @@ -204,14 +227,15 @@ impl PublicInputs {
codelen: usize,
memory_segments: &MemorySegmentMap,
) -> Self {
let output_range = memory_segments.get(&MemorySegment::Output);
let output_segment = memory_segments.get(&SegmentName::Output);

let mut public_memory = (1..=codelen as u64)
.map(|i| (Felt252::from(i), *memory.get(&i).unwrap()))
.collect::<HashMap<Felt252, Felt252>>();

if let Some(output_range) = output_range {
for addr in output_range.clone() {
if let Some(output_segment) = output_segment {
let range: Range<u64> = output_segment.clone().into();
for addr in range.clone() {
public_memory.insert(Felt252::from(addr), *memory.get(&addr).unwrap());
}
};
Expand Down Expand Up @@ -262,12 +286,14 @@ impl Serializable for PublicInputs {
let mut memory_segment_bytes = vec![];
for (segment, range) in self.memory_segments.iter() {
let segment_type = match segment {
MemorySegment::RangeCheck => 0u8,
MemorySegment::Output => 1u8,
SegmentName::RangeCheck => 0u8,
SegmentName::Output => 1u8,
SegmentName::Program => 2u8,
SegmentName::Execution => 3u8,
};
memory_segment_bytes.extend(segment_type.to_be_bytes());
memory_segment_bytes.extend(range.start.to_be_bytes());
memory_segment_bytes.extend(range.end.to_be_bytes());
memory_segment_bytes.extend(range.begin_addr.to_be_bytes());
memory_segment_bytes.extend(range.stop_ptr.to_be_bytes());
}
let memory_segment_length = self.memory_segments.len();
bytes.extend(memory_segment_length.to_be_bytes());
Expand Down Expand Up @@ -390,8 +416,10 @@ impl Deserializable for PublicInputs {
return Err(DeserializationError::InvalidAmountOfBytes);
}
let segment_type = match bytes[0] {
0 => MemorySegment::RangeCheck,
1 => MemorySegment::Output,
0u8 => SegmentName::RangeCheck,
1u8 => SegmentName::Output,
2u8 => SegmentName::Program,
3u8 => SegmentName::Execution,
_ => return Err(DeserializationError::FieldFromBytesError),
};
bytes = &bytes[1..];
Expand All @@ -411,7 +439,7 @@ impl Deserializable for PublicInputs {
.map_err(|_| DeserializationError::InvalidAmountOfBytes)?,
);
bytes = &bytes[8..];
memory_segments.insert(segment_type, start..end);
memory_segments.insert(segment_type, Segment::from(start..end));
}

let mut public_memory = HashMap::new();
Expand Down Expand Up @@ -511,9 +539,9 @@ fn add_pub_memory_in_public_input_section(
let mut a_aux = addresses.to_owned();
let mut v_aux = values.to_owned();

let output_range = public_input.memory_segments.get(&MemorySegment::Output);
let output_segment = public_input.memory_segments.get(&SegmentName::Output);

let pub_addrs = get_pub_memory_addrs(output_range, public_input);
let pub_addrs = get_pub_memory_addrs(output_segment, public_input);
let mut pub_addrs_iter = pub_addrs.iter();

// Iterate over addresses
Expand Down Expand Up @@ -542,12 +570,13 @@ fn add_pub_memory_in_public_input_section(
/// If the output builtin is used, `output_range` is `Some(...)` and this function adds incrementally to the resulting
/// `Vec` addresses from the start to the end of the unwrapped `output_range`.
fn get_pub_memory_addrs(
output_range: Option<&Range<u64>>,
output_segment: Option<&Segment>,
public_input: &PublicInputs,
) -> Vec<FieldElement<Stark252PrimeField>> {
let public_memory_len = public_input.public_memory.len() as u64;

if let Some(output_range) = output_range {
if let Some(output_segment) = output_segment {
let output_range: Range<u64> = output_segment.clone().into();
let output_section = output_range.end - output_range.start;
let program_section = public_memory_len - output_section;

Expand Down Expand Up @@ -669,7 +698,7 @@ impl AIR for CairoAIR {
// layouts functionality. The `has_rc_builtin` boolean should not exist, we will know the
// layout from the Cairo public inputs directly, and the number of constraints and columns
// will be enforced through that.
let has_rc_builtin = pub_inputs.memory_segments.get(&MemorySegment::RangeCheck).is_some();
let has_rc_builtin = pub_inputs.memory_segments.get(&SegmentName::RangeCheck).is_some();
if has_rc_builtin {
trace_columns += 8 + 1; // 8 columns for each rc of the range-check builtin values decomposition, 1 for the values
transition_degrees.push(1); // Range check builtin constraint
Expand Down Expand Up @@ -1421,7 +1450,7 @@ mod test {
range_check_max: None,
range_check_min: None,
num_steps: 1,
memory_segments: MemorySegmentMap::from([(MemorySegment::Output, 20..21)]),
memory_segments: MemorySegmentMap::from([(SegmentName::Output, Segment::from(20..21))]),
codelen: 3,
};

Expand Down Expand Up @@ -1602,7 +1631,7 @@ mod prop_test {
Felt252,
};

use super::{MemorySegment, MemorySegmentMap, PublicInputs};
use super::{MemorySegmentMap, PublicInputs, Segment, SegmentName};

prop_compose! {
fn some_felt()(base in any::<u64>(), exponent in any::<u128>()) -> Felt252 {
Expand All @@ -1624,7 +1653,7 @@ mod prop_test {
codelen in any::<usize>(),
) -> PublicInputs {
let public_memory = public_memory.iter().map(|(k, v)| (Felt252::from(*k), Felt252::from(*v))).collect();
let memory_segments = MemorySegmentMap::from([(MemorySegment::Output, 10u64..16u64), (MemorySegment::RangeCheck, 20u64..71u64)]);
let memory_segments = MemorySegmentMap::from([(SegmentName::Output, Segment::from(10u64..16u64)), (SegmentName::RangeCheck, Segment::from(20u64..71u64))]);
PublicInputs {
pc_init,
ap_init,
Expand Down
20 changes: 10 additions & 10 deletions provers/cairo/src/execution_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@ use super::{
},
register_states::RegisterStates,
};
use crate::air::{EXTRA_ADDR, RC_HOLES};
use crate::air::{Segment, EXTRA_ADDR, RC_HOLES};
use crate::{
air::{
MemorySegment, PublicInputs, FRAME_DST_ADDR, FRAME_OP0_ADDR, FRAME_OP1_ADDR, FRAME_PC,
PublicInputs, SegmentName, FRAME_DST_ADDR, FRAME_OP0_ADDR, FRAME_OP1_ADDR, FRAME_PC,
OFF_DST, OFF_OP0, OFF_OP1,
},
Felt252,
Expand Down Expand Up @@ -285,11 +285,10 @@ pub fn build_cairo_execution_trace(
trace_cols.push(extra_vals);
trace_cols.push(rc_holes);

if let Some(range_check_builtin_range) = public_inputs
.memory_segments
.get(&MemorySegment::RangeCheck)
if let Some(range_check_builtin_segment) =
public_inputs.memory_segments.get(&SegmentName::RangeCheck)
{
add_rc_builtin_columns(&mut trace_cols, range_check_builtin_range.clone(), memory);
add_rc_builtin_columns(&mut trace_cols, range_check_builtin_segment.clone(), memory);
}

TraceTable::from_columns(trace_cols, 1)
Expand All @@ -298,12 +297,13 @@ pub fn build_cairo_execution_trace(
// Build range-check builtin columns: rc_0, rc_1, ... , rc_7, rc_value
fn add_rc_builtin_columns(
trace_cols: &mut Vec<Vec<Felt252>>,
range_check_builtin_range: Range<u64>,
range_check_builtin_segment: Segment,
memory: &CairoMemory,
) {
let range_checked_values: Vec<&Felt252> = range_check_builtin_range
.map(|addr| memory.get(&addr).unwrap())
.collect();
let range: Range<u64> = range_check_builtin_segment.into();
let range_checked_values: Vec<&Felt252> =
range.map(|addr| memory.get(&addr).unwrap()).collect();

let mut rc_trace_columns = decompose_rc_values_into_trace_columns(&range_checked_values);

// rc decomposition columns are appended with zeros and then pushed to the trace table
Expand Down
10 changes: 7 additions & 3 deletions provers/cairo/src/runner/run.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use crate::air::{MemorySegment, MemorySegmentMap, PublicInputs};
use crate::air::{MemorySegmentMap, PublicInputs, Segment, SegmentName};
use crate::cairo_layout::CairoLayout;
use crate::cairo_mem::CairoMemory;
use crate::execution_trace::build_main_trace;
Expand Down Expand Up @@ -93,6 +93,7 @@ pub fn run_program(
layout: layout.as_str(),
proof_mode: true,
secure_run: None,
disable_trace_padding: false,
};

let (runner, vm) =
Expand Down Expand Up @@ -208,10 +209,13 @@ fn create_memory_segment_map(
let mut memory_segments = MemorySegmentMap::new();

if let Some(range_check_builtin_range) = range_check_builtin_range {
memory_segments.insert(MemorySegment::RangeCheck, range_check_builtin_range);
memory_segments.insert(
SegmentName::RangeCheck,
Segment::from(range_check_builtin_range),
);
}
if let Some(output_range) = output_range {
memory_segments.insert(MemorySegment::Output, output_range.clone());
memory_segments.insert(SegmentName::Output, Segment::from(output_range.clone()));
}

memory_segments
Expand Down
7 changes: 4 additions & 3 deletions provers/cairo/src/tests/integration_tests.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
air::{
generate_cairo_proof, verify_cairo_proof, CairoAIR, MemorySegment, MemorySegmentMap,
PublicInputs, FRAME_DST_ADDR, FRAME_OP0_ADDR, FRAME_OP1_ADDR, FRAME_PC,
generate_cairo_proof, verify_cairo_proof, CairoAIR, MemorySegmentMap, PublicInputs,
Segment, SegmentName, FRAME_DST_ADDR, FRAME_OP0_ADDR, FRAME_OP1_ADDR, FRAME_PC,
},
cairo_layout::CairoLayout,
execution_trace::build_main_trace,
Expand Down Expand Up @@ -196,7 +196,8 @@ fn test_verifier_rejects_proof_with_overflowing_range_check_value() {
// These is the regular setup for generating the trace and the Cairo AIR, but now
// we do it with the malicious memory
let proof_options = ProofOptions::default_test_options();
let memory_segments = MemorySegmentMap::from([(MemorySegment::RangeCheck, 27..29)]);
let memory_segments =
MemorySegmentMap::from([(SegmentName::RangeCheck, Segment::from(27..29))]);

let mut pub_inputs = PublicInputs::from_regs_and_mem(
&register_states,
Expand Down