Skip to content

Commit

Permalink
Add support for preimage hash for EIP1271 signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
Uxio0 committed Dec 19, 2023
1 parent 33c987a commit 5d2958a
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 38 deletions.
93 changes: 58 additions & 35 deletions gnosis/safe/safe_signature.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from enum import Enum
from logging import getLogger
from typing import List, Union
from typing import List, Optional, Union

from eth_abi import decode as decode_abi
from eth_abi import encode as encode_abi
Expand All @@ -12,7 +12,10 @@
from web3.exceptions import Web3Exception

from gnosis.eth import EthereumClient
from gnosis.eth.contracts import get_safe_contract, get_safe_V1_1_1_contract
from gnosis.eth.contracts import (
get_compatibility_fallback_handler_contract,
get_safe_contract,
)
from gnosis.eth.utils import fast_to_checksum_address
from gnosis.safe.signatures import (
get_signing_address,
Expand Down Expand Up @@ -68,9 +71,13 @@ def uint_to_address(value: int) -> ChecksumAddress:


class SafeSignature(ABC):
def __init__(self, signature: EthereumBytes, safe_tx_hash: EthereumBytes):
def __init__(self, signature: EthereumBytes, safe_hash: EthereumBytes):
"""
:param signature: Owner signature
:param safe_hash: Signed hash for the Safe (message or transaction)
"""
self.signature = HexBytes(signature)
self.safe_tx_hash = HexBytes(safe_tx_hash)
self.safe_hash = HexBytes(safe_hash)
self.v, self.r, self.s = signature_split(self.signature)

def __str__(self):
Expand All @@ -80,12 +87,14 @@ def __str__(self):
def parse_signature(
cls,
signatures: EthereumBytes,
safe_tx_hash: EthereumBytes,
safe_hash: EthereumBytes,
safe_hash_preimage: Optional[EthereumBytes] = None,
ignore_trailing: bool = True,
) -> List["SafeSignature"]:
"""
:param signatures: One or more signatures appended. EIP1271 data at the end is supported.
:param safe_tx_hash:
:param safe_hash: Signed hash for the Safe (message or transaction)
:param safe_hash_preimage: ``safe_hash`` preimage for EIP1271 validation
:param ignore_trailing: Ignore trailing data on the signature. Some libraries pad it and add some zeroes at
the end
:return: List of SafeSignatures decoded
Expand Down Expand Up @@ -124,14 +133,17 @@ def parse_signature(
s + 32 : s + 32 + contract_signature_len
] # Skip array size (32 bytes)
safe_signature = SafeSignatureContract(
signature, safe_tx_hash, contract_signature
signature,
safe_hash,
safe_hash_preimage or safe_hash,
contract_signature,
)
elif signature_type == SafeSignatureType.APPROVED_HASH:
safe_signature = SafeSignatureApprovedHash(signature, safe_tx_hash)
safe_signature = SafeSignatureApprovedHash(signature, safe_hash)
elif signature_type == SafeSignatureType.EOA:
safe_signature = SafeSignatureEOA(signature, safe_tx_hash)
safe_signature = SafeSignatureEOA(signature, safe_hash)
elif signature_type == SafeSignatureType.ETH_SIGN:
safe_signature = SafeSignatureEthSign(signature, safe_tx_hash)
safe_signature = SafeSignatureEthSign(signature, safe_hash)

safe_signatures.append(safe_signature)
return safe_signatures
Expand Down Expand Up @@ -174,17 +186,25 @@ class SafeSignatureContract(SafeSignature):
def __init__(
self,
signature: EthereumBytes,
safe_tx_hash: EthereumBytes,
safe_hash: EthereumBytes,
safe_hash_preimage: EthereumBytes,
contract_signature: EthereumBytes,
):
super().__init__(signature, safe_tx_hash)
"""
:param signature:
:param safe_hash: Signed hash for the Safe (message or transaction)
:param safe_hash_preimage: ``safe_hash`` preimage for EIP1271 validation
:param contract_signature:
"""
super().__init__(signature, safe_hash)
self.safe_hash_preimage = HexBytes(safe_hash_preimage)
self.contract_signature = HexBytes(contract_signature)

@property
def owner(self) -> ChecksumAddress:
"""
:return: Address of contract signing. No further checks to get the owner are needed,
but it could be a non existing contract
but it could be a non-existing contract
"""

return uint_to_address(self.r)
Expand All @@ -208,21 +228,26 @@ def export_signature(self) -> HexBytes:
)

def is_valid(self, ethereum_client: EthereumClient, *args) -> bool:
safe_contract = get_safe_V1_1_1_contract(ethereum_client.w3, self.owner)
# Newest versions of the Safe contract have `isValidSignature` on the compatibility fallback handler
for block_identifier in ("pending", "latest"):
try:
return safe_contract.functions.isValidSignature(
self.safe_tx_hash, self.contract_signature
).call(block_identifier=block_identifier) in (
self.EIP1271_MAGIC_VALUE,
self.EIP1271_MAGIC_VALUE_UPDATED,
)
except (Web3Exception, DecodingError, ValueError):
# Error using `pending` block identifier or contract does not exist
logger.warning(
"Cannot check EIP1271 signature from contract %s", self.owner
)
compatibility_fallback_handler = get_compatibility_fallback_handler_contract(
ethereum_client.w3, self.owner
)
is_valid_signature_fn = (
compatibility_fallback_handler.get_function_by_signature(
"isValidSignature(bytes,bytes)"
)
)
try:
return is_valid_signature_fn(
self.safe_hash_preimage, self.contract_signature
).call() in (
self.EIP1271_MAGIC_VALUE,
self.EIP1271_MAGIC_VALUE_UPDATED,
)
except (Web3Exception, DecodingError, ValueError):
# Error using `pending` block identifier or contract does not exist
logger.warning(
"Cannot check EIP1271 signature from contract %s", self.owner
)
return False


Expand All @@ -236,13 +261,11 @@ def signature_type(self):
return SafeSignatureType.APPROVED_HASH

@classmethod
def build_for_owner(
cls, owner: str, safe_tx_hash: str
) -> "SafeSignatureApprovedHash":
def build_for_owner(cls, owner: str, safe_hash: str) -> "SafeSignatureApprovedHash":
r = owner.lower().replace("0x", "").rjust(64, "0")
s = "0" * 64
v = "01"
return cls(HexBytes(r + s + v), safe_tx_hash)
return cls(HexBytes(r + s + v), safe_hash)

def is_valid(self, ethereum_client: EthereumClient, safe_address: str) -> bool:
safe_contract = get_safe_contract(ethereum_client.w3, safe_address)
Expand All @@ -251,7 +274,7 @@ def is_valid(self, ethereum_client: EthereumClient, safe_address: str) -> bool:
try:
return (
safe_contract.functions.approvedHashes(
self.owner, self.safe_tx_hash
self.owner, self.safe_hash_preimage
).call(block_identifier=block_identifier)
== 1
)
Expand All @@ -265,7 +288,7 @@ class SafeSignatureEthSign(SafeSignature):
@property
def owner(self):
# defunct_hash_message prepends `\x19Ethereum Signed Message:\n32`
message_hash = defunct_hash_message(primitive=self.safe_tx_hash)
message_hash = defunct_hash_message(primitive=self.safe_hash)
return get_signing_address(message_hash, self.v - 4, self.r, self.s)

@property
Expand All @@ -279,7 +302,7 @@ def is_valid(self, *args) -> bool:
class SafeSignatureEOA(SafeSignature):
@property
def owner(self):
return get_signing_address(self.safe_tx_hash, self.v, self.r, self.s)
return get_signing_address(self.safe_hash, self.v, self.r, self.s)

@property
def signature_type(self):
Expand Down
9 changes: 7 additions & 2 deletions gnosis/safe/safe_tx.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,10 @@
from gnosis.eth import EthereumClient
from gnosis.eth.constants import NULL_ADDRESS
from gnosis.eth.contracts import get_safe_contract
from gnosis.eth.eip712 import eip712_encode_hash
from gnosis.eth.eip712 import eip712_encode
from gnosis.eth.ethereum_client import TxSpeed

from ..eth.utils import fast_keccak
from .exceptions import (
CouldNotFinishInitialization,
CouldNotPayGasWithEther,
Expand Down Expand Up @@ -183,9 +184,13 @@ def eip712_structured_data(self) -> Dict[str, Any]:

return payload

@property
def safe_tx_hash_preimage(self) -> HexBytes:
return HexBytes(b"".join(eip712_encode(self.eip712_structured_data)))

@property
def safe_tx_hash(self) -> HexBytes:
return HexBytes(eip712_encode_hash(self.eip712_structured_data))
return HexBytes(fast_keccak(self.safe_tx_hash_preimage))

@property
def signers(self) -> List[str]:
Expand Down
31 changes: 30 additions & 1 deletion gnosis/safe/tests/test_safe_signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django.test import TestCase

import eth_abi
from eth_abi import encode as encode_abi
from eth_abi.packed import encode_packed
from eth_account import Account
Expand All @@ -18,6 +19,7 @@
SafeSignatureEthSign,
SafeSignatureType,
)
from ..signatures import signature_to_bytes
from .safe_test_case import SafeTestCaseMixin

logger = logging.getLogger(__name__)
Expand All @@ -33,7 +35,7 @@ def test_contract_signature(self):
"0x00000000000000000000000005c85ab5b09eb8a55020d72daf6091e04e264af900000000000000000000000"
"0000000000000000000000000000000000000000000"
)
safe_signature = SafeSignatureContract(signature, safe_tx_hash, b"")
safe_signature = SafeSignatureContract(signature, safe_tx_hash, b"", b"")
self.assertEqual(safe_signature.owner, owner)
self.assertEqual(
safe_signature.signature_type, SafeSignatureType.CONTRACT_SIGNATURE
Expand Down Expand Up @@ -177,6 +179,32 @@ def test_parse_signature_empty(self):


class TestSafeContractSignature(SafeTestCaseMixin, TestCase):
def test_contract_signature_for_message(self):
account = Account.create()
safe_owner = self.deploy_test_safe(owners=[account.address])
safe = self.deploy_test_safe(owners=[safe_owner.address])

safe_address = safe.address
message = "Testing EIP191 message signing"
message_hash = defunct_hash_message(text=message)
safe_owner_message_hash = safe_owner.get_message_hash(message_hash)
safe_owner_signature = account.signHash(safe_owner_message_hash)["signature"]
safe_parent_message_hash = safe.get_message_hash(message_hash)

# Build EIP1271 signature v=0 r=safe v=dynamic_part dynamic_part=size+owner_signature
signature_1271 = (
signature_to_bytes(
0, int.from_bytes(HexBytes(safe_owner.address), byteorder="big"), 65
)
+ eth_abi.encode(["bytes"], [safe_owner_signature])[32:]
)

safe_signatures = SafeSignature.parse_signature(
signature_1271, safe_parent_message_hash, message_hash
)
self.assertEqual(len(safe_signatures), 1)
self.assertTrue(safe_signatures[0].is_valid(self.ethereum_client))

def test_contract_signature(self):
owner_1 = self.ethereum_test_account
safe = self.deploy_test_safe_v1_1_1(
Expand All @@ -194,6 +222,7 @@ def test_contract_signature(self):
signature = signature_r + signature_s + signature_v + contract_signature

safe_signature = SafeSignature.parse_signature(signature, safe_tx_hash)[0]
self.assertIsInstance(safe_signature, SafeSignatureContract)
self.assertFalse(safe_signature.is_valid(self.ethereum_client, None))

# Check with previously signedMessage
Expand Down

0 comments on commit 5d2958a

Please sign in to comment.