From 217a2d307b29a809fca5888c5d15449c81b4b2a2 Mon Sep 17 00:00:00 2001 From: Nikita Masych Date: Thu, 17 Oct 2024 19:12:38 +0300 Subject: [PATCH] fix: saturating subtraction for weights --- src/party.rs | 10 +++++++--- tests/mod.rs | 16 +++++++++++++++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/party.rs b/src/party.rs index bf2390c..293a665 100644 --- a/src/party.rs +++ b/src/party.rs @@ -495,7 +495,7 @@ impl> Party { self.cfg.party_weights[routing.sender as usize] as u128; let self_weight = self.cfg.party_weights[self.id as usize] as u128; - if self.messages_1b_weight >= self.cfg.threshold - self_weight { + if self.messages_1b_weight >= self.cfg.threshold.saturating_sub(self_weight) { self.status = PartyStatus::Passed1b; } } @@ -576,7 +576,9 @@ impl> Party { ); let self_weight = self.cfg.party_weights[self.id as usize] as u128; - if self.messages_2av_state.get_weight() >= self.cfg.threshold - self_weight { + if self.messages_2av_state.get_weight() + >= self.cfg.threshold.saturating_sub(self_weight) + { self.status = PartyStatus::Passed2av; } } @@ -609,7 +611,9 @@ impl> Party { ); let self_weight = self.cfg.party_weights[self.id as usize] as u128; - if self.messages_2b_state.get_weight() >= self.cfg.threshold - self_weight { + if self.messages_2b_state.get_weight() + >= self.cfg.threshold.saturating_sub(self_weight) + { self.status = PartyStatus::Passed2b; } } diff --git a/tests/mod.rs b/tests/mod.rs index d249ea6..3d32bfa 100644 --- a/tests/mod.rs +++ b/tests/mod.rs @@ -286,7 +286,6 @@ async fn test_ballot_many_parties() { analyze_ballot(results); } -#[ignore = "failing for now"] #[tokio::test] async fn test_ballot_max_weight() { let weights = vec![u64::MAX, 1]; @@ -301,3 +300,18 @@ async fn test_ballot_max_weight() { analyze_ballot(results); } + +#[tokio::test] +async fn test_ballot_weights_underflow() { + let weights = vec![100, 1, 2, 3, 4]; + let threshold = BPConConfig::compute_bft_threshold(weights.clone()); + let cfg = BPConConfig::with_default_timeouts(weights, threshold); + + let (parties, receivers, senders) = create_parties(cfg); + let ballot_tasks = launch_parties(parties); + let p2p_task = propagate_p2p(receivers, senders); + let results = await_results(ballot_tasks).await; + p2p_task.abort(); + + analyze_ballot(results); +}