Skip to content

Commit

Permalink
Add multi-thread pk read/write
Browse files Browse the repository at this point in the history
  • Loading branch information
nyunyunyunyu committed Oct 27, 2023
1 parent 782df4a commit 4beeb03
Show file tree
Hide file tree
Showing 6 changed files with 402 additions and 96 deletions.
17 changes: 9 additions & 8 deletions halo2_proofs/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
[package]
name = "halo2_proofs"
version = "0.2.0"
authors = [
"Sean Bowe <[email protected]>",
"Ying Tong Lai <[email protected]>",
"Daira Hopwood <[email protected]>",
"Jack Grigg <[email protected]>",
]
authors = ["Sean Bowe <[email protected]>", "Ying Tong Lai <[email protected]>", "Daira Hopwood <[email protected]>", "Jack Grigg <[email protected]>"]
edition = "2021"
rust-version = "1.56.1"
description = """
Expand Down Expand Up @@ -55,14 +50,18 @@ ff = "0.12"
group = "0.12"
halo2curves = { path = "../arithmetic/curves" }
rand = "0.8"
rand_core = { version = "0.6", default-features = false}
rand_core = { version = "0.6", default-features = false }
tracing = "0.1"
blake2b_simd = "1"
rustc-hash = "1.1.0"
sha3 = "0.9.1"
ark-std = { version = "0.3.0", features = ["print-trace"], optional = true }
serde = { version = "1.0", default-features = false, features = ["derive"] }
bincode = "1.3.3"
maybe-rayon = { version = "0.1.0", default-features = false }
itertools = "0.10"
tokio = { version = "1.33", features = ["full"] }
serde_json = { version = "1.0", default-features = false }

# Developer tooling dependencies
plotters = { version = "0.3.0", optional = true }
Expand All @@ -80,8 +79,10 @@ rand_chacha = "0.3.1"
getrandom = { version = "0.2", features = ["js"] }

[features]
default = ["batch"]
default = ["batch", "multicore"]
multicore = ["maybe-rayon/threads"]
dev-graph = ["plotters", "tabbycat"]
test-dev-graph = ["dev-graph", "plotters/bitmap_backend", "plotters/bitmap_encoder", "plotters/ttf"]
gadget-traces = ["backtrace"]
sanity-checks = []
batch = ["rand/getrandom"]
Expand Down
200 changes: 115 additions & 85 deletions halo2_proofs/examples/serialization.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,91 +126,121 @@ impl Circuit<Fr> for StandardPlonk {
}
}

fn main() -> std::io::Result<()> {
#[tokio::main(flavor = "multi_thread", worker_threads = 24)]
async fn main() -> std::io::Result<()> {
let k = 22;
let circuit = StandardPlonk(Fr::random(OsRng));
let params = ParamsKZG::<Bn256>::setup(k, OsRng);
let vk = keygen_vk(&params, &circuit).expect("vk should not fail");
let pk = keygen_pk(&params, vk, &circuit).expect("pk should not fail");

for buf_size in [1024, 8 * 1024, 1024 * 1024, 1024 * 1024 * 1024] {
println!("buf_size: {buf_size}");
// Using halo2_proofs serde implementation
let f = File::create("serialization-test.pk")?;
let mut writer = BufWriter::with_capacity(buf_size, f);
let start = std::time::Instant::now();
pk.write(&mut writer, SerdeFormat::RawBytes)?;
writer.flush().unwrap();
println!("SerdeFormat::RawBytes pk write time: {:?}", start.elapsed());

let f = File::open("serialization-test.pk")?;
let mut reader = BufReader::with_capacity(buf_size, f);
let start = std::time::Instant::now();
let pk =
ProvingKey::<G1Affine>::read::<_, StandardPlonk>(&mut reader, SerdeFormat::RawBytes)
.unwrap();
println!("SerdeFormat::RawBytes pk read time: {:?}", start.elapsed());

let metadata = fs::metadata("serialization-test.pk")?;
let file_size = metadata.len();
println!("The size of the file is {} bytes", file_size);
std::fs::remove_file("serialization-test.pk")?;

// Using bincode
let f = File::create("serialization-test.pk")?;
let mut writer = BufWriter::with_capacity(buf_size, f);
let start = std::time::Instant::now();
bincode::serialize_into(&mut writer, &pk).unwrap();
writer.flush().unwrap();
println!("bincode pk write time: {:?}", start.elapsed());

let f = File::open("serialization-test.pk").unwrap();
let mut reader = BufReader::with_capacity(buf_size, f);
let start = std::time::Instant::now();
let pk: ProvingKey<G1Affine> = bincode::deserialize_from(&mut reader).unwrap();
println!("bincode pk read time: {:?}", start.elapsed());

let metadata = fs::metadata("serialization-test.pk")?;
let file_size = metadata.len();
println!("The size of the file is {} bytes", file_size);
std::fs::remove_file("serialization-test.pk").unwrap();

let instances: &[&[Fr]] = &[&[circuit.clone().0]];
let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);
create_proof::<
KZGCommitmentScheme<Bn256>,
ProverGWC<'_, Bn256>,
Challenge255<G1Affine>,
_,
Blake2bWrite<Vec<u8>, G1Affine, Challenge255<_>>,
_,
>(
&params,
&pk,
&[circuit.clone()],
&[instances],
OsRng,
&mut transcript,
)
.expect("prover should not fail");
let proof = transcript.finalize();

let strategy = SingleStrategy::new(&params);
let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]);
assert!(verify_proof::<
KZGCommitmentScheme<Bn256>,
VerifierGWC<'_, Bn256>,
Challenge255<G1Affine>,
Blake2bRead<&[u8], G1Affine, Challenge255<G1Affine>>,
SingleStrategy<'_, Bn256>,
>(
&params,
pk.get_vk(),
strategy,
&[instances],
&mut transcript
)
.is_ok());
}

let buf_size = 1024 * 1024;

// let pk_path = "/home/ubuntu/playground/serialization-test.pk";
// let f = File::open(pk_path)?;
// let mut reader = BufReader::with_capacity(buf_size, f);
// let start = std::time::Instant::now();
// let pk = ProvingKey::<G1Affine>::read::<_, StandardPlonk>(&mut reader, SerdeFormat::RawBytes)
// .unwrap();
// println!("SerdeFormat::RawBytes pk read time: {:?}", start.elapsed());

// let pk_folder = "/home/ubuntu/playground/serialization-test/";
let pk_folder = "/mnt/ramdisk/serialization-test";
// let start = std::time::Instant::now();
// pk.multi_thread_write(pk_folder, SerdeFormat::RawBytes)?;
// println!(
// "SerdeFormat::RawBytes pk multi thread write time: {:?}",
// start.elapsed()
// );

let start = std::time::Instant::now();
ProvingKey::<G1Affine>::multi_thread_read::<StandardPlonk>(pk_folder, SerdeFormat::RawBytes)
.await?;
println!(
"SerdeFormat::RawBytes pk multi thread read time: {:?}",
start.elapsed()
);

Ok(())
// let circuit = StandardPlonk(Fr::random(OsRng));
// let params = ParamsKZG::<Bn256>::setup(k, OsRng);
// let vk = keygen_vk(&params, &circuit).expect("vk should not fail");
// let pk = keygen_pk(&params, vk, &circuit).expect("pk should not fail");

// for buf_size in [1024, 8 * 1024, 1024 * 1024, 1024 * 1024 * 1024] {
// println!("buf_size: {buf_size}");
// // Using halo2_proofs serde implementation
// let f = File::create("serialization-test.pk")?;
// let mut writer = BufWriter::with_capacity(buf_size, f);
// let start = std::time::Instant::now();
// pk.write(&mut writer, SerdeFormat::RawBytes)?;
// writer.flush().unwrap();
// println!("SerdeFormat::RawBytes pk write time: {:?}", start.elapsed());

// let f = File::open("serialization-test.pk")?;
// let mut reader = BufReader::with_capacity(buf_size, f);
// let start = std::time::Instant::now();
// let pk =
// ProvingKey::<G1Affine>::read::<_, StandardPlonk>(&mut reader, SerdeFormat::RawBytes)
// .unwrap();
// println!("SerdeFormat::RawBytes pk read time: {:?}", start.elapsed());

// let metadata = fs::metadata("serialization-test.pk")?;
// let file_size = metadata.len();
// println!("The size of the file is {} bytes", file_size);
// std::fs::remove_file("serialization-test.pk")?;

// // Using bincode
// let f = File::create("serialization-test.pk")?;
// let mut writer = BufWriter::with_capacity(buf_size, f);
// let start = std::time::Instant::now();
// bincode::serialize_into(&mut writer, &pk).unwrap();
// writer.flush().unwrap();
// println!("bincode pk write time: {:?}", start.elapsed());

// let f = File::open("serialization-test.pk").unwrap();
// let mut reader = BufReader::with_capacity(buf_size, f);
// let start = std::time::Instant::now();
// let pk: ProvingKey<G1Affine> = bincode::deserialize_from(&mut reader).unwrap();
// println!("bincode pk read time: {:?}", start.elapsed());

// let metadata = fs::metadata("serialization-test.pk")?;
// let file_size = metadata.len();
// println!("The size of the file is {} bytes", file_size);
// std::fs::remove_file("serialization-test.pk").unwrap();

// let instances: &[&[Fr]] = &[&[circuit.clone().0]];
// let mut transcript = Blake2bWrite::<_, _, Challenge255<_>>::init(vec![]);
// create_proof::<
// KZGCommitmentScheme<Bn256>,
// ProverGWC<'_, Bn256>,
// Challenge255<G1Affine>,
// _,
// Blake2bWrite<Vec<u8>, G1Affine, Challenge255<_>>,
// _,
// >(
// &params,
// &pk,
// &[circuit.clone()],
// &[instances],
// OsRng,
// &mut transcript,
// )
// .expect("prover should not fail");
// let proof = transcript.finalize();

// let strategy = SingleStrategy::new(&params);
// let mut transcript = Blake2bRead::<_, _, Challenge255<_>>::init(&proof[..]);
// assert!(verify_proof::<
// KZGCommitmentScheme<Bn256>,
// VerifierGWC<'_, Bn256>,
// Challenge255<G1Affine>,
// Blake2bRead<&[u8], G1Affine, Challenge255<G1Affine>>,
// SingleStrategy<'_, Bn256>,
// >(
// &params,
// pk.get_vk(),
// strategy,
// &[instances],
// &mut transcript
// )
// .is_ok());
// }
// Ok(())
}
69 changes: 68 additions & 1 deletion halo2_proofs/src/helpers.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
use crate::poly::Polynomial;
use ff::PrimeField;
use halo2curves::{pairing::Engine, serde::SerdeObject, CurveAffine};
use std::io;
use itertools::Itertools;
use maybe_rayon::prelude::{
IndexedParallelIterator, IntoParallelIterator, IntoParallelRefIterator, ParallelIterator,
};
use std::{
fs::File,
io::{self, BufReader, BufWriter},
path::Path,
};

/// This enum specifies how various types are serialized and deserialized.
#[derive(Clone, Copy, Debug)]
Expand Down Expand Up @@ -125,6 +133,35 @@ pub(crate) fn read_polynomial_vec<R: io::Read, F: SerdePrimeField, B>(
.collect()
}

/// Reads `n` polynomials in parallel, one file per polynomial.
///
/// Polynomial `i` is expected at the path `<pk_prefix_path>_<i>`, matching the
/// layout produced by `multi_thread_write_polynomial_slice`. The blocking file
/// open and read are moved onto tokio's blocking thread pool via
/// `spawn_blocking`, so the async runtime's worker threads are not stalled by
/// synchronous I/O (the previous version both opened the file on the calling
/// task and ran the blocking read inside `tokio::spawn`ed async tasks).
///
/// # Panics
/// Panics if any polynomial file cannot be opened, or if a spawned read task
/// panics.
pub(crate) async fn multi_thread_read_polynomial_vec<
    F: SerdePrimeField,
    B: Send + Sync + 'static,
>(
    pk_prefix_path: impl AsRef<Path>,
    format: SerdeFormat,
    n: usize,
) -> Vec<Polynomial<F, B>> {
    // 1 MiB per reader: large enough to amortize syscall overhead for big
    // polynomials without holding excessive memory per task.
    const BUFFER_SIZE: usize = 1024 * 1024;
    let join_handles: Vec<_> = (0..n)
        .map(|i| {
            // Build "<prefix>_<i>" for the i-th polynomial file.
            // (`as_ref()` already yields a `&Path`; no clone needed.)
            let mut poly_path = pk_prefix_path.as_ref().to_path_buf().into_os_string();
            poly_path.push(format!("_{i}"));
            // File open + read are blocking; run both off the async executor.
            tokio::task::spawn_blocking(move || {
                let file = File::open(&poly_path)
                    .unwrap_or_else(|e| panic!("failed to open {poly_path:?}: {e}"));
                let mut reader = BufReader::with_capacity(BUFFER_SIZE, file);
                Polynomial::<F, B>::read(&mut reader, format)
            })
        })
        .collect();
    // Await in order so the returned vector preserves polynomial index order.
    let mut ret = Vec::with_capacity(join_handles.len());
    for join_handle in join_handles {
        ret.push(join_handle.await.expect("polynomial read task panicked"));
    }
    ret
}

/// Writes a slice of polynomials to buffer
pub(crate) fn write_polynomial_slice<W: io::Write, F: SerdePrimeField, B>(
slice: &[Polynomial<F, B>],
Expand All @@ -139,6 +176,36 @@ pub(crate) fn write_polynomial_slice<W: io::Write, F: SerdePrimeField, B>(
}
}

/// Writes a slice of polynomials to buffer
pub(crate) fn multi_thread_write_polynomial_slice<F: SerdePrimeField, B: Send + Sync>(
slice: &[Polynomial<F, B>],
pk_prefix_path: impl AsRef<Path>,
format: SerdeFormat,
) {
const BUFFER_SIZE: usize = 1024 * 1024;
let poly_path = slice
.iter()
.enumerate()
.map(|(i, _)| {
let mut poly_path = pk_prefix_path
.as_ref()
.clone()
.to_path_buf()
.into_os_string();
poly_path.push(format!("_{i}"));
poly_path
})
.collect_vec();
slice
.par_iter()
.zip_eq(poly_path.par_iter())
.for_each(|(poly, poly_path)| {
let mut writer =
BufWriter::with_capacity(BUFFER_SIZE, File::create(poly_path).unwrap());
poly.write(&mut writer, format);
});
}

/// Gets the total number of bytes of a slice of polynomials, assuming all polynomials are the same length
pub(crate) fn polynomial_slice_byte_length<F: PrimeField, B>(slice: &[Polynomial<F, B>]) -> usize {
let field_len = F::default().to_repr().as_ref().len();
Expand Down
Loading

0 comments on commit 4beeb03

Please sign in to comment.