Skip to content

Commit

Permalink
[CHIA-1211] Use better key resolution logic in derivation commands (#…
Browse files Browse the repository at this point in the history
…18516)

* Resolve root secret key always during non observer derivation

* Tweak a bit

* Fix tests

* A bit simpler

* Add test coverage
  • Loading branch information
Quexington authored Sep 12, 2024
1 parent ab21776 commit d611c7e
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 42 deletions.
26 changes: 22 additions & 4 deletions chia/_tests/core/cmds/test_keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from chia.cmds.chia import cli
from chia.cmds.keys import delete_all_cmd, generate_and_print_cmd, sign_cmd, verify_cmd
from chia.cmds.keys_funcs import get_private_key_with_fingerprint_or_prompt
from chia.util.config import load_config
from chia.util.default_root import DEFAULT_KEYS_ROOT_PATH
from chia.util.keychain import Keychain, KeyData, generate_mnemonic
Expand Down Expand Up @@ -1347,7 +1348,7 @@ def test_derive_wallet_address(self, tmp_path, keyring_with_one_public_one_priva
],
)
assert result.exit_code == 0
assert result.output.find("Need a private key for non observer derivation of wallet addresses") != -1
assert result.output.find("Could not resolve private key for non-observer derivation") != -1

def test_derive_wallet_testnet_address(self, tmp_path, keyring_with_one_public_one_private_key):
"""
Expand Down Expand Up @@ -1685,9 +1686,7 @@ def test_derive_child_keys(self, tmp_path, keyring_with_one_public_one_private_k
],
)

assert isinstance(result.exception, ValueError) and result.exception.args == (
"Cannot perform non-observer derivation on an observer-only key",
)
assert result.output.find("Could not resolve private key for non-observer derivation") != -1

result: Result = runner.invoke(
cli,
Expand All @@ -1714,3 +1713,22 @@ def test_derive_child_keys(self, tmp_path, keyring_with_one_public_one_private_k
assert isinstance(result.exception, ValueError) and result.exception.args == (
"Hardened path specified for observer key",
)

@pytest.mark.anyio
async def test_get_private_key_with_fingerprint_or_prompt(
self, monkeypatch, keyring_with_one_public_one_private_key
) -> None:
[sk1_plus_ent] = keyring_with_one_public_one_private_key.get_all_private_keys()
sk1, _ = sk1_plus_ent
[pk1, pk2] = keyring_with_one_public_one_private_key.get_all_public_keys()
assert pk1.get_fingerprint() == TEST_FINGERPRINT
assert pk2.get_fingerprint() == TEST_PK_FINGERPRINT

assert get_private_key_with_fingerprint_or_prompt(TEST_FINGERPRINT) == (TEST_FINGERPRINT, sk1)
assert get_private_key_with_fingerprint_or_prompt(TEST_PK_FINGERPRINT) == (TEST_PK_FINGERPRINT, None)

monkeypatch.setattr("builtins.input", lambda _: "1")
assert get_private_key_with_fingerprint_or_prompt(None) == (TEST_FINGERPRINT, sk1)

monkeypatch.setattr("builtins.input", lambda _: "2")
assert get_private_key_with_fingerprint_or_prompt(None) == (TEST_PK_FINGERPRINT, None)
62 changes: 48 additions & 14 deletions chia/cmds/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Optional, Tuple

import click
from chia_rs import PrivateKey

from chia.cmds import options

Expand Down Expand Up @@ -204,8 +205,13 @@ def sign_cmd(
) -> None:
from .keys_funcs import resolve_derivation_master_key, sign

private_key = resolve_derivation_master_key(filename if filename is not None else fingerprint)
sign(message, private_key, hd_path, as_bytes, json)
_, resolved_sk = resolve_derivation_master_key(filename if filename is not None else fingerprint)

if resolved_sk is None:
print("Could not resolve a secret key to sign with.")
return

sign(message, resolved_sk, hd_path, as_bytes, json)


def parse_signature_json(json_str: str) -> Tuple[str, str, str, str]:
Expand Down Expand Up @@ -330,9 +336,11 @@ def search_cmd(
filename: Optional[str] = ctx.obj.get("filename", None)

# Specifying the master key is optional for the search command. If not specified, we'll search all keys.
sk = None
if fingerprint is None and filename is not None:
sk = resolve_derivation_master_key(filename)
resolved_sk = None
if fingerprint is not None or filename is not None:
_, resolved_sk = resolve_derivation_master_key(filename if filename is not None else fingerprint)
if resolved_sk is None:
print("Could not resolve private key from fingerprint/mnemonic file")

found: bool = search_derive(
ctx.obj["root_path"],
Expand All @@ -344,12 +352,36 @@ def search_cmd(
("all",) if "all" in search_type else search_type,
derive_from_hd_path,
prefix,
sk,
resolved_sk,
)

sys.exit(0 if found else 1)


class ResolutionError(Exception):
pass


def _resolve_fingerprint_and_sk(
filename: Optional[str], fingerprint: Optional[int], non_observer_derivation: bool
) -> Tuple[Optional[int], Optional[PrivateKey]]:
from .keys_funcs import resolve_derivation_master_key

reolved_fp, resolved_sk = resolve_derivation_master_key(filename if filename is not None else fingerprint)

if non_observer_derivation and resolved_sk is None:
print("Could not resolve private key for non-observer derivation")
raise ResolutionError()
else:
pass

if reolved_fp is None:
print("A fingerprint of a root key to derive from is required")
raise ResolutionError()

return reolved_fp, resolved_sk


@derive_cmd.command("wallet-address", help="Derive wallet receive addresses")
@click.option(
"--index", "-i", help="Index of the first wallet address to derive. Index 0 is the first wallet address.", default=0
Expand All @@ -376,14 +408,15 @@ def search_cmd(
def wallet_address_cmd(
ctx: click.Context, index: int, count: int, prefix: Optional[str], non_observer_derivation: bool, show_hd_path: bool
) -> None:
from .keys_funcs import derive_wallet_address, resolve_derivation_master_key
from .keys_funcs import derive_wallet_address

fingerprint: Optional[int] = ctx.obj.get("fingerprint", None)
filename: Optional[str] = ctx.obj.get("filename", None)

sk = None
if fingerprint is None and filename is not None:
sk = resolve_derivation_master_key(filename)
try:
fingerprint, sk = _resolve_fingerprint_and_sk(filename, fingerprint, non_observer_derivation)
except ResolutionError:
return

derive_wallet_address(
ctx.obj["root_path"], fingerprint, index, count, prefix, non_observer_derivation, show_hd_path, sk
Expand Down Expand Up @@ -450,17 +483,18 @@ def child_key_cmd(
show_hd_path: bool,
bech32m_prefix: Optional[str],
) -> None:
from .keys_funcs import derive_child_key, resolve_derivation_master_key
from .keys_funcs import derive_child_key

if key_type is None and derive_from_hd_path is None:
ctx.fail("--type or --derive-from-hd-path is required")

fingerprint: Optional[int] = ctx.obj.get("fingerprint", None)
filename: Optional[str] = ctx.obj.get("filename", None)

sk = None
if fingerprint is None and filename is not None:
sk = resolve_derivation_master_key(filename)
try:
fingerprint, sk = _resolve_fingerprint_and_sk(filename, fingerprint, non_observer_derivation)
except ResolutionError:
return

derive_child_key(
fingerprint,
Expand Down
49 changes: 25 additions & 24 deletions chia/cmds/keys_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,10 +657,7 @@ def derive_wallet_address(
"""
if fingerprint is not None:
key_data: KeyData = Keychain().get_key(fingerprint, include_secrets=non_observer_derivation)
if non_observer_derivation and key_data.secrets is None:
print("Need a private key for non observer derivation of wallet addresses")
return
elif non_observer_derivation:
if non_observer_derivation:
sk = key_data.private_key
else:
sk = None
Expand Down Expand Up @@ -729,9 +726,6 @@ def derive_child_key(
current_pk = private_key.get_g1()
current_sk = private_key

if non_observer_derivation and current_sk is None:
raise ValueError("Cannot perform non-observer derivation on an observer-only key")

# Key type was specified
if key_type is not None:
path_indices: List[int] = [12381, 8444]
Expand Down Expand Up @@ -807,16 +801,7 @@ def private_key_for_fingerprint(fingerprint: int) -> Optional[PrivateKey]:
return None


def get_private_key_with_fingerprint_or_prompt(fingerprint: Optional[int]) -> Optional[PrivateKey]:
"""
Get a private key with the specified fingerprint. If fingerprint is not
specified, prompt the user to select a key.
"""

# Return the private key matching the specified fingerprint
if fingerprint is not None:
return private_key_for_fingerprint(fingerprint)

def prompt_for_fingerprint() -> Optional[int]:
fingerprints: List[int] = [pk.get_fingerprint() for pk in Keychain().get_all_public_keys()]
while True:
print("Choose key:")
Expand All @@ -836,7 +821,23 @@ def get_private_key_with_fingerprint_or_prompt(fingerprint: Optional[int]) -> Op
val = None
continue
else:
return private_key_for_fingerprint(fingerprints[index])
return fingerprints[index]


def get_private_key_with_fingerprint_or_prompt(
fingerprint: Optional[int],
) -> Tuple[Optional[int], Optional[PrivateKey]]:
"""
Get a private key with the specified fingerprint. If fingerprint is not
specified, prompt the user to select a key.
"""

# Return the private key matching the specified fingerprint
if fingerprint is not None:
return fingerprint, private_key_for_fingerprint(fingerprint)

fingerprint_prompt = prompt_for_fingerprint()
return fingerprint_prompt, None if fingerprint_prompt is None else private_key_for_fingerprint(fingerprint_prompt)


def private_key_from_mnemonic_seed_file(filename: Path) -> PrivateKey:
Expand All @@ -849,15 +850,15 @@ def private_key_from_mnemonic_seed_file(filename: Path) -> PrivateKey:
return AugSchemeMPL.key_gen(seed)


def resolve_derivation_master_key(fingerprint_or_filename: Optional[Union[int, str, Path]]) -> PrivateKey:
def resolve_derivation_master_key(
fingerprint_or_filename: Optional[Union[int, str, Path]]
) -> Tuple[Optional[int], Optional[PrivateKey]]:
"""
Given a key fingerprint of file containing a mnemonic seed, return the private key.
"""

if fingerprint_or_filename is not None and (isinstance(fingerprint_or_filename, (str, Path))):
return private_key_from_mnemonic_seed_file(Path(os.fspath(fingerprint_or_filename)))
sk = private_key_from_mnemonic_seed_file(Path(os.fspath(fingerprint_or_filename)))
return sk.get_g1().get_fingerprint(), sk
else:
ret = get_private_key_with_fingerprint_or_prompt(fingerprint_or_filename)
if ret is None:
raise ValueError("Abort. No private key")
return ret
return get_private_key_with_fingerprint_or_prompt(fingerprint_or_filename)

0 comments on commit d611c7e

Please sign in to comment.