Skip to content

Commit

Permalink
implement sign_raw for all other chains, fix formatting and imports
Browse files Browse the repository at this point in the history
  • Loading branch information
MHHukiewitz committed Aug 12, 2023
1 parent 3c25982 commit 1dd4c06
Show file tree
Hide file tree
Showing 12 changed files with 66 additions and 55 deletions.
13 changes: 6 additions & 7 deletions src/aleph/sdk/chains/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ def _setup_sender(self, message: Dict) -> Dict:
else:
raise ValueError("Message sender does not match the account's public key.")

@abstractmethod
async def sign_message(self, message: Dict) -> Dict:
"""
Returns a signed message from an Aleph message.
Expand All @@ -72,19 +71,19 @@ async def sign_message(self, message: Dict) -> Dict:
Dict: Signed message
"""
message = self._setup_sender(message)
sig = await self.sign_raw(get_verification_buffer(message))
message["signature"] = sig.hex()
message["signature"] = await self.sign_raw(get_verification_buffer(message))
return message

async def sign_raw(self, buffer: bytes) -> bytes:

@abstractmethod
async def sign_raw(self, buffer: bytes) -> str:
"""
Returns a signed message from a raw buffer.
Args:
buffer: Buffer to sign
Returns:
bytes: Signed buffer
str: Signature in preferred format
"""

raise NotImplementedError

@abstractmethod
def get_address(self) -> str:
Expand Down
20 changes: 10 additions & 10 deletions src/aleph/sdk/chains/cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,27 +51,27 @@ def __init__(self, private_key=None, hrp=DEFAULT_HRP):

async def sign_message(self, message):
message = self._setup_sender(message)

verif = get_verification_string(message)

privkey = ecdsa.SigningKey.from_string(self.private_key, curve=ecdsa.SECP256k1)
signature_compact = privkey.sign_deterministic(
verif.encode("utf-8"),
hashfunc=hashlib.sha256,
sigencode=ecdsa.util.sigencode_string_canonize,
)
signature_base64_str = base64.b64encode(signature_compact).decode("utf-8")
base64_pubkey = base64.b64encode(self.get_public_key().encode()).decode("utf-8")

sig = {
"signature": signature_base64_str,
"signature": self.sign_raw(verif.encode("utf-8")),
"pub_key": {"type": "tendermint/PubKeySecp256k1", "value": base64_pubkey},
"account_number": str(0),
"sequence": str(0),
}
message["signature"] = json.dumps(sig)
return message

async def sign_raw(self, buffer: bytes) -> str:
privkey = ecdsa.SigningKey.from_string(self.private_key, curve=ecdsa.SECP256k1)
signature_compact = privkey.sign_deterministic(
buffer,
hashfunc=hashlib.sha256,
sigencode=ecdsa.util.sigencode_string_canonize,
)
return base64.b64encode(signature_compact).decode("utf-8")

def get_address(self) -> str:
return privkey_to_address(self.private_key)

Expand Down
9 changes: 4 additions & 5 deletions src/aleph/sdk/chains/ethereum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path
from typing import Dict, Optional, Union
from typing import Optional, Union

from eth_account import Account
from eth_account.messages import encode_defunct
Expand All @@ -11,7 +11,6 @@
BaseAccount,
get_fallback_private_key,
get_public_key,
get_verification_buffer,
)


Expand All @@ -23,12 +22,12 @@ class ETHAccount(BaseAccount):
def __init__(self, private_key: bytes):
self.private_key = private_key
self._account = Account.from_key(self.private_key)
async def sign_raw(self, buffer: bytes) -> bytes:

async def sign_raw(self, buffer: bytes) -> str:
"""Sign a raw buffer."""
msghash = encode_defunct(text=buffer.decode("utf-8"))
sig = self._account.sign_message(msghash)
return sig["signature"]
return sig["signature"].hex()

def get_address(self) -> str:
return self._account.address
Expand Down
4 changes: 4 additions & 0 deletions src/aleph/sdk/chains/nuls1.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,10 @@ async def sign_message(self, message):
message["signature"] = sig.serialize().hex()
return message

async def sign_raw(self, buffer: bytes) -> str:
sig = NulsSignature.sign_data(self.private_key, buffer)
return sig.serialize().hex()

def get_address(self):
return address_from_hash(
public_key_to_hash(self.get_public_key(), chain_id=self.chain_id)
Expand Down
14 changes: 3 additions & 11 deletions src/aleph/sdk/chains/nuls2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
BaseAccount,
get_fallback_private_key,
get_public_key,
get_verification_buffer,
)


Expand All @@ -37,16 +36,9 @@ def __init__(self, private_key=None, chain_id=1, prefix=None):
else:
self.prefix = prefix

async def sign_message(self, message):
# sig = NulsSignature.sign_message(self.private_key,
# get_verification_buffer(message))
message = self._setup_sender(message)

sig = sign_recoverable_message(
self.private_key, get_verification_buffer(message)
)
message["signature"] = base64.b64encode(sig).decode()
return message
async def sign_raw(self, buffer: bytes) -> str:
sig = sign_recoverable_message(self.private_key, buffer)
return base64.b64encode(sig).decode()

def get_address(self):
return address_from_hash(
Expand Down
3 changes: 3 additions & 0 deletions src/aleph/sdk/chains/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ async def sign_message(self, message: Dict) -> Dict:
response.raise_for_status()
return await response.json()

async def sign_raw(self, buffer: bytes) -> str:
raise NotImplementedError()

def get_address(self) -> str:
return self._address

Expand Down
6 changes: 3 additions & 3 deletions src/aleph/sdk/chains/sol.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,15 @@ async def sign_message(self, message: Dict) -> Dict:
verif = get_verification_buffer(message)
sig = {
"publicKey": self.get_address(),
"signature": encode(self.sign_raw(verif)),
"signature": self.sign_raw(verif),
}
message["signature"] = json.dumps(sig)
return message

async def sign_raw(self, buffer: bytes) -> bytes:
async def sign_raw(self, buffer: bytes) -> str:
"""Sign a raw buffer."""
sig = self._signing_key.sign(buffer)
return sig.signature
return encode(sig.signature)

def get_address(self) -> str:
return encode(self._signing_key.verify_key)
Expand Down
4 changes: 4 additions & 0 deletions src/aleph/sdk/chains/substrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ async def sign_message(self, message):
message["signature"] = json.dumps(sig)
return message

async def sign_raw(self, buffer: bytes) -> str:
sig = self._account.sign(buffer)
return sig.hex()

def get_address(self):
return self._account.ss58_address

Expand Down
5 changes: 4 additions & 1 deletion src/aleph/sdk/chains/tezos.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,15 @@ async def sign_message(self, message: Dict) -> Dict:
verif = get_verification_buffer(message)
sig = {
"publicKey": self.get_public_key(),
"signature": self._account.sign(verif),
"signature": self.sign_raw(verif),
}

message["signature"] = json.dumps(sig)
return message

async def sign_raw(self, buffer: bytes) -> str:
return self._account.sign(buffer)

def get_address(self) -> str:
return self._account.public_key_hash()

Expand Down
29 changes: 21 additions & 8 deletions src/aleph/sdk/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import threading
import time
from datetime import datetime
from io import BytesIO
from pathlib import Path
from typing import (
Any,
Expand All @@ -23,7 +24,6 @@
TypeVar,
Union,
)
from io import BytesIO

import aiohttp
from aleph_message.models import (
Expand All @@ -45,17 +45,18 @@
)
from aleph_message.models.execution.base import Encoding
from aleph_message.status import MessageStatus
from pydantic import ValidationError, BaseModel
from pydantic import ValidationError

from aleph.sdk.types import Account, GenericMessage, StorageEnum
from aleph.sdk.utils import copy_async_readable_to_buffer, Writable, AsyncReadable
from aleph.sdk.utils import Writable, copy_async_readable_to_buffer

from .conf import settings
from .exceptions import (
BroadcastError,
FileTooLarge,
InvalidMessageError,
MessageNotFoundError,
MultipleMessagesError,
FileTooLarge,
)
from .models import MessagesResponse
from .utils import check_unix_socket_valid, get_message_type_value
Expand Down Expand Up @@ -237,12 +238,24 @@ def download_file_ipfs(self, file_hash: str) -> bytes:
self.async_session.download_file_ipfs,
file_hash=file_hash,
)
def download_file_to_buffer(self, file_hash: str, output_buffer: Writable[bytes]) -> bytes:
return self._wrap(self.async_session.download_file_to_buffer, file_hash=file_hash, output_buffer=output_buffer)

def download_file_ipfs_to_buffer(self, file_hash: str, output_buffer: Writable[bytes]) -> bytes:
return self._wrap(self.async_session.download_file_ipfs_to_buffer, file_hash=file_hash, output_buffer=output_buffer)
def download_file_to_buffer(
self, file_hash: str, output_buffer: Writable[bytes]
) -> bytes:
return self._wrap(
self.async_session.download_file_to_buffer,
file_hash=file_hash,
output_buffer=output_buffer,
)

def download_file_ipfs_to_buffer(
self, file_hash: str, output_buffer: Writable[bytes]
) -> bytes:
return self._wrap(
self.async_session.download_file_ipfs_to_buffer,
file_hash=file_hash,
output_buffer=output_buffer,
)

def watch_messages(
self,
Expand Down
13 changes: 3 additions & 10 deletions src/aleph/sdk/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import Enum
from pathlib import Path
from shutil import make_archive
from typing import Tuple, Type, Union
from typing import Protocol, Tuple, Type, TypeVar, Union
from zipfile import BadZipFile, ZipFile

from aleph_message.models import MessageType
Expand All @@ -13,13 +13,6 @@
from aleph.sdk.conf import settings
from aleph.sdk.types import GenericMessage

from typing import (
Tuple,
Type,
TypeVar,
Protocol,
)

logger = logging.getLogger(__name__)

try:
Expand Down Expand Up @@ -54,7 +47,7 @@ def create_archive(path: Path) -> Tuple[Path, Encoding]:
return archive_path, Encoding.zip
elif os.path.isfile(path):
if path.suffix == ".squashfs" or (
magic and magic.from_file(path).startswith("Squashfs filesystem")
magic and magic.from_file(path).startswith("Squashfs filesystem")
):
return path, Encoding.squashfs
else:
Expand Down Expand Up @@ -101,7 +94,7 @@ def write(self, buffer: U) -> int:


async def copy_async_readable_to_buffer(
readable: AsyncReadable[T], buffer: Writable[T], chunk_size: int
readable: AsyncReadable[T], buffer: Writable[T], chunk_size: int
):
while True:
chunk = await readable.read(chunk_size)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/test_download.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest

from aleph.sdk import AlephClient
from aleph.sdk.conf import settings as sdk_settings

Expand Down

0 comments on commit 1dd4c06

Please sign in to comment.