From a19fc19ebc1dfcad06d3197a5f8a8e5e6411e30d Mon Sep 17 00:00:00 2001 From: Thibault Martinez Date: Tue, 6 Feb 2024 21:09:56 +0100 Subject: [PATCH] Add ContextInputs type (#1956) * Add ContextInputs type * nits * Nits * Add block_issuance_credits and rewards iterators * Use the new methods * Clippy * no_std --- sdk/src/client/secret/ledger_nano.rs | 4 +- sdk/src/client/secret/mod.rs | 4 +- sdk/src/types/block/context_input/mod.rs | 91 ++++++++++++++++++- sdk/src/types/block/error.rs | 4 +- sdk/src/types/block/payload/mod.rs | 2 +- .../block/payload/signed_transaction/mod.rs | 2 +- .../payload/signed_transaction/transaction.rs | 56 +----------- sdk/src/types/block/semantic/error.rs | 2 +- sdk/src/types/block/semantic/mod.rs | 14 +-- sdk/src/types/block/semantic/unlock.rs | 12 +-- 10 files changed, 107 insertions(+), 84 deletions(-) diff --git a/sdk/src/client/secret/ledger_nano.rs b/sdk/src/client/secret/ledger_nano.rs index 2bfe919bcf..8daff56bb2 100644 --- a/sdk/src/client/secret/ledger_nano.rs +++ b/sdk/src/client/secret/ledger_nano.rs @@ -526,8 +526,8 @@ fn merge_unlocks( let slot_index = prepared_transaction_data .transaction .context_inputs() - .iter() - .find_map(|c| c.as_commitment_opt().map(|c| c.slot_index())); + .commitment() + .map(|c| c.slot_index()); let transaction_signing_hash = prepared_transaction_data.transaction.signing_hash(); let mut merged_unlocks = Vec::new(); diff --git a/sdk/src/client/secret/mod.rs b/sdk/src/client/secret/mod.rs index d078c78373..be43098461 100644 --- a/sdk/src/client/secret/mod.rs +++ b/sdk/src/client/secret/mod.rs @@ -556,8 +556,8 @@ where let slot_index = prepared_transaction_data .transaction .context_inputs() - .iter() - .find_map(|c| c.as_commitment_opt().map(|c| c.slot_index())); + .commitment() + .map(|c| c.slot_index()); // Assuming inputs_data is ordered by address type for (current_block_index, input) in prepared_transaction_data.inputs_data.iter().enumerate() { diff --git a/sdk/src/types/block/context_input/mod.rs b/sdk/src/types/block/context_input/mod.rs index 09ee7f9c9a..f172c8b93f 100644 --- a/sdk/src/types/block/context_input/mod.rs +++ b/sdk/src/types/block/context_input/mod.rs @@ -5,9 +5,12 @@ mod block_issuance_credit; mod commitment; mod reward; -use core::ops::RangeInclusive; +use alloc::{boxed::Box, vec::Vec}; +use core::{cmp::Ordering, ops::RangeInclusive}; -use derive_more::{Display, From}; +use derive_more::{Deref, Display, From}; +use iterator_sorted::is_unique_sorted_by; +use packable::{bounded::BoundedU16, prefix::BoxedSlicePrefix, Packable}; pub(crate) use self::reward::RewardContextInputIndex; pub use self::{ @@ -75,6 +78,90 @@ impl core::fmt::Debug for ContextInput { } } +pub(crate) type ContextInputCount = + BoundedU16<{ *CONTEXT_INPUT_COUNT_RANGE.start() }, { *CONTEXT_INPUT_COUNT_RANGE.end() }>; + +#[derive(Clone, Debug, Eq, PartialEq, Ord, PartialOrd, Hash, Deref, Packable)] +#[packable(unpack_error = Error, with = |e| e.unwrap_item_err_or_else(|p| Error::InvalidContextInputCount(p.into())))] +pub struct ContextInputs( + #[packable(verify_with = verify_context_inputs_packable)] BoxedSlicePrefix, +); + +impl TryFrom> for ContextInputs { + type Error = Error; + + #[inline(always)] + fn try_from(features: Vec) -> Result { + Self::from_vec(features) + } +} + +impl IntoIterator for ContextInputs { + type Item = ContextInput; + type IntoIter = alloc::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + Vec::from(Into::>::into(self.0)).into_iter() + } +} + +impl ContextInputs { + /// Creates a new [`ContextInputs`] from a vec. + pub fn from_vec(features: Vec) -> Result { + let mut context_inputs = + BoxedSlicePrefix::::try_from(features.into_boxed_slice()) + .map_err(Error::InvalidContextInputCount)?; + + context_inputs.sort_by(context_inputs_cmp); + // Sort is obviously fine now but uniqueness still needs to be checked. + verify_context_inputs(&context_inputs)?; + + Ok(Self(context_inputs)) + } + + /// Gets a reference to a [`CommitmentContextInput`], if any. + pub fn commitment(&self) -> Option<&CommitmentContextInput> { + self.0.iter().find_map(|c| c.as_commitment_opt()) + } + + /// Returns an iterator over [`BlockIssuanceCreditContextInput`], if any. + pub fn block_issuance_credits(&self) -> impl Iterator { + self.iter().filter_map(|c| c.as_block_issuance_credit_opt()) + } + + /// Returns an iterator over [`RewardContextInput`], if any. + pub fn rewards(&self) -> impl Iterator { + self.iter().filter_map(|c| c.as_reward_opt()) + } +} + +fn verify_context_inputs_packable(context_inputs: &[ContextInput]) -> Result<(), Error> { + if VERIFY { + verify_context_inputs(context_inputs)?; + } + Ok(()) +} + +fn context_inputs_cmp(a: &ContextInput, b: &ContextInput) -> Ordering { + a.kind().cmp(&b.kind()).then_with(|| match (a, b) { + (ContextInput::Commitment(_), ContextInput::Commitment(_)) => Ordering::Equal, + (ContextInput::BlockIssuanceCredit(a), ContextInput::BlockIssuanceCredit(b)) => { + a.account_id().cmp(b.account_id()) + } + (ContextInput::Reward(a), ContextInput::Reward(b)) => a.index().cmp(&b.index()), + // No need to evaluate all combinations as `then_with` is only called on Equal. + _ => unreachable!(), + }) +} + +fn verify_context_inputs(context_inputs: &[ContextInput]) -> Result<(), Error> { + if !is_unique_sorted_by(context_inputs.iter(), |a, b| context_inputs_cmp(a, b)) { + return Err(Error::ContextInputsNotUniqueSorted); + } + + Ok(()) +} + #[cfg(test)] mod tests { use pretty_assertions::assert_eq; diff --git a/sdk/src/types/block/error.rs b/sdk/src/types/block/error.rs index ad9174341c..5224424c40 100644 --- a/sdk/src/types/block/error.rs +++ b/sdk/src/types/block/error.rs @@ -15,7 +15,7 @@ use primitive_types::U256; use super::slot::EpochIndex; use crate::types::block::{ address::{AddressCapabilityFlag, WeightedAddressCount}, - context_input::RewardContextInputIndex, + context_input::{ContextInputCount, RewardContextInputIndex}, input::UtxoInput, mana::ManaAllotmentCount, output::{ @@ -26,7 +26,7 @@ use crate::types::block::{ }, payload::{ tagged_data::{TagLength, TaggedDataLength}, - ContextInputCount, InputCount, OutputCount, + InputCount, OutputCount, }, protocol::ProtocolParametersHash, unlock::{UnlockCount, UnlockIndex, UnlocksCount}, diff --git a/sdk/src/types/block/payload/mod.rs b/sdk/src/types/block/payload/mod.rs index f72f633450..4f1ba4c336 100644 --- a/sdk/src/types/block/payload/mod.rs +++ b/sdk/src/types/block/payload/mod.rs @@ -18,7 +18,7 @@ use packable::{ Packable, PackableExt, }; -pub(crate) use self::signed_transaction::{ContextInputCount, InputCount, OutputCount}; +pub(crate) use self::signed_transaction::{InputCount, OutputCount}; pub use self::{ candidacy_announcement::CandidacyAnnouncementPayload, signed_transaction::SignedTransactionPayload, tagged_data::TaggedDataPayload, diff --git a/sdk/src/types/block/payload/signed_transaction/mod.rs b/sdk/src/types/block/payload/signed_transaction/mod.rs index 39af5185a6..262315b4af 100644 --- a/sdk/src/types/block/payload/signed_transaction/mod.rs +++ b/sdk/src/types/block/payload/signed_transaction/mod.rs @@ -8,7 +8,7 @@ mod transaction_id; use packable::{Packable, PackableExt}; -pub(crate) use self::transaction::{ContextInputCount, InputCount, OutputCount}; +pub(crate) use self::transaction::{InputCount, OutputCount}; pub use self::{ transaction::{Transaction, TransactionBuilder, TransactionCapabilities, TransactionCapabilityFlag}, transaction_id::{TransactionHash, TransactionId, TransactionSigningHash}, diff --git a/sdk/src/types/block/payload/signed_transaction/transaction.rs b/sdk/src/types/block/payload/signed_transaction/transaction.rs index c66e5d2d81..85f274ff6e 100644 --- a/sdk/src/types/block/payload/signed_transaction/transaction.rs +++ b/sdk/src/types/block/payload/signed_transaction/transaction.rs @@ -2,17 +2,15 @@ // SPDX-License-Identifier: Apache-2.0 use alloc::{collections::BTreeSet, vec::Vec}; -use core::cmp::Ordering; use crypto::hashes::{blake2b::Blake2b256, Digest}; use hashbrown::HashSet; -use iterator_sorted::is_unique_sorted_by; use packable::{bounded::BoundedU16, prefix::BoxedSlicePrefix, Packable, PackableExt}; use crate::{ types::block::{ capabilities::{Capabilities, CapabilityFlag}, - context_input::{ContextInput, CONTEXT_INPUT_COUNT_RANGE}, + context_input::{ContextInput, ContextInputs}, input::{Input, INPUT_COUNT_RANGE}, mana::{verify_mana_allotments_sum, ManaAllotment, ManaAllotments}, output::{Output, OUTPUT_COUNT_RANGE}, @@ -128,7 +126,7 @@ impl TransactionBuilder { /// Finishes a [`TransactionBuilder`] into a [`Transaction`]. pub fn finish_with_params<'a>( - mut self, + self, params: impl Into>, ) -> Result { let params = params.into(); @@ -160,16 +158,6 @@ impl TransactionBuilder { }) .ok_or(Error::InvalidField("creation slot"))?; - self.context_inputs.sort_by(context_inputs_cmp); - - let context_inputs: BoxedSlicePrefix = self - .context_inputs - .into_boxed_slice() - .try_into() - .map_err(Error::InvalidContextInputCount)?; - - verify_context_inputs(&context_inputs)?; - let inputs: BoxedSlicePrefix = self .inputs .into_boxed_slice() @@ -199,7 +187,7 @@ impl TransactionBuilder { Ok(Transaction { network_id: self.network_id, creation_slot, - context_inputs, + context_inputs: ContextInputs::from_vec(self.context_inputs)?, inputs, allotments, capabilities: self.capabilities, @@ -215,8 +203,6 @@ impl TransactionBuilder { } } -pub(crate) type ContextInputCount = - BoundedU16<{ *CONTEXT_INPUT_COUNT_RANGE.start() }, { *CONTEXT_INPUT_COUNT_RANGE.end() }>; pub(crate) type InputCount = BoundedU16<{ *INPUT_COUNT_RANGE.start() }, { *INPUT_COUNT_RANGE.end() }>; pub(crate) type OutputCount = BoundedU16<{ *OUTPUT_COUNT_RANGE.start() }, { *OUTPUT_COUNT_RANGE.end() }>; @@ -230,9 +216,7 @@ pub struct Transaction { network_id: u64, /// The slot index in which the transaction was created. creation_slot: SlotIndex, - #[packable(verify_with = verify_context_inputs_packable)] - #[packable(unpack_error_with = |e| e.unwrap_item_err_or_else(|p| Error::InvalidContextInputCount(p.into())))] - context_inputs: BoxedSlicePrefix, + context_inputs: ContextInputs, #[packable(verify_with = verify_inputs_packable)] #[packable(unpack_error_with = |e| e.unwrap_item_err_or_else(|p| Error::InvalidInputCount(p.into())))] inputs: BoxedSlicePrefix, @@ -262,7 +246,7 @@ impl Transaction { } /// Returns the context inputs of a [`Transaction`]. - pub fn context_inputs(&self) -> &[ContextInput] { + pub fn context_inputs(&self) -> &ContextInputs { &self.context_inputs } @@ -359,36 +343,6 @@ fn verify_network_id(network_id: &u64, visitor: &ProtocolPar Ok(()) } -fn verify_context_inputs_packable( - context_inputs: &[ContextInput], - _visitor: &ProtocolParameters, -) -> Result<(), Error> { - if VERIFY { - verify_context_inputs(context_inputs)?; - } - Ok(()) -} - -fn context_inputs_cmp(a: &ContextInput, b: &ContextInput) -> Ordering { - a.kind().cmp(&b.kind()).then_with(|| match (a, b) { - (ContextInput::Commitment(_), ContextInput::Commitment(_)) => Ordering::Equal, - (ContextInput::BlockIssuanceCredit(a), ContextInput::BlockIssuanceCredit(b)) => { - a.account_id().cmp(b.account_id()) - } - (ContextInput::Reward(a), ContextInput::Reward(b)) => a.index().cmp(&b.index()), - // No need to evaluate all combinations as `then_with` is only called on Equal. - _ => unreachable!(), - }) -} - -fn verify_context_inputs(context_inputs: &[ContextInput]) -> Result<(), Error> { - if !is_unique_sorted_by(context_inputs.iter(), |a, b| context_inputs_cmp(a, b)) { - return Err(Error::ContextInputsNotUniqueSorted); - } - - Ok(()) -} - fn verify_inputs(inputs: &[Input]) -> Result<(), Error> { let mut seen_utxos = HashSet::new(); diff --git a/sdk/src/types/block/semantic/error.rs b/sdk/src/types/block/semantic/error.rs index 9ce2cd6332..f627f197dc 100644 --- a/sdk/src/types/block/semantic/error.rs +++ b/sdk/src/types/block/semantic/error.rs @@ -236,6 +236,6 @@ impl TryFrom for TransactionFailureReason { type Error = Error; fn try_from(c: u8) -> Result { - TransactionFailureReason::from_repr(c).ok_or(Self::Error::InvalidTransactionFailureReason(c)) + Self::from_repr(c).ok_or(Self::Error::InvalidTransactionFailureReason(c)) } } diff --git a/sdk/src/types/block/semantic/mod.rs b/sdk/src/types/block/semantic/mod.rs index 8f0d0a5302..ac30255a8e 100644 --- a/sdk/src/types/block/semantic/mod.rs +++ b/sdk/src/types/block/semantic/mod.rs @@ -108,12 +108,7 @@ impl<'a> SemanticValidationContext<'a> { // Validation of inputs. let mut has_implicit_account_creation_address = false; - self.commitment_context_input = self - .transaction - .context_inputs() - .iter() - .find_map(|c| c.as_commitment_opt()) - .copied(); + self.commitment_context_input = self.transaction.context_inputs().commitment().copied(); self.bic_context_input = self .transaction @@ -122,12 +117,7 @@ impl<'a> SemanticValidationContext<'a> { .find_map(|c| c.as_block_issuance_credit_opt()) .copied(); - for reward_context_input in self - .transaction - .context_inputs() - .iter() - .filter_map(|c| c.as_reward_opt()) - { + for reward_context_input in self.transaction.context_inputs().rewards() { if let Some(output_id) = self.inputs.get(reward_context_input.index() as usize).map(|v| v.0) { self.reward_context_inputs.insert(*output_id, *reward_context_input); } else { diff --git a/sdk/src/types/block/semantic/unlock.rs b/sdk/src/types/block/semantic/unlock.rs index 4adcaadafa..f6ff3ed2cc 100644 --- a/sdk/src/types/block/semantic/unlock.rs +++ b/sdk/src/types/block/semantic/unlock.rs @@ -110,11 +110,7 @@ impl SemanticValidationContext<'_> { ) -> Result<(), TransactionFailureReason> { match output { Output::Basic(output) => { - let slot_index = self - .transaction - .context_inputs() - .iter() - .find_map(|c| c.as_commitment_opt().map(|c| c.slot_index())); + let slot_index = self.transaction.context_inputs().commitment().map(|c| c.slot_index()); let locked_address = output .unlock_conditions() .locked_address( @@ -144,11 +140,7 @@ impl SemanticValidationContext<'_> { // Output::Anchor(_) => return Err(Error::UnsupportedOutputKind(AnchorOutput::KIND)), Output::Foundry(output) => self.address_unlock(&Address::from(*output.account_address()), unlock)?, Output::Nft(output) => { - let slot_index = self - .transaction - .context_inputs() - .iter() - .find_map(|c| c.as_commitment_opt().map(|c| c.slot_index())); + let slot_index = self.transaction.context_inputs().commitment().map(|c| c.slot_index()); let locked_address = output .unlock_conditions() .locked_address(