Skip to content

Commit

Permalink
add approval func and related tests
Browse files Browse the repository at this point in the history
  • Loading branch information
programskillforverification committed Sep 16, 2024
1 parent 5b871cc commit b65f68b
Showing 1 changed file with 88 additions and 16 deletions.
104 changes: 88 additions & 16 deletions contracts/src/token/erc1155/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use stylus_sdk::{
prelude::*,
};

use crate::utils::math::storage::{AddAssignUnchecked, SubAssignUnchecked};
use crate::utils::math::storage::SubAssignUnchecked;

pub mod extensions;

Expand Down Expand Up @@ -367,7 +367,7 @@ pub trait IErc1155 {
) -> Result<(), Self::Error>;
}

#[external]
#[public]
impl IErc1155 for Erc1155 {
type Error = Error;

Expand All @@ -386,12 +386,12 @@ impl IErc1155 for Erc1155 {
) -> Result<Vec<U256>, Self::Error> {
if accounts.len() != token_ids.len() {
return Err(Error::InvalidArrayLength(ERC1155InvalidArrayLength {
ids_length: uint!(token_ids.len()),
values_length: uint!(accounts.len()),
ids_length: U256::from(token_ids.len()),
values_length: U256::from(accounts.len()),
}));
}

let mut balances = accounts
let balances: Vec<Uint<256, 4>> = accounts
.iter()
.zip(token_ids.iter())
.map(|(&account, &token_id)| {
Expand All @@ -406,6 +406,7 @@ impl IErc1155 for Erc1155 {
operator: Address,
approved: bool,
) -> Result<(), Self::Error> {
self._set_approval_for_all(msg::sender(), operator, approved)?;
Ok(())
}

Expand Down Expand Up @@ -471,13 +472,13 @@ impl Erc1155 {
) -> Result<(), Error> {
if token_ids.len() != values.len() {
return Err(Error::InvalidArrayLength(ERC1155InvalidArrayLength {
ids_length: uint!(token_ids.len()),
values_length: uint!(values.len()),
ids_length: U256::from(token_ids.len()),
values_length: U256::from(values.len()),
}));
}

let operator = msg::sender();
token_ids.iter().zip(values.iter()).for_each(|(&token_id, &value)| {
for (&token_id, &value) in token_ids.iter().zip(values.iter()) {
let from_balance = self._balances.get(token_id).get(from);
if from_balance < value {
return Err(Error::InsufficientBalance(
Expand All @@ -501,7 +502,7 @@ impl Erc1155 {
.checked_add(value)
.expect("should not exceed `U256::MAX` for `_balances`");
}
});
}

if token_ids.len() == 1 {
evm::log(TransferSingle {
Expand Down Expand Up @@ -562,6 +563,34 @@ impl Erc1155 {
Ok(())
}

/// Approve `operator` to operate on all of `owner` tokens
///
/// Emits an [`ApprovalForAll`] event.
///
/// Requirements:
///
/// - `operator` cannot be the zero address.
///
/// # Errors
///
/// If `operator` is the zero address, then the error
/// [`Error::InvalidOperator`] is returned.
fn _set_approval_for_all(
&mut self,
owner: Address,
operator: Address,
approved: bool,
) -> Result<(), Error> {
if operator.is_zero() {
return Err(Error::InvalidOperator(ERC1155InvalidOperator {
operator,
}));
}
self._operator_approvals.setter(owner).setter(operator).set(approved);
evm::log(ApprovalForAll { account: owner, operator, approved });
Ok(())
}

/// Performs an acceptance check for the provided `operator` by
/// calling [`IERC1155Receiver::on_erc_1155_received`] on the `to` address.
/// The `operator` is generally the address that initiated the token
Expand Down Expand Up @@ -605,9 +634,14 @@ impl Erc1155 {

let receiver = IERC1155Receiver::new(to);
let call = Call::new_in(self);
let data = data.to_vec();
let result = receiver
.on_erc_1155_received(call, operator, from, token_id, value, data);
let result = receiver.on_erc_1155_received(
call,
operator,
from,
token_id,
value,
data.to_vec().into(),
);

let id = match result {
Ok(id) => id,
Expand Down Expand Up @@ -674,9 +708,13 @@ impl Erc1155 {

let receiver = IERC1155Receiver::new(to);
let call = Call::new_in(self);
let data = data.to_vec();
let result = receiver.on_erc_1155_batch_received(
call, operator, from, token_ids, values, data,
call,
operator,
from,
token_ids,
values,
data.to_vec().into(),
);

let id = match result {
Expand Down Expand Up @@ -705,10 +743,12 @@ impl Erc1155 {
#[cfg(all(test, feature = "std"))]
mod tests {
use alloy_primitives::{address, uint, Address, U256};
use alloy_sol_types::token;
use stylus_sdk::{contract, msg};

use super::{ERC1155InvalidArrayLength, Erc1155, Error, IErc1155};
use super::{
ERC1155InvalidArrayLength, ERC1155InvalidOperator, Erc1155, Error,
IErc1155,
};

const ALICE: Address = address!("A11CEacF9aa32246d767FCCD72e02d6bCbcC375d");
const BOB: Address = address!("F4EaCDAbEf3c8f1EdE91b6f2A6840bc2E4DD3526");
Expand Down Expand Up @@ -762,4 +802,36 @@ mod tests {
assert_eq!(U256::ZERO, balance);
}
}

#[motsu::test]
fn test_set_approval_for_all(contract: Erc1155) {
let alice = msg::sender();
contract._operator_approvals.setter(alice).setter(BOB).set(false);

contract
.set_approval_for_all(BOB, true)
.expect("should approve Bob for operations on all Alice's tokens");
assert_eq!(contract.is_approved_for_all(alice, BOB), true);

contract.set_approval_for_all(BOB, false).expect(
"should disapprove Bob for operations on all Alice's tokens",
);
assert_eq!(contract.is_approved_for_all(alice, BOB), false);
}

#[motsu::test]
fn test_error_invalid_operator_when_approval_for_all(contract: Erc1155) {
let invalid_operator = Address::ZERO;

let err = contract
.set_approval_for_all(invalid_operator, true)
.expect_err("should not approve for all for invalid operator");

assert!(matches!(
err,
Error::InvalidOperator(ERC1155InvalidOperator {
operator
}) if operator == invalid_operator
));
}
}

0 comments on commit b65f68b

Please sign in to comment.