diff --git a/chia/_tests/core/cmds/test_keys.py b/chia/_tests/core/cmds/test_keys.py index 5acd81954af4..7629b1924bb9 100644 --- a/chia/_tests/core/cmds/test_keys.py +++ b/chia/_tests/core/cmds/test_keys.py @@ -11,6 +11,7 @@ from chia.cmds.chia import cli from chia.cmds.keys import delete_all_cmd, generate_and_print_cmd, sign_cmd, verify_cmd +from chia.cmds.keys_funcs import get_private_key_with_fingerprint_or_prompt from chia.util.config import load_config from chia.util.default_root import DEFAULT_KEYS_ROOT_PATH from chia.util.keychain import Keychain, KeyData, generate_mnemonic @@ -1347,7 +1348,7 @@ def test_derive_wallet_address(self, tmp_path, keyring_with_one_public_one_priva ], ) assert result.exit_code == 0 - assert result.output.find("Need a private key for non observer derivation of wallet addresses") != -1 + assert result.output.find("Could not resolve private key for non-observer derivation") != -1 def test_derive_wallet_testnet_address(self, tmp_path, keyring_with_one_public_one_private_key): """ @@ -1685,9 +1686,7 @@ def test_derive_child_keys(self, tmp_path, keyring_with_one_public_one_private_k ], ) - assert isinstance(result.exception, ValueError) and result.exception.args == ( - "Cannot perform non-observer derivation on an observer-only key", - ) + assert result.output.find("Could not resolve private key for non-observer derivation") != -1 result: Result = runner.invoke( cli, @@ -1714,3 +1713,22 @@ def test_derive_child_keys(self, tmp_path, keyring_with_one_public_one_private_k assert isinstance(result.exception, ValueError) and result.exception.args == ( "Hardened path specified for observer key", ) + + @pytest.mark.anyio + async def test_get_private_key_with_fingerprint_or_prompt( + self, monkeypatch, keyring_with_one_public_one_private_key + ) -> None: + [sk1_plus_ent] = keyring_with_one_public_one_private_key.get_all_private_keys() + sk1, _ = sk1_plus_ent + [pk1, pk2] = keyring_with_one_public_one_private_key.get_all_public_keys() + assert pk1.get_fingerprint() == TEST_FINGERPRINT + assert pk2.get_fingerprint() == TEST_PK_FINGERPRINT + + assert get_private_key_with_fingerprint_or_prompt(TEST_FINGERPRINT) == (TEST_FINGERPRINT, sk1) + assert get_private_key_with_fingerprint_or_prompt(TEST_PK_FINGERPRINT) == (TEST_PK_FINGERPRINT, None) + + monkeypatch.setattr("builtins.input", lambda _: "1") + assert get_private_key_with_fingerprint_or_prompt(None) == (TEST_FINGERPRINT, sk1) + + monkeypatch.setattr("builtins.input", lambda _: "2") + assert get_private_key_with_fingerprint_or_prompt(None) == (TEST_PK_FINGERPRINT, None) diff --git a/chia/cmds/keys.py b/chia/cmds/keys.py index f561d02440e8..ff9a30ea064f 100644 --- a/chia/cmds/keys.py +++ b/chia/cmds/keys.py @@ -3,6 +3,7 @@ from typing import Optional, Tuple import click +from chia_rs import PrivateKey from chia.cmds import options @@ -204,8 +205,13 @@ def sign_cmd( ) -> None: from .keys_funcs import resolve_derivation_master_key, sign - private_key = resolve_derivation_master_key(filename if filename is not None else fingerprint) - sign(message, private_key, hd_path, as_bytes, json) + _, resolved_sk = resolve_derivation_master_key(filename if filename is not None else fingerprint) + + if resolved_sk is None: + print("Could not resolve a secret key to sign with.") + return + + sign(message, resolved_sk, hd_path, as_bytes, json) def parse_signature_json(json_str: str) -> Tuple[str, str, str, str]: @@ -330,9 +336,11 @@ def search_cmd( filename: Optional[str] = ctx.obj.get("filename", None) # Specifying the master key is optional for the search command. If not specified, we'll search all keys. - sk = None - if fingerprint is None and filename is not None: - sk = resolve_derivation_master_key(filename) + resolved_sk = None + if fingerprint is not None or filename is not None: + _, resolved_sk = resolve_derivation_master_key(filename if filename is not None else fingerprint) + if resolved_sk is None: + print("Could not resolve private key from fingerprint/mnemonic file") found: bool = search_derive( ctx.obj["root_path"], @@ -344,12 +352,36 @@ def search_cmd( ("all",) if "all" in search_type else search_type, derive_from_hd_path, prefix, - sk, + resolved_sk, ) sys.exit(0 if found else 1) +class ResolutionError(Exception): + pass + + +def _resolve_fingerprint_and_sk( + filename: Optional[str], fingerprint: Optional[int], non_observer_derivation: bool +) -> Tuple[Optional[int], Optional[PrivateKey]]: + from .keys_funcs import resolve_derivation_master_key + + reolved_fp, resolved_sk = resolve_derivation_master_key(filename if filename is not None else fingerprint) + + if non_observer_derivation and resolved_sk is None: + print("Could not resolve private key for non-observer derivation") + raise ResolutionError() + else: + pass + + if reolved_fp is None: + print("A fingerprint of a root key to derive from is required") + raise ResolutionError() + + return reolved_fp, resolved_sk + + @derive_cmd.command("wallet-address", help="Derive wallet receive addresses") @click.option( "--index", "-i", help="Index of the first wallet address to derive. Index 0 is the first wallet address.", default=0 @@ -376,14 +408,15 @@ def search_cmd( def wallet_address_cmd( ctx: click.Context, index: int, count: int, prefix: Optional[str], non_observer_derivation: bool, show_hd_path: bool ) -> None: - from .keys_funcs import derive_wallet_address, resolve_derivation_master_key + from .keys_funcs import derive_wallet_address fingerprint: Optional[int] = ctx.obj.get("fingerprint", None) filename: Optional[str] = ctx.obj.get("filename", None) - sk = None - if fingerprint is None and filename is not None: - sk = resolve_derivation_master_key(filename) + try: + fingerprint, sk = _resolve_fingerprint_and_sk(filename, fingerprint, non_observer_derivation) + except ResolutionError: + return derive_wallet_address( ctx.obj["root_path"], fingerprint, index, count, prefix, non_observer_derivation, show_hd_path, sk @@ -450,7 +483,7 @@ def child_key_cmd( show_hd_path: bool, bech32m_prefix: Optional[str], ) -> None: - from .keys_funcs import derive_child_key, resolve_derivation_master_key + from .keys_funcs import derive_child_key if key_type is None and derive_from_hd_path is None: ctx.fail("--type or --derive-from-hd-path is required") @@ -458,9 +491,10 @@ def child_key_cmd( fingerprint: Optional[int] = ctx.obj.get("fingerprint", None) filename: Optional[str] = ctx.obj.get("filename", None) - sk = None - if fingerprint is None and filename is not None: - sk = resolve_derivation_master_key(filename) + try: + fingerprint, sk = _resolve_fingerprint_and_sk(filename, fingerprint, non_observer_derivation) + except ResolutionError: + return derive_child_key( fingerprint, diff --git a/chia/cmds/keys_funcs.py b/chia/cmds/keys_funcs.py index aed2184d925c..4a4681bec91f 100644 --- a/chia/cmds/keys_funcs.py +++ b/chia/cmds/keys_funcs.py @@ -657,10 +657,7 @@ def derive_wallet_address( """ if fingerprint is not None: key_data: KeyData = Keychain().get_key(fingerprint, include_secrets=non_observer_derivation) - if non_observer_derivation and key_data.secrets is None: - print("Need a private key for non observer derivation of wallet addresses") - return - elif non_observer_derivation: + if non_observer_derivation: sk = key_data.private_key else: sk = None @@ -729,9 +726,6 @@ def derive_child_key( current_pk = private_key.get_g1() current_sk = private_key - if non_observer_derivation and current_sk is None: - raise ValueError("Cannot perform non-observer derivation on an observer-only key") - # Key type was specified if key_type is not None: path_indices: List[int] = [12381, 8444] @@ -807,16 +801,7 @@ def private_key_for_fingerprint(fingerprint: int) -> Optional[PrivateKey]: return None -def get_private_key_with_fingerprint_or_prompt(fingerprint: Optional[int]) -> Optional[PrivateKey]: - """ - Get a private key with the specified fingerprint. If fingerprint is not - specified, prompt the user to select a key. - """ - - # Return the private key matching the specified fingerprint - if fingerprint is not None: - return private_key_for_fingerprint(fingerprint) - +def prompt_for_fingerprint() -> Optional[int]: fingerprints: List[int] = [pk.get_fingerprint() for pk in Keychain().get_all_public_keys()] while True: print("Choose key:") @@ -836,7 +821,23 @@ def get_private_key_with_fingerprint_or_prompt(fingerprint: Optional[int]) -> Op val = None continue else: - return private_key_for_fingerprint(fingerprints[index]) + return fingerprints[index] + + +def get_private_key_with_fingerprint_or_prompt( + fingerprint: Optional[int], +) -> Tuple[Optional[int], Optional[PrivateKey]]: + """ + Get a private key with the specified fingerprint. If fingerprint is not + specified, prompt the user to select a key. + """ + + # Return the private key matching the specified fingerprint + if fingerprint is not None: + return fingerprint, private_key_for_fingerprint(fingerprint) + + fingerprint_prompt = prompt_for_fingerprint() + return fingerprint_prompt, None if fingerprint_prompt is None else private_key_for_fingerprint(fingerprint_prompt) def private_key_from_mnemonic_seed_file(filename: Path) -> PrivateKey: @@ -849,15 +850,15 @@ def private_key_from_mnemonic_seed_file(filename: Path) -> PrivateKey: return AugSchemeMPL.key_gen(seed) -def resolve_derivation_master_key(fingerprint_or_filename: Optional[Union[int, str, Path]]) -> PrivateKey: +def resolve_derivation_master_key( + fingerprint_or_filename: Optional[Union[int, str, Path]] +) -> Tuple[Optional[int], Optional[PrivateKey]]: """ Given a key fingerprint of file containing a mnemonic seed, return the private key. """ if fingerprint_or_filename is not None and (isinstance(fingerprint_or_filename, (str, Path))): - return private_key_from_mnemonic_seed_file(Path(os.fspath(fingerprint_or_filename))) + sk = private_key_from_mnemonic_seed_file(Path(os.fspath(fingerprint_or_filename))) + return sk.get_g1().get_fingerprint(), sk else: - ret = get_private_key_with_fingerprint_or_prompt(fingerprint_or_filename) - if ret is None: - raise ValueError("Abort. No private key") - return ret + return get_private_key_with_fingerprint_or_prompt(fingerprint_or_filename)