diff --git a/iot_packet_verifier/src/balances.rs b/iot_packet_verifier/src/balances.rs index 11fa7d396..ebe0a21fa 100644 --- a/iot_packet_verifier/src/balances.rs +++ b/iot_packet_verifier/src/balances.rs @@ -5,17 +5,20 @@ use crate::{ use futures_util::StreamExt; use helium_crypto::PublicKeyBinary; use solana::SolanaNetwork; -use std::{collections::HashMap, sync::Arc}; +use std::{ + collections::{hash_map::Entry, HashMap}, + sync::Arc, +}; use tokio::sync::Mutex; /// Caches balances fetched from the solana chain and debits made by the /// packet verifier. pub struct BalanceCache { - balances: BalanceStore, + payer_accounts: BalanceStore, solana: S, } -pub type BalanceStore = Arc>>; +pub type BalanceStore = Arc>>; impl BalanceCache where @@ -40,7 +43,7 @@ where let balance = solana.payer_balance(&payer).await?; balances.insert( payer, - Balance { + PayerAccount { burned: burn_amount as u64, balance, }, @@ -48,7 +51,7 @@ where } Ok(Self { - balances: Arc::new(Mutex::new(balances)), + payer_accounts: Arc::new(Mutex::new(balances)), solana, }) } @@ -56,7 +59,7 @@ where impl BalanceCache { pub fn balances(&self) -> BalanceStore { - self.balances.clone() + self.payer_accounts.clone() } } @@ -73,40 +76,44 @@ where &self, payer: &PublicKeyBinary, amount: u64, + trigger_balance_check_threshold: u64, ) -> Result, S::Error> { - let mut balances = self.balances.lock().await; + let mut payer_accounts = self.payer_accounts.lock().await; - let balance = if !balances.contains_key(payer) { - let new_balance = self.solana.payer_balance(payer).await?; - balances.insert(payer.clone(), Balance::new(new_balance)); - balances.get_mut(payer).unwrap() - } else { - let balance = balances.get_mut(payer).unwrap(); + // Fetch the balance if we haven't seen the payer before + if let Entry::Vacant(payer_account) = payer_accounts.entry(payer.clone()) { + let payer_account = + payer_account.insert(PayerAccount::new(self.solana.payer_balance(payer).await?)); + return Ok((payer_account.balance >= amount).then(|| { + payer_account.burned += amount; + payer_account.balance - amount + })); + } - // If the balance is not sufficient, check to see if it has been increased - if balance.balance < amount + balance.burned { - balance.balance = self.solana.payer_balance(payer).await?; + let payer_account = payer_accounts.get_mut(payer).unwrap(); + match payer_account + .balance + .checked_sub(amount + payer_account.burned) + { + Some(remaining_balance) => { + if remaining_balance < trigger_balance_check_threshold { + payer_account.balance = self.solana.payer_balance(payer).await?; + } + payer_account.burned += amount; + Ok(Some(payer_account.balance - payer_account.burned)) } - - balance - }; - - Ok(if balance.balance >= amount + balance.burned { - balance.burned += amount; - Some(balance.balance - balance.burned) - } else { - None - }) + None => Ok(None), + } } } #[derive(Copy, Clone, Debug, Default)] -pub struct Balance { +pub struct PayerAccount { pub balance: u64, pub burned: u64, } -impl Balance { +impl PayerAccount { pub fn new(balance: u64) -> Self { Self { balance, burned: 0 } } diff --git a/iot_packet_verifier/src/burner.rs b/iot_packet_verifier/src/burner.rs index dd73256b4..bb68af136 100644 --- a/iot_packet_verifier/src/burner.rs +++ b/iot_packet_verifier/src/burner.rs @@ -75,18 +75,22 @@ where .await .map_err(BurnError::SolanaError)?; - // Now that we have successfully executed the burn and are no long in - // sync land, we can remove the amount burned. + // Now that we have successfully executed the burn and are no longer in + // sync land, we can remove the amount burned: self.pending_burns .subtract_burned_amount(&payer, amount) .await .map_err(BurnError::SqlError)?; let mut balance_lock = self.balances.lock().await; - let balances = balance_lock.get_mut(&payer).unwrap(); - balances.burned -= amount; - // Zero the balance in order to force a reset: - balances.balance = 0; + let payer_account = balance_lock.get_mut(&payer).unwrap(); + payer_account.burned -= amount; + // Reset the balance of the payer: + payer_account.balance = self + .solana + .payer_balance(&payer) + .await + .map_err(BurnError::SolanaError)?; metrics::counter!("burned", amount, "payer" => payer.to_string()); diff --git a/iot_packet_verifier/src/daemon.rs b/iot_packet_verifier/src/daemon.rs index 51d8e784d..59a1e2a09 100644 --- a/iot_packet_verifier/src/daemon.rs +++ b/iot_packet_verifier/src/daemon.rs @@ -2,7 +2,7 @@ use crate::{ balances::BalanceCache, burner::Burner, settings::Settings, - verifier::{ConfigServer, Verifier}, + verifier::{CachedOrgClient, ConfigServer, Verifier}, }; use anyhow::{bail, Error, Result}; use file_store::{ @@ -21,11 +21,10 @@ use tokio::{ signal, sync::{mpsc::Receiver, Mutex}, }; -use tracing::debug; struct Daemon { pool: Pool, - verifier: Verifier>>, Arc>>, + verifier: Verifier>>, Arc>>, report_files: Receiver>, valid_packets: FileSinkClient, invalid_packets: FileSinkClient, @@ -69,9 +68,7 @@ impl Daemon { &self.invalid_packets, ) .await?; - debug!("Committing transaction"); transaction.commit().await?; - debug!("Committing files"); self.valid_packets.commit().await?; self.invalid_packets.commit().await?; @@ -159,9 +156,9 @@ impl Cmd { .create() .await?; - let org_client = Arc::new(Mutex::new(OrgClient::from_settings( + let org_client = Arc::new(Mutex::new(CachedOrgClient::new(OrgClient::from_settings( &settings.iot_config_client, - )?)); + )?))); let file_store = FileStore::from_settings(&settings.ingest).await?; diff --git a/iot_packet_verifier/src/verifier.rs b/iot_packet_verifier/src/verifier.rs index a95f32da1..34f381442 100644 --- a/iot_packet_verifier/src/verifier.rs +++ b/iot_packet_verifier/src/verifier.rs @@ -19,7 +19,6 @@ use tokio::{ task::JoinError, time::{sleep_until, Duration, Instant}, }; -use tracing::debug; pub struct Verifier { pub debiter: D, @@ -65,32 +64,25 @@ where tokio::pin!(reports); while let Some(report) = reports.next().await { - debug!(%report.received_timestamp, "Processing packet report"); - let debit_amount = payload_size_to_dc(report.payload_size as u64); - debug!(%report.oui, "Fetching payer"); let payer = self .config_server .fetch_org(report.oui, &mut org_cache) .await .map_err(VerificationError::ConfigError)?; - debug!(%payer, "Debiting payer"); - let remaining_balance = self + + if let Some(remaining_balance) = self .debiter - .debit_if_sufficient(&payer, debit_amount) + .debit_if_sufficient(&payer, debit_amount, minimum_allowed_balance) .await - .map_err(VerificationError::DebitError)?; - - if let Some(remaining_balance) = remaining_balance { - debug!(%debit_amount, "Adding debit amount to pending burns"); - + .map_err(VerificationError::DebitError)? + { pending_burns .add_burned_amount(&payer, debit_amount) .await .map_err(VerificationError::BurnError)?; - debug!("Writing valid packet report"); valid_packets .write(ValidPacket { packet_timestamp: report.timestamp(), @@ -103,14 +95,12 @@ where .map_err(VerificationError::ValidPacketWriterError)?; if remaining_balance < minimum_allowed_balance { - debug!(%report.oui, "Disabling org"); self.config_server .disable_org(report.oui) .await .map_err(VerificationError::ConfigError)?; } } else { - debug!("Writing invalid packet report"); invalid_packets .write(InvalidPacket { payload_size: report.payload_size, @@ -145,6 +135,7 @@ pub trait Debiter { &self, payer: &PublicKeyBinary, amount: u64, + trigger_balance_check_threshold: u64, ) -> Result, Self::Error>; } @@ -156,6 +147,7 @@ impl Debiter for Arc>> { &self, payer: &PublicKeyBinary, amount: u64, + _trigger_balance_check_threshold: u64, ) -> Result, Infallible> { let map = self.lock().await; let balance = map.get(payer).unwrap(); @@ -276,19 +268,34 @@ pub enum ConfigServerError { NotFound(u64), } +pub struct CachedOrgClient { + client: OrgClient, + locked_cache: HashMap, +} + +impl CachedOrgClient { + pub fn new(client: OrgClient) -> Self { + Self { + client, + locked_cache: HashMap::new(), + } + } +} + #[async_trait] -impl ConfigServer for Arc> { +impl ConfigServer for Arc> { type Error = ConfigServerError; async fn fetch_org( &self, oui: u64, - cache: &mut HashMap, + oui_cache: &mut HashMap, ) -> Result { - if let Entry::Vacant(e) = cache.entry(oui) { + if let Entry::Vacant(e) = oui_cache.entry(oui) { let pubkey = PublicKeyBinary::from( self.lock() .await + .client .get(oui) .await? .org @@ -297,16 +304,24 @@ impl ConfigServer for Arc> { ); e.insert(pubkey); } - Ok(cache.get(&oui).unwrap().clone()) + Ok(oui_cache.get(&oui).unwrap().clone()) } async fn disable_org(&self, oui: u64) -> Result<(), Self::Error> { - self.lock().await.disable(oui).await?; + let mut cached_client = self.lock().await; + if *cached_client.locked_cache.entry(oui).or_insert(true) { + cached_client.client.disable(oui).await?; + *cached_client.locked_cache.get_mut(&oui).unwrap() = false; + } Ok(()) } async fn enable_org(&self, oui: u64) -> Result<(), Self::Error> { - self.lock().await.enable(oui).await?; + let mut cached_client = self.lock().await; + if !*cached_client.locked_cache.entry(oui).or_insert(false) { + cached_client.client.enable(oui).await?; + *cached_client.locked_cache.get_mut(&oui).unwrap() = true; + } Ok(()) } @@ -314,6 +329,7 @@ impl ConfigServer for Arc> { Ok(self .lock() .await + .client .list() .await? .into_iter() diff --git a/iot_packet_verifier/tests/integration_tests.rs b/iot_packet_verifier/tests/integration_tests.rs index 60f868a23..4603c1815 100644 --- a/iot_packet_verifier/tests/integration_tests.rs +++ b/iot_packet_verifier/tests/integration_tests.rs @@ -87,6 +87,7 @@ impl Debiter for InstantBurnedBalance { &self, payer: &PublicKeyBinary, amount: u64, + _trigger_balance_check_threshold: u64, ) -> Result, ()> { let map = self.0.lock().await; let balance = map.get(payer).unwrap(); @@ -484,43 +485,4 @@ async fn test_end_to_end() { invalid_packets, vec![invalid_packet(BYTES_PER_DC as u32, vec![5])] ); - - // Add one DC to the balance: - *solana_network.lock().await.get_mut(&payer).unwrap() = 1; - - valid_packets.clear(); - invalid_packets.clear(); - - // First packet should be invalid since it is too large, second - // should clear - verifier - .verify( - 1, - pending_burns.clone(), - stream::iter(vec![ - packet_report(0, 5, 2 * BYTES_PER_DC as u32, vec![6]), - packet_report(0, 6, BYTES_PER_DC as u32, vec![7]), - ]), - &mut valid_packets, - &mut invalid_packets, - ) - .await - .unwrap(); - - assert_eq!( - invalid_packets, - vec![invalid_packet(2 * BYTES_PER_DC as u32, vec![6])] - ); - assert_eq!( - valid_packets, - vec![valid_packet(6000, BYTES_PER_DC as u32, vec![7])] - ); - - let balance = { - let balances = verifier.debiter.balances(); - let balances = balances.lock().await; - *balances.get(&payer).unwrap() - }; - assert_eq!(balance.balance, 1); - assert_eq!(balance.burned, 1); }