Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Public inputs refactor #676

Merged
merged 35 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
b14e8f0
rename MemorySegment to SegmentName
juan518munoz Nov 9, 2023
efb2435
add SegmentName enum options
juan518munoz Nov 9, 2023
6edc4c3
Segment struct declaration
juan518munoz Nov 9, 2023
ed2be99
initial integratation of struct Segment
juan518munoz Nov 9, 2023
0579bd9
rmv comment
juan518munoz Nov 9, 2023
68497f3
Merge branch 'main' into PublicInputsRefactor
juan518munoz Nov 13, 2023
2c75f5a
update cairo-vm
juan518munoz Nov 13, 2023
f5a601b
change cairo-vm branch
juan518munoz Nov 14, 2023
2a7d3eb
add ecdsa & pedersen
juan518munoz Nov 14, 2023
4284f8a
test pub inputs from vm
juan518munoz Nov 15, 2023
97e1e44
Merge branch 'main' into PublicInputsRefactor
juan518munoz Nov 15, 2023
f73fc39
public inputs most fields derived from vm
juan518munoz Nov 15, 2023
8ba4bcb
clippy & fmt
juan518munoz Nov 15, 2023
845ebd1
update cairo-vm dep to latest rev
juan518munoz Nov 15, 2023
2ae0109
get memory segment from vm
juan518munoz Nov 16, 2023
b556554
rename data_len to codelen
juan518munoz Nov 16, 2023
e13cc69
Merge branch 'main' into PublicInputsRefactor
juan518munoz Nov 16, 2023
14d3c14
Merge branch 'main' into PublicInputsRefactor
juan518munoz Nov 17, 2023
32d75d5
Merge branch 'main' into PublicInputsRefactor
juan518munoz Nov 17, 2023
56024d0
Merge branch 'main' into PublicInputsRefactor
entropidelic Dec 4, 2023
be520f9
Extract public memory directly from public inputs of Cairo VM
entropidelic Dec 4, 2023
666f007
Refactor get_memory_holes function
entropidelic Dec 5, 2023
9d720cd
Merge branch 'main' into PublicInputsRefactor
entropidelic Dec 5, 2023
f495565
Remove unnecessary function and legacy test
entropidelic Dec 5, 2023
b317f1b
Remove some commented code and fix some tests
entropidelic Dec 5, 2023
f1b9293
Solve clippy issues
entropidelic Dec 5, 2023
95168ba
Remove legacy test
entropidelic Dec 5, 2023
f6817ee
Fix some comments on tests
entropidelic Dec 5, 2023
8415e05
Merge branch 'main' into PublicInputsRefactor
entropidelic Dec 5, 2023
7bdce89
Refactor pub addresses in add_pub_memory_in_public_input_section func…
entropidelic Dec 6, 2023
40e5211
Merge branch 'main' into PublicInputsRefactor
entropidelic Dec 6, 2023
17c5b1b
Merge remote-tracking branch 'origin/PublicInputsRefactor' into Publi…
entropidelic Dec 6, 2023
65b8e09
Refactor segment_size method
entropidelic Dec 6, 2023
a0b7612
iterate over addr value pairs in add_pub_memory_in_public_input_secti…
entropidelic Dec 6, 2023
58a5d4e
Merge branch 'main' into PublicInputsRefactor
entropidelic Dec 6, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion provers/cairo/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ thiserror = "1.0.38"
log = "0.4.17"
bincode = { version = "2.0.0-rc.2", tag = "v2.0.0-rc.2", git = "https://github.com/bincode-org/bincode.git", features= ['serde'] }
# NOTE: For cairo 1 compatibility, add the `cairo-1-hints` feature.
cairo-vm = { git = "https://github.com/lambdaclass/cairo-vm", rev = "e763cef", default-features = false }
cairo-vm = { git = "https://github.com/lambdaclass/cairo-vm", rev = "e61ae177edb94e29470fed23bdda43329b16c057", default-features = false }
sha3 = "0.10.6"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
Expand Down
91 changes: 74 additions & 17 deletions provers/cairo/src/air.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::ops::Range;

use cairo_vm::without_std::collections::HashMap;
use cairo_vm::{air_public_input::MemorySegmentAddresses, without_std::collections::HashMap};
use lambdaworks_math::{
errors::DeserializationError,
field::{
Expand Down Expand Up @@ -147,12 +147,60 @@ pub const MEM_P_TRACE_OFFSET: usize = 17;
pub const MEM_A_TRACE_OFFSET: usize = 19;

#[derive(Debug, Clone, PartialEq, Eq, Hash, serde::Serialize, serde::Deserialize)]
pub enum MemorySegment {
pub enum SegmentName {
RangeCheck,
Output,
Program,
Execution,
Ecdsa,
Pedersen,
}

pub type MemorySegmentMap = HashMap<MemorySegment, Range<u64>>;
impl From<&str> for SegmentName {
fn from(value: &str) -> Self {
match value {
"range_check" => SegmentName::RangeCheck,
"output" => SegmentName::Output,
"program" => SegmentName::Program,
"execution" => SegmentName::Execution,
"ecdsa" => SegmentName::Ecdsa,
"pedersen" => SegmentName::Pedersen,
n => panic!("Invalid segment name {n}"),
}
}
}

#[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Segment {
pub begin_addr: u64,
pub stop_ptr: u64,
}

impl From<Range<u64>> for Segment {
fn from(range: Range<u64>) -> Self {
Segment {
begin_addr: range.start,
stop_ptr: range.end,
}
}
}

impl From<Segment> for Range<u64> {
fn from(val: Segment) -> Self {
val.begin_addr..val.stop_ptr
}
}

impl From<&MemorySegmentAddresses> for Segment {
fn from(value: &MemorySegmentAddresses) -> Self {
Self {
begin_addr: value.begin_addr as u64,
stop_ptr: value.stop_ptr as u64,
}
}
}

pub type MemorySegmentMap = HashMap<SegmentName, Segment>;

#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct PublicInputs {
Expand Down Expand Up @@ -236,12 +284,16 @@ impl Serializable for PublicInputs {
let mut memory_segment_bytes = vec![];
for (segment, range) in self.memory_segments.iter() {
let segment_type = match segment {
MemorySegment::RangeCheck => 0u8,
MemorySegment::Output => 1u8,
SegmentName::RangeCheck => 0u8,
SegmentName::Output => 1u8,
SegmentName::Program => 2u8,
SegmentName::Execution => 3u8,
SegmentName::Ecdsa => 4u8,
SegmentName::Pedersen => 5u8,
};
memory_segment_bytes.extend(segment_type.to_be_bytes());
memory_segment_bytes.extend(range.start.to_be_bytes());
memory_segment_bytes.extend(range.end.to_be_bytes());
memory_segment_bytes.extend(range.begin_addr.to_be_bytes());
memory_segment_bytes.extend(range.stop_ptr.to_be_bytes());
}
let memory_segment_length = self.memory_segments.len();
bytes.extend(memory_segment_length.to_be_bytes());
Expand Down Expand Up @@ -364,8 +416,12 @@ impl Deserializable for PublicInputs {
return Err(DeserializationError::InvalidAmountOfBytes);
}
let segment_type = match bytes[0] {
0 => MemorySegment::RangeCheck,
1 => MemorySegment::Output,
0u8 => SegmentName::RangeCheck,
1u8 => SegmentName::Output,
2u8 => SegmentName::Program,
3u8 => SegmentName::Execution,
4u8 => SegmentName::Ecdsa,
5u8 => SegmentName::Pedersen,
_ => return Err(DeserializationError::FieldFromBytesError),
};
bytes = &bytes[1..];
Expand All @@ -385,7 +441,7 @@ impl Deserializable for PublicInputs {
.map_err(|_| DeserializationError::InvalidAmountOfBytes)?,
);
bytes = &bytes[8..];
memory_segments.insert(segment_type, start..end);
memory_segments.insert(segment_type, Segment::from(start..end));
}

let mut public_memory = HashMap::new();
Expand Down Expand Up @@ -474,9 +530,9 @@ fn add_pub_memory_in_public_input_section(
let mut a_aux = addresses.to_owned();
let mut v_aux = values.to_owned();

let output_range = public_input.memory_segments.get(&MemorySegment::Output);
let output_segment = public_input.memory_segments.get(&SegmentName::Output);

let pub_addrs = get_pub_memory_addrs(output_range, public_input);
let pub_addrs = get_pub_memory_addrs(output_segment, public_input);
let mut pub_addrs_iter = pub_addrs.iter();

// Iterate over addresses
Expand Down Expand Up @@ -505,12 +561,13 @@ fn add_pub_memory_in_public_input_section(
/// If the output builtin is used, `output_range` is `Some(...)` and this function adds incrementally to the resulting
/// `Vec` addresses from the start to the end of the unwrapped `output_range`.
fn get_pub_memory_addrs(
output_range: Option<&Range<u64>>,
output_segment: Option<&Segment>,
public_input: &PublicInputs,
) -> Vec<FieldElement<Stark252PrimeField>> {
let public_memory_len = public_input.public_memory.len() as u64;

if let Some(output_range) = output_range {
if let Some(output_segment) = output_segment {
let output_range: Range<u64> = output_segment.clone().into();
let output_section = output_range.end - output_range.start;
let program_section = public_memory_len - output_section;

Expand Down Expand Up @@ -1315,7 +1372,7 @@ mod test {
range_check_max: None,
range_check_min: None,
num_steps: 1,
memory_segments: MemorySegmentMap::from([(MemorySegment::Output, 20..21)]),
memory_segments: MemorySegmentMap::from([(SegmentName::Output, Segment::from(20..21))]),
codelen: 3,
};

Expand Down Expand Up @@ -1496,7 +1553,7 @@ mod prop_test {
Felt252,
};

use super::{MemorySegment, MemorySegmentMap, PublicInputs};
use super::{MemorySegmentMap, PublicInputs, Segment, SegmentName};

prop_compose! {
fn some_felt()(base in any::<u64>(), exponent in any::<u128>()) -> Felt252 {
Expand All @@ -1518,7 +1575,7 @@ mod prop_test {
codelen in any::<usize>(),
) -> PublicInputs {
let public_memory = public_memory.iter().map(|(k, v)| (Felt252::from(*k), Felt252::from(*v))).collect();
let memory_segments = MemorySegmentMap::from([(MemorySegment::Output, 10u64..16u64), (MemorySegment::RangeCheck, 20u64..71u64)]);
let memory_segments = MemorySegmentMap::from([(SegmentName::Output, Segment::from(10u64..16u64)), (SegmentName::RangeCheck, Segment::from(20u64..71u64))]);
PublicInputs {
pc_init,
ap_init,
Expand Down
13 changes: 9 additions & 4 deletions provers/cairo/src/execution_trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@ use super::{
},
register_states::RegisterStates,
};
use crate::air::{EXTRA_ADDR, RC_HOLES};
use crate::{
air::{
PublicInputs, EXTRA_ADDR, FRAME_DST_ADDR, FRAME_OP0_ADDR, FRAME_OP1_ADDR, FRAME_PC,
OFF_DST, OFF_OP0, OFF_OP1, RC_HOLES,
PublicInputs, FRAME_DST_ADDR, FRAME_OP0_ADDR, FRAME_OP1_ADDR, FRAME_PC, OFF_DST, OFF_OP0,
OFF_OP1,
},
Felt252,
};
Expand Down Expand Up @@ -55,8 +56,12 @@ pub fn build_main_trace(
address_cols.sort_by_key(|x| x.representative());

let (rc_holes, rc_min, rc_max) = get_rc_holes(&main_trace, &[OFF_DST, OFF_OP0, OFF_OP1]);
public_input.range_check_min = Some(rc_min);
public_input.range_check_max = Some(rc_max);

// this will avaluate to true if the public inputs weren't obtained from the run_program() function
if public_input.range_check_min.is_none() && public_input.range_check_max.is_none() {
public_input.range_check_min = Some(rc_min);
public_input.range_check_max = Some(rc_max);
}
fill_rc_holes(&mut main_trace, &rc_holes);

let memory_holes = get_memory_holes(&address_cols, public_input.codelen);
Expand Down
45 changes: 36 additions & 9 deletions provers/cairo/src/runner/run.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use crate::air::PublicInputs;
use crate::air::{PublicInputs, Segment, SegmentName};
use crate::cairo_layout::CairoLayout;
use crate::cairo_mem::CairoMemory;
use crate::execution_trace::build_main_trace;
use crate::register_states::RegisterStates;
use crate::Felt252;

use super::vec_writer::VecWriter;
use cairo_vm::cairo_run::{self, EncodeTraceError};
Expand All @@ -13,6 +14,7 @@ use cairo_vm::vm::errors::{
cairo_run_errors::CairoRunError, trace_errors::TraceError, vm_errors::VirtualMachineError,
};

use cairo_vm::without_std::collections::HashMap;
use lambdaworks_math::field::fields::fft_friendly::stark_252_prime_field::Stark252PrimeField;
use stark_platinum_prover::trace::TraceTable;

Expand Down Expand Up @@ -79,7 +81,7 @@ pub fn run_program(
entrypoint_function: Option<&str>,
layout: CairoLayout,
program_content: &[u8],
) -> Result<(RegisterStates, CairoMemory, usize), Error> {
) -> Result<(RegisterStates, CairoMemory, PublicInputs), Error> {
// default value for entrypoint is "main"
let entrypoint = entrypoint_function.unwrap_or("main");

Expand All @@ -92,6 +94,7 @@ pub fn run_program(
layout: layout.as_str(),
proof_mode: true,
secure_run: None,
disable_trace_padding: false,
};

let (runner, vm) =
Expand Down Expand Up @@ -122,22 +125,46 @@ pub fn run_program(
let cairo_mem = CairoMemory::from_bytes_le(&memory_vec).unwrap();
let register_states = RegisterStates::from_bytes_le(&trace_vec).unwrap();

let data_len = runner.get_program().data_len();
let codelen = runner.get_program().data_len();

let public_memory = (1..=codelen as u64)
.map(|i| (Felt252::from(i), *cairo_mem.get(&i).unwrap()))
.collect::<HashMap<Felt252, Felt252>>();

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that we have the public inputs directly from the cairo vm, there is no need to ask for the codelen, and the public memory should be obtained from the public inputs, this was just a hacky solution until we had access to them.
The approach should be similar to the one you did with the memory segments

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As a result of this, the PublicInputs field codelen could be removed, since it can be deduced with the length of the Program memory segment

Copy link
Contributor Author

@juan518munoz juan518munoz Nov 17, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only function that makes use of the codelen field is get_memory_holes which itself is only called inside build_main_trace. How should we handle this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The approach should be similar to the one you did with the memory segments

I'm having trouble doing an iteration over the vm public memory, doing something like this:

    let mut public_memory: HashMap<Felt252, Felt252> = HashMap::new();
    vm_public_inputs.public_memory.iter().for_each(|e| {
        let v = e.value.clone().unwrap().to_str_radix(16);
        public_memory.insert(
            Felt252::from((e.address + e.page) as u64),
            Felt252::from_hex_unchecked(&v),
        );
    });

Breaks some of our tests

image

Maybe I'm missing something?

let vm_public_inputs = runner.get_air_public_input(&vm).unwrap();

let mut memory_segments: HashMap<SegmentName, Segment> = HashMap::new();
vm_public_inputs.memory_segments.iter().for_each(|(k, v)| {
memory_segments.insert(SegmentName::from(*k), Segment::from(v));
});

let num_steps = register_states.steps();
let public_inputs = PublicInputs {
pc_init: Felt252::from(register_states.rows[0].pc),
ap_init: Felt252::from(register_states.rows[0].ap),
fp_init: Felt252::from(register_states.rows[0].fp),
pc_final: Felt252::from(register_states.rows[num_steps - 1].pc),
ap_final: Felt252::from(register_states.rows[num_steps - 1].ap),
range_check_min: Some(vm_public_inputs.rc_min as u16),
range_check_max: Some(vm_public_inputs.rc_max as u16),
memory_segments,
public_memory,
num_steps,
codelen,
};

Ok((register_states, cairo_mem, data_len))
Ok((register_states, cairo_mem, public_inputs))
}

pub fn generate_prover_args(
program_content: &[u8],
layout: CairoLayout,
) -> Result<(TraceTable<Stark252PrimeField>, PublicInputs), Error> {
let (register_states, memory, program_size) = run_program(None, layout, program_content)?;

let mut pub_inputs = PublicInputs::from_regs_and_mem(&register_states, &memory, program_size);
let (register_states, memory, mut public_inputs) = run_program(None, layout, program_content)?;

let main_trace = build_main_trace(&register_states, &memory, &mut pub_inputs);
let main_trace = build_main_trace(&register_states, &memory, &mut public_inputs);

Ok((main_trace, pub_inputs))
Ok((main_trace, public_inputs))
}

pub fn generate_prover_args_from_trace(
Expand Down
Loading