Skip to content

Commit

Permalink
Merge pull request #3367 from ProvableHQ/fix/transmission-checksum
Browse files Browse the repository at this point in the history
[Fix] Add checksum to `TransmissionID`
  • Loading branch information
zosorock authored Jul 30, 2024
2 parents 878624d + 9f76888 commit b0a80c3
Show file tree
Hide file tree
Showing 19 changed files with 647 additions and 1,167 deletions.
1,319 changes: 301 additions & 1,018 deletions Cargo.lock

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ version = "=0.1.24"
default-features = false

[workspace.dependencies.snarkvm]
#path = "../snarkVM"
git = "https://github.com/AleoNet/snarkVM.git"
rev = "d170a9f" # If this is updated, the rev in `node/rest/Cargo.toml` must be updated as well.
rev = "68f4f31" # If this is updated, the rev in `node/rest/Cargo.toml` must be updated as well.
#version = "=0.16.18"
features = [ "circuit", "console", "rocks" ]

Expand Down
9 changes: 7 additions & 2 deletions node/bft/events/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,15 @@ pub mod prop_tests {
.boxed()
}

pub fn any_transmission_checksum() -> BoxedStrategy<<CurrentNetwork as Network>::TransmissionChecksum> {
Just(0).prop_perturb(|_, mut rng| rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>()).boxed()
}

pub fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
prop_oneof![
any_transaction_id().prop_map(TransmissionID::Transaction),
any_solution_id().prop_map(TransmissionID::Solution),
(any_transaction_id(), any_transmission_checksum())
.prop_map(|(id, cs)| TransmissionID::Transaction(id, cs)),
(any_solution_id(), any_transmission_checksum()).prop_map(|(id, cs)| TransmissionID::Solution(id, cs)),
]
.boxed()
}
Expand Down
23 changes: 3 additions & 20 deletions node/bft/events/src/transmission_request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,32 +58,15 @@ impl<N: Network> FromBytes for TransmissionRequest<N> {

#[cfg(test)]
pub mod prop_tests {
use crate::{
prop_tests::{any_solution_id, any_transaction_id},
TransmissionRequest,
};
use snarkvm::{
console::prelude::{FromBytes, ToBytes},
ledger::narwhal::TransmissionID,
};
use crate::{prop_tests::any_transmission_id, TransmissionRequest};
use snarkvm::console::prelude::{FromBytes, ToBytes};

use bytes::{Buf, BufMut, BytesMut};
use proptest::{
prelude::{BoxedStrategy, Strategy},
prop_oneof,
};
use proptest::prelude::{BoxedStrategy, Strategy};
use test_strategy::proptest;

type CurrentNetwork = snarkvm::prelude::MainnetV0;

fn any_transmission_id() -> BoxedStrategy<TransmissionID<CurrentNetwork>> {
prop_oneof![
any_solution_id().prop_map(TransmissionID::Solution),
any_transaction_id().prop_map(TransmissionID::Transaction),
]
.boxed()
}

pub fn any_transmission_request() -> BoxedStrategy<TransmissionRequest<CurrentNetwork>> {
any_transmission_id().prop_map(TransmissionRequest::new).boxed()
}
Expand Down
22 changes: 13 additions & 9 deletions node/bft/events/src/transmission_response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl<N: Network> FromBytes for TransmissionResponse<N> {
#[cfg(test)]
pub mod prop_tests {
use crate::{
prop_tests::{any_solution_id, any_transaction_id},
prop_tests::{any_solution_id, any_transaction_id, any_transmission_checksum},
TransmissionResponse,
};
use snarkvm::{
Expand All @@ -82,14 +82,18 @@ pub mod prop_tests {

pub fn any_transmission() -> BoxedStrategy<(TransmissionID<CurrentNetwork>, Transmission<CurrentNetwork>)> {
prop_oneof![
(any_solution_id(), collection::vec(any::<u8>(), 256..=256)).prop_map(|(pc, bytes)| (
TransmissionID::Solution(pc),
Transmission::Solution(Data::Buffer(Bytes::from(bytes)))
)),
(any_transaction_id(), collection::vec(any::<u8>(), 512..=512)).prop_map(|(tid, bytes)| (
TransmissionID::Transaction(tid),
Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))
)),
(any_solution_id(), any_transmission_checksum(), collection::vec(any::<u8>(), 256..=256)).prop_map(
|(pc, cs, bytes)| (
TransmissionID::Solution(pc, cs),
Transmission::Solution(Data::Buffer(Bytes::from(bytes)))
)
),
(any_transaction_id(), any_transmission_checksum(), collection::vec(any::<u8>(), 512..=512)).prop_map(
|(tid, cs, bytes)| (
TransmissionID::Transaction(tid, cs),
Transmission::Transaction(Data::Buffer(Bytes::from(bytes)))
)
),
]
.boxed()
}
Expand Down
35 changes: 30 additions & 5 deletions node/bft/ledger-service/src/ledger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,8 +205,8 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for CoreLedgerService<
fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool> {
match transmission_id {
TransmissionID::Ratification => Ok(false),
TransmissionID::Solution(solution_id) => self.ledger.contains_solution_id(solution_id),
TransmissionID::Transaction(transaction_id) => self.ledger.contains_transaction_id(transaction_id),
TransmissionID::Solution(solution_id, _) => self.ledger.contains_solution_id(solution_id),
TransmissionID::Transaction(transaction_id, _) => self.ledger.contains_transaction_id(transaction_id),
}
}

Expand All @@ -218,7 +218,10 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for CoreLedgerService<
) -> Result<()> {
match (transmission_id, transmission) {
(TransmissionID::Ratification, Transmission::Ratification) => {}
(TransmissionID::Transaction(expected_transaction_id), Transmission::Transaction(transaction_data)) => {
(
TransmissionID::Transaction(expected_transaction_id, expected_checksum),
Transmission::Transaction(transaction_data),
) => {
// Deserialize the transaction. If the transaction exceeds the maximum size, then return an error.
let transaction = match transaction_data.clone() {
Data::Object(transaction) => transaction,
Expand All @@ -227,11 +230,21 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for CoreLedgerService<
// Ensure the transaction ID matches the expected transaction ID.
if transaction.id() != expected_transaction_id {
bail!(
"Received mismatching transaction ID - expected {}, found {}",
"Received mismatching transaction ID - expected {}, found {}",
fmt_id(expected_transaction_id),
fmt_id(transaction.id()),
);
}

// Ensure the transmission checksum matches the expected checksum.
let checksum = transaction_data.to_checksum::<N>()?;
if checksum != expected_checksum {
bail!(
"Received mismatching checksum for transaction {} - expected {expected_checksum} but found {checksum}",
fmt_id(expected_transaction_id)
);
}

// Ensure the transaction is not a fee transaction.
if transaction.is_fee() {
bail!("Received a fee transaction in a transmission");
Expand All @@ -240,7 +253,10 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for CoreLedgerService<
// Update the transmission with the deserialized transaction.
*transaction_data = Data::Object(transaction);
}
(TransmissionID::Solution(expected_solution_id), Transmission::Solution(solution_data)) => {
(
TransmissionID::Solution(expected_solution_id, expected_checksum),
Transmission::Solution(solution_data),
) => {
match solution_data.clone().deserialize_blocking() {
Ok(solution) => {
if solution.id() != expected_solution_id {
Expand All @@ -251,6 +267,15 @@ impl<N: Network, C: ConsensusStorage<N>> LedgerService<N> for CoreLedgerService<
);
}

// Ensure the transmission checksum matches the expected checksum.
let checksum = solution_data.to_checksum::<N>()?;
if checksum != expected_checksum {
bail!(
"Received mismatching checksum for solution {} - expected {expected_checksum} but found {checksum}",
fmt_id(expected_solution_id)
);
}

// Update the transmission with the deserialized solution.
*solution_data = Data::Object(solution);
}
Expand Down
12 changes: 10 additions & 2 deletions node/bft/ledger-service/src/mock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,11 @@ impl<N: Network> LedgerService<N> for MockLedgerService<N> {

/// Returns `false` for all queries.
fn contains_transmission(&self, transmission_id: &TransmissionID<N>) -> Result<bool> {
trace!("[MockLedgerService] Contains transmission ID {} - false", fmt_id(transmission_id));
trace!(
"[MockLedgerService] Contains transmission ID {}.{} - false",
fmt_id(transmission_id),
fmt_id(transmission_id.checksum().unwrap_or_default())
);
Ok(false)
}

Expand All @@ -179,7 +183,11 @@ impl<N: Network> LedgerService<N> for MockLedgerService<N> {
transmission_id: TransmissionID<N>,
_transmission: &mut Transmission<N>,
) -> Result<()> {
trace!("[MockLedgerService] Ensure transmission ID matches {:?} - Ok", fmt_id(transmission_id));
trace!(
"[MockLedgerService] Ensure transmission ID matches {}.{} - Ok",
fmt_id(transmission_id),
fmt_id(transmission_id.checksum().unwrap_or_default())
);
Ok(())
}

Expand Down
37 changes: 36 additions & 1 deletion node/bft/src/bft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -597,10 +597,34 @@ impl<N: Network> BFT<N> {
if !IS_SYNCING {
// Initialize a map for the deduped transmissions.
let mut transmissions = IndexMap::new();
// Initialize a map for the deduped transaction ids.
let mut seen_transaction_ids = IndexSet::new();
// Initialize a map for the deduped solution ids.
let mut seen_solution_ids = IndexSet::new();
// Start from the oldest leader certificate.
for certificate in commit_subdag.values().flatten() {
// Retrieve the transmissions.
for transmission_id in certificate.transmission_ids() {
// If the transaction ID or solution ID already exists in the map, skip it.
// Note: This additional check is done to ensure that we do not include duplicate
// transaction IDs or solution IDs that may have a different transmission ID.
match transmission_id {
TransmissionID::Solution(solution_id, _) => {
// If the solution already exists, skip it.
if seen_solution_ids.contains(&solution_id) {
continue;
}
}
TransmissionID::Transaction(transaction_id, _) => {
// If the transaction already exists, skip it.
if seen_transaction_ids.contains(transaction_id) {
continue;
}
}
TransmissionID::Ratification => {
bail!("Ratifications are currently not supported in the BFT.")
}
}
// If the transmission already exists in the map, skip it.
if transmissions.contains_key(transmission_id) {
continue;
Expand All @@ -613,11 +637,22 @@ impl<N: Network> BFT<N> {
// Retrieve the transmission.
let Some(transmission) = self.storage().get_transmission(*transmission_id) else {
bail!(
"BFT failed to retrieve transmission '{}' from round {}",
"BFT failed to retrieve transmission '{}.{}' from round {}",
fmt_id(transmission_id),
fmt_id(transmission_id.checksum().unwrap_or_default()).dimmed(),
certificate.round()
);
};
// Insert the transaction ID or solution ID into the map.
match transmission_id {
TransmissionID::Solution(id, _) => {
seen_solution_ids.insert(id);
}
TransmissionID::Transaction(id, _) => {
seen_transaction_ids.insert(id);
}
TransmissionID::Ratification => {}
}
// Add the transmission to the set.
transmissions.insert(*transmission_id, transmission);
}
Expand Down
2 changes: 1 addition & 1 deletion node/bft/src/helpers/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ mod tests {

impl Input for TransmissionID<CurrentNetwork> {
fn input() -> Self {
TransmissionID::Transaction(Default::default())
TransmissionID::Transaction(Default::default(), Default::default())
}
}

Expand Down
5 changes: 3 additions & 2 deletions node/bft/src/helpers/partition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,9 @@ mod tests {
]);
let hash = sha256d_to_u128(data);
assert_eq!(hash, 274520597840828436951879875061540363633u128);
let transmission_id: TransmissionID<CurrentNetwork> = TransmissionID::Solution(SolutionID::from(123456789));
let transmission_id: TransmissionID<CurrentNetwork> =
TransmissionID::Solution(SolutionID::from(123456789), 12345);
let worker_id = assign_to_worker(transmission_id, 5).unwrap();
assert_eq!(worker_id, 2);
assert_eq!(worker_id, 4);
}
}
45 changes: 36 additions & 9 deletions node/bft/src/helpers/pending.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,22 @@ mod tests {
assert_eq!(pending.len(), 0);

// Initialize the solution IDs.
let solution_id_1 = TransmissionID::Solution(rng.gen::<u64>().into());
let solution_id_2 = TransmissionID::Solution(rng.gen::<u64>().into());
let solution_id_3 = TransmissionID::Solution(rng.gen::<u64>().into());
let solution_id_4 = TransmissionID::Solution(rng.gen::<u64>().into());
let solution_id_1 = TransmissionID::Solution(
rng.gen::<u64>().into(),
rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>(),
);
let solution_id_2 = TransmissionID::Solution(
rng.gen::<u64>().into(),
rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>(),
);
let solution_id_3 = TransmissionID::Solution(
rng.gen::<u64>().into(),
rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>(),
);
let solution_id_4 = TransmissionID::Solution(
rng.gen::<u64>().into(),
rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>(),
);

// Initialize the SocketAddrs.
let addr_1 = SocketAddr::from(([127, 0, 0, 1], 1234));
Expand Down Expand Up @@ -303,7 +315,10 @@ mod tests {
assert!(pending.contains_peer(solution_id_4, addr_4));
assert!(!pending.contains_peer_with_sent_request(solution_id_4, addr_4));

let unknown_id = TransmissionID::Solution(rng.gen::<u64>().into());
let unknown_id = TransmissionID::Solution(
rng.gen::<u64>().into(),
rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>(),
);
assert!(!pending.contains(unknown_id));

// Check get.
Expand Down Expand Up @@ -336,7 +351,10 @@ mod tests {
assert_eq!(pending.len(), 0);

// Initialize the solution ID.
let solution_id_1 = TransmissionID::Solution(rng.gen::<u64>().into());
let solution_id_1 = TransmissionID::Solution(
rng.gen::<u64>().into(),
rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>(),
);

// Initialize the SocketAddrs.
let addr_1 = SocketAddr::from(([127, 0, 0, 1], 1234));
Expand Down Expand Up @@ -382,7 +400,10 @@ mod tests {

for _ in 0..ITERATIONS {
// Generate a solution ID.
let solution_id = TransmissionID::Solution(rng.gen::<u64>().into());
let solution_id = TransmissionID::Solution(
rng.gen::<u64>().into(),
rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>(),
);
// Check if the number of sent requests is correct.
let mut expected_num_sent_requests = 0;
for i in 0..ITERATIONS {
Expand Down Expand Up @@ -416,8 +437,14 @@ mod tests {
assert_eq!(pending.len(), 0);

// Initialize the solution IDs.
let solution_id_1 = TransmissionID::Solution(rng.gen::<u64>().into());
let solution_id_2 = TransmissionID::Solution(rng.gen::<u64>().into());
let solution_id_1 = TransmissionID::Solution(
rng.gen::<u64>().into(),
rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>(),
);
let solution_id_2 = TransmissionID::Solution(
rng.gen::<u64>().into(),
rng.gen::<<CurrentNetwork as Network>::TransmissionChecksum>(),
);

// Initialize the SocketAddrs.
let addr_1 = SocketAddr::from(([127, 0, 0, 1], 1234));
Expand Down
Loading

0 comments on commit b0a80c3

Please sign in to comment.