Skip to content

Commit

Permalink
Add a few more tests, refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
jankjr committed Jun 30, 2023
1 parent a4f7051 commit c61746c
Show file tree
Hide file tree
Showing 4 changed files with 117 additions and 17 deletions.
44 changes: 31 additions & 13 deletions contracts/plugins/assets/erc20/RewardableERC20.sol
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// SPDX-License-Identifier: BlueOak-1.0.0
pragma solidity ^0.8.19;

import { ReentrancyGuard } from "@openzeppelin/contracts/security/ReentrancyGuard.sol";
import "@openzeppelin/contracts/token/ERC20/utils/SafeERC20.sol";
import "@openzeppelin/contracts/token/ERC20/ERC20.sol";
import "@openzeppelin/contracts/token/ERC20/IERC20.sol";
Expand All @@ -15,7 +16,7 @@ import "../../../interfaces/IRewardable.sol";
* - override _claimAssetRewards()
* - call ERC20 constructor elsewhere during construction
*/
abstract contract RewardableERC20 is IRewardable, ERC20 {
abstract contract RewardableERC20 is IRewardable, ERC20, ReentrancyGuard {
using SafeERC20 for IERC20;

uint256 public immutable one; // {qShare/share}
Expand All @@ -37,27 +38,32 @@ abstract contract RewardableERC20 is IRewardable, ERC20 {
one = 10**_decimals; // set via pass-in to prevent inheritance issues
}

function claimRewards() external {
function claimRewards() external nonReentrant {
_claimAndSyncRewards();
_syncAccount(msg.sender);
_claimAccountRewards(msg.sender);
}

function sync() external {
_claimAndSyncRewards();
}

function _syncAccount(address account) internal {
if (account == address(0)) return;

// {qRewards/share}
uint256 accountRewardsPerShare = lastRewardsPerShare[account];

// {qShare}
uint256 shares = balanceOf(account);

// {qRewards}
uint256 _accumuatedRewards = accumulatedRewards[account];

// {qRewards/share}
uint256 _rewardsPerShare = rewardsPerShare;
if (accountRewardsPerShare < _rewardsPerShare) {
// {qRewards/share}
uint256 delta = _rewardsPerShare - accountRewardsPerShare;
delta = (delta * shares) / one;
// {qRewards} = {qRewards/share} * {qShare} / {qShare/share}
_accumuatedRewards = _accumuatedRewards + delta;

// {qRewards} = {qRewards/share} * {qShare}
_accumuatedRewards += (delta * shares) / one;
}
lastRewardsPerShare[account] = _rewardsPerShare;
accumulatedRewards[account] = _accumuatedRewards;
Expand Down Expand Up @@ -85,17 +91,29 @@ abstract contract RewardableERC20 is IRewardable, ERC20 {

function _claimAccountRewards(address account) internal {
uint256 claimableRewards = accumulatedRewards[account] - claimedRewards[account];

emit RewardsClaimed(IERC20(address(rewardToken)), claimableRewards);

if (claimableRewards == 0) {
return;
}

emit RewardsClaimed(IERC20(address(rewardToken)), claimableRewards);
claimedRewards[account] = accumulatedRewards[account];

uint256 currentRewardTokenBalance = rewardToken.balanceOf(address(this));

// This is just to handle the edge case where totalSupply() == 0 and there
// are still reward tokens in the contract.
uint256 nonDistributed = currentRewardTokenBalance > previousBalance
? currentRewardTokenBalance - previousBalance
: 0;

rewardToken.safeTransfer(account, claimableRewards);

// If totalSupply() == 0 and someone with claimable rewards calls claim rewards
// we could possibly end up having unaccounted reward tokens in the contract.
previousBalance = rewardToken.balanceOf(address(this));
currentRewardTokenBalance = rewardToken.balanceOf(address(this));
previousBalance = currentRewardTokenBalance > nonDistributed
? currentRewardTokenBalance - nonDistributed
: 0;
}

function _beforeTokenTransfer(
Expand Down
4 changes: 4 additions & 0 deletions contracts/plugins/mocks/RewardableERC20WrapperTest.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,8 @@ contract RewardableERC20WrapperTest is RewardableERC20Wrapper {
function _claimAssetRewards() internal virtual override {
ERC20MockRewarding(address(underlying)).claim();
}

function sync() external {
_claimAndSyncRewards();
}
}
4 changes: 4 additions & 0 deletions contracts/plugins/mocks/RewardableERC4626VaultTest.sol
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,8 @@ contract RewardableERC4626VaultTest is RewardableERC4626Vault {
function _claimAssetRewards() internal virtual override {
ERC20MockRewarding(asset()).claim();
}

function sync() external {
_claimAndSyncRewards();
}
}
82 changes: 78 additions & 4 deletions test/plugins/RewardableERC20.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,20 @@ import {
ERC20MockDecimals,
ERC20MockRewarding,
RewardableERC20Wrapper,
RewardableERC20WrapperTest,
RewardableERC4626Vault,
RewardableERC4626VaultTest,
} from '../../typechain'
import { cartesianProduct } from '../utils/cases'
import { useEnv } from '#/utils/env'
import { Implementation } from '../fixtures'
import snapshotGasCost from '../utils/snapshotGasCost'
import { parseUnits } from 'ethers/lib/utils'
import { formatUnits, parseUnits } from 'ethers/lib/utils'

type Fixture<T> = () => Promise<T>

interface RewardableERC20Fixture {
rewardableVault: RewardableERC4626Vault | RewardableERC20Wrapper
rewardableVault: RewardableERC4626VaultTest | RewardableERC20WrapperTest
rewardableAsset: ERC20MockRewarding
rewardToken: ERC20MockDecimals
}
Expand Down Expand Up @@ -60,7 +62,7 @@ for (const wrapperName of wrapperNames) {
)

const rewardableVaultFactory: ContractFactory = await ethers.getContractFactory(wrapperName)
const rewardableVault = <RewardableERC4626Vault | RewardableERC20Wrapper>(
const rewardableVault = <RewardableERC4626VaultTest | RewardableERC20WrapperTest>(
await rewardableVaultFactory.deploy(
rewardableAsset.address,
'Rewarding Test Asset Vault',
Expand Down Expand Up @@ -96,14 +98,28 @@ for (const wrapperName of wrapperNames) {
return wrapperERC4626.withdraw(amount, to, to)
}
}
const withdrawAll = async (
wrapper: RewardableERC4626Vault | RewardableERC20Wrapper,
to?: string
): Promise<ContractTransaction> => {
const owner = await wrapper.signer.getAddress()
to = to || owner
if (wrapperName == Wrapper.ERC20) {
const wrapperERC20 = wrapper as RewardableERC20Wrapper
return wrapperERC20.withdraw(await wrapperERC20.balanceOf(owner), to)
} else {
const wrapperERC4626 = wrapper as RewardableERC4626Vault
return wrapperERC4626.withdraw(await wrapperERC4626.maxWithdraw(owner), to, owner)
}
}

const runTests = (assetDecimals: number, rewardDecimals: number) => {
describe(wrapperName, () => {
// Decimals
let shareDecimals: number

// Assets
let rewardableVault: RewardableERC20Wrapper | RewardableERC4626Vault
let rewardableVault: RewardableERC20WrapperTest | RewardableERC4626VaultTest
let rewardableAsset: ERC20MockRewarding
let rewardToken: ERC20MockDecimals

Expand Down Expand Up @@ -166,6 +182,64 @@ for (const wrapperName of wrapperNames) {
await rewardableVault.sync()
expect(await rewardableVault.rewardsPerShare()).to.equal(parseUnits('1', rewardDecimals))
})

it('correctly handles reward tracking if supply is burned', async () => {
await rewardableVault
.connect(alice)
.deposit(parseUnits('10', assetDecimals), alice.address)
expect(await rewardableVault.rewardsPerShare()).to.equal(bn(0))
expect(await rewardableVault.lastRewardsPerShare(alice.address)).to.equal(bn(0))
await rewardToken.mint(rewardableVault.address, parseUnits('10', rewardDecimals))
await rewardableVault.sync()
expect(await rewardableVault.rewardsPerShare()).to.equal(parseUnits('1', rewardDecimals))

// Setting supply to 0
await withdrawAll(rewardableVault.connect(alice))
expect(await rewardableVault.totalSupply()).to.equal(bn(0))

// Add some undistributed reward tokens to the vault
await rewardToken.mint(rewardableVault.address, parseUnits('10', rewardDecimals))

// Claim whatever rewards are available
expect(await rewardToken.balanceOf(alice.address)).to.be.equal(bn(0))
await rewardableVault.connect(alice).claimRewards()

expect(await rewardToken.balanceOf(alice.address)).to.be.equal(
parseUnits('10', rewardDecimals)
)

// Nothing updates.. as totalSupply as totalSupply is 0
await rewardableVault.sync()
expect(await rewardableVault.rewardsPerShare()).to.equal(parseUnits('1', rewardDecimals))
await rewardableVault
.connect(alice)
.deposit(parseUnits('10', assetDecimals), alice.address)
await rewardableVault.sync()

await rewardableVault.connect(alice).claimRewards()
expect(await rewardToken.balanceOf(alice.address)).to.be.equal(
parseUnits('20', rewardDecimals)
)
})

it('1 wei supply', async () => {
await rewardableVault.connect(alice).deposit('1', alice.address)
expect(await rewardableVault.rewardsPerShare()).to.equal(bn(0))
expect(await rewardableVault.lastRewardsPerShare(alice.address)).to.equal(bn(0))
await rewardToken.mint(rewardableVault.address, parseUnits('1', rewardDecimals))
await rewardableVault.sync()
await rewardableVault.connect(bob).deposit('10', bob.address)
await rewardableVault.connect(alice).deposit('10', alice.address)
await rewardToken.mint(rewardableVault.address, parseUnits('99', rewardDecimals))
await rewardableVault.connect(alice).claimRewards()
await rewardableVault.connect(bob).claimRewards()
const aliceBalance = await rewardToken.balanceOf(await alice.getAddress())
const bobBalance = await rewardToken.balanceOf(await bob.getAddress())

expect(parseFloat(formatUnits(aliceBalance, rewardDecimals))).to.be.closeTo(52.8, 0.1)

expect(parseFloat(formatUnits(bobBalance, rewardDecimals))).to.be.closeTo(47.1, 0.1)
})
})

describe('alice deposit, accrue, alice deposit, bob deposit', () => {
Expand Down

0 comments on commit c61746c

Please sign in to comment.