diff --git a/l1-contracts/contracts/state-transition/libraries/Merkle.sol b/l1-contracts/contracts/state-transition/libraries/Merkle.sol index 8680f9ac6..6b92875af 100644 --- a/l1-contracts/contracts/state-transition/libraries/Merkle.sol +++ b/l1-contracts/contracts/state-transition/libraries/Merkle.sol @@ -38,6 +38,46 @@ library Merkle { return currentHash; } + /// @dev Calculate Merkle root by the provided Merkle proof for a range of elements + /// NOTE: When using this function, check that the _startPath and _endPath lengths are equal to the tree height to prevent shorter/longer paths attack + /// @param _startPath Merkle path from the first element of the range to the root + /// @param _endPath Merkle path from the last element of the range to the root + /// @param _startIndex Index of the first element of the range in the tree + /// @param _itemHashes Hashes of the elements in the range + /// @return The Merkle root + function calculateRoot( + bytes32[] calldata _startPath, + bytes32[] calldata _endPath, + uint256 _startIndex, + bytes32[] calldata _itemHashes + ) internal pure returns (bytes32) { + uint256 pathLength = _startPath.length; + + require(pathLength == _endPath.length, ""); + require(pathLength > 0, ""); + require(pathLength < 256, ""); + + uint256 levelLen = _itemHashes.length; + + require(_startIndex + levelLen <= (1 << pathLength), ""); + + bytes32[] memory itemHashes = _itemHashes; + + for (uint256 level; level < pathLength; level = level.uncheckedInc()) { + uint256 parity = _startIndex % 2; + uint256 nextLevelLen = levelLen / 2 + (parity | (levelLen % 2)); + for (uint256 i; i < nextLevelLen; i = i.uncheckedInc()) { + bytes32 lhs = (i == 0 && parity == 1) ? _startPath[level] : itemHashes[2 * i - parity]; + bytes32 rhs = (i == nextLevelLen - 1 && (levelLen - parity) % 2 == 1) ? _endPath[level] : itemHashes[2 * i + 1 - parity]; + itemHashes[i] = _efficientHash(lhs, rhs); + } + levelLen = nextLevelLen; + _startIndex /= 2; + } + + return itemHashes[0]; + } + /// @dev Keccak hash of the concatenation of two 32-byte words function _efficientHash(bytes32 _lhs, bytes32 _rhs) private pure returns (bytes32 result) { assembly {