From d80360266af57abfc2212071d51c5c8c5767fa94 Mon Sep 17 00:00:00 2001 From: Thoralf-M <46689931+Thoralf-M@users.noreply.github.com> Date: Mon, 22 Jan 2024 18:10:05 +0100 Subject: [PATCH] Improve ISA with many native tokens in inputs (#1784) * Improve ISA with many native tokens in inputs * Cleanup * Changelog entry * Update sdk/src/client/api/block_builder/input_selection/core/requirement/amount.rs * Fmt --------- Co-authored-by: Thibault Martinez --- sdk/CHANGELOG.md | 1 + .../core/requirement/amount.rs | 50 +++++++++++++- .../client/input_selection/native_tokens.rs | 69 ++++++++++++++++++- 3 files changed, 116 insertions(+), 4 deletions(-) diff --git a/sdk/CHANGELOG.md b/sdk/CHANGELOG.md index dd798e3418..af016b5a17 100644 --- a/sdk/CHANGELOG.md +++ b/sdk/CHANGELOG.md @@ -34,6 +34,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Ed25519Signature` type no longer requires validated public key bytes to construct; - `SelfControlledAliasOutput` and `SelfDepositNft` conditions; +- Input selection for > max native tokens; ## 1.1.3 - 2023-12-07 diff --git a/sdk/src/client/api/block_builder/input_selection/core/requirement/amount.rs b/sdk/src/client/api/block_builder/input_selection/core/requirement/amount.rs index 2ab4ac21b5..381f3738f9 100644 --- a/sdk/src/client/api/block_builder/input_selection/core/requirement/amount.rs +++ b/sdk/src/client/api/block_builder/input_selection/core/requirement/amount.rs @@ -1,7 +1,7 @@ // Copyright 2023 IOTA Stiftung // SPDX-License-Identifier: Apache-2.0 -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use super::{Error, InputSelection, Requirement}; use crate::{ @@ -11,7 +11,7 @@ use crate::{ input::INPUT_COUNT_MAX, output::{ unlock_condition::StorageDepositReturnUnlockCondition, AliasOutputBuilder, AliasTransition, - FoundryOutputBuilder, NftOutputBuilder, Output, OutputId, Rent, + FoundryOutputBuilder, NativeTokens, NftOutputBuilder, Output, OutputId, Rent, TokenId, }, }, }; @@ -81,6 +81,7 @@ struct AmountSelection { remainder_amount: u64, native_tokens_remainder: bool, timestamp: u32, + selected_native_tokens: HashSet, } impl AmountSelection { @@ -90,6 +91,17 @@ impl AmountSelection { &input_selection.outputs, input_selection.timestamp, ); + let selected_native_tokens = HashSet::::from_iter( + input_selection + .selected_inputs + .iter() + .flat_map(|i| { + i.output + .native_tokens() + .map(|n| n.iter().copied().map(|n| *n.token_id()).collect::>()) + }) + .flatten(), + ); let (remainder_amount, native_tokens_remainder) = input_selection.remainder_amount()?; Ok(Self { @@ -101,6 +113,7 @@ impl AmountSelection { remainder_amount, native_tokens_remainder, timestamp: input_selection.timestamp, + selected_native_tokens, }) } @@ -146,6 +159,19 @@ impl AmountSelection { *self.inputs_sdr.entry(*sdruc.return_address()).or_default() += sdruc.amount(); } + if let Some(nt) = input.output.native_tokens() { + let mut selected_native_tokens = self.selected_native_tokens.clone(); + + selected_native_tokens.extend(nt.iter().map(|t| t.token_id())); + // Don't select input if the tx would end up with more than allowed native tokens. + if selected_native_tokens.len() > NativeTokens::COUNT_MAX.into() { + continue; + } else { + // Update selected with NTs from this output. + self.selected_native_tokens = selected_native_tokens; + } + } + self.inputs_sum += input.output.amount(); self.newly_selected_inputs .insert(*input.output_id(), (input.clone(), None)); @@ -308,7 +334,25 @@ impl InputSelection { return Ok(r); } - if self.selected_inputs.len() + amount_selection.newly_selected_inputs.len() > INPUT_COUNT_MAX.into() { + // If the available inputs have more NTs than are allowed in a single tx, we might not be able to find inputs + // without exceeding the threshold, so in this case we also try again with the outputs ordered the other way + // around. + let potentially_too_many_native_tokens = HashSet::::from_iter( + self.available_inputs + .iter() + .flat_map(|i| { + i.output + .native_tokens() + .map(|n| n.iter().copied().map(|n| *n.token_id()).collect::>()) + }) + .flatten(), + ) + .len() + > NativeTokens::COUNT_MAX.into(); + + if self.selected_inputs.len() + amount_selection.newly_selected_inputs.len() > INPUT_COUNT_MAX.into() + || potentially_too_many_native_tokens + { // Clear before trying with reversed ordering. log::debug!("Clearing amount selection"); amount_selection = AmountSelection::new(self)?; diff --git a/sdk/tests/client/input_selection/native_tokens.rs b/sdk/tests/client/input_selection/native_tokens.rs index 3d770dbce3..a08dbc9331 100644 --- a/sdk/tests/client/input_selection/native_tokens.rs +++ b/sdk/tests/client/input_selection/native_tokens.rs @@ -5,7 +5,7 @@ use std::str::FromStr; use iota_sdk::{ client::api::input_selection::{Burn, Error, InputSelection}, - types::block::{output::TokenId, protocol::protocol_parameters}, + types::block::{output::TokenId, protocol::protocol_parameters, rand::bytes::rand_bytes_array}, }; use pretty_assertions::assert_eq; use primitive_types::U256; @@ -1468,6 +1468,73 @@ fn two_basic_outputs_native_tokens_not_needed() { )); } +#[test] +fn higher_nts_count_but_below_max_native_tokens() { + let protocol_parameters = protocol_parameters(); + + let mut input_native_tokens_0 = Vec::new(); + for _ in 0..10 { + input_native_tokens_0.push((TokenId::from(rand_bytes_array()).to_string(), 10)); + } + let mut input_native_tokens_1 = Vec::new(); + for _ in 0..64 { + input_native_tokens_1.push((TokenId::from(rand_bytes_array()).to_string(), 10)); + } + + let inputs = build_inputs([ + Basic( + 3_000_000, + BECH32_ADDRESS_ED25519_0, + Some(input_native_tokens_0.iter().map(|(t, a)| (t.as_str(), *a)).collect()), + None, + None, + None, + None, + None, + ), + Basic( + 10_000_000, + BECH32_ADDRESS_ED25519_0, + Some(input_native_tokens_1.iter().map(|(t, a)| (t.as_str(), *a)).collect()), + None, + None, + None, + None, + None, + ), + ]); + let outputs = build_outputs([Basic( + 5_000_000, + BECH32_ADDRESS_ED25519_0, + None, + None, + None, + None, + None, + None, + )]); + + let selected = InputSelection::new( + inputs.clone(), + outputs.clone(), + addresses([BECH32_ADDRESS_ED25519_0]), + protocol_parameters, + ) + .select() + .unwrap(); + + assert_eq!(selected.inputs.len(), 1); + assert!(selected.inputs.contains(&inputs[1])); + assert_eq!(selected.outputs.len(), 2); + assert!(selected.outputs.contains(&outputs[0])); + assert!(is_remainder_or_return( + &selected.outputs[1], + 5_000_000, + BECH32_ADDRESS_ED25519_0, + Some(input_native_tokens_1.iter().map(|(t, a)| (t.as_str(), *a)).collect()), + )); +} + // T27: :wavy_dash: // inputs: [basic{ amount: 1_000_000, native_tokens: [{‘a’: 100}] }, basic{ amount: 1_000_000, native_tokens: [{‘a’: // 200}] }] }] outputs: [basic{ amount: 500_000, native_tokens: [{‘a’: 150}] }]