Skip to content

Commit

Permalink
Make split weights uint256 (#345)
Browse files Browse the repository at this point in the history
  • Loading branch information
CodeSandwich committed Mar 4, 2024
1 parent fbe3456 commit ae7606d
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/Drips.sol
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ contract Drips is Managed, Streams, Splits {
/// Limits the cost of splitting.
uint256 public constant MAX_SPLITS_RECEIVERS = _MAX_SPLITS_RECEIVERS;
/// @notice The total splits weight of an account
uint32 public constant TOTAL_SPLITS_WEIGHT = _TOTAL_SPLITS_WEIGHT;
uint256 public constant TOTAL_SPLITS_WEIGHT = _TOTAL_SPLITS_WEIGHT;
/// @notice The offset of the controlling driver ID in the account ID.
/// In other words the controlling driver ID is the highest 32 bits of the account ID.
/// Every account ID is a 256-bit integer constructed by concatenating:
Expand Down
6 changes: 4 additions & 2 deletions src/ImmutableSplitsDriver.sol
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ contract ImmutableSplitsDriver is Managed {
/// @notice The driver ID which this driver uses when calling Drips.
uint32 public immutable driverId;
/// @notice The required total splits weight of each splits configuration
uint32 public immutable totalSplitsWeight;
uint256 public immutable totalSplitsWeight;
/// @notice The ERC-1967 storage slot holding a single `uint256` counter of created identities.
bytes32 private immutable _counterSlot = _erc1967Slot("eip1967.immutableSplitsDriver.storage");

Expand Down Expand Up @@ -76,7 +76,9 @@ contract ImmutableSplitsDriver is Managed {
uint256 weightSum = 0;
unchecked {
for (uint256 i = 0; i < receivers.length; i++) {
weightSum += receivers[i].weight;
uint256 weight = receivers[i].weight;
if (weight > totalSplitsWeight) weight = totalSplitsWeight + 1;
weightSum += weight;
}
}
require(weightSum == totalSplitsWeight, "Invalid total receivers weight");
Expand Down
11 changes: 8 additions & 3 deletions src/Splits.sol
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ struct SplitsReceiver {
/// @notice The splits weight. Must never be zero.
/// The account will be getting `weight / _TOTAL_SPLITS_WEIGHT`
/// share of the funds collected by the splitting account.
uint32 weight;
uint256 weight;
}

/// @notice Splits can keep track of at most `type(uint128).max`
Expand All @@ -22,7 +22,7 @@ abstract contract Splits {
/// Limits the cost of splitting.
uint256 internal constant _MAX_SPLITS_RECEIVERS = 200;
/// @notice The total splits weight of an account.
uint32 internal constant _TOTAL_SPLITS_WEIGHT = 1_000_000;
uint256 internal constant _TOTAL_SPLITS_WEIGHT = 1_000_000;
/// @notice The amount the contract can keep track of each ERC-20 token.
// slither-disable-next-line unused-state
uint128 internal constant _MAX_SPLITS_BALANCE = _SPLITTABLE_MASK;
Expand Down Expand Up @@ -127,6 +127,8 @@ abstract contract Splits {
unchecked {
uint256 splitsWeight = 0;
for (uint256 i = currReceivers.length; i != 0;) {
// This will not overflow because the receivers list
// is verified to add up to no more than _TOTAL_SPLITS_WEIGHT
splitsWeight += currReceivers[--i].weight;
}
splitAmt = uint128(amount * splitsWeight / _TOTAL_SPLITS_WEIGHT);
Expand Down Expand Up @@ -164,6 +166,8 @@ abstract contract Splits {
unchecked {
uint256 splitsWeight = 0;
for (uint256 i = 0; i < currReceivers.length; i++) {
// This will not overflow because the receivers list
// is verified to add up to no more than _TOTAL_SPLITS_WEIGHT
splitsWeight += currReceivers[i].weight;
uint128 currSplitAmt = splitAmt;
splitAmt = uint128(splittable * splitsWeight / _TOTAL_SPLITS_WEIGHT);
Expand Down Expand Up @@ -247,8 +251,9 @@ abstract contract Splits {
uint256 prevAccountId = 0;
for (uint256 i = 0; i < receivers.length; i++) {
SplitsReceiver memory receiver = receivers[i];
uint32 weight = receiver.weight;
uint256 weight = receiver.weight;
require(weight != 0, "Splits receiver weight is zero");
if (weight > _TOTAL_SPLITS_WEIGHT) weight = _TOTAL_SPLITS_WEIGHT + 1;
totalWeight += weight;
uint256 accountId = receiver.accountId;
if (accountId <= prevAccountId) require(i == 0, "Splits receivers not sorted");
Expand Down
10 changes: 5 additions & 5 deletions test/Drips.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ contract DripsTest is Test {
list = new SplitsReceiver[](0);
}

function splitsReceivers(uint256 splitsReceiver, uint32 weight)
function splitsReceivers(uint256 splitsReceiver, uint256 weight)
internal
pure
returns (SplitsReceiver[] memory list)
Expand All @@ -203,9 +203,9 @@ contract DripsTest is Test {

function splitsReceivers(
uint256 splitsReceiver1,
uint32 weight1,
uint256 weight1,
uint256 splitsReceiver2,
uint32 weight2
uint256 weight2
) internal pure returns (SplitsReceiver[] memory list) {
list = new SplitsReceiver[](2);
list[0] = SplitsReceiver(splitsReceiver1, weight1);
Expand Down Expand Up @@ -418,7 +418,7 @@ contract DripsTest is Test {
}

function testUncollectedFundsAreSplitUsingCurrentConfig() public {
uint32 totalWeight = drips.TOTAL_SPLITS_WEIGHT();
uint256 totalWeight = drips.TOTAL_SPLITS_WEIGHT();
setSplits(accountId1, splitsReceivers(receiver1, totalWeight));
setStreams(accountId2, 0, 5, streamsReceivers(accountId1, 5));
skipToCycleEnd();
Expand Down Expand Up @@ -508,7 +508,7 @@ contract DripsTest is Test {
}

function testSplitSplitsFundsReceivedFromAllSources() public {
uint32 totalWeight = drips.TOTAL_SPLITS_WEIGHT();
uint256 totalWeight = drips.TOTAL_SPLITS_WEIGHT();
// Gives
give(accountId2, accountId1, 1);

Expand Down
32 changes: 25 additions & 7 deletions test/ImmutableSplitsDriver.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {Test} from "forge-std/Test.sol";
contract ImmutableSplitsDriverTest is Test {
Drips internal drips;
ImmutableSplitsDriver internal driver;
uint32 internal totalSplitsWeight;
uint256 internal totalSplitsWeight;

function setUp() public {
Drips dripsLogic = new Drips(10);
Expand All @@ -25,10 +25,18 @@ contract ImmutableSplitsDriverTest is Test {
totalSplitsWeight = driver.totalSplitsWeight();
}

function splitsReceivers(uint256 weight1, uint256 weight2)
internal
pure
returns (SplitsReceiver[] memory list)
{
list = new SplitsReceiver[](2);
list[0] = SplitsReceiver(1, weight1);
list[1] = SplitsReceiver(2, weight2);
}

function testCreateSplits() public {
SplitsReceiver[] memory receivers = new SplitsReceiver[](2);
receivers[0] = SplitsReceiver({accountId: 1, weight: totalSplitsWeight - 1});
receivers[1] = SplitsReceiver({accountId: 2, weight: 1});
SplitsReceiver[] memory receivers = splitsReceivers(totalSplitsWeight - 1, 1);
uint256 nextAccountId = driver.nextAccountId();
AccountMetadata[] memory metadata = new AccountMetadata[](1);
metadata[0] = AccountMetadata("key", "value");
Expand All @@ -43,10 +51,20 @@ contract ImmutableSplitsDriverTest is Test {
}

function testCreateSplitsRevertsWhenWeightsSumTooLow() public {
SplitsReceiver[] memory receivers = new SplitsReceiver[](2);
receivers[0] = SplitsReceiver({accountId: 1, weight: totalSplitsWeight - 2});
receivers[1] = SplitsReceiver({accountId: 2, weight: 1});
SplitsReceiver[] memory receivers = splitsReceivers(totalSplitsWeight - 2, 1);
vm.expectRevert("Invalid total receivers weight");
driver.createSplits(receivers, new AccountMetadata[](0));
}

function testCreateSplitsRevertsWhenWeightsSumTooHigh() public {
SplitsReceiver[] memory receivers = splitsReceivers(totalSplitsWeight - 1, 2);
vm.expectRevert("Invalid total receivers weight");
driver.createSplits(receivers, new AccountMetadata[](0));
}

function testCreateSplitsRevertsWhenWeightsSumOverflows() public {
SplitsReceiver[] memory receivers =
splitsReceivers(totalSplitsWeight + 1, type(uint256).max);
vm.expectRevert("Invalid total receivers weight");
driver.createSplits(receivers, new AccountMetadata[](0));
}
Expand Down
32 changes: 21 additions & 11 deletions test/Splits.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ contract SplitsTest is Test, Splits {
list = new SplitsReceiver[](0);
}

function splitsReceivers(uint256 usedAccountId, uint32 weight)
function splitsReceivers(uint256 usedAccountId, uint256 weight)
internal
pure
returns (SplitsReceiver[] memory list)
Expand All @@ -36,7 +36,7 @@ contract SplitsTest is Test, Splits {
list[0] = SplitsReceiver(usedAccountId, weight);
}

function splitsReceivers(uint256 account1, uint32 weight1, uint256 account2, uint32 weight2)
function splitsReceivers(uint256 account1, uint256 weight1, uint256 account2, uint256 weight2)
internal
pure
returns (SplitsReceiver[] memory list)
Expand Down Expand Up @@ -190,13 +190,23 @@ contract SplitsTest is Test, Splits {
}

function testRejectsTooHighTotalWeightSplitsReceivers() public {
uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
setSplits(accountId, splitsReceivers(receiver, totalWeight));
assertSetSplitsReverts(
accountId, splitsReceivers(receiver, totalWeight + 1), "Splits weights sum too high"
);
}

function testRejectsOverflowingTotalWeightSplitsReceivers() public {
uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
setSplits(accountId, splitsReceivers(receiver, totalWeight));
assertSetSplitsReverts(
accountId,
splitsReceivers(receiver1, type(uint256).max, receiver2, 4),
"Splits weights sum too high"
);
}

function testRejectsZeroWeightSplitsReceivers() public {
assertSetSplitsReverts(
accountId, splitsReceivers(receiver, 0), "Splits receiver weight is zero"
Expand All @@ -216,7 +226,7 @@ contract SplitsTest is Test, Splits {
}

function testCanSplitAllWhenCollectedDoesNotSplitEvenly() public {
uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
// 3 waiting for accountId
addSplittable(accountId, 3);

Expand All @@ -239,7 +249,7 @@ contract SplitsTest is Test, Splits {
}

function testSplittingSplitsAllFundsEvenWhenTheyDoNotDivideEvenly() public {
uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
setSplits(
accountId, splitsReceivers(receiver1, (totalWeight / 5) * 2, receiver2, totalWeight / 5)
);
Expand All @@ -251,7 +261,7 @@ contract SplitsTest is Test, Splits {
}

function testAccountCanSplitToItself() public {
uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
// receiver1 receives 30%, gets 50% split to themselves and receiver2 gets split 20%
setSplits(
receiver1, splitsReceivers(receiver1, totalWeight / 2, receiver2, totalWeight / 5)
Expand Down Expand Up @@ -280,7 +290,7 @@ contract SplitsTest is Test, Splits {
}

function testSplitsConfigurationIsCommonBetweenTokens() public {
uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
setSplits(accountId, splitsReceivers(receiver1, totalWeight / 10));
erc20 = defaultErc20;
addSplittable(accountId, 30);
Expand All @@ -298,7 +308,7 @@ contract SplitsTest is Test, Splits {
}

function testForwardSplits() public {
uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;

addSplittable(accountId, 10);
setSplits(accountId, splitsReceivers(receiver1, totalWeight));
Expand All @@ -315,7 +325,7 @@ contract SplitsTest is Test, Splits {
}

function testSplitMultipleReceivers() public {
uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT;
addSplittable(accountId, 10);

setSplits(
Expand Down Expand Up @@ -352,6 +362,7 @@ contract SplitsTest is Test, Splits {
uint256 weightSum = 0;
for (uint256 i = 0; i < receivers.length; i++) {
receivers[i] = receiversRaw[i];
receivers[i].weight %= _TOTAL_SPLITS_WEIGHT;
weightSum += receivers[i].weight;
}
if (weightSum == 0) weightSum = 1;
Expand All @@ -360,8 +371,7 @@ contract SplitsTest is Test, Splits {
for (uint256 i = 0; i < receivers.length; i++) {
uint256 usedTotalWeight = totalWeight * usedWeight / weightSum;
usedWeight += receivers[i].weight;
receivers[i].weight =
uint32((totalWeight * usedWeight / weightSum) - usedTotalWeight + 1);
receivers[i].weight = (totalWeight * usedWeight / weightSum) - usedTotalWeight + 1;
}
}

Expand Down
2 changes: 1 addition & 1 deletion test/dataStore/DripsDataProxy.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ contract DripsDataProxyTest is Test {
}

function testSplit() public {
uint32 splitWeight = drips.TOTAL_SPLITS_WEIGHT() / 4;
uint256 splitWeight = drips.TOTAL_SPLITS_WEIGHT() / 4;
uint128 totalAmt = 8;
uint128 splitAmt = 2;
uint128 collectableAmt = 6;
Expand Down

0 comments on commit ae7606d

Please sign in to comment.