From d23652b7fe2f55fde3fca2b03d462cd14d55b6b0 Mon Sep 17 00:00:00 2001 From: CodeSandwich Date: Thu, 18 Jul 2024 15:36:07 +0200 Subject: [PATCH] Add initialization to ManagedProxy --- src/DripsDeployer.sol | 14 ++++++++++++-- src/Managed.sol | 17 ++++++++++++++++- test/AddressDriver.t.sol | 4 ++-- test/Drips.t.sol | 2 +- test/DriverTransferUtils.t.sol | 2 +- test/Giver.t.sol | 6 +++--- test/ImmutableSplitsDriver.t.sol | 4 ++-- test/Managed.t.sol | 31 ++++++++++++++++++++++++++++++- test/NFTDriver.t.sol | 4 ++-- test/RepoDriver.t.sol | 4 ++-- 10 files changed, 71 insertions(+), 17 deletions(-) diff --git a/src/DripsDeployer.sol b/src/DripsDeployer.sol index 28bd96e6..c3cb668e 100644 --- a/src/DripsDeployer.sol +++ b/src/DripsDeployer.sol @@ -102,19 +102,28 @@ abstract contract ProxyDeployerModule is BaseModule { bytes32 public immutable proxySalt = "proxy"; address public proxyAdmin; address public logic; + bytes public proxyDelegateCalldata; function proxy() public view returns (address) { return Create3Factory.getDeployed(proxySalt); } function proxyArgs() public view returns (bytes memory) { - return abi.encode(logic, proxyAdmin); + return abi.encode(logic, proxyAdmin, proxyDelegateCalldata); } function logicArgs() public view virtual returns (bytes memory); - // slither-disable-next-line reentrancy-benign function _deployProxy(address proxyAdmin_, bytes memory logicCreationCode) internal { + _deployProxy(proxyAdmin_, logicCreationCode, ""); + } + + // slither-disable-next-line reentrancy-benign + function _deployProxy( + address proxyAdmin_, + bytes memory logicCreationCode, + bytes memory proxyDelegateCalldata_ + ) internal { // Deploy logic address logic_; bytes memory logicInitCode = abi.encodePacked(logicCreationCode, logicArgs()); @@ -126,6 +135,7 @@ abstract contract ProxyDeployerModule is BaseModule { logic = logic_; // Deploy proxy proxyAdmin = proxyAdmin_; + proxyDelegateCalldata = proxyDelegateCalldata_; // slither-disable-next-line too-many-digits bytes memory proxyInitCode = abi.encodePacked(type(ManagedProxy).creationCode, proxyArgs()); Create3Factory.deploy(0, proxySalt, proxyInitCode); diff --git a/src/Managed.sol b/src/Managed.sol index a44fd7f7..fab96bee 100644 --- a/src/Managed.sol +++ b/src/Managed.sol @@ -4,6 +4,7 @@ pragma solidity ^0.8.20; import {UUPSUpgradeable} from "openzeppelin-contracts/proxy/utils/UUPSUpgradeable.sol"; import {ERC1967Proxy} from "openzeppelin-contracts/proxy/ERC1967/ERC1967Proxy.sol"; import {EnumerableSet} from "openzeppelin-contracts/utils/structs/EnumerableSet.sol"; +import {Address} from "openzeppelin-contracts/utils/Address.sol"; import {StorageSlot} from "openzeppelin-contracts/utils/StorageSlot.sol"; using EnumerableSet for EnumerableSet.AddressSet; @@ -55,6 +56,15 @@ abstract contract Managed is UUPSUpgradeable { _; } + /// @notice Throws if called by any caller other than the admin. + /// May be called by anybody if delegated to from a constructor. + modifier onlyAdminOrConstructor() { + if (Address.isContract(address(this))) { + require(admin() == msg.sender, "Caller not the admin"); + } + _; + } + /// @notice Throws if called by any caller other than the admin or a pauser. modifier onlyAdminOrPauser() { require(admin() == msg.sender || isPauser(msg.sender), "Caller not the admin or a pauser"); @@ -201,7 +211,12 @@ abstract contract Managed is UUPSUpgradeable { /// @notice A generic proxy for contracts implementing `Managed`. contract ManagedProxy is ERC1967Proxy { - constructor(Managed logic, address admin) ERC1967Proxy(address(logic), new bytes(0)) { + /// @param logic The initial implementation address of the proxy. + /// @param admin The initial admin of the proxy. + /// @param data If non-empty, used as calldata to delegate to `logic`. + constructor(Managed logic, address admin, bytes memory data) + ERC1967Proxy(address(logic), data) + { _changeAdmin(admin); } } diff --git a/test/AddressDriver.t.sol b/test/AddressDriver.t.sol index 5fd1ad78..96c7b72a 100644 --- a/test/AddressDriver.t.sol +++ b/test/AddressDriver.t.sol @@ -31,7 +31,7 @@ contract AddressDriverTest is Test { function setUp() public { Drips dripsLogic = new Drips(10); - drips = Drips(address(new ManagedProxy(dripsLogic, address(this)))); + drips = Drips(address(new ManagedProxy(dripsLogic, address(this), ""))); caller = new Caller(); @@ -40,7 +40,7 @@ contract AddressDriverTest is Test { drips.registerDriver(address(1)); uint32 driverId = drips.registerDriver(address(this)); AddressDriver driverLogic = new AddressDriver(drips, address(caller), driverId); - driver = AddressDriver(address(new ManagedProxy(driverLogic, admin))); + driver = AddressDriver(address(new ManagedProxy(driverLogic, admin, ""))); drips.updateDriverAddress(driverId, address(driver)); thisId = driver.calcAccountId(address(this)); diff --git a/test/Drips.t.sol b/test/Drips.t.sol index 74e1ee20..54c739a3 100644 --- a/test/Drips.t.sol +++ b/test/Drips.t.sol @@ -48,7 +48,7 @@ contract DripsTest is Test { otherErc20 = new ERC20PresetFixedSupply("other", "other", 2 ** 128, address(this)); erc20 = defaultErc20; Drips dripsLogic = new Drips(10); - drips = Drips(address(new ManagedProxy(dripsLogic, admin))); + drips = Drips(address(new ManagedProxy(dripsLogic, admin, ""))); driverId = drips.registerDriver(driver); uint256 baseAccountId = driverId << 224; diff --git a/test/DriverTransferUtils.t.sol b/test/DriverTransferUtils.t.sol index c70e4e4d..e7a242bc 100644 --- a/test/DriverTransferUtils.t.sol +++ b/test/DriverTransferUtils.t.sol @@ -76,7 +76,7 @@ contract DriverTransferUtilsTest is Test { function setUp() public { Drips dripsLogic = new Drips(10); - drips = Drips(address(new ManagedProxy(dripsLogic, address(this)))); + drips = Drips(address(new ManagedProxy(dripsLogic, address(this), ""))); caller = new Caller(); diff --git a/test/Giver.t.sol b/test/Giver.t.sol index 0363cd95..ab4db86b 100644 --- a/test/Giver.t.sol +++ b/test/Giver.t.sol @@ -64,16 +64,16 @@ contract GiversRegistryTest is Test { function setUp() public { Drips dripsLogic = new Drips(10); - drips = Drips(address(new ManagedProxy(dripsLogic, admin))); + drips = Drips(address(new ManagedProxy(dripsLogic, admin, ""))); drips.registerDriver(address(1)); AddressDriver addressDriverLogic = new AddressDriver(drips, address(0), drips.nextDriverId()); - addressDriver = AddressDriver(address(new ManagedProxy(addressDriverLogic, admin))); + addressDriver = AddressDriver(address(new ManagedProxy(addressDriverLogic, admin, ""))); drips.registerDriver(address(addressDriver)); nativeTokenWrapper = new NativeTokenWrapper(); GiversRegistry giversRegistryLogic = new GiversRegistry(addressDriver, nativeTokenWrapper); - giversRegistry = GiversRegistry(address(new ManagedProxy(giversRegistryLogic, admin))); + giversRegistry = GiversRegistry(address(new ManagedProxy(giversRegistryLogic, admin, ""))); accountId = 1234; giver = payable(giversRegistry.giver(accountId)); emit log_named_address("GIVER", giver); diff --git a/test/ImmutableSplitsDriver.t.sol b/test/ImmutableSplitsDriver.t.sol index 5c0039a2..8ea77fbd 100644 --- a/test/ImmutableSplitsDriver.t.sol +++ b/test/ImmutableSplitsDriver.t.sol @@ -14,14 +14,14 @@ contract ImmutableSplitsDriverTest is Test { function setUp() public { Drips dripsLogic = new Drips(10); - drips = Drips(address(new ManagedProxy(dripsLogic, address(this)))); + drips = Drips(address(new ManagedProxy(dripsLogic, address(this), ""))); // Make the driver ID non-0 to test if it's respected by the driver drips.registerDriver(address(1)); drips.registerDriver(address(1)); uint32 driverId = drips.registerDriver(address(this)); ImmutableSplitsDriver driverLogic = new ImmutableSplitsDriver(drips, driverId); - driver = ImmutableSplitsDriver(address(new ManagedProxy(driverLogic, admin))); + driver = ImmutableSplitsDriver(address(new ManagedProxy(driverLogic, admin, ""))); drips.updateDriverAddress(driverId, address(driver)); totalSplitsWeight = driver.totalSplitsWeight(); } diff --git a/test/Managed.t.sol b/test/Managed.t.sol index 3eb6669c..e440f67c 100644 --- a/test/Managed.t.sol +++ b/test/Managed.t.sol @@ -6,6 +6,7 @@ import {Test} from "forge-std/Test.sol"; contract Logic is Managed { uint256 public immutable instanceId; + bool public called; constructor(uint256 instanceId_) { instanceId = instanceId_; @@ -14,6 +15,10 @@ contract Logic is Managed { function erc1967Slot(string memory name) public pure returns (bytes32 slot) { return _erc1967Slot(name); } + + function onlyAdminOrConstructorFunction() public onlyAdminOrConstructor { + called = true; + } } contract ManagedTest is Test { @@ -29,7 +34,7 @@ contract ManagedTest is Test { function setUp() public { logic = new Logic(0); - proxy = Logic(address(new ManagedProxy(logic, admin))); + proxy = Logic(address(new ManagedProxy(logic, admin, ""))); vm.prank(admin); proxy.grantPauser(pauser); } @@ -45,6 +50,30 @@ contract ManagedTest is Test { assertEq(logic.allPausers(), new address[](0), "Pausers not empty"); } + function testArbitraryUserCanNotCallOnlyAdminOrConstructorFunction() public { + vm.expectRevert(ERROR_NOT_ADMIN); + proxy.onlyAdminOrConstructorFunction(); + } + + function testPauserCanNotCallOnlyAdminOrConstructorFunction() public { + vm.prank(pauser); + vm.expectRevert(ERROR_NOT_ADMIN); + proxy.onlyAdminOrConstructorFunction(); + } + + function testAdminCanCallOnlyAdminOrConstructorFunction() public { + vm.prank(admin); + proxy.onlyAdminOrConstructorFunction(); + assertTrue(proxy.called(), "Function wasn't called"); + } + + function testProxyConstructorCanCallOnlyAdminOrConstructorFunction() public { + bytes memory data = abi.encodeCall(Logic.onlyAdminOrConstructorFunction, ()); + Logic proxy_ = Logic(address(new ManagedProxy(new Logic(0), admin, data))); + assertEq(proxy_.admin(), admin, "Invalid admin address"); + assertTrue(proxy_.called(), "Function wasn't called"); + } + function testAdminCanProposeNewAdmin() public { assertEq(proxy.proposedAdmin(), address(0)); diff --git a/test/NFTDriver.t.sol b/test/NFTDriver.t.sol index 7abe55f6..797b1611 100644 --- a/test/NFTDriver.t.sol +++ b/test/NFTDriver.t.sol @@ -36,7 +36,7 @@ contract NFTDriverTest is Test { function setUp() public { Drips dripsLogic = new Drips(10); - drips = Drips(address(new ManagedProxy(dripsLogic, address(this)))); + drips = Drips(address(new ManagedProxy(dripsLogic, address(this), ""))); caller = new Caller(); @@ -45,7 +45,7 @@ contract NFTDriverTest is Test { drips.registerDriver(address(1)); uint32 driverId = drips.registerDriver(address(this)); NFTDriver driverLogic = new NFTDriver(drips, address(caller), driverId); - driver = NFTDriver(address(new ManagedProxy(driverLogic, admin))); + driver = NFTDriver(address(new ManagedProxy(driverLogic, admin, ""))); drips.updateDriverAddress(driverId, address(driver)); tokenId = driver.mint(address(this), noMetadata()); diff --git a/test/RepoDriver.t.sol b/test/RepoDriver.t.sol index 575cd0f2..3b7b9338 100644 --- a/test/RepoDriver.t.sol +++ b/test/RepoDriver.t.sol @@ -69,7 +69,7 @@ contract RepoDriverTest is Test { function setUp() public { Drips dripsLogic = new Drips(10); - drips = Drips(address(new ManagedProxy(dripsLogic, address(this)))); + drips = Drips(address(new ManagedProxy(dripsLogic, address(this), ""))); caller = new Caller(); @@ -105,7 +105,7 @@ contract RepoDriverTest is Test { function deployDriverUninitialized() internal { uint32 driverId = drips.registerDriver(address(this)); RepoDriver driverLogic = new RepoDriver(drips, address(caller), driverId); - driver = RepoDriver(address(new ManagedProxy(driverLogic, admin))); + driver = RepoDriver(address(new ManagedProxy(driverLogic, admin, ""))); drips.updateDriverAddress(driverId, address(driver)); driverNonce = 0; }