diff --git a/Cargo.toml b/Cargo.toml index 527aacf..24e9620 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,8 +19,8 @@ bincode = "^1.3.3" rkyv = { version = "^0.7.44", features = ["validation"] } tokio = { version = "^1.39.2", features = ["full"] } rand = "^0.9.0-alpha.2" -seeded-random = "^0.6.0" thiserror = "^1.0.63" +rand_chacha = "0.3.1" [features] default = ["tokio/full", "rkyv/validation"] diff --git a/src/leader.rs b/src/leader.rs index ea6041c..a2a0778 100644 --- a/src/leader.rs +++ b/src/leader.rs @@ -5,8 +5,8 @@ use crate::party::Party; use crate::value::{Value, ValueSelector}; -use seeded_random::{Random, Seed}; -use std::cmp::Ordering; +use rand_chacha::rand_core::{RngCore, SeedableRng}; +use rand_chacha::ChaCha20Rng; use std::hash::{DefaultHasher, Hash, Hasher}; use thiserror::Error; @@ -70,7 +70,7 @@ impl DefaultLeaderElector { /// Hashes the seed to a value within a specified range. /// - /// This method uses the computed seed to generate a value within the range [0, range). + /// This method uses the computed seed to generate a value within the range [0, range]. /// The algorithm ensures uniform distribution of the resulting value, which is crucial /// for fair leader election. /// @@ -79,19 +79,19 @@ impl DefaultLeaderElector { /// - `range`: The upper limit for the random value generation, typically the sum of party weights. /// /// # Returns - /// A `u64` value within the specified range. - fn hash_to_range(seed: u64, range: u64) -> u64 { + /// A `u128` value within the specified range. + fn hash_to_range(seed: u64, range: u128) -> u128 { // Determine the number of bits required to represent the range - let mut k = 64; - while 1u64 << (k - 1) >= range { + let mut k = 128; + while 1u128 << (k - 1) > range { k -= 1; } // Use a seeded random generator to produce a value within the desired range - let rng = Random::from_seed(Seed::unsafe_new(seed)); + let mut rng = ChaCha20Rng::seed_from_u64(seed); loop { - let mut raw_res: u64 = rng.gen(); - raw_res >>= 64 - k; + let mut raw_res = ((rng.next_u64() as u128) << 64) | (rng.next_u64() as u128); + raw_res >>= 128 - k; if raw_res < range { return raw_res; @@ -124,31 +124,23 @@ impl> LeaderElector for DefaultLeaderElect fn elect_leader(&self, party: &Party) -> Result> { let seed = DefaultLeaderElector::compute_seed(party); - let total_weight: u64 = party.cfg.party_weights.iter().sum(); + let total_weight: u128 = party.cfg.party_weights.iter().map(|&x| x as u128).sum(); if total_weight == 0 { return Err(DefaultLeaderElectorError::ZeroWeightSum.into()); } - // Generate a random number in the range [0, total_weight) + // Generate a random number in the range [0, total_weight] let random_value = DefaultLeaderElector::hash_to_range(seed, total_weight); - // Use binary search to find the corresponding participant based on the cumulative weight - let mut cumulative_weights = vec![0; party.cfg.party_weights.len()]; - cumulative_weights[0] = party.cfg.party_weights[0]; - - for i in 1..party.cfg.party_weights.len() { - cumulative_weights[i] = cumulative_weights[i - 1] + party.cfg.party_weights[i]; - } - - match cumulative_weights.binary_search_by(|&weight| { - if random_value < weight { - Ordering::Greater - } else { - Ordering::Less + let mut cumulative_sum = 0u128; + for (index, &weight) in party.cfg.party_weights.iter().enumerate() { + cumulative_sum += weight as u128; + if random_value <= cumulative_sum { + return Ok(index as u64); } - }) { - Ok(index) | Err(index) => Ok(index as u64), } + + unreachable!("Index is guaranteed to be returned in a loop.") } } @@ -161,6 +153,17 @@ mod tests { use std::thread; use std::time::Duration; + #[test] + fn test_default_leader_elector_weight_one() { + let mut party = MockParty::default(); + party.cfg.party_weights = vec![0, 1, 0, 0]; + + let elector = DefaultLeaderElector::new(); + + let leader = elector.elect_leader(&party).unwrap(); + println!("leader: {}", leader); + } + #[test] fn test_default_leader_elector_determinism() { let party = MockParty::default(); @@ -209,11 +212,11 @@ mod tests { k -= 1; } - let rng = Random::from_seed(Seed::unsafe_new(seed)); + let mut rng = ChaCha20Rng::seed_from_u64(seed); let mut iteration = 1u64; loop { - let mut raw_res: u64 = rng.gen(); + let mut raw_res: u64 = rng.next_u64(); raw_res >>= 64 - k; if raw_res < range { @@ -260,15 +263,15 @@ mod tests { #[test] fn test_rng() { - let rng1 = Random::from_seed(Seed::unsafe_new(123456)); - let rng2 = Random::from_seed(Seed::unsafe_new(123456)); + let mut rng1 = ChaCha20Rng::seed_from_u64(123456); + let mut rng2 = ChaCha20Rng::seed_from_u64(123456); - println!("{}", rng1.gen::()); - println!("{}", rng2.gen::()); + println!("{}", rng1.next_u64()); + println!("{}", rng2.next_u64()); thread::sleep(Duration::from_secs(2)); - println!("{}", rng1.gen::()); - println!("{}", rng2.gen::()); + println!("{}", rng1.next_u64()); + println!("{}", rng2.next_u64()); } } diff --git a/src/party.rs b/src/party.rs index bf2390c..293a665 100644 --- a/src/party.rs +++ b/src/party.rs @@ -495,7 +495,7 @@ impl> Party { self.cfg.party_weights[routing.sender as usize] as u128; let self_weight = self.cfg.party_weights[self.id as usize] as u128; - if self.messages_1b_weight >= self.cfg.threshold - self_weight { + if self.messages_1b_weight >= self.cfg.threshold.saturating_sub(self_weight) { self.status = PartyStatus::Passed1b; } } @@ -576,7 +576,9 @@ impl> Party { ); let self_weight = self.cfg.party_weights[self.id as usize] as u128; - if self.messages_2av_state.get_weight() >= self.cfg.threshold - self_weight { + if self.messages_2av_state.get_weight() + >= self.cfg.threshold.saturating_sub(self_weight) + { self.status = PartyStatus::Passed2av; } } @@ -609,7 +611,9 @@ impl> Party { ); let self_weight = self.cfg.party_weights[self.id as usize] as u128; - if self.messages_2b_state.get_weight() >= self.cfg.threshold - self_weight { + if self.messages_2b_state.get_weight() + >= self.cfg.threshold.saturating_sub(self_weight) + { self.status = PartyStatus::Passed2b; } } diff --git a/tests/mod.rs b/tests/mod.rs index e81927f..3d32bfa 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -199,7 +199,7 @@ async fn test_ballot_malicious_party() { let elector = DefaultLeaderElector::new(); let leader = elector.elect_leader(&parties[0]).unwrap(); - const MALICIOUS_PARTY_ID: u64 = 1; + const MALICIOUS_PARTY_ID: u64 = 2; assert_ne!( MALICIOUS_PARTY_ID, leader, @@ -285,3 +285,33 @@ async fn test_ballot_many_parties() { analyze_ballot(results); } + +#[tokio::test] +async fn test_ballot_max_weight() { + let weights = vec![u64::MAX, 1]; + let threshold = BPConConfig::compute_bft_threshold(weights.clone()); + let cfg = BPConConfig::with_default_timeouts(weights, threshold); + + let (parties, receivers, senders) = create_parties(cfg); + let ballot_tasks = launch_parties(parties); + let p2p_task = propagate_p2p(receivers, senders); + let results = await_results(ballot_tasks).await; + p2p_task.abort(); + + analyze_ballot(results); +} + +#[tokio::test] +async fn test_ballot_weights_underflow() { + let weights = vec![100, 1, 2, 3, 4]; + let threshold = BPConConfig::compute_bft_threshold(weights.clone()); + let cfg = BPConConfig::with_default_timeouts(weights, threshold); + + let (parties, receivers, senders) = create_parties(cfg); + let ballot_tasks = launch_parties(parties); + let p2p_task = propagate_p2p(receivers, senders); + let results = await_results(ballot_tasks).await; + p2p_task.abort(); + + analyze_ballot(results); +}