diff --git a/src/Drips.sol b/src/Drips.sol index 7c5ce7e0..21dc719c 100644 --- a/src/Drips.sol +++ b/src/Drips.sol @@ -64,7 +64,7 @@ contract Drips is Managed, Streams, Splits { /// Limits the cost of splitting. uint256 public constant MAX_SPLITS_RECEIVERS = _MAX_SPLITS_RECEIVERS; /// @notice The total splits weight of an account - uint32 public constant TOTAL_SPLITS_WEIGHT = _TOTAL_SPLITS_WEIGHT; + uint256 public constant TOTAL_SPLITS_WEIGHT = _TOTAL_SPLITS_WEIGHT; /// @notice The offset of the controlling driver ID in the account ID. /// In other words the controlling driver ID is the highest 32 bits of the account ID. /// Every account ID is a 256-bit integer constructed by concatenating: diff --git a/src/ImmutableSplitsDriver.sol b/src/ImmutableSplitsDriver.sol index 6b53d7c3..c2c60cac 100644 --- a/src/ImmutableSplitsDriver.sol +++ b/src/ImmutableSplitsDriver.sol @@ -17,7 +17,7 @@ contract ImmutableSplitsDriver is Managed { /// @notice The driver ID which this driver uses when calling Drips. uint32 public immutable driverId; /// @notice The required total splits weight of each splits configuration - uint32 public immutable totalSplitsWeight; + uint256 public immutable totalSplitsWeight; /// @notice The ERC-1967 storage slot holding a single `uint256` counter of created identities. bytes32 private immutable _counterSlot = _erc1967Slot("eip1967.immutableSplitsDriver.storage"); @@ -76,7 +76,9 @@ contract ImmutableSplitsDriver is Managed { uint256 weightSum = 0; unchecked { for (uint256 i = 0; i < receivers.length; i++) { - weightSum += receivers[i].weight; + uint256 weight = receivers[i].weight; + if (weight > totalSplitsWeight) weight = totalSplitsWeight + 1; + weightSum += weight; } } require(weightSum == totalSplitsWeight, "Invalid total receivers weight"); diff --git a/src/Splits.sol b/src/Splits.sol index b0254c2b..4a88d785 100644 --- a/src/Splits.sol +++ b/src/Splits.sol @@ -10,7 +10,7 @@ struct SplitsReceiver { /// @notice The splits weight. Must never be zero. /// The account will be getting `weight / _TOTAL_SPLITS_WEIGHT` /// share of the funds collected by the splitting account. - uint32 weight; + uint256 weight; } /// @notice Splits can keep track of at most `type(uint128).max` @@ -22,7 +22,7 @@ abstract contract Splits { /// Limits the cost of splitting. uint256 internal constant _MAX_SPLITS_RECEIVERS = 200; /// @notice The total splits weight of an account. - uint32 internal constant _TOTAL_SPLITS_WEIGHT = 1_000_000; + uint256 internal constant _TOTAL_SPLITS_WEIGHT = 1_000_000; /// @notice The amount the contract can keep track of each ERC-20 token. // slither-disable-next-line unused-state uint128 internal constant _MAX_SPLITS_BALANCE = _SPLITTABLE_MASK; @@ -127,6 +127,8 @@ abstract contract Splits { unchecked { uint256 splitsWeight = 0; for (uint256 i = currReceivers.length; i != 0;) { + // This will not overflow because the receivers list + // is verified to add up to no more than _TOTAL_SPLITS_WEIGHT splitsWeight += currReceivers[--i].weight; } splitAmt = uint128(amount * splitsWeight / _TOTAL_SPLITS_WEIGHT); @@ -164,6 +166,8 @@ abstract contract Splits { unchecked { uint256 splitsWeight = 0; for (uint256 i = 0; i < currReceivers.length; i++) { + // This will not overflow because the receivers list + // is verified to add up to no more than _TOTAL_SPLITS_WEIGHT splitsWeight += currReceivers[i].weight; uint128 currSplitAmt = splitAmt; splitAmt = uint128(splittable * splitsWeight / _TOTAL_SPLITS_WEIGHT); @@ -247,8 +251,9 @@ abstract contract Splits { uint256 prevAccountId = 0; for (uint256 i = 0; i < receivers.length; i++) { SplitsReceiver memory receiver = receivers[i]; - uint32 weight = receiver.weight; + uint256 weight = receiver.weight; require(weight != 0, "Splits receiver weight is zero"); + if (weight > _TOTAL_SPLITS_WEIGHT) weight = _TOTAL_SPLITS_WEIGHT + 1; totalWeight += weight; uint256 accountId = receiver.accountId; if (accountId <= prevAccountId) require(i == 0, "Splits receivers not sorted"); diff --git a/test/Drips.t.sol b/test/Drips.t.sol index 96c3f02b..b10e2144 100644 --- a/test/Drips.t.sol +++ b/test/Drips.t.sol @@ -192,7 +192,7 @@ contract DripsTest is Test { list = new SplitsReceiver[](0); } - function splitsReceivers(uint256 splitsReceiver, uint32 weight) + function splitsReceivers(uint256 splitsReceiver, uint256 weight) internal pure returns (SplitsReceiver[] memory list) @@ -203,9 +203,9 @@ contract DripsTest is Test { function splitsReceivers( uint256 splitsReceiver1, - uint32 weight1, + uint256 weight1, uint256 splitsReceiver2, - uint32 weight2 + uint256 weight2 ) internal pure returns (SplitsReceiver[] memory list) { list = new SplitsReceiver[](2); list[0] = SplitsReceiver(splitsReceiver1, weight1); @@ -418,7 +418,7 @@ contract DripsTest is Test { } function testUncollectedFundsAreSplitUsingCurrentConfig() public { - uint32 totalWeight = drips.TOTAL_SPLITS_WEIGHT(); + uint256 totalWeight = drips.TOTAL_SPLITS_WEIGHT(); setSplits(accountId1, splitsReceivers(receiver1, totalWeight)); setStreams(accountId2, 0, 5, streamsReceivers(accountId1, 5)); skipToCycleEnd(); @@ -508,7 +508,7 @@ contract DripsTest is Test { } function testSplitSplitsFundsReceivedFromAllSources() public { - uint32 totalWeight = drips.TOTAL_SPLITS_WEIGHT(); + uint256 totalWeight = drips.TOTAL_SPLITS_WEIGHT(); // Gives give(accountId2, accountId1, 1); diff --git a/test/ImmutableSplitsDriver.t.sol b/test/ImmutableSplitsDriver.t.sol index a2bd257f..afa712de 100644 --- a/test/ImmutableSplitsDriver.t.sol +++ b/test/ImmutableSplitsDriver.t.sol @@ -9,7 +9,7 @@ import {Test} from "forge-std/Test.sol"; contract ImmutableSplitsDriverTest is Test { Drips internal drips; ImmutableSplitsDriver internal driver; - uint32 internal totalSplitsWeight; + uint256 internal totalSplitsWeight; function setUp() public { Drips dripsLogic = new Drips(10); @@ -25,10 +25,18 @@ contract ImmutableSplitsDriverTest is Test { totalSplitsWeight = driver.totalSplitsWeight(); } + function splitsReceivers(uint256 weight1, uint256 weight2) + internal + pure + returns (SplitsReceiver[] memory list) + { + list = new SplitsReceiver[](2); + list[0] = SplitsReceiver(1, weight1); + list[1] = SplitsReceiver(2, weight2); + } + function testCreateSplits() public { - SplitsReceiver[] memory receivers = new SplitsReceiver[](2); - receivers[0] = SplitsReceiver({accountId: 1, weight: totalSplitsWeight - 1}); - receivers[1] = SplitsReceiver({accountId: 2, weight: 1}); + SplitsReceiver[] memory receivers = splitsReceivers(totalSplitsWeight - 1, 1); uint256 nextAccountId = driver.nextAccountId(); AccountMetadata[] memory metadata = new AccountMetadata[](1); metadata[0] = AccountMetadata("key", "value"); @@ -43,10 +51,20 @@ contract ImmutableSplitsDriverTest is Test { } function testCreateSplitsRevertsWhenWeightsSumTooLow() public { - SplitsReceiver[] memory receivers = new SplitsReceiver[](2); - receivers[0] = SplitsReceiver({accountId: 1, weight: totalSplitsWeight - 2}); - receivers[1] = SplitsReceiver({accountId: 2, weight: 1}); + SplitsReceiver[] memory receivers = splitsReceivers(totalSplitsWeight - 2, 1); + vm.expectRevert("Invalid total receivers weight"); + driver.createSplits(receivers, new AccountMetadata[](0)); + } + + function testCreateSplitsRevertsWhenWeightsSumTooHigh() public { + SplitsReceiver[] memory receivers = splitsReceivers(totalSplitsWeight - 1, 2); + vm.expectRevert("Invalid total receivers weight"); + driver.createSplits(receivers, new AccountMetadata[](0)); + } + function testCreateSplitsRevertsWhenWeightsSumOverflows() public { + SplitsReceiver[] memory receivers = + splitsReceivers(totalSplitsWeight + 1, type(uint256).max); vm.expectRevert("Invalid total receivers weight"); driver.createSplits(receivers, new AccountMetadata[](0)); } diff --git a/test/Splits.t.sol b/test/Splits.t.sol index 83da965b..6f6794f6 100644 --- a/test/Splits.t.sol +++ b/test/Splits.t.sol @@ -27,7 +27,7 @@ contract SplitsTest is Test, Splits { list = new SplitsReceiver[](0); } - function splitsReceivers(uint256 usedAccountId, uint32 weight) + function splitsReceivers(uint256 usedAccountId, uint256 weight) internal pure returns (SplitsReceiver[] memory list) @@ -36,7 +36,7 @@ contract SplitsTest is Test, Splits { list[0] = SplitsReceiver(usedAccountId, weight); } - function splitsReceivers(uint256 account1, uint32 weight1, uint256 account2, uint32 weight2) + function splitsReceivers(uint256 account1, uint256 weight1, uint256 account2, uint256 weight2) internal pure returns (SplitsReceiver[] memory list) @@ -190,13 +190,23 @@ contract SplitsTest is Test, Splits { } function testRejectsTooHighTotalWeightSplitsReceivers() public { - uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; + uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; setSplits(accountId, splitsReceivers(receiver, totalWeight)); assertSetSplitsReverts( accountId, splitsReceivers(receiver, totalWeight + 1), "Splits weights sum too high" ); } + function testRejectsOverflowingTotalWeightSplitsReceivers() public { + uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; + setSplits(accountId, splitsReceivers(receiver, totalWeight)); + assertSetSplitsReverts( + accountId, + splitsReceivers(receiver1, type(uint256).max, receiver2, 4), + "Splits weights sum too high" + ); + } + function testRejectsZeroWeightSplitsReceivers() public { assertSetSplitsReverts( accountId, splitsReceivers(receiver, 0), "Splits receiver weight is zero" @@ -216,7 +226,7 @@ contract SplitsTest is Test, Splits { } function testCanSplitAllWhenCollectedDoesNotSplitEvenly() public { - uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; + uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; // 3 waiting for accountId addSplittable(accountId, 3); @@ -239,7 +249,7 @@ contract SplitsTest is Test, Splits { } function testSplittingSplitsAllFundsEvenWhenTheyDoNotDivideEvenly() public { - uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; + uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; setSplits( accountId, splitsReceivers(receiver1, (totalWeight / 5) * 2, receiver2, totalWeight / 5) ); @@ -251,7 +261,7 @@ contract SplitsTest is Test, Splits { } function testAccountCanSplitToItself() public { - uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; + uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; // receiver1 receives 30%, gets 50% split to themselves and receiver2 gets split 20% setSplits( receiver1, splitsReceivers(receiver1, totalWeight / 2, receiver2, totalWeight / 5) @@ -280,7 +290,7 @@ contract SplitsTest is Test, Splits { } function testSplitsConfigurationIsCommonBetweenTokens() public { - uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; + uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; setSplits(accountId, splitsReceivers(receiver1, totalWeight / 10)); erc20 = defaultErc20; addSplittable(accountId, 30); @@ -298,7 +308,7 @@ contract SplitsTest is Test, Splits { } function testForwardSplits() public { - uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; + uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; addSplittable(accountId, 10); setSplits(accountId, splitsReceivers(receiver1, totalWeight)); @@ -315,7 +325,7 @@ contract SplitsTest is Test, Splits { } function testSplitMultipleReceivers() public { - uint32 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; + uint256 totalWeight = Splits._TOTAL_SPLITS_WEIGHT; addSplittable(accountId, 10); setSplits( @@ -352,6 +362,7 @@ contract SplitsTest is Test, Splits { uint256 weightSum = 0; for (uint256 i = 0; i < receivers.length; i++) { receivers[i] = receiversRaw[i]; + receivers[i].weight %= _TOTAL_SPLITS_WEIGHT; weightSum += receivers[i].weight; } if (weightSum == 0) weightSum = 1; @@ -360,8 +371,7 @@ contract SplitsTest is Test, Splits { for (uint256 i = 0; i < receivers.length; i++) { uint256 usedTotalWeight = totalWeight * usedWeight / weightSum; usedWeight += receivers[i].weight; - receivers[i].weight = - uint32((totalWeight * usedWeight / weightSum) - usedTotalWeight + 1); + receivers[i].weight = (totalWeight * usedWeight / weightSum) - usedTotalWeight + 1; } } diff --git a/test/dataStore/DripsDataProxy.t.sol b/test/dataStore/DripsDataProxy.t.sol index 29339d9c..92c1a065 100644 --- a/test/dataStore/DripsDataProxy.t.sol +++ b/test/dataStore/DripsDataProxy.t.sol @@ -63,7 +63,7 @@ contract DripsDataProxyTest is Test { } function testSplit() public { - uint32 splitWeight = drips.TOTAL_SPLITS_WEIGHT() / 4; + uint256 splitWeight = drips.TOTAL_SPLITS_WEIGHT() / 4; uint128 totalAmt = 8; uint128 splitAmt = 2; uint128 collectableAmt = 6;