Skip to content

Commit

Permalink
various typing fixes (#15899)
Browse files Browse the repository at this point in the history
* add blspy-stubs to improve mypy

* various typing fixes

* avoid cast

* use ValueError exception instead of assert when aborting picking a wallet key
  • Loading branch information
arvidn authored Aug 18, 2023
1 parent ffb2b61 commit 7b0bea4
Show file tree
Hide file tree
Showing 21 changed files with 197 additions and 45 deletions.
118 changes: 118 additions & 0 deletions blspy-stubs/blspy-stubs/__init__.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

class G1Element:
SIZE: ClassVar[int] = ...
def __new__(cls) -> G1Element: ...
def get_fingerprint(self) -> int: ...
@staticmethod
def from_bytes_unchecked(b: bytes) -> G1Element: ...
def pair(self, other: G2Element) -> GTElement: ...
@staticmethod
def generator() -> G1Element: ...
def __add__(self, other: G1Element) -> G1Element: ...
def __iadd__(self, other: G1Element) -> None: ...
def __init__(self) -> None: ...
def __hash__(self) -> int: ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
def __richcmp__(self) -> Any: ...
def __deepcopy__(self) -> G1Element: ...
def __copy__(self) -> G1Element: ...
@staticmethod
def from_bytes(bytes) -> G1Element: ...
@staticmethod
def parse_rust(ReadableBuffer) -> Tuple[G1Element, int]: ...
def to_bytes(self) -> bytes: ...
def __bytes__(self) -> bytes: ...
def get_hash(self) -> bytes32: ...
def to_json_dict(self) -> Dict[str, Any]: ...
@staticmethod
def from_json_dict(o: Dict[str, Any]) -> G1Element: ...

class G2Element:
SIZE: ClassVar[int] = ...
def __new__(cls) -> G2Element: ...
def get_fingerprint(self) -> int: ...
@staticmethod
def from_bytes_unchecked(b: bytes) -> G2Element: ...
def pair(self, other: G1Element) -> GTElement: ...
@staticmethod
def generator() -> G2Element: ...
def __hash__(self) -> int: ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
def __richcmp__(self) -> bool: ...
def __deepcopy__(self) -> G2Element: ...
def __copy__(self) -> G2Element: ...
@staticmethod
def from_bytes(bytes) -> G2Element: ...
@staticmethod
def parse_rust(ReadableBuffer) -> Tuple[G2Element, int]: ...
def to_bytes(self) -> bytes: ...
def __bytes__(self) -> bytes: ...
def get_hash(self) -> bytes32: ...
def to_json_dict(self) -> Dict[str, Any]: ...
@staticmethod
def from_json_dict(o: Dict[str, Any]) -> G2Element: ...

class GTElement:
SIZE: ClassVar[int] = ...
@staticmethod
def from_bytes_unchecked(b: bytes) -> GTElement: ...
def __mul__(self, rhs: GTElement) -> GTElement: ...
def __imul__(self, rhs: GTElement) -> None: ...
def __hash__(self) -> int: ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
def __richcmp__(self) -> bool: ...
def __deepcopy__(self) -> GTElement: ...
def __copy__(self) -> GTElement: ...
@staticmethod
def from_bytes(bytes) -> GTElement: ...
@staticmethod
def parse_rust(ReadableBuffer) -> Tuple[GTElement, int]: ...
def to_bytes(self) -> bytes: ...
def __bytes__(self) -> bytes: ...
def get_hash(self) -> bytes32: ...
def to_json_dict(self) -> Dict[str, Any]: ...
@staticmethod
def from_json_dict(o: Dict[str, Any]) -> GTElement: ...

class PrivateKey:
PRIVATE_KEY_SIZE: ClassVar[int] = ...
def sign_g2(self, msg: bytes, dst: bytes) -> G2Element: ...
def get_g1(self) -> G1Element: ...
def __hash__(self) -> int: ...
def __str__(self) -> str: ...
def __repr__(self) -> str: ...
def __richcmp__(self) -> bool: ...
def __deepcopy__(self) -> PrivateKey: ...
def __copy__(self) -> PrivateKey: ...
@staticmethod
def from_bytes(bytes) -> PrivateKey: ...
@staticmethod
def parse_rust(ReadableBuffer) -> Tuple[PrivateKey, int]: ...
def to_bytes(self) -> bytes: ...
def __bytes__(self) -> bytes: ...
def get_hash(self) -> bytes32: ...
def to_json_dict(self) -> Dict[str, Any]: ...
@staticmethod
def from_json_dict(o: Dict[str, Any]) -> PrivateKey: ...

class AugSchemeMPL:
@staticmethod
def sign(pk: PrivateKey, msg: bytes, prepend_pk: G1Element = None) -> G2Element: ...
@staticmethod
def aggregate(sigs: Sequence[G2Element]) -> G2Element: ...
@staticmethod
def verify(pk: G1Element, msg: bytes, sig: G2Element) -> bool: ...
@staticmethod
def aggregate_verify(pks: Sequence[G1Element], msgs: Sequence[bytes], sig: G2Element) -> bool: ...
@staticmethod
def key_gen(seed: bytes) -> PrivateKey: ...
@staticmethod
def g2_from_message(msg: bytes) -> G2Element: ...
@staticmethod
def derive_child_sk(pk: PrivateKey, index: int) -> PrivateKey: ...
@staticmethod
def derive_child_sk_unhardened(pk: PrivateKey, index: int) -> PrivateKey: ...
9 changes: 9 additions & 0 deletions blspy-stubs/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from __future__ import annotations

from setuptools import setup

setup(
name="blspy-stubs",
packages=["blspy-stubs"],
package_data={"blspy-stubs": ["__init__.pyi"]},
)
12 changes: 8 additions & 4 deletions chia/cmds/keys_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,9 +285,9 @@ def sign(message: str, private_key: PrivateKey, hd_path: str, as_bytes: bool, js

def verify(message: str, public_key: str, signature: str, as_bytes: bool) -> None:
data = bytes.fromhex(message) if as_bytes else bytes(message, "utf-8")
public_key = G1Element.from_bytes(bytes.fromhex(public_key))
signature = G2Element.from_bytes(bytes.fromhex(signature))
print(AugSchemeMPL.verify(public_key, data, signature))
pk = G1Element.from_bytes(bytes.fromhex(public_key))
sig = G2Element.from_bytes(bytes.fromhex(signature))
print(AugSchemeMPL.verify(pk, data, sig))


def as_bytes_from_signing_mode(signing_mode_str: str) -> bool:
Expand Down Expand Up @@ -361,6 +361,7 @@ class DerivedSearchResultType(Enum):

if search_address:
# Generate a wallet address using the standard p2_delegated_puzzle_or_hidden_puzzle puzzle
assert child_pk is not None
# TODO: consider generating addresses using other puzzles
address = encode_puzzle_hash(create_puzzlehash_for_pk(child_pk), prefix)

Expand Down Expand Up @@ -741,4 +742,7 @@ def resolve_derivation_master_key(fingerprint_or_filename: Optional[Union[int, s
):
return private_key_from_mnemonic_seed_file(Path(os.fspath(fingerprint_or_filename)))
else:
return get_private_key_with_fingerprint_or_prompt(fingerprint_or_filename)
ret = get_private_key_with_fingerprint_or_prompt(fingerprint_or_filename)
if ret is None:
raise ValueError("Abort. No private key")
return ret
2 changes: 2 additions & 0 deletions chia/consensus/block_header_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,7 @@ def validate_unfinished_header_block(

# 17. Check foliage block signature by plot key
if header_block.foliage.foliage_transaction_block_hash is not None:
assert header_block.foliage.foliage_transaction_block_signature is not None
if not AugSchemeMPL.verify(
header_block.reward_chain_block.proof_of_space.plot_public_key,
header_block.foliage.foliage_transaction_block_hash,
Expand Down Expand Up @@ -789,6 +790,7 @@ def validate_unfinished_header_block(
# 20b. If pospace has a pool pk, heck pool target signature. Should not check this for genesis block.
if header_block.reward_chain_block.proof_of_space.pool_public_key is not None:
assert header_block.reward_chain_block.proof_of_space.pool_contract_puzzle_hash is None
assert header_block.foliage.foliage_block_data.pool_signature is not None
if not AugSchemeMPL.verify(
header_block.reward_chain_block.proof_of_space.pool_public_key,
bytes(header_block.foliage.foliage_block_data.pool_target),
Expand Down
1 change: 1 addition & 0 deletions chia/full_node/full_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,7 @@ def has_valid_pool_sig(self, block: Union[UnfinishedBlock, FullBlock]) -> bool:
and block.foliage.prev_block_hash != self.constants.GENESIS_CHALLENGE
and block.reward_chain_block.proof_of_space.pool_public_key is not None
):
assert block.foliage.foliage_block_data.pool_signature is not None
if not AugSchemeMPL.verify(
block.reward_chain_block.proof_of_space.pool_public_key,
bytes(block.foliage.foliage_block_data.pool_target),
Expand Down
2 changes: 1 addition & 1 deletion chia/plotting/create_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ async def resolve(self) -> PlotKeys:
await keychain_proxy.close()
return self.resolved_keys

async def get_sk(self, keychain_proxy: Optional[KeychainProxy] = None) -> Optional[Tuple[PrivateKey, bytes]]:
async def get_sk(self, keychain_proxy: Optional[KeychainProxy] = None) -> Optional[PrivateKey]:
sk: Optional[PrivateKey] = None
if keychain_proxy:
try:
Expand Down
2 changes: 1 addition & 1 deletion chia/rpc/wallet_rpc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2284,7 +2284,7 @@ async def did_recovery_spend(self, request: Dict[str, Any]) -> EndpointResult:
pubkey = G1Element.from_bytes(hexstr_to_bytes(request["pubkey"]))
else:
assert wallet.did_info.temp_pubkey is not None
pubkey = wallet.did_info.temp_pubkey
pubkey = G1Element.from_bytes(wallet.did_info.temp_pubkey)

if "puzhash" in request:
puzhash = bytes32.from_hexstr(request["puzhash"])
Expand Down
8 changes: 6 additions & 2 deletions chia/simulator/block_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,12 @@ async def setup_keys(self, fingerprint: Optional[int] = None, reward_ph: Optiona
bytes_to_mnemonic(self.pool_master_sk_entropy),
)
else:
self.farmer_master_sk = await keychain_proxy.get_key_for_fingerprint(fingerprint)
self.pool_master_sk = await keychain_proxy.get_key_for_fingerprint(fingerprint)
sk = await keychain_proxy.get_key_for_fingerprint(fingerprint)
assert sk is not None
self.farmer_master_sk = sk
sk = await keychain_proxy.get_key_for_fingerprint(fingerprint)
assert sk is not None
self.pool_master_sk = sk

self.farmer_pk = master_sk_to_farmer_sk(self.farmer_master_sk).get_g1()
self.pool_pk = master_sk_to_pool_sk(self.pool_master_sk).get_g1()
Expand Down
9 changes: 6 additions & 3 deletions chia/util/cached_bls.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def get_pairings(

# G1Element.from_bytes can be expensive due to subgroup check, so we avoid recomputing it with this cache
pk_bytes_to_g1: Dict[bytes48, G1Element] = {}
ret: List[GTElement] = []
for i, pairing in enumerate(pairings):
if pairing is None:
aug_msg = pks[i] + msgs[i]
Expand All @@ -42,12 +43,14 @@ def get_pairings(
pk_parsed = G1Element.from_bytes(pks[i])
pk_bytes_to_g1[pks[i]] = pk_parsed

pairing = pk_parsed.pair(aug_hash)
pairing = aug_hash.pair(pk_parsed)

h = std_hash(aug_msg)
cache.put(h, pairing)
pairings[i] = pairing
return pairings
ret.append(pairing)
else:
ret.append(pairing)
return ret


# Increasing this number will increase RAM usage, but decrease BLS validation time for blocks and unfinished blocks.
Expand Down
6 changes: 3 additions & 3 deletions chia/util/keychain.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,14 +202,14 @@ def __post_init__(self) -> None:
# an attribute mismatch for calculated cached values. Should be ok since we don't handle a lot of keys here.
if self.secrets is not None and self.public_key != self.private_key.get_g1():
raise KeychainKeyDataMismatch("public_key")
if self.public_key.get_fingerprint() != self.fingerprint:
if uint32(self.public_key.get_fingerprint()) != self.fingerprint:
raise KeychainKeyDataMismatch("fingerprint")

@classmethod
def from_mnemonic(cls, mnemonic: str, label: Optional[str] = None) -> KeyData:
private_key = AugSchemeMPL.key_gen(mnemonic_to_seed(mnemonic))
return cls(
fingerprint=private_key.get_g1().get_fingerprint(),
fingerprint=uint32(private_key.get_g1().get_fingerprint()),
public_key=private_key.get_g1(),
label=label,
secrets=KeyDataSecrets.from_mnemonic(mnemonic),
Expand Down Expand Up @@ -285,7 +285,7 @@ def _get_key_data(self, index: int, include_secrets: bool = True) -> KeyData:
entropy = str_bytes[G1Element.SIZE : G1Element.SIZE + 32]

return KeyData(
fingerprint=fingerprint,
fingerprint=uint32(fingerprint),
public_key=public_key,
label=self.keyring_wrapper.get_label(fingerprint),
secrets=KeyDataSecrets.from_entropy(entropy) if include_secrets else None,
Expand Down
7 changes: 6 additions & 1 deletion install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ usage() {

PACMAN_AUTOMATED=
EXTRAS=
BLSPY_STUBS=
SKIP_PACKAGE_INSTALL=
PLOTTER_INSTALL=
EDITABLE='-e'
Expand All @@ -30,7 +31,7 @@ do
# automated
a) PACMAN_AUTOMATED=--noconfirm;;
# development
d) EXTRAS=${EXTRAS}dev,;;
d) EXTRAS=${EXTRAS}dev,;BLSPY_STUBS=1;;
# non-editable
i) EDITABLE='';;
# legacy keyring
Expand Down Expand Up @@ -334,6 +335,10 @@ python -m pip install wheel
python -m pip install --extra-index-url https://pypi.chia.net/simple/ miniupnpc==2.2.2
python -m pip install ${EDITABLE} ."${EXTRAS}" --extra-index-url https://pypi.chia.net/simple/

if [ -n "$BLSPY_STUBS" ]; then
python -m pip install ${EDITABLE} ./blspy-stubs
fi

if [ -n "$PLOTTER_INSTALL" ]; then
set +e
PREV_VENV="$VIRTUAL_ENV"
Expand Down
2 changes: 1 addition & 1 deletion tests/blockchain/test_blockchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -2040,7 +2040,7 @@ async def test_aggsig_garbage(self, empty_blockchain, opcode, with_garbage, expe
synthetic_secret_key = calculate_synthetic_secret_key(secret_key, DEFAULT_HIDDEN_PUZZLE_HASH)
public_key = synthetic_secret_key.get_g1()

args = [public_key, b"msg"] + ([b"garbage"] if with_garbage else [])
args = [bytes(public_key), b"msg"] + ([b"garbage"] if with_garbage else [])
conditions = {opcode: [ConditionWithArgs(opcode, args)]}

tx2: SpendBundle = wt.generate_signed_transaction(
Expand Down
6 changes: 3 additions & 3 deletions tests/clvm/test_puzzles.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Iterable, List, Tuple
from unittest import TestCase

from blspy import AugSchemeMPL, BasicSchemeMPL, G1Element, G2Element
from blspy import AugSchemeMPL, G1Element, G2Element

from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.types.blockchain_format.program import Program
Expand Down Expand Up @@ -34,7 +34,7 @@

def secret_exponent_for_index(index: int) -> int:
blob = index.to_bytes(32, "big")
hashed_blob = BasicSchemeMPL.key_gen(std_hash(b"foo" + blob))
hashed_blob = AugSchemeMPL.key_gen(std_hash(b"foo" + blob))
r = int.from_bytes(hashed_blob, "big")
return r

Expand Down Expand Up @@ -87,7 +87,7 @@ def do_test_spend(
assert 0

# make sure we can actually sign the solution
signatures = []
signatures: List[G2Element] = []
for coin_spend in spend_bundle.coin_spends:
signature = key_lookup.signature_for_solution(coin_spend, bytes([2] * 32))
signatures.append(signature)
Expand Down
2 changes: 1 addition & 1 deletion tests/pools/test_pool_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def mempool_empty() -> bool:
[some_wallet.wallet_state_manager.private_key], status.current.owner_pubkey
)
assert owner_sk is not None
assert owner_sk != auth_sk
assert owner_sk[0] != auth_sk

@pytest.mark.asyncio
async def test_absorb_self(
Expand Down
2 changes: 1 addition & 1 deletion tests/util/benchmark_cost.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def benchmark_all_operators():

for i in range(0, 1000):
private_key: PrivateKey = master_sk_to_wallet_sk(secret_key, uint32(i))
public_key = private_key.public_key()
public_key = private_key.get_g1()
solution = wallet_tool.make_solution(
{ConditionOpcode.ASSERT_MY_COIN_ID: [ConditionWithArgs(ConditionOpcode.ASSERT_MY_COIN_ID, [token_bytes()])]}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/util/key_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def sign(self, public_key: bytes, message: bytes) -> G2Element:
bls_private_key = PrivateKey.from_bytes(secret_exponent.to_bytes(32, "big"))
return AugSchemeMPL.sign(bls_private_key, message)

def signature_for_solution(self, coin_spend: CoinSpend, additional_data: bytes) -> AugSchemeMPL:
def signature_for_solution(self, coin_spend: CoinSpend, additional_data: bytes) -> G2Element:
signatures = []
conditions_dict = conditions_dict_for_solution(
coin_spend.puzzle_reveal, coin_spend.solution, test_constants.MAX_BLOCK_COST_CLVM
Expand Down
Loading

0 comments on commit 7b0bea4

Please sign in to comment.