Skip to content

Commit

Permalink
Calculate proxy expected address in ProxyFactory
Browse files Browse the repository at this point in the history
- Add `calculate_proxy_address` to ProxyFactory
- Add option to deploy specific chain id proxies (feature added in v1.4.1)
  • Loading branch information
Uxio0 committed Nov 15, 2023
1 parent c7d6d9d commit 88328ec
Show file tree
Hide file tree
Showing 3 changed files with 167 additions and 35 deletions.
119 changes: 105 additions & 14 deletions gnosis/safe/proxy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from abc import ABCMeta
from typing import Callable, Optional

from eth_abi.packed import encode_packed
from eth_account.signers.local import LocalAccount
from eth_typing import ChecksumAddress
from web3 import Web3
Expand All @@ -21,12 +22,20 @@
get_proxy_factory_V1_3_0_contract,
get_proxy_factory_V1_4_1_contract,
)
from gnosis.eth.utils import compare_byte_code, get_empty_tx_params
from gnosis.eth.utils import (
compare_byte_code,
fast_keccak,
get_empty_tx_params,
mk_contract_address_2,
)
from gnosis.util import cache


class ProxyFactory(ContractBase, metaclass=ABCMeta):
def __new__(cls, *args, version: str = "1.4.1", **kwargs) -> "ProxyFactory":
if cls is not ProxyFactory:
return super().__new__(cls)

versions = {
"1.0.0": ProxyFactoryV100,
"1.1.1": ProxyFactoryV111,
Expand Down Expand Up @@ -57,6 +66,30 @@ def deploy_contract(
deployer_account, constructor_data
)

@cache
def get_proxy_creation_code(self) -> bytes:
"""
:return: Creation code used for the Proxy deployment.
With this it is easily possible to calculate predicted address.
"""
return self.contract.functions.proxyCreationCode().call()

@cache
def get_proxy_runtime_code(self) -> bytes:
"""
:return: Runtime code of a deployed Proxy. For v1.4.1 onwards the method is not available, so `None`
will be returned
"""
return self.contract.functions.proxyRuntimeCode().call()

def get_deploy_function(self, chain_specific: bool) -> ContractFunction:
if chain_specific:
raise NotImplementedError(
f"createChainSpecificProxyWithNonce is not supported in {self.__class__.__name__}"
)

return self.contract.functions.createProxyWithNonce

def check_proxy_code(self, address: ChecksumAddress) -> bool:
"""
Check if proxy bytecode matches any of the deployed by the supported Proxy Factories
Expand All @@ -65,6 +98,12 @@ def check_proxy_code(self, address: ChecksumAddress) -> bool:
:return: ``True`` if proxy is valid, ``False`` otherwise
"""

def get_proxy_runtime_code() -> Optional[bytes]:
try:
return self.get_proxy_runtime_code()
except NotImplementedError:
return None

deployed_proxy_code = self.w3.eth.get_code(address)
proxy_code_fns = (
get_proxy_1_4_1_deployed_bytecode,
Expand All @@ -73,7 +112,7 @@ def check_proxy_code(self, address: ChecksumAddress) -> bool:
get_proxy_1_1_1_mainnet_deployed_bytecode,
get_proxy_1_0_0_deployed_bytecode,
get_paying_proxy_deployed_bytecode,
self.get_proxy_runtime_code,
get_proxy_runtime_code,
)
for proxy_code_fn in proxy_code_fns:
proxy_code = proxy_code_fn()
Expand Down Expand Up @@ -140,6 +179,48 @@ def deploy_proxy_contract(
deployer_account, create_proxy_fn, gas=gas, gas_price=gas_price, nonce=nonce
)

def calculate_proxy_address(
self,
master_copy: ChecksumAddress,
initializer: bytes,
salt_nonce: int,
chain_specific: bool = False,
) -> ChecksumAddress:
"""
Calculate proxy address for calling deploy_proxy_contract_with_nonce
:param master_copy:
:param initializer:
:param salt_nonce:
:param chain_specific: Calculate chain specific address (to prevent same address in other chains)
:return:
"""

if chain_specific:
salt_nonce = fast_keccak(
encode_packed(
["bytes32", "uint256", "uint256"],
[
fast_keccak(initializer),
salt_nonce,
self.ethereum_client.get_chain_id(),
],
)
)
else:
salt_nonce = fast_keccak(
encode_packed(
["bytes32", "uint256"], [fast_keccak(initializer), salt_nonce]
)
)

proxy_creation_code = self.get_proxy_creation_code()
deployment_data = encode_packed(
["bytes", "uint256"], [proxy_creation_code, int(master_copy, 0)]
)

return mk_contract_address_2(self.address, salt_nonce, deployment_data)

def deploy_proxy_contract_with_nonce(
self,
deployer_account: LocalAccount,
Expand All @@ -149,6 +230,7 @@ def deploy_proxy_contract_with_nonce(
gas: Optional[int] = None,
gas_price: Optional[int] = None,
nonce: Optional[int] = None,
chain_specific: bool = False,
) -> EthereumTxSent:
"""
Deploy proxy contract via Proxy Factory using `createProxyWithNonce` (CREATE2 opcode)
Expand All @@ -160,26 +242,18 @@ def deploy_proxy_contract_with_nonce(
:param gas: Gas
:param gas_price: Gas Price
:param nonce: Nonce
:param chain_specific: Calculate chain specific address (to prevent same address in other chains)
:return: EthereumTxSent
"""

function = self.get_deploy_function(chain_specific)
salt_nonce = salt_nonce if salt_nonce is not None else secrets.randbits(256)
create_proxy_fn = self.contract.functions.createProxyWithNonce(
master_copy, initializer, salt_nonce
)
create_proxy_fn = function(master_copy, initializer, salt_nonce)

return self._deploy_proxy_contract(
deployer_account, create_proxy_fn, gas=gas, gas_price=gas_price, nonce=nonce
)

@cache
def get_proxy_runtime_code(self) -> Optional[bytes]:
"""
:return: Runtime code for current proxy factory. For v1.4.1 onwards the method is not avaiable, so `None`
will be returned
"""
if hasattr(self.contract.functions, "proxyRuntimeCode"):
return self.contract.functions.proxyRuntimeCode().call()


class ProxyFactoryV100(ProxyFactory):
def get_contract_fn(self) -> Callable[[Web3, ChecksumAddress], Contract]:
Expand All @@ -200,6 +274,23 @@ class ProxyFactoryV141(ProxyFactory):
def get_contract_fn(self) -> Callable[[Web3, ChecksumAddress], Contract]:
return get_proxy_factory_V1_4_1_contract

@cache
def get_proxy_runtime_code(self) -> Optional[bytes]:
"""
:return: From v1.4.1 onwards the method is not available
:raises: NotImplementedError
"""
raise NotImplementedError(
"Deprecated, only creation code is available using `get_proxy_creation_code`"
)

def get_deploy_function(self, chain_specific: bool) -> ContractFunction:
return (
self.contract.functions.createChainSpecificProxyWithNonce
if chain_specific
else super().get_deploy_function(chain_specific)
)

def deploy_proxy_contract(self, *args, **kwargs):
"""
.. deprecated:: ``createProxy`` function was deprecated in v1.4.1, use ``deploy_proxy_contract_with_nonce``
Expand Down
40 changes: 20 additions & 20 deletions gnosis/safe/safe.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,27 +99,27 @@ def __new__(
assert fast_is_checksum_address(address), "%s is not a valid address" % address
if cls is not Safe:
return super().__new__(cls, address, ethereum_client, *args, **kwargs)
else:
versions: Dict[str, Safe] = {
"0.0.1": SafeV001,
"1.0.0": SafeV100,
"1.1.1": SafeV111,
"1.2.0": SafeV120,
"1.3.0": SafeV130,
"1.4.1": SafeV141,
}
default_version = SafeV141

version: Optional[str]
try:
contract = get_safe_contract(ethereum_client.w3, address=address)
version = contract.functions.VERSION().call(block_identifier="latest")
except (Web3Exception, ValueError):
version = None # Cannot detect the version

instance_class = versions.get(version, default_version)
instance = super().__new__(instance_class)
return instance
versions: Dict[str, Safe] = {
"0.0.1": SafeV001,
"1.0.0": SafeV100,
"1.1.1": SafeV111,
"1.2.0": SafeV120,
"1.3.0": SafeV130,
"1.4.1": SafeV141,
}
default_version = SafeV141

version: Optional[str]
try:
contract = get_safe_contract(ethereum_client.w3, address=address)
version = contract.functions.VERSION().call(block_identifier="latest")
except (Web3Exception, ValueError):
version = None # Cannot detect the version

instance_class = versions.get(version, default_version)
instance = super().__new__(instance_class)
return instance

def __init__(
self,
Expand Down
43 changes: 42 additions & 1 deletion gnosis/safe/tests/test_proxy_factory/test_proxy_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from django.test import TestCase

from eth_account import Account
from web3 import Web3

from gnosis.eth import EthereumClient
from gnosis.eth.contracts import (
Expand Down Expand Up @@ -100,6 +101,45 @@ def test_check_proxy_code_mainnet(self):
with self.subTest(safe=safe):
self.assertTrue(proxy_factory.check_proxy_code(safe))

def test_calculate_proxy_address(self):
salt_nonce = 12
address = self.proxy_factory.calculate_proxy_address(
self.safe_contract_V1_4_1.address, b"", salt_nonce
)
self.assertTrue(Web3.is_checksum_address(address))
# Same call with same parameters should return the same address
same_address = self.proxy_factory.calculate_proxy_address(
self.safe_contract_V1_4_1.address, b"", salt_nonce
)
self.assertEqual(address, same_address)
ethereum_tx_sent = self.proxy_factory.deploy_proxy_contract_with_nonce(
self.ethereum_test_account,
self.safe_contract_V1_4_1.address,
initializer=b"",
salt_nonce=salt_nonce,
)
self.assertEqual(ethereum_tx_sent.contract_address, address)

# Calculating the proxy address after deployment should return the same address
address_after_deploying = self.proxy_factory.calculate_proxy_address(
self.safe_contract_V1_4_1.address, b"", salt_nonce
)
self.assertEqual(ethereum_tx_sent.contract_address, address_after_deploying)

chain_specific_address = self.proxy_factory.calculate_proxy_address(
self.safe_contract_V1_4_1.address, b"", salt_nonce, chain_specific=True
)
self.assertTrue(Web3.is_checksum_address(chain_specific_address))
self.assertNotEqual(address, chain_specific_address)
ethereum_tx_sent = self.proxy_factory.deploy_proxy_contract_with_nonce(
self.ethereum_test_account,
self.safe_contract_V1_4_1.address,
initializer=b"",
salt_nonce=salt_nonce,
chain_specific=True,
)
self.assertEqual(ethereum_tx_sent.contract_address, chain_specific_address)

def test_deploy_proxy_contract_with_nonce(self):
salt_nonce = generate_salt_nonce()
owners = [Account.create().address for _ in range(2)]
Expand Down Expand Up @@ -146,4 +186,5 @@ def test_deploy_proxy_contract_with_nonce(self):
)

def test_get_proxy_runtime_code(self):
self.assertIsNone(self.proxy_factory.get_proxy_runtime_code())
with self.assertRaises(NotImplementedError):
self.proxy_factory.get_proxy_runtime_code()

0 comments on commit 88328ec

Please sign in to comment.