diff --git a/axiom-client-sdk/src/cmd.rs b/axiom-client-sdk/src/cmd.rs index f1fa99f..63e26fd 100644 --- a/axiom-client-sdk/src/cmd.rs +++ b/axiom-client-sdk/src/cmd.rs @@ -1,6 +1,21 @@ -use std::{env, fmt::Debug, fs, path::PathBuf}; +use std::{ + env, + fmt::Debug, + fs::{self, File}, + io::BufWriter, + path::PathBuf, +}; -use axiom_client::axiom_eth::halo2_base::{gates::circuit::BaseCircuitParams, AssignedValue}; +use axiom_client::{ + axiom_eth::{ + halo2_base::{gates::circuit::BaseCircuitParams, AssignedValue}, + halo2_proofs::{plonk::ProvingKey, SerdeFormat}, + halo2curves::bn256::G1Affine, + rlc::virtual_region::RlcThreadBreakPoints, + }, + scaffold::AxiomCircuit, + types::AxiomCircuitParams, +}; pub use clap::Parser; use clap::Subcommand; use ethers::providers::{Http, Provider}; @@ -45,6 +60,8 @@ pub struct Cli { pub provider: Option, #[arg(short, long = "input")] pub input_path: Option, + #[arg(short, long = "data-path")] + pub data_path: Option, } pub fn run_cli() @@ -71,6 +88,7 @@ where .provider .unwrap_or_else(|| env::var("PROVIDER_URI").unwrap()); let provider = Provider::::try_from(provider_uri).unwrap(); + let data_path = cli.data_path.unwrap_or_else(|| PathBuf::from("data")); let params = BaseCircuitParams { k: 12, @@ -89,24 +107,73 @@ where .mock(); } SnarkCmd::Keygen => { - AxiomCompute::::new() + let circuit = AxiomCompute::::new() .use_params(params) - .use_provider(provider) - .keygen(); + .use_provider(provider); + let (vkey, pkey, breakpoints) = circuit.keygen(); + let pk_path = data_path.join(PathBuf::from("pk.bin")); + if pk_path.exists() { + fs::remove_file(&pk_path).unwrap(); + } + let vk_path = data_path.join(PathBuf::from("vk.bin")); + if vk_path.exists() { + fs::remove_file(&vk_path).unwrap(); + } + let f = File::create(&vk_path) + .unwrap_or_else(|_| panic!("Could not create file at {vk_path:?}")); + let mut writer = BufWriter::new(f); + vkey.write(&mut writer, SerdeFormat::RawBytes) + .expect("writing vkey should not fail"); + + let f = File::create(&pk_path) + .unwrap_or_else(|_| panic!("Could not create file at {pk_path:?}")); + let mut writer = BufWriter::new(f); + pkey.write(&mut writer, SerdeFormat::RawBytes) + .expect("writing pkey should not fail"); + + let breakpoints_path = data_path.join(PathBuf::from("breakpoints.json")); + if breakpoints_path.exists() { + fs::remove_file(&breakpoints_path).unwrap(); + } + let f = File::create(&breakpoints_path) + .unwrap_or_else(|_| panic!("Could not create file at {breakpoints_path:?}")); + let mut writer = BufWriter::new(f); + serde_json::to_writer_pretty(&mut writer, &breakpoints) + .expect("writing breakpoints should not fail"); } SnarkCmd::Prove => { let compute = AxiomCompute::::new() - .use_params(params) + .use_params(params.clone()) .use_provider(provider); - let (_vk, pk) = compute.keygen(); - compute.use_inputs(input).prove(pk); + let pk_path = data_path.join(PathBuf::from("pk.bin")); + let mut f = File::open(&pk_path).unwrap(); + let pk = ProvingKey::::read::<_, AxiomCircuit>>( + &mut f, + SerdeFormat::RawBytes, + AxiomCircuitParams::Base(params), + ) + .unwrap(); + let breakpoints_path = data_path.join(PathBuf::from("breakpoints.json")); + let f = File::open(&breakpoints_path).unwrap(); + let breakpoints: RlcThreadBreakPoints = serde_json::from_reader(f).unwrap(); + compute.use_inputs(input).prove(pk, breakpoints); } SnarkCmd::Run => { let compute = AxiomCompute::::new() - .use_params(params) + .use_params(params.clone()) .use_provider(provider); - let (_vk, pk) = compute.keygen(); - compute.use_inputs(input).run(pk); + let pk_path = data_path.join(PathBuf::from("pk.bin")); + let mut f = File::open(&pk_path).unwrap(); + let pk = ProvingKey::::read::<_, AxiomCircuit>>( + &mut f, + SerdeFormat::RawBytes, + AxiomCircuitParams::Base(params), + ) + .unwrap(); + let breakpoints_path = data_path.join(PathBuf::from("breakpoints.json")); + let f = File::open(&breakpoints_path).unwrap(); + let breakpoints: RlcThreadBreakPoints = serde_json::from_reader(f).unwrap(); + compute.use_inputs(input).run(pk, breakpoints); } } } diff --git a/axiom-client-sdk/src/compute.rs b/axiom-client-sdk/src/compute.rs index 01f55ee..466bf04 100644 --- a/axiom-client-sdk/src/compute.rs +++ b/axiom-client-sdk/src/compute.rs @@ -11,7 +11,7 @@ use axiom_client::{ }, halo2_proofs::plonk::{ProvingKey, VerifyingKey}, halo2curves::bn256::G1Affine, - rlc::circuit::builder::RlcCircuitBuilder, + rlc::{circuit::builder::RlcCircuitBuilder, virtual_region::RlcThreadBreakPoints}, snark_verifier_sdk::Snark, utils::hilo::HiLo, }, @@ -141,14 +141,20 @@ where mock::(provider, AxiomCircuitParams::Base(params), converted_input); } - pub fn keygen(&self) -> (VerifyingKey, ProvingKey) { + pub fn keygen( + &self, + ) -> ( + VerifyingKey, + ProvingKey, + RlcThreadBreakPoints, + ) { self.check_provider_and_params_set(); let provider = self.provider.clone().unwrap(); let params = self.params.clone().unwrap(); keygen::(provider, AxiomCircuitParams::Base(params), None) } - pub fn prove(&self, pk: ProvingKey) -> Snark { + pub fn prove(&self, pk: ProvingKey, break_points: RlcThreadBreakPoints) -> Snark { self.check_all_set(); let provider = self.provider.clone().unwrap(); let params = self.params.clone().unwrap(); @@ -158,10 +164,15 @@ where AxiomCircuitParams::Base(params), converted_input, pk, + break_points, ) } - pub fn run(&self, pk: ProvingKey) -> AxiomV2CircuitOutput { + pub fn run( + &self, + pk: ProvingKey, + break_points: RlcThreadBreakPoints, + ) -> AxiomV2CircuitOutput { self.check_all_set(); let provider = self.provider.clone().unwrap(); let params = self.params.clone().unwrap(); @@ -171,6 +182,7 @@ where AxiomCircuitParams::Base(params), converted_input, pk, + break_points, ) } diff --git a/axiom-client/src/run/inner.rs b/axiom-client/src/run/inner.rs index 342f251..3fcc407 100644 --- a/axiom-client/src/run/inner.rs +++ b/axiom-client/src/run/inner.rs @@ -13,6 +13,7 @@ use axiom_query::axiom_eth::{ SerdeFormat, }, halo2curves::bn256::{Fr, G1Affine}, + rlc::virtual_region::RlcThreadBreakPoints, snark_verifier_sdk::{halo2::gen_snark_shplonk, Snark}, utils::keccak::decorator::RlcKeccakCircuitParams, }; @@ -48,7 +49,11 @@ pub fn keygen>( provider: Provider

, raw_circuit_params: AxiomCircuitParams, inputs: Option, -) -> (VerifyingKey, ProvingKey) { +) -> ( + VerifyingKey, + ProvingKey, + RlcThreadBreakPoints, +) { let circuit_params = RlcKeccakCircuitParams::from(raw_circuit_params.clone()); let params = gen_srs(circuit_params.k() as u32); let mut runner = AxiomCircuit::<_, _, S>::new(provider, raw_circuit_params).use_inputs(inputs); @@ -56,6 +61,7 @@ pub fn keygen>( runner.calculate_params(); } let vk = keygen_vk(¶ms, &runner).expect("Failed to generate vk"); + let breakpoints = runner.break_points(); let path = Path::new("data/vk.bin"); if let Some(parent) = path.parent() { create_dir_all(parent).expect("Failed to create data directory"); @@ -68,7 +74,7 @@ pub fn keygen>( let mut pk_file = File::create(path).expect("Failed to create pk file"); pk.write(&mut pk_file, SerdeFormat::Processed) .expect("Failed to write pk"); - (vk, pk) + (vk, pk, breakpoints) } pub fn prove>( @@ -76,10 +82,12 @@ pub fn prove>( raw_circuit_params: AxiomCircuitParams, inputs: Option, pk: ProvingKey, + break_points: RlcThreadBreakPoints, ) -> Snark { let circuit_params = RlcKeccakCircuitParams::from(raw_circuit_params.clone()); let params = gen_srs(circuit_params.k() as u32); let mut runner = AxiomCircuit::<_, _, S>::new(provider, raw_circuit_params).use_inputs(inputs); + runner.set_break_points(break_points); if circuit_params.keccak_rows_per_round > 0 { runner.calculate_params(); } @@ -91,12 +99,14 @@ pub fn run>( raw_circuit_params: AxiomCircuitParams, inputs: Option, pk: ProvingKey, + break_points: RlcThreadBreakPoints, ) -> AxiomV2CircuitOutput { let circuit_params = RlcKeccakCircuitParams::from(raw_circuit_params.clone()); let k = circuit_params.k(); let params = gen_srs(k as u32); let mut runner = AxiomCircuit::<_, _, S>::new(provider, raw_circuit_params.clone()).use_inputs(inputs); + runner.set_break_points(break_points); let output = runner.scaffold_output(); if circuit_params.keccak_rows_per_round > 0 { runner.calculate_params(); diff --git a/axiom-client/src/scaffold.rs b/axiom-client/src/scaffold.rs index 7aacd0a..ee3d048 100644 --- a/axiom-client/src/scaffold.rs +++ b/axiom-client/src/scaffold.rs @@ -186,7 +186,16 @@ impl> AxiomCir } pub fn break_points(&self) -> RlcThreadBreakPoints { - self.builder.borrow().break_points() + let rlc_params = self.builder.borrow().params(); + if rlc_params.num_rlc_columns == 0 { + let break_points = self.builder.borrow().base.break_points(); + RlcThreadBreakPoints { + base: break_points, + rlc: vec![], + } + } else { + self.builder.borrow().break_points() + } } pub fn k(&self) -> usize { diff --git a/axiom-client/src/tests/keccak.rs b/axiom-client/src/tests/keccak.rs index 0858641..98d2c1f 100644 --- a/axiom-client/src/tests/keccak.rs +++ b/axiom-client/src/tests/keccak.rs @@ -119,8 +119,8 @@ pub fn mock>(_circuit: S) { let params = get_keccak_test_params(); let agg_circuit_params = get_agg_test_params(); let client = get_provider(); - let (_, pk) = keygen::<_, S>(client.clone(), params.clone(), None); - let snark = prove::<_, S>(client, params, None, pk); + let (_, pk, break_points) = keygen::<_, S>(client.clone(), params.clone(), None); + let snark = prove::<_, S>(client, params, None, pk, break_points); agg_circuit_mock(agg_circuit_params, snark); } @@ -138,8 +138,8 @@ pub fn test_single_subquery_instances>(_circui let num_user_output_fe = runner.output_num_instances(); let subquery_fe = runner.subquery_num_instances(); let results = runner.scaffold_output(); - let (_, pk) = keygen::<_, S>(client.clone(), params.clone(), None); - let snark = prove::<_, S>(client, params, None, pk); + let (_, pk, break_points) = keygen::<_, S>(client.clone(), params.clone(), None); + let snark = prove::<_, S>(client, params, None, pk, break_points); let agg_circuit = create_aggregation_circuit(agg_circuit_params, snark.clone(), CircuitBuilderStage::Mock); let instances = agg_circuit.instances(); @@ -163,15 +163,15 @@ pub fn test_compute_query>(_circuit: S) { let params = get_keccak_test_params(); let agg_circuit_params = get_agg_test_params(); let client = get_provider(); - let (_vk, pk) = keygen::<_, S>(client.clone(), params.clone(), None); - let output = run::<_, S>(client, params.clone(), None, pk); - let (agg_vk, agg_pk, break_points) = + let (_vk, pk, break_points) = keygen::<_, S>(client.clone(), params.clone(), None); + let output = run::<_, S>(client, params.clone(), None, pk, break_points); + let (agg_vk, agg_pk, agg_break_points) = agg_circuit_keygen(agg_circuit_params, output.snark.clone()); let final_output = agg_circuit_run( agg_circuit_params, output.snark.clone(), agg_pk, - break_points, + agg_break_points, output.data, ); let circuit = create_aggregation_circuit( diff --git a/axiom-client/src/tests/shared_tests.rs b/axiom-client/src/tests/shared_tests.rs index dcc5296..eaa79f2 100644 --- a/axiom-client/src/tests/shared_tests.rs +++ b/axiom-client/src/tests/shared_tests.rs @@ -157,8 +157,8 @@ pub fn check_compute_proof_and_query_format>( is_aggregation: bool, ) { let client = get_provider(); - let (vk, pk) = keygen::<_, S>(client.clone(), params.clone(), None); - let output = run::<_, S>(client, params.clone(), None, pk); + let (vk, pk, break_points) = keygen::<_, S>(client.clone(), params.clone(), None); + let output = run::<_, S>(client, params.clone(), None, pk, break_points); check_compute_proof_format(output.clone(), is_aggregation); check_compute_query_format(output, params, vk); }