diff --git a/crates/chia-bls/benches/cache.rs b/crates/chia-bls/benches/cache.rs index e8db3116f..e8c92f75c 100644 --- a/crates/chia-bls/benches/cache.rs +++ b/crates/chia-bls/benches/cache.rs @@ -16,7 +16,7 @@ fn cache_benchmark(c: &mut Criterion) { let mut agg_sig = Signature::default(); for i in 0..1000 { - let derived = sk.derive_hardened(i as u32); + let derived = sk.derive_hardened(i); let pk = derived.public_key(); let sig = sign(&derived, msg); agg_sig.aggregate(&sig); @@ -28,43 +28,43 @@ fn cache_benchmark(c: &mut Criterion) { c.bench_function("bls_cache.aggregate_verify, 0% cache hits", |b| { let mut cache = bls_cache.clone(); b.iter(|| { - assert!(cache.aggregate_verify(&pks, [&msg].iter().cycle(), &agg_sig)); + assert!(cache.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig)); }); }); // populate 10% of keys - bls_cache.aggregate_verify(&pks[0..100], [&msg].iter().cycle(), &agg_sig); + bls_cache.aggregate_verify(pks[0..100].iter().zip([&msg].iter().cycle()), &agg_sig); c.bench_function("bls_cache.aggregate_verify, 10% cache hits", |b| { let mut cache = bls_cache.clone(); b.iter(|| { - assert!(cache.aggregate_verify(&pks, [&msg].iter().cycle(), &agg_sig)); + assert!(cache.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig)); }); }); // populate another 10% of keys - bls_cache.aggregate_verify(&pks[100..200], [&msg].iter().cycle(), &agg_sig); + bls_cache.aggregate_verify(pks[100..200].iter().zip([&msg].iter().cycle()), &agg_sig); c.bench_function("bls_cache.aggregate_verify, 20% cache hits", |b| { let mut cache = bls_cache.clone(); b.iter(|| { - assert!(cache.aggregate_verify(&pks, [&msg].iter().cycle(), &agg_sig)); + assert!(cache.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig)); }); }); // populate another 30% of keys - bls_cache.aggregate_verify(&pks[200..500], [&msg].iter().cycle(), &agg_sig); + bls_cache.aggregate_verify(pks[200..500].iter().zip([&msg].iter().cycle()), &agg_sig); c.bench_function("bls_cache.aggregate_verify, 50% cache hits", |b| { let mut cache = bls_cache.clone(); b.iter(|| { - assert!(cache.aggregate_verify(&pks, [&msg].iter().cycle(), &agg_sig)); + assert!(cache.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig)); }); }); // populate all other keys - bls_cache.aggregate_verify(&pks[500..1000], [&msg].iter().cycle(), &agg_sig); + bls_cache.aggregate_verify(pks[500..1000].iter().zip([&msg].iter().cycle()), &agg_sig); c.bench_function("bls_cache.aggregate_verify, 100% cache hits", |b| { let mut cache = bls_cache.clone(); b.iter(|| { - assert!(cache.aggregate_verify(&pks, [&msg].iter().cycle(), &agg_sig)); + assert!(cache.aggregate_verify(pks.iter().zip([&msg].iter().cycle()), &agg_sig)); }); }); diff --git a/crates/chia-bls/src/bls_cache.rs b/crates/chia-bls/src/bls_cache.rs index 38e255c63..9d9339608 100644 --- a/crates/chia-bls/src/bls_cache.rs +++ b/crates/chia-bls/src/bls_cache.rs @@ -1,11 +1,9 @@ -use std::borrow::Borrow; -use std::num::NonZeroUsize; - -use lru::LruCache; -use sha2::{Digest, Sha256}; - use crate::{aggregate_verify_gt, hash_to_g2}; use crate::{GTElement, PublicKey, Signature}; +use lru::LruCache; +use sha2::{Digest, Sha256}; +use std::borrow::Borrow; +use std::num::NonZeroUsize; /// This is a cache of pairings of public keys and their corresponding message. /// It accelerates aggregate verification when some public keys have already @@ -44,13 +42,12 @@ impl BlsCache { self.cache.is_empty() } - pub fn aggregate_verify( + pub fn aggregate_verify, Msg: AsRef<[u8]>>( &mut self, - pks: impl IntoIterator>, - msgs: impl IntoIterator>, + pks_msgs: impl IntoIterator, sig: &Signature, ) -> bool { - let iter = pks.into_iter().zip(msgs).map(|(pk, msg)| -> GTElement { + let iter = pks_msgs.into_iter().map(|(pk, msg)| -> GTElement { // Hash pubkey + message let mut hasher = Sha256::new(); hasher.update(pk.borrow().to_bytes()); @@ -124,7 +121,7 @@ impl BlsCache { .map(|item| item?.extract()) .collect::>>()?; - Ok(self.aggregate_verify(pks, msgs, sig)) + Ok(self.aggregate_verify(pks.into_iter().zip(msgs), sig)) } #[pyo3(name = "len")] @@ -187,11 +184,11 @@ pub mod tests { assert!(bls_cache.is_empty()); // Verify the signature and add to the cache. - assert!(bls_cache.aggregate_verify(pk_list, msg_list, &sig)); + assert!(bls_cache.aggregate_verify(pk_list.into_iter().zip(msg_list), &sig)); assert_eq!(bls_cache.len(), 1); // Now that it's cached, it shouldn't cache it again. - assert!(bls_cache.aggregate_verify(pk_list, msg_list, &sig)); + assert!(bls_cache.aggregate_verify(pk_list.into_iter().zip(msg_list), &sig)); assert_eq!(bls_cache.len(), 1); } @@ -211,7 +208,7 @@ pub mod tests { assert!(bls_cache.is_empty()); // Add the first signature to cache. - assert!(bls_cache.aggregate_verify(&pk_list, &msg_list, &agg_sig)); + assert!(bls_cache.aggregate_verify(pk_list.iter().zip(msg_list.iter()), &agg_sig)); assert_eq!(bls_cache.len(), 1); // Try with the first key message pair in the cache but not the second. @@ -223,7 +220,7 @@ pub mod tests { pk_list.push(pk2); msg_list.push(msg2); - assert!(bls_cache.aggregate_verify(&pk_list, &msg_list, &agg_sig)); + assert!(bls_cache.aggregate_verify(pk_list.iter().zip(msg_list.iter()), &agg_sig)); assert_eq!(bls_cache.len(), 2); // Try reusing a public key. @@ -234,7 +231,7 @@ pub mod tests { msg_list.push(msg3); // Verify this signature and add to the cache as well (since it's still a different aggregate). - assert!(bls_cache.aggregate_verify(pk_list, msg_list, &agg_sig)); + assert!(bls_cache.aggregate_verify(pk_list.iter().zip(msg_list), &agg_sig)); assert_eq!(bls_cache.len(), 3); } @@ -257,7 +254,7 @@ pub mod tests { let msg_list = [msg]; // Add to cache by validating them one at a time. - assert!(bls_cache.aggregate_verify(pk_list.iter(), msg_list.iter(), &sig)); + assert!(bls_cache.aggregate_verify(pk_list.into_iter().zip(msg_list), &sig)); } // The cache should be full now. @@ -285,6 +282,6 @@ pub mod tests { let pks: [&PublicKey; 0] = []; let msgs: [&[u8]; 0] = []; - assert!(bls_cache.aggregate_verify(pks, msgs, &Signature::default())); + assert!(bls_cache.aggregate_verify(pks.into_iter().zip(msgs), &Signature::default())); } }