diff --git a/contracts/src/utils/Uint16Array.sol b/contracts/src/utils/Uint16Array.sol index fcffa22fbc..7caeaeb4ca 100644 --- a/contracts/src/utils/Uint16Array.sol +++ b/contracts/src/utils/Uint16Array.sol @@ -30,18 +30,27 @@ pragma solidity 0.8.22; * then to bit-index 128 to 143 of uint256[1]. */ library Uint16Array { + /** + * @dev stores the backing array and the length. + */ struct Array { uint256[] data; + uint256 length; } + + /** + * @dev error for when out of bound accesses occur. + */ + error IndexOutOfBounds(); + /** * @dev Creates a new counter which can store at least `length` counters. * @param length The amount of counters. */ - function create(uint256 length) internal pure returns (Array memory) { // create space for `length` elements and round up if needed. uint256 bufferLength = length / 16 + (length % 16 == 0 ? 0 : 1); - return Array({data: new uint256[](bufferLength)}); + return Array({data: new uint256[](bufferLength), length: length}); } /** @@ -50,6 +59,9 @@ library Uint16Array { * @param index The logical index. */ function get(Array storage self, uint256 index) internal view returns (uint16) { + if (index >= self.length) { + revert IndexOutOfBounds(); + } // Right-shift the index by 4. This truncates the first 4 bits (bit-index) leaving us with the index // into the array. uint256 element = index >> 4; @@ -66,6 +78,9 @@ library Uint16Array { * @param value The value to set the counter to. */ function set(Array storage self, uint256 index, uint16 value) internal { + if (index >= self.length) { + revert IndexOutOfBounds(); + } // Right-shift the index by 4. This truncates the first 4 bits (bit-index) leaving us with the index // into the array. uint256 element = index >> 4; diff --git a/contracts/test/Uint16Array.t.sol b/contracts/test/Uint16Array.t.sol index b28009f2f8..e537037bb9 100644 --- a/contracts/test/Uint16Array.t.sol +++ b/contracts/test/Uint16Array.t.sol @@ -106,4 +106,16 @@ contract Uint16ArrayTest is Test { console.log("round2:index at %d set %d and get %d", index, value, new_value); assertEq(value, new_value); } + + function testCounterGetOutOfBounds() public { + counters = Uint16Array.create(17); + vm.expectRevert(Uint16Array.IndexOutOfBounds.selector); + counters.get(17); + } + + function testCounterSetOutOfBounds() public { + counters = Uint16Array.create(17); + vm.expectRevert(Uint16Array.IndexOutOfBounds.selector); + counters.set(17, 1); + } }