diff --git a/Cargo.lock b/Cargo.lock
index 93d6d3184..5eeacc611 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -314,7 +314,7 @@ dependencies = [
  "criterion",
  "hex",
  "hkdf",
- "lru",
+ "linked-hash-map",
  "pyo3",
  "rand",
  "rstest",
@@ -1373,6 +1373,12 @@ dependencies = [
  "vcpkg",
 ]
 
+[[package]]
+name = "linked-hash-map"
+version = "0.5.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0717cef1bc8b636c6e1c1bbdefc09e6322da8a9321966e8928ef80d20f7f770f"
+
 [[package]]
 name = "log"
 version = "0.4.22"
diff --git a/Cargo.toml b/Cargo.toml
index 59370c183..9940eaf63 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -131,7 +131,6 @@ hex = "0.4.3"
 thiserror = "1.0.63"
 pyo3 = "0.22.5"
 arbitrary = "1.3.2"
-lru = "0.12.4"
 rand = "0.8.5"
 criterion = "0.5.1"
 rstest = "0.22.0"
diff --git a/crates/chia-bls/Cargo.toml b/crates/chia-bls/Cargo.toml
index 97a4d9bd0..9cffe6356 100644
--- a/crates/chia-bls/Cargo.toml
+++ b/crates/chia-bls/Cargo.toml
@@ -26,7 +26,7 @@ hex = { workspace = true }
 thiserror = { workspace = true }
 pyo3 = { workspace = true, features = ["multiple-pymethods"], optional = true }
 arbitrary = { workspace = true, optional = true }
-lru = { workspace = true }
+linked-hash-map = "0.5.6"
 
 [dev-dependencies]
 rand = { workspace = true }
diff --git a/crates/chia-bls/benches/cache.rs b/crates/chia-bls/benches/cache.rs
index e8c92f75c..8ac330817 100644
--- a/crates/chia-bls/benches/cache.rs
+++ b/crates/chia-bls/benches/cache.rs
@@ -2,7 +2,7 @@ use chia_bls::aggregate_verify;
 use chia_bls::{sign, BlsCache, SecretKey, Signature};
 use criterion::{criterion_group, criterion_main, Criterion};
 use rand::rngs::StdRng;
-use rand::{Rng, SeedableRng};
+use rand::{seq::SliceRandom, Rng, SeedableRng};
 
 fn cache_benchmark(c: &mut Criterion) {
     let mut rng = StdRng::seed_from_u64(1337);
@@ -76,6 +76,43 @@ fn cache_benchmark(c: &mut Criterion) {
             ));
         });
     });
+
+    // Add more pairs to the cache so we can evict a relatively larger number
+    for i in 1_000..20_000 {
+        let derived = sk.derive_hardened(i);
+        let pk = derived.public_key();
+        let sig = sign(&derived, msg);
+        agg_sig.aggregate(&sig);
+        pks.push(pk);
+    }
+    bls_cache.aggregate_verify(
+        pks[1_000..20_000].iter().zip([&msg].iter().cycle()),
+        &agg_sig,
+    );
+
+    c.bench_function("bls_cache.evict 5% of the items", |b| {
+        let mut cache = bls_cache.clone();
+        let mut pks_shuffled = pks.clone();
+        pks_shuffled.shuffle(&mut rng);
+        b.iter(|| {
+            if cache.is_empty() {
+                return;
+            }
+            cache.evict(pks_shuffled.iter().take(1_000).zip([&msg].iter().cycle()));
+        });
+    });
+
+    c.bench_function("bls_cache.evict 100% of the items", |b| {
+        let mut cache = bls_cache.clone();
+        let mut pks_shuffled = pks.clone();
+        pks_shuffled.shuffle(&mut rng);
+        b.iter(|| {
+            if cache.is_empty() {
+                return;
+            }
+            cache.evict(pks_shuffled.iter().zip([&msg].iter().cycle()));
+        });
+    });
 }
 
 criterion_group!(cache, cache_benchmark);
diff --git a/crates/chia-bls/src/bls_cache.rs b/crates/chia-bls/src/bls_cache.rs
index e3b8cdcf9..4aa2d030d 100644
--- a/crates/chia-bls/src/bls_cache.rs
+++ b/crates/chia-bls/src/bls_cache.rs
@@ -2,7 +2,7 @@ use std::borrow::Borrow;
 use std::num::NonZeroUsize;
 
 use chia_sha2::Sha256;
-use lru::LruCache;
+use linked_hash_map::LinkedHashMap;
 use std::sync::Mutex;
 
 use crate::{aggregate_verify_gt, hash_to_g2};
@@ -17,16 +17,35 @@ use crate::{GTElement, PublicKey, Signature};
 /// However, validating a signature where we have no cached GT elements, the
 /// aggregate_verify() primitive is faster. When long-syncing, that's
 /// preferable.
+
+#[derive(Debug, Clone)]
+struct BlsCacheData {
+    // sha256(pubkey + message) -> GTElement
+    items: LinkedHashMap<[u8; 32], GTElement>,
+    capacity: NonZeroUsize,
+}
+
+impl BlsCacheData {
+    pub fn put(&mut self, hash: [u8; 32], pairing: GTElement) {
+        // If the cache is full, remove the oldest item.
+        if self.items.len() == self.capacity.get() {
+            if let Some((oldest_key, _)) = self.items.pop_front() {
+                self.items.remove(&oldest_key);
+            }
+        }
+        self.items.insert(hash, pairing);
+    }
+}
+
 #[cfg_attr(feature = "py-bindings", pyo3::pyclass(name = "BLSCache"))]
 #[derive(Debug)]
 pub struct BlsCache {
-    // sha256(pubkey + message) -> GTElement
-    cache: Mutex<LruCache<[u8; 32], GTElement>>,
+    cache: Mutex<BlsCacheData>,
 }
 
 impl Default for BlsCache {
     fn default() -> Self {
-        Self::new(NonZeroUsize::new(50000).unwrap())
+        Self::new(NonZeroUsize::new(50_000).unwrap())
     }
 }
 
@@ -39,18 +58,21 @@ impl Clone for BlsCache {
 }
 
 impl BlsCache {
-    pub fn new(cache_size: NonZeroUsize) -> Self {
+    pub fn new(capacity: NonZeroUsize) -> Self {
         Self {
-            cache: Mutex::new(LruCache::new(cache_size)),
+            cache: Mutex::new(BlsCacheData {
+                items: LinkedHashMap::new(),
+                capacity,
+            }),
         }
     }
 
     pub fn len(&self) -> usize {
-        self.cache.lock().expect("cache").len()
+        self.cache.lock().expect("cache").items.len()
     }
 
     pub fn is_empty(&self) -> bool {
-        self.cache.lock().expect("cache").is_empty()
+        self.cache.lock().expect("cache").items.is_empty()
     }
 
     pub fn aggregate_verify<Pk: Borrow<PublicKey>, Msg: AsRef<[u8]>>(
@@ -67,7 +89,7 @@
             let hash: [u8; 32] = hasher.finalize();
 
             // If the pairing is in the cache, we don't need to recalculate it.
-            if let Some(pairing) = self.cache.lock().expect("cache").get(&hash).cloned() {
+            if let Some(pairing) = self.cache.lock().expect("cache").items.get(&hash).cloned() {
                 return pairing;
             }
 
@@ -88,6 +110,22 @@
         let hash: [u8; 32] = hasher.finalize();
         self.cache.lock().expect("cache").put(hash, gt);
     }
+
+    pub fn evict<Pk, Msg>(&mut self, pks_msgs: impl IntoIterator<Item = (Pk, Msg)>)
+    where
+        Pk: Borrow<PublicKey>,
+        Msg: AsRef<[u8]>,
+    {
+        let mut c = self.cache.lock().expect("cache");
+        for (pk, msg) in pks_msgs {
+            let mut hasher = Sha256::new();
+            let mut aug_msg = pk.borrow().to_bytes().to_vec();
+            aug_msg.extend_from_slice(msg.as_ref());
+            hasher.update(&aug_msg);
+            let hash: [u8; 32] = hasher.finalize();
+            c.items.remove(&hash);
+        }
+    }
 }
 
 #[cfg(feature = "py-bindings")]
@@ -148,7 +186,7 @@ impl BlsCache {
         use pyo3::types::PyBytes;
         let ret = PyList::empty_bound(py);
         let c = self.cache.lock().expect("cache");
-        for (key, value) in &*c {
+        for (key, value) in &c.items {
             ret.append((PyBytes::new_bound(py, key), value.clone().into_py(py)))?;
         }
         Ok(ret.into())
@@ -167,6 +205,20 @@ impl BlsCache {
         }
         Ok(())
     }
+
+    #[pyo3(name = "evict")]
+    pub fn py_evict(&mut self, pks: &Bound<'_, PyList>, msgs: &Bound<'_, PyList>) -> PyResult<()> {
+        let pks = pks
+            .iter()?
+            .map(|item| item?.extract())
+            .collect::<PyResult<Vec<PublicKey>>>()?;
+        let msgs = msgs
+            .iter()?
+            .map(|item| item?.extract())
+            .collect::<PyResult<Vec<Vec<u8>>>>()?;
+        self.evict(pks.into_iter().zip(msgs));
+        Ok(())
+    }
 }
 #[cfg(test)]
 pub mod tests {
@@ -261,21 +313,24 @@
         }
 
         // The cache should be full now.
-        assert_eq!(bls_cache.cache.lock().expect("cache").len(), 3);
-
-        // Recreate first key.
-        let sk = SecretKey::from_seed(&[1; 32]);
-        let pk = sk.public_key();
-        let msg = [106; 32];
-
-        let aug_msg = [&pk.to_bytes(), msg.as_ref()].concat();
-
-        let mut hasher = Sha256::new();
-        hasher.update(aug_msg);
-        let hash: [u8; 32] = hasher.finalize();
+        assert_eq!(bls_cache.len(), 3);
 
-        // The first key should have been removed, since it's the oldest that's been accessed.
-        assert!(!bls_cache.cache.lock().expect("cache").contains(&hash));
+        // Recreate first two keys and make sure they got removed.
+        for i in 1..=2 {
+            let sk = SecretKey::from_seed(&[i; 32]);
+            let pk = sk.public_key();
+            let msg = [106; 32];
+            let aug_msg = [&pk.to_bytes(), msg.as_ref()].concat();
+            let mut hasher = Sha256::new();
+            hasher.update(aug_msg);
+            let hash: [u8; 32] = hasher.finalize();
+            assert!(!bls_cache
+                .cache
+                .lock()
+                .expect("cache")
+                .items
+                .contains_key(&hash));
+        }
     }
 
     #[test]
@@ -286,4 +341,51 @@ pub mod tests {
 
         assert!(bls_cache.aggregate_verify(pks_msgs, &Signature::default()));
     }
+
+    #[test]
+    fn test_evict() {
+        let mut bls_cache = BlsCache::new(NonZeroUsize::new(5).unwrap());
+        // Create 5 pk msg pairs and add them to the cache.
+        let mut pks_msgs = Vec::new();
+        for i in 1..=5 {
+            let sk = SecretKey::from_seed(&[i; 32]);
+            let pk = sk.public_key();
+            let msg = [42; 32];
+            let sig = sign(&sk, msg);
+            pks_msgs.push((pk, msg));
+            assert!(bls_cache.aggregate_verify([(pk, msg)], &sig));
+        }
+        assert_eq!(bls_cache.len(), 5);
+        // Evict the first and third entries.
+        let pks_msgs_to_evict = vec![pks_msgs[0], pks_msgs[2]];
+        bls_cache.evict(pks_msgs_to_evict.iter().copied());
+        // The cache should have 3 items now.
+        assert_eq!(bls_cache.len(), 3);
+        // Check that the evicted entries are no longer in the cache.
+        for (pk, msg) in &pks_msgs_to_evict {
+            let aug_msg = [&pk.to_bytes(), msg.as_ref()].concat();
+            let mut hasher = Sha256::new();
+            hasher.update(aug_msg);
+            let hash: [u8; 32] = hasher.finalize();
+            assert!(!bls_cache
+                .cache
+                .lock()
+                .expect("cache")
+                .items
+                .contains_key(&hash));
+        }
+        // Check that the remaining entries are still in the cache.
+        for (pk, msg) in &[pks_msgs[1], pks_msgs[3], pks_msgs[4]] {
+            let aug_msg = [&pk.to_bytes(), msg.as_ref()].concat();
+            let mut hasher = Sha256::new();
+            hasher.update(aug_msg);
+            let hash: [u8; 32] = hasher.finalize();
+            assert!(bls_cache
+                .cache
+                .lock()
+                .expect("cache")
+                .items
+                .contains_key(&hash));
+        }
+    }
 }
diff --git a/wheel/generate_type_stubs.py b/wheel/generate_type_stubs.py
index 169ab53b6..fb85366fb 100644
--- a/wheel/generate_type_stubs.py
+++ b/wheel/generate_type_stubs.py
@@ -368,6 +368,7 @@ def len(self) -> int: ...
     def aggregate_verify(self, pks: list[G1Element], msgs: list[bytes], sig: G2Element) -> bool: ...
     def items(self) -> list[tuple[bytes, GTElement]]: ...
     def update(self, other: Sequence[tuple[bytes, GTElement]]) -> None: ...
+    def evict(self, pks: list[G1Element], msgs: list[bytes]) -> None: ...
 
 @final
 class AugSchemeMPL:
diff --git a/wheel/python/chia_rs/chia_rs.pyi b/wheel/python/chia_rs/chia_rs.pyi
index 7f8dc7c25..171b0566f 100644
--- a/wheel/python/chia_rs/chia_rs.pyi
+++ b/wheel/python/chia_rs/chia_rs.pyi
@@ -99,6 +99,7 @@ class BLSCache:
     def aggregate_verify(self, pks: list[G1Element], msgs: list[bytes], sig: G2Element) -> bool: ...
     def items(self) -> list[tuple[bytes, GTElement]]: ...
     def update(self, other: Sequence[tuple[bytes, GTElement]]) -> None: ...
+    def evict(self, pks: list[G1Element], msgs: list[bytes]) -> None: ...
 
 @final
 class AugSchemeMPL:
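
For reviewers, a minimal usage sketch of the new `evict` API from Rust, mirroring the `test_evict` case added above. The capacity of 3, the `[7; 32]` seed, and the `[42; 32]` message are illustrative values only, not part of this change:

```rust
use std::num::NonZeroUsize;

use chia_bls::{sign, BlsCache, SecretKey};

fn main() {
    // Illustrative capacity; the default introduced in this change is 50_000 entries.
    let mut cache = BlsCache::new(NonZeroUsize::new(3).unwrap());

    let sk = SecretKey::from_seed(&[7; 32]);
    let pk = sk.public_key();
    let msg = [42u8; 32];
    let sig = sign(&sk, msg);

    // A successful aggregate_verify() stores the GT element keyed by
    // sha256(pubkey + message) in the cache.
    assert!(cache.aggregate_verify([(&pk, msg)], &sig));
    assert_eq!(cache.len(), 1);

    // evict() takes the same (public key, message) pairs and removes their
    // cached GT elements, e.g. once a transaction leaves the mempool.
    cache.evict([(&pk, msg)]);
    assert!(cache.is_empty());
}
```

Note on the eviction policy: with `LinkedHashMap` and the `put()` shown above, a full cache drops entries in insertion order (FIFO) rather than least-recently-used order, since `items.get()` in `aggregate_verify()` does not refresh an entry's position; the updated cache-limit test reflects this by expecting the first two inserted keys to be gone.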