Skip to content

Commit

Permalink
feat: save +load vk/pk/breakpoints
Browse files Browse the repository at this point in the history
  • Loading branch information
rpalakkal committed Feb 6, 2024
1 parent e38a7d3 commit f89892f
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 28 deletions.
89 changes: 78 additions & 11 deletions axiom-client-sdk/src/cmd.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,21 @@
use std::{env, fmt::Debug, fs, path::PathBuf};
use std::{
env,
fmt::Debug,
fs::{self, File},
io::BufWriter,
path::PathBuf,
};

use axiom_client::axiom_eth::halo2_base::{gates::circuit::BaseCircuitParams, AssignedValue};
use axiom_client::{
axiom_eth::{
halo2_base::{gates::circuit::BaseCircuitParams, AssignedValue},
halo2_proofs::{plonk::ProvingKey, SerdeFormat},
halo2curves::bn256::G1Affine,
rlc::virtual_region::RlcThreadBreakPoints,
},
scaffold::AxiomCircuit,
types::AxiomCircuitParams,
};
pub use clap::Parser;
use clap::Subcommand;
use ethers::providers::{Http, Provider};
Expand Down Expand Up @@ -45,6 +60,8 @@ pub struct Cli {
pub provider: Option<String>,
#[arg(short, long = "input")]
pub input_path: Option<PathBuf>,
#[arg(short, long = "data-path")]
pub data_path: Option<PathBuf>,
}

pub fn run_cli<A: AxiomComputeFn>()
Expand All @@ -71,6 +88,7 @@ where
.provider
.unwrap_or_else(|| env::var("PROVIDER_URI").unwrap());
let provider = Provider::<Http>::try_from(provider_uri).unwrap();
let data_path = cli.data_path.unwrap_or_else(|| PathBuf::from("data"));

let params = BaseCircuitParams {
k: 12,
Expand All @@ -89,24 +107,73 @@ where
.mock();
}
SnarkCmd::Keygen => {
AxiomCompute::<A>::new()
let circuit = AxiomCompute::<A>::new()
.use_params(params)
.use_provider(provider)
.keygen();
.use_provider(provider);
let (vkey, pkey, breakpoints) = circuit.keygen();
let pk_path = data_path.join(PathBuf::from("pk.bin"));
if pk_path.exists() {
fs::remove_file(&pk_path).unwrap();
}
let vk_path = data_path.join(PathBuf::from("vk.bin"));
if vk_path.exists() {
fs::remove_file(&vk_path).unwrap();
}
let f = File::create(&vk_path)
.unwrap_or_else(|_| panic!("Could not create file at {vk_path:?}"));
let mut writer = BufWriter::new(f);
vkey.write(&mut writer, SerdeFormat::RawBytes)
.expect("writing vkey should not fail");

let f = File::create(&pk_path)
.unwrap_or_else(|_| panic!("Could not create file at {pk_path:?}"));
let mut writer = BufWriter::new(f);
pkey.write(&mut writer, SerdeFormat::RawBytes)
.expect("writing pkey should not fail");

let breakpoints_path = data_path.join(PathBuf::from("breakpoints.json"));
if breakpoints_path.exists() {
fs::remove_file(&breakpoints_path).unwrap();
}
let f = File::create(&breakpoints_path)
.unwrap_or_else(|_| panic!("Could not create file at {breakpoints_path:?}"));
let mut writer = BufWriter::new(f);
serde_json::to_writer_pretty(&mut writer, &breakpoints)
.expect("writing breakpoints should not fail");
}
SnarkCmd::Prove => {
let compute = AxiomCompute::<A>::new()
.use_params(params)
.use_params(params.clone())
.use_provider(provider);
let (_vk, pk) = compute.keygen();
compute.use_inputs(input).prove(pk);
let pk_path = data_path.join(PathBuf::from("pk.bin"));
let mut f = File::open(&pk_path).unwrap();
let pk = ProvingKey::<G1Affine>::read::<_, AxiomCircuit<Fr, Http, AxiomCompute<A>>>(
&mut f,
SerdeFormat::RawBytes,
AxiomCircuitParams::Base(params),
)
.unwrap();
let breakpoints_path = data_path.join(PathBuf::from("breakpoints.json"));
let f = File::open(&breakpoints_path).unwrap();
let breakpoints: RlcThreadBreakPoints = serde_json::from_reader(f).unwrap();
compute.use_inputs(input).prove(pk, breakpoints);
}
SnarkCmd::Run => {
let compute = AxiomCompute::<A>::new()
.use_params(params)
.use_params(params.clone())
.use_provider(provider);
let (_vk, pk) = compute.keygen();
compute.use_inputs(input).run(pk);
let pk_path = data_path.join(PathBuf::from("pk.bin"));
let mut f = File::open(&pk_path).unwrap();
let pk = ProvingKey::<G1Affine>::read::<_, AxiomCircuit<Fr, Http, AxiomCompute<A>>>(
&mut f,
SerdeFormat::RawBytes,
AxiomCircuitParams::Base(params),
)
.unwrap();
let breakpoints_path = data_path.join(PathBuf::from("breakpoints.json"));
let f = File::open(&breakpoints_path).unwrap();
let breakpoints: RlcThreadBreakPoints = serde_json::from_reader(f).unwrap();
compute.use_inputs(input).run(pk, breakpoints);
}
}
}
20 changes: 16 additions & 4 deletions axiom-client-sdk/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use axiom_client::{
},
halo2_proofs::plonk::{ProvingKey, VerifyingKey},
halo2curves::bn256::G1Affine,
rlc::circuit::builder::RlcCircuitBuilder,
rlc::{circuit::builder::RlcCircuitBuilder, virtual_region::RlcThreadBreakPoints},
snark_verifier_sdk::Snark,
utils::hilo::HiLo,
},
Expand Down Expand Up @@ -141,14 +141,20 @@ where
mock::<Http, Self>(provider, AxiomCircuitParams::Base(params), converted_input);
}

pub fn keygen(&self) -> (VerifyingKey<G1Affine>, ProvingKey<G1Affine>) {
pub fn keygen(
&self,
) -> (
VerifyingKey<G1Affine>,
ProvingKey<G1Affine>,
RlcThreadBreakPoints,
) {
self.check_provider_and_params_set();
let provider = self.provider.clone().unwrap();
let params = self.params.clone().unwrap();
keygen::<Http, Self>(provider, AxiomCircuitParams::Base(params), None)
}

pub fn prove(&self, pk: ProvingKey<G1Affine>) -> Snark {
pub fn prove(&self, pk: ProvingKey<G1Affine>, break_points: RlcThreadBreakPoints) -> Snark {
self.check_all_set();
let provider = self.provider.clone().unwrap();
let params = self.params.clone().unwrap();
Expand All @@ -158,10 +164,15 @@ where
AxiomCircuitParams::Base(params),
converted_input,
pk,
break_points,
)
}

pub fn run(&self, pk: ProvingKey<G1Affine>) -> AxiomV2CircuitOutput {
pub fn run(
&self,
pk: ProvingKey<G1Affine>,
break_points: RlcThreadBreakPoints,
) -> AxiomV2CircuitOutput {
self.check_all_set();
let provider = self.provider.clone().unwrap();
let params = self.params.clone().unwrap();
Expand All @@ -171,6 +182,7 @@ where
AxiomCircuitParams::Base(params),
converted_input,
pk,
break_points,
)
}

Expand Down
14 changes: 12 additions & 2 deletions axiom-client/src/run/inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use axiom_query::axiom_eth::{
SerdeFormat,
},
halo2curves::bn256::{Fr, G1Affine},
rlc::virtual_region::RlcThreadBreakPoints,
snark_verifier_sdk::{halo2::gen_snark_shplonk, Snark},
utils::keccak::decorator::RlcKeccakCircuitParams,
};
Expand Down Expand Up @@ -48,14 +49,19 @@ pub fn keygen<P: JsonRpcClient + Clone, S: AxiomCircuitScaffold<P, Fr>>(
provider: Provider<P>,
raw_circuit_params: AxiomCircuitParams,
inputs: Option<S::InputValue>,
) -> (VerifyingKey<G1Affine>, ProvingKey<G1Affine>) {
) -> (
VerifyingKey<G1Affine>,
ProvingKey<G1Affine>,
RlcThreadBreakPoints,
) {
let circuit_params = RlcKeccakCircuitParams::from(raw_circuit_params.clone());
let params = gen_srs(circuit_params.k() as u32);
let mut runner = AxiomCircuit::<_, _, S>::new(provider, raw_circuit_params).use_inputs(inputs);
if circuit_params.keccak_rows_per_round > 0 {
runner.calculate_params();
}
let vk = keygen_vk(&params, &runner).expect("Failed to generate vk");
let breakpoints = runner.break_points();
let path = Path::new("data/vk.bin");
if let Some(parent) = path.parent() {
create_dir_all(parent).expect("Failed to create data directory");
Expand All @@ -68,18 +74,20 @@ pub fn keygen<P: JsonRpcClient + Clone, S: AxiomCircuitScaffold<P, Fr>>(
let mut pk_file = File::create(path).expect("Failed to create pk file");
pk.write(&mut pk_file, SerdeFormat::Processed)
.expect("Failed to write pk");
(vk, pk)
(vk, pk, breakpoints)
}

pub fn prove<P: JsonRpcClient + Clone, S: AxiomCircuitScaffold<P, Fr>>(
provider: Provider<P>,
raw_circuit_params: AxiomCircuitParams,
inputs: Option<S::InputValue>,
pk: ProvingKey<G1Affine>,
break_points: RlcThreadBreakPoints,
) -> Snark {
let circuit_params = RlcKeccakCircuitParams::from(raw_circuit_params.clone());
let params = gen_srs(circuit_params.k() as u32);
let mut runner = AxiomCircuit::<_, _, S>::new(provider, raw_circuit_params).use_inputs(inputs);
runner.set_break_points(break_points);
if circuit_params.keccak_rows_per_round > 0 {
runner.calculate_params();
}
Expand All @@ -91,12 +99,14 @@ pub fn run<P: JsonRpcClient + Clone, S: AxiomCircuitScaffold<P, Fr>>(
raw_circuit_params: AxiomCircuitParams,
inputs: Option<S::InputValue>,
pk: ProvingKey<G1Affine>,
break_points: RlcThreadBreakPoints,
) -> AxiomV2CircuitOutput {
let circuit_params = RlcKeccakCircuitParams::from(raw_circuit_params.clone());
let k = circuit_params.k();
let params = gen_srs(k as u32);
let mut runner =
AxiomCircuit::<_, _, S>::new(provider, raw_circuit_params.clone()).use_inputs(inputs);
runner.set_break_points(break_points);
let output = runner.scaffold_output();
if circuit_params.keccak_rows_per_round > 0 {
runner.calculate_params();
Expand Down
11 changes: 10 additions & 1 deletion axiom-client/src/scaffold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,16 @@ impl<F: Field, P: JsonRpcClient + Clone, A: AxiomCircuitScaffold<P, F>> AxiomCir
}

pub fn break_points(&self) -> RlcThreadBreakPoints {
self.builder.borrow().break_points()
let rlc_params = self.builder.borrow().params();
if rlc_params.num_rlc_columns == 0 {
let break_points = self.builder.borrow().base.break_points();
RlcThreadBreakPoints {
base: break_points,
rlc: vec![],
}
} else {
self.builder.borrow().break_points()
}
}

pub fn k(&self) -> usize {
Expand Down
16 changes: 8 additions & 8 deletions axiom-client/src/tests/keccak.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ pub fn mock<S: AxiomCircuitScaffold<Http, Fr>>(_circuit: S) {
let params = get_keccak_test_params();
let agg_circuit_params = get_agg_test_params();
let client = get_provider();
let (_, pk) = keygen::<_, S>(client.clone(), params.clone(), None);
let snark = prove::<_, S>(client, params, None, pk);
let (_, pk, break_points) = keygen::<_, S>(client.clone(), params.clone(), None);
let snark = prove::<_, S>(client, params, None, pk, break_points);
agg_circuit_mock(agg_circuit_params, snark);
}

Expand All @@ -138,8 +138,8 @@ pub fn test_single_subquery_instances<S: AxiomCircuitScaffold<Http, Fr>>(_circui
let num_user_output_fe = runner.output_num_instances();
let subquery_fe = runner.subquery_num_instances();
let results = runner.scaffold_output();
let (_, pk) = keygen::<_, S>(client.clone(), params.clone(), None);
let snark = prove::<_, S>(client, params, None, pk);
let (_, pk, break_points) = keygen::<_, S>(client.clone(), params.clone(), None);
let snark = prove::<_, S>(client, params, None, pk, break_points);
let agg_circuit =
create_aggregation_circuit(agg_circuit_params, snark.clone(), CircuitBuilderStage::Mock);
let instances = agg_circuit.instances();
Expand All @@ -163,15 +163,15 @@ pub fn test_compute_query<S: AxiomCircuitScaffold<Http, Fr>>(_circuit: S) {
let params = get_keccak_test_params();
let agg_circuit_params = get_agg_test_params();
let client = get_provider();
let (_vk, pk) = keygen::<_, S>(client.clone(), params.clone(), None);
let output = run::<_, S>(client, params.clone(), None, pk);
let (agg_vk, agg_pk, break_points) =
let (_vk, pk, break_points) = keygen::<_, S>(client.clone(), params.clone(), None);
let output = run::<_, S>(client, params.clone(), None, pk, break_points);
let (agg_vk, agg_pk, agg_break_points) =
agg_circuit_keygen(agg_circuit_params, output.snark.clone());
let final_output = agg_circuit_run(
agg_circuit_params,
output.snark.clone(),
agg_pk,
break_points,
agg_break_points,
output.data,
);
let circuit = create_aggregation_circuit(
Expand Down
4 changes: 2 additions & 2 deletions axiom-client/src/tests/shared_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,8 +157,8 @@ pub fn check_compute_proof_and_query_format<S: AxiomCircuitScaffold<Http, Fr>>(
is_aggregation: bool,
) {
let client = get_provider();
let (vk, pk) = keygen::<_, S>(client.clone(), params.clone(), None);
let output = run::<_, S>(client, params.clone(), None, pk);
let (vk, pk, break_points) = keygen::<_, S>(client.clone(), params.clone(), None);
let output = run::<_, S>(client, params.clone(), None, pk, break_points);
check_compute_proof_format(output.clone(), is_aggregation);
check_compute_query_format(output, params, vk);
}

0 comments on commit f89892f

Please sign in to comment.