From 88328ec31da49591975ec8590f4c36793d8e7efe Mon Sep 17 00:00:00 2001 From: Uxio Fuentefria Date: Mon, 13 Nov 2023 19:08:11 +0100 Subject: [PATCH] Calculate proxy expected address in ProxyFactory - Add `calculate_proxy_address` to ProxyFactory - Add option to deploy specific chain id proxies (feature added in v1.4.1) --- gnosis/safe/proxy_factory.py | 119 +++++++++++++++--- gnosis/safe/safe.py | 40 +++--- .../test_proxy_factory/test_proxy_factory.py | 43 ++++++- 3 files changed, 167 insertions(+), 35 deletions(-) diff --git a/gnosis/safe/proxy_factory.py b/gnosis/safe/proxy_factory.py index 25bf22068..c767c377e 100644 --- a/gnosis/safe/proxy_factory.py +++ b/gnosis/safe/proxy_factory.py @@ -2,6 +2,7 @@ from abc import ABCMeta from typing import Callable, Optional +from eth_abi.packed import encode_packed from eth_account.signers.local import LocalAccount from eth_typing import ChecksumAddress from web3 import Web3 @@ -21,12 +22,20 @@ get_proxy_factory_V1_3_0_contract, get_proxy_factory_V1_4_1_contract, ) -from gnosis.eth.utils import compare_byte_code, get_empty_tx_params +from gnosis.eth.utils import ( + compare_byte_code, + fast_keccak, + get_empty_tx_params, + mk_contract_address_2, +) from gnosis.util import cache class ProxyFactory(ContractBase, metaclass=ABCMeta): def __new__(cls, *args, version: str = "1.4.1", **kwargs) -> "ProxyFactory": + if cls is not ProxyFactory: + return super().__new__(cls) + versions = { "1.0.0": ProxyFactoryV100, "1.1.1": ProxyFactoryV111, @@ -57,6 +66,30 @@ def deploy_contract( deployer_account, constructor_data ) + @cache + def get_proxy_creation_code(self) -> bytes: + """ + :return: Creation code used for the Proxy deployment. + With this it is easily possible to calculate predicted address. + """ + return self.contract.functions.proxyCreationCode().call() + + @cache + def get_proxy_runtime_code(self) -> bytes: + """ + :return: Runtime code of a deployed Proxy. For v1.4.1 onwards the method is not available, so `None` + will be returned + """ + return self.contract.functions.proxyRuntimeCode().call() + + def get_deploy_function(self, chain_specific: bool) -> ContractFunction: + if chain_specific: + raise NotImplementedError( + f"createChainSpecificProxyWithNonce is not supported in {self.__class__.__name__}" + ) + + return self.contract.functions.createProxyWithNonce + def check_proxy_code(self, address: ChecksumAddress) -> bool: """ Check if proxy bytecode matches any of the deployed by the supported Proxy Factories @@ -65,6 +98,12 @@ def check_proxy_code(self, address: ChecksumAddress) -> bool: :return: ``True`` if proxy is valid, ``False`` otherwise """ + def get_proxy_runtime_code() -> Optional[bytes]: + try: + return self.get_proxy_runtime_code() + except NotImplementedError: + return None + deployed_proxy_code = self.w3.eth.get_code(address) proxy_code_fns = ( get_proxy_1_4_1_deployed_bytecode, @@ -73,7 +112,7 @@ def check_proxy_code(self, address: ChecksumAddress) -> bool: get_proxy_1_1_1_mainnet_deployed_bytecode, get_proxy_1_0_0_deployed_bytecode, get_paying_proxy_deployed_bytecode, - self.get_proxy_runtime_code, + get_proxy_runtime_code, ) for proxy_code_fn in proxy_code_fns: proxy_code = proxy_code_fn() @@ -140,6 +179,48 @@ def deploy_proxy_contract( deployer_account, create_proxy_fn, gas=gas, gas_price=gas_price, nonce=nonce ) + def calculate_proxy_address( + self, + master_copy: ChecksumAddress, + initializer: bytes, + salt_nonce: int, + chain_specific: bool = False, + ) -> ChecksumAddress: + """ + Calculate proxy address for calling deploy_proxy_contract_with_nonce + + :param master_copy: + :param initializer: + :param salt_nonce: + :param chain_specific: Calculate chain specific address (to prevent same address in other chains) + :return: + """ + + if chain_specific: + salt_nonce = fast_keccak( + encode_packed( + ["bytes32", "uint256", "uint256"], + [ + fast_keccak(initializer), + salt_nonce, + self.ethereum_client.get_chain_id(), + ], + ) + ) + else: + salt_nonce = fast_keccak( + encode_packed( + ["bytes32", "uint256"], [fast_keccak(initializer), salt_nonce] + ) + ) + + proxy_creation_code = self.get_proxy_creation_code() + deployment_data = encode_packed( + ["bytes", "uint256"], [proxy_creation_code, int(master_copy, 0)] + ) + + return mk_contract_address_2(self.address, salt_nonce, deployment_data) + def deploy_proxy_contract_with_nonce( self, deployer_account: LocalAccount, @@ -149,6 +230,7 @@ def deploy_proxy_contract_with_nonce( gas: Optional[int] = None, gas_price: Optional[int] = None, nonce: Optional[int] = None, + chain_specific: bool = False, ) -> EthereumTxSent: """ Deploy proxy contract via Proxy Factory using `createProxyWithNonce` (CREATE2 opcode) @@ -160,26 +242,18 @@ def deploy_proxy_contract_with_nonce( :param gas: Gas :param gas_price: Gas Price :param nonce: Nonce + :param chain_specific: Calculate chain specific address (to prevent same address in other chains) :return: EthereumTxSent """ + + function = self.get_deploy_function(chain_specific) salt_nonce = salt_nonce if salt_nonce is not None else secrets.randbits(256) - create_proxy_fn = self.contract.functions.createProxyWithNonce( - master_copy, initializer, salt_nonce - ) + create_proxy_fn = function(master_copy, initializer, salt_nonce) return self._deploy_proxy_contract( deployer_account, create_proxy_fn, gas=gas, gas_price=gas_price, nonce=nonce ) - @cache - def get_proxy_runtime_code(self) -> Optional[bytes]: - """ - :return: Runtime code for current proxy factory. For v1.4.1 onwards the method is not avaiable, so `None` - will be returned - """ - if hasattr(self.contract.functions, "proxyRuntimeCode"): - return self.contract.functions.proxyRuntimeCode().call() - class ProxyFactoryV100(ProxyFactory): def get_contract_fn(self) -> Callable[[Web3, ChecksumAddress], Contract]: @@ -200,6 +274,23 @@ class ProxyFactoryV141(ProxyFactory): def get_contract_fn(self) -> Callable[[Web3, ChecksumAddress], Contract]: return get_proxy_factory_V1_4_1_contract + @cache + def get_proxy_runtime_code(self) -> Optional[bytes]: + """ + :return: From v1.4.1 onwards the method is not available + :raises: NotImplementedError + """ + raise NotImplementedError( + "Deprecated, only creation code is available using `get_proxy_creation_code`" + ) + + def get_deploy_function(self, chain_specific: bool) -> ContractFunction: + return ( + self.contract.functions.createChainSpecificProxyWithNonce + if chain_specific + else super().get_deploy_function(chain_specific) + ) + def deploy_proxy_contract(self, *args, **kwargs): """ .. deprecated:: ``createProxy`` function was deprecated in v1.4.1, use ``deploy_proxy_contract_with_nonce`` diff --git a/gnosis/safe/safe.py b/gnosis/safe/safe.py index 1d6d2723f..c4d0baf6a 100644 --- a/gnosis/safe/safe.py +++ b/gnosis/safe/safe.py @@ -99,27 +99,27 @@ def __new__( assert fast_is_checksum_address(address), "%s is not a valid address" % address if cls is not Safe: return super().__new__(cls, address, ethereum_client, *args, **kwargs) - else: - versions: Dict[str, Safe] = { - "0.0.1": SafeV001, - "1.0.0": SafeV100, - "1.1.1": SafeV111, - "1.2.0": SafeV120, - "1.3.0": SafeV130, - "1.4.1": SafeV141, - } - default_version = SafeV141 - version: Optional[str] - try: - contract = get_safe_contract(ethereum_client.w3, address=address) - version = contract.functions.VERSION().call(block_identifier="latest") - except (Web3Exception, ValueError): - version = None # Cannot detect the version - - instance_class = versions.get(version, default_version) - instance = super().__new__(instance_class) - return instance + versions: Dict[str, Safe] = { + "0.0.1": SafeV001, + "1.0.0": SafeV100, + "1.1.1": SafeV111, + "1.2.0": SafeV120, + "1.3.0": SafeV130, + "1.4.1": SafeV141, + } + default_version = SafeV141 + + version: Optional[str] + try: + contract = get_safe_contract(ethereum_client.w3, address=address) + version = contract.functions.VERSION().call(block_identifier="latest") + except (Web3Exception, ValueError): + version = None # Cannot detect the version + + instance_class = versions.get(version, default_version) + instance = super().__new__(instance_class) + return instance def __init__( self, diff --git a/gnosis/safe/tests/test_proxy_factory/test_proxy_factory.py b/gnosis/safe/tests/test_proxy_factory/test_proxy_factory.py index f33d09de6..7d8191def 100644 --- a/gnosis/safe/tests/test_proxy_factory/test_proxy_factory.py +++ b/gnosis/safe/tests/test_proxy_factory/test_proxy_factory.py @@ -3,6 +3,7 @@ from django.test import TestCase from eth_account import Account +from web3 import Web3 from gnosis.eth import EthereumClient from gnosis.eth.contracts import ( @@ -100,6 +101,45 @@ def test_check_proxy_code_mainnet(self): with self.subTest(safe=safe): self.assertTrue(proxy_factory.check_proxy_code(safe)) + def test_calculate_proxy_address(self): + salt_nonce = 12 + address = self.proxy_factory.calculate_proxy_address( + self.safe_contract_V1_4_1.address, b"", salt_nonce + ) + self.assertTrue(Web3.is_checksum_address(address)) + # Same call with same parameters should return the same address + same_address = self.proxy_factory.calculate_proxy_address( + self.safe_contract_V1_4_1.address, b"", salt_nonce + ) + self.assertEqual(address, same_address) + ethereum_tx_sent = self.proxy_factory.deploy_proxy_contract_with_nonce( + self.ethereum_test_account, + self.safe_contract_V1_4_1.address, + initializer=b"", + salt_nonce=salt_nonce, + ) + self.assertEqual(ethereum_tx_sent.contract_address, address) + + # Calculating the proxy address after deployment should return the same address + address_after_deploying = self.proxy_factory.calculate_proxy_address( + self.safe_contract_V1_4_1.address, b"", salt_nonce + ) + self.assertEqual(ethereum_tx_sent.contract_address, address_after_deploying) + + chain_specific_address = self.proxy_factory.calculate_proxy_address( + self.safe_contract_V1_4_1.address, b"", salt_nonce, chain_specific=True + ) + self.assertTrue(Web3.is_checksum_address(chain_specific_address)) + self.assertNotEqual(address, chain_specific_address) + ethereum_tx_sent = self.proxy_factory.deploy_proxy_contract_with_nonce( + self.ethereum_test_account, + self.safe_contract_V1_4_1.address, + initializer=b"", + salt_nonce=salt_nonce, + chain_specific=True, + ) + self.assertEqual(ethereum_tx_sent.contract_address, chain_specific_address) + def test_deploy_proxy_contract_with_nonce(self): salt_nonce = generate_salt_nonce() owners = [Account.create().address for _ in range(2)] @@ -146,4 +186,5 @@ def test_deploy_proxy_contract_with_nonce(self): ) def test_get_proxy_runtime_code(self): - self.assertIsNone(self.proxy_factory.get_proxy_runtime_code()) + with self.assertRaises(NotImplementedError): + self.proxy_factory.get_proxy_runtime_code()