From a93b458db8f30c3ae4c3f200a735da6f80b041bd Mon Sep 17 00:00:00 2001 From: Kevin Rodriguez <_@kevinrodriguez.io> Date: Sun, 9 Jul 2023 16:08:46 -0600 Subject: [PATCH] lang: add checked operations for lamport management --- CHANGELOG.md | 1 + lang/src/lib.rs | 51 +++++++++++++++++++++++- tests/misc/programs/lamports/src/lib.rs | 53 +++++++++++++++++++++++++ tests/misc/tests/lamports/lamports.ts | 20 ++++++++++ 4 files changed, 123 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f88abe5212..1832cf18d7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ The minor version will be incremented upon a breaking change and the patch versi ### Features +- lang: Add `add_lamports_checked` and `sub_lamports_checked` methods for all account types. - lang: Add `get_lamports`, `add_lamports` and `sub_lamports` methods for all account types ([#2552](https://github.com/coral-xyz/anchor/pull/2552)). - client: Add a helper struct `DynSigner` to simplify use of `Client where >` with Solana clap CLI utils that loads `Signer` as `Box` ([#2550](https://github.com/coral-xyz/anchor/pull/2550)). - lang: Allow CPI calls matching an interface without pinning program ID ([#2559](https://github.com/coral-xyz/anchor/pull/2559)). diff --git a/lang/src/lib.rs b/lang/src/lib.rs index 1ef35046de..b347b83f98 100644 --- a/lang/src/lib.rs +++ b/lang/src/lib.rs @@ -160,12 +160,35 @@ pub trait Lamports<'info>: AsRef> { /// the transaction. /// 3. `lamports` field of the account info should not currently be borrowed. /// - /// See [`Lamports::sub_lamports`] for subtracting lamports. + /// - See [`Lamports::add_lamports_checked`] for adding lamports with overflow checking. + /// - See [`Lamports::sub_lamports`] for subtracting lamports. + /// - See [`Lamports::sub_lamports_checked`] for subtracting lamports with overflow checking. fn add_lamports(&self, amount: u64) -> Result<&Self> { **self.as_ref().try_borrow_mut_lamports()? += amount; Ok(self) } + /// Add lamports to the account, checking for overflow. + /// + /// This method is useful for transfering lamports from a PDA. + /// Also raises an error if the operation overflows. + /// + /// # Requirements + /// + /// 1. The account must be marked `mut`. + /// 2. The total lamports **before** the transaction must equal to total lamports **after** + /// the transaction. + /// 3. `lamports` field of the account info should not currently be borrowed. + /// + /// - See [`Lamports::add_lamports`] for adding lamports. + /// - See [`Lamports::sub_lamports`] for subtracting lamports. + /// - See [`Lamports::sub_lamports_checked`] for subtracting lamports with overflow checking. + fn add_lamports_checked(&self, amount: u64, error: error::Error) -> Result<&Self> { + let result = self.get_lamports().checked_add(amount).ok_or(error)?; + **self.as_ref().try_borrow_mut_lamports()? = result; + Ok(self) + } + /// Subtract lamports from the account. /// /// This method is useful for transfering lamports from a PDA. @@ -178,11 +201,35 @@ pub trait Lamports<'info>: AsRef> { /// the transaction. /// 4. `lamports` field of the account info should not currently be borrowed. /// - /// See [`Lamports::add_lamports`] for adding lamports. + /// - See [`Lamports::add_lamports`] for adding lamports. + /// - See [`Lamports::add_lamports_checked`] for adding lamports with overflow checking. + /// - See [`Lamports::sub_lamports_checked`] for subtracting lamports with overflow checking. fn sub_lamports(&self, amount: u64) -> Result<&Self> { **self.as_ref().try_borrow_mut_lamports()? -= amount; Ok(self) } + + /// Subtract lamports from the account, checking for overflow. + /// + /// This method is useful for transfering lamports from a PDA. + /// Also raises an error if the operation overflows. + /// + /// # Requirements + /// + /// 1. The account must be owned by the executing program. + /// 2. The account must be marked `mut`. + /// 3. The total lamports **before** the transaction must equal to total lamports **after** + /// the transaction. + /// 4. `lamports` field of the account info should not currently be borrowed. + /// + /// - See [`Lamports::add_lamports`] for adding lamports. + /// - See [`Lamports::add_lamports_checked`] for adding lamports with overflow checking. + /// - See [`Lamports::sub_lamports`] for subtracting lamports. + fn sub_lamports_checked(&self, amount: u64, error: error::Error) -> Result<&Self> { + let result = self.get_lamports().checked_sub(amount).ok_or(error)?; + **self.as_ref().try_borrow_mut_lamports()? = result; + Ok(self) + } } impl<'info, T: AsRef>> Lamports<'info> for T {} diff --git a/tests/misc/programs/lamports/src/lib.rs b/tests/misc/programs/lamports/src/lib.rs index 84ac072d27..18691100cf 100644 --- a/tests/misc/programs/lamports/src/lib.rs +++ b/tests/misc/programs/lamports/src/lib.rs @@ -52,6 +52,53 @@ pub mod lamports { Ok(()) } + + pub fn test_lamports_trait_checked(ctx: Context, amount: u64) -> Result<()> { + let pda = &ctx.accounts.pda; + let signer = &ctx.accounts.signer; + + // Transfer **to** PDA + { + // Get the balance of the PDA **before** the transfer to PDA + let pda_balance_before = pda.get_lamports(); + + // Transfer to the PDA + anchor_lang::system_program::transfer( + CpiContext::new( + ctx.accounts.system_program.to_account_info(), + anchor_lang::system_program::Transfer { + from: signer.to_account_info(), + to: pda.to_account_info(), + }, + ), + amount, + )?; + + // Get the balance of the PDA **after** the transfer to PDA + let pda_balance_after = pda.get_lamports(); + + // Validate balance + require_eq!(pda_balance_after, pda_balance_before + amount); + } + + // Transfer **from** PDA + { + // Get the balance of the PDA **before** the transfer from PDA + let pda_balance_before = pda.get_lamports(); + + // Transfer from the PDA + pda.sub_lamports_checked(amount, LamportsError::NumericOverflow)?; + signer.add_lamports_checked(amount, LamportsError::NumericOverflow)?; + + // Get the balance of the PDA **after** the transfer from PDA + let pda_balance_after = pda.get_lamports(); + + // Validate balance + require_eq!(pda_balance_after, pda_balance_before - amount); + } + + Ok(()) + } } #[derive(Accounts)] @@ -73,3 +120,9 @@ pub struct TestLamportsTrait<'info> { #[account] pub struct LamportsPda {} + +#[error_code] +pub enum LamportsError { + #[msg("Numeric Overflow.")] + NumericOverflow, +} \ No newline at end of file diff --git a/tests/misc/tests/lamports/lamports.ts b/tests/misc/tests/lamports/lamports.ts index 678ed5c795..acc39b22d7 100644 --- a/tests/misc/tests/lamports/lamports.ts +++ b/tests/misc/tests/lamports/lamports.ts @@ -20,4 +20,24 @@ describe(IDL.name, () => { .accounts({ signer, pda }) .rpc(); }); + + it("Can use the Lamports trait for checked arithmetic", async () => { + const signer = program.provider.publicKey!; + const [pda] = anchor.web3.PublicKey.findProgramAddressSync( + [Buffer.from("lamports")], + program.programId + ); + + await program.methods + .testLamportsTraitChecked(new anchor.BN(anchor.web3.LAMPORTS_PER_SOL)) + .accounts({ signer, pda }) + .rpc(); + }); + + it("Throws when overflow occurs", async () => { + const signer = program.provider.publicKey!; + const [pda] = anchor.web3.PublicKey.findProgramAddressSync( + [Buffer.from("lamports")], + program.programId + ); });