From ed4c5217cc15d912cacd05a8d99936ec7dc06fcc Mon Sep 17 00:00:00 2001 From: Jack Nelson Date: Tue, 25 Jul 2023 22:42:52 -0400 Subject: [PATCH] Add Click CLI Unit Tests (Proof of Concept) (#15746) * testing utils v1 * testing utils v2 * tests for chia show * oopsies * fix chia root not working & rename class * change to generated full block instead of pre-generated mess * fix capsys, BASE_LIST and add more comments * ignore mypy ... --- tests/cmds/cmd_test_utils.py | 224 ++++++++++++++++++++++++++++++++++ tests/cmds/conftest.py | 23 ++++ tests/cmds/test_show.py | 118 ++++++++++++++++++ tests/cmds/testing_classes.py | 59 +++++++++ 4 files changed, 424 insertions(+) create mode 100644 tests/cmds/cmd_test_utils.py create mode 100644 tests/cmds/conftest.py create mode 100644 tests/cmds/test_show.py create mode 100644 tests/cmds/testing_classes.py diff --git a/tests/cmds/cmd_test_utils.py b/tests/cmds/cmd_test_utils.py new file mode 100644 index 000000000000..cac7209377df --- /dev/null +++ b/tests/cmds/cmd_test_utils.py @@ -0,0 +1,224 @@ +from __future__ import annotations + +import sys +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type, cast + +import chia.cmds.wallet_funcs +from chia.cmds.chia import cli as chia_cli +from chia.cmds.cmds_util import _T_RpcClient, node_config_section_names +from chia.consensus.block_record import BlockRecord +from chia.consensus.default_constants import DEFAULT_CONSTANTS +from chia.rpc.data_layer_rpc_client import DataLayerRpcClient +from chia.rpc.farmer_rpc_client import FarmerRpcClient +from chia.rpc.full_node_rpc_client import FullNodeRpcClient +from chia.rpc.rpc_client import RpcClient +from chia.rpc.wallet_rpc_client import WalletRpcClient +from chia.simulator.simulator_full_node_rpc_client import SimulatorFullNodeRpcClient +from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.util.config import load_config +from chia.util.ints import uint16, uint32 +from tests.cmds.testing_classes import create_test_block_record + +# Any functions that are the same for every command being tested should be below. +# Functions that are specific to a command should be in the test file for that command. + + +@dataclass +class TestRpcClient: + client_type: Type[RpcClient] + rpc_port: Optional[uint16] = None + root_path: Optional[Path] = None + config: Optional[Dict[str, Any]] = None + create_called: bool = field(init=False, default=False) + rpc_log: Dict[str, List[Tuple[Any, ...]]] = field(init=False, default_factory=dict) + + async def create(self, _: str, rpc_port: uint16, root_path: Path, config: Dict[str, Any]) -> None: + self.rpc_port = rpc_port + self.root_path = root_path + self.config = config + self.create_called = True + + def add_to_log(self, method_name: str, args: Tuple[Any, ...]) -> None: + if method_name not in self.rpc_log: + self.rpc_log[method_name] = [] + self.rpc_log[method_name].append(args) + + def check_log(self, expected_calls: Dict[str, Optional[List[Tuple[Any, ...]]]]) -> None: + for k, v in expected_calls.items(): + assert k in self.rpc_log + if v is not None: # None means we don't care about the value used when calling the rpc. + assert self.rpc_log[k] == v + self.rpc_log = {} + + +@dataclass +class TestFarmerRpcClient(TestRpcClient): + client_type: Type[FarmerRpcClient] = field(init=False, default=FarmerRpcClient) + + +@dataclass +class TestWalletRpcClient(TestRpcClient): + client_type: Type[WalletRpcClient] = field(init=False, default=WalletRpcClient) + fingerprint: int = field(init=False, default=0) + + +@dataclass +class TestFullNodeRpcClient(TestRpcClient): + client_type: Type[FullNodeRpcClient] = field(init=False, default=FullNodeRpcClient) + + async def get_blockchain_state(self) -> Dict[str, Any]: + response: Dict[str, Any] = { + "peak": cast(BlockRecord, create_test_block_record()), + "genesis_challenge_initialized": True, + "sync": { + "sync_mode": False, + "synced": True, + "sync_tip_height": 0, + "sync_progress_height": 0, + }, + "difficulty": 1024, + "sub_slot_iters": 147849216, + "space": 29569289860555554816, + "mempool_size": 3, + "mempool_cost": 88304083, + "mempool_fees": 50, + "mempool_min_fees": { + # We may give estimates for varying costs in the future + # This Dict sets us up for that in the future + "cost_5000000": 0, + }, + "mempool_max_total_cost": 550000000000, + "block_max_cost": DEFAULT_CONSTANTS.MAX_BLOCK_COST_CLVM, + "node_id": "7991a584ae4784ab7525bda352ea9b155ce2ac108d361afc13d5964a0f33fa6d", + } + self.add_to_log("get_blockchain_state", ()) + return response + + async def get_block_record_by_height(self, height: int) -> Optional[BlockRecord]: + self.add_to_log("get_block_record_by_height", (height,)) + return cast(BlockRecord, create_test_block_record(height=uint32(height))) + + async def get_block_record(self, header_hash: bytes32) -> Optional[BlockRecord]: + self.add_to_log("get_block_record", (header_hash,)) + return cast(BlockRecord, create_test_block_record(header_hash=header_hash)) + + +@dataclass +class TestDataLayerRpcClient(TestRpcClient): + client_type: Type[DataLayerRpcClient] = field(init=False, default=DataLayerRpcClient) + + +@dataclass +class TestSimulatorFullNodeRpcClient(TestRpcClient): + client_type: Type[SimulatorFullNodeRpcClient] = field(init=False, default=SimulatorFullNodeRpcClient) + + +@dataclass +class TestRpcClients: + """ + Because this data is in a class, it can be modified by the tests even after the generator is created and imported. + This is important, as we need an easy way to modify the monkey-patched functions. + """ + + farmer_rpc_client: TestFarmerRpcClient = TestFarmerRpcClient() + wallet_rpc_client: TestWalletRpcClient = TestWalletRpcClient() + full_node_rpc_client: TestFullNodeRpcClient = TestFullNodeRpcClient() + data_layer_rpc_client: TestDataLayerRpcClient = TestDataLayerRpcClient() + simulator_full_node_rpc_client: TestSimulatorFullNodeRpcClient = TestSimulatorFullNodeRpcClient() + + def get_client(self, client_type: Type[_T_RpcClient]) -> _T_RpcClient: + if client_type == FarmerRpcClient: + return cast(FarmerRpcClient, self.farmer_rpc_client) # type: ignore[return-value] + elif client_type == WalletRpcClient: + return cast(WalletRpcClient, self.wallet_rpc_client) # type: ignore[return-value] + elif client_type == FullNodeRpcClient: + return cast(FullNodeRpcClient, self.full_node_rpc_client) # type: ignore[return-value] + elif client_type == DataLayerRpcClient: + return cast(DataLayerRpcClient, self.data_layer_rpc_client) # type: ignore[return-value] + elif client_type == SimulatorFullNodeRpcClient: + return cast(SimulatorFullNodeRpcClient, self.simulator_full_node_rpc_client) # type: ignore[return-value] + else: + raise ValueError(f"Invalid client type requested: {client_type.__name__}") + + +def create_service_and_wallet_client_generators(test_rpc_clients: TestRpcClients, default_root: Path) -> None: + """ + Create and monkey patch custom generators designed for testing. + These are monkey patched into the chia.cmds.cmds_util module. + Each generator below replaces the original function with a new one that returns a custom client, given by the class. + The clients given can be changed by changing the variables in the class above, after running this function. + """ + + @asynccontextmanager + async def test_get_any_service_client( + client_type: Type[_T_RpcClient], + rpc_port: Optional[int] = None, + root_path: Path = default_root, + consume_errors: bool = True, + ) -> AsyncIterator[Tuple[_T_RpcClient, Dict[str, Any]]]: + node_type = node_config_section_names.get(client_type) + if node_type is None: + # Click already checks this, so this should never happen + raise ValueError(f"Invalid client type requested: {client_type.__name__}") + # load variables from config file + config = load_config( + root_path, + "config.yaml", + fill_missing_services=issubclass(client_type, DataLayerRpcClient), + ) + self_hostname = config["self_hostname"] + if rpc_port is None: + rpc_port = config[node_type]["rpc_port"] + test_rpc_client = test_rpc_clients.get_client(client_type) + + await test_rpc_client.create(self_hostname, uint16(rpc_port), root_path, config) + yield test_rpc_client, config + + @asynccontextmanager + async def test_get_wallet_client( + wallet_rpc_port: Optional[int] = None, + fingerprint: Optional[int] = None, + root_path: Path = default_root, + ) -> AsyncIterator[Tuple[WalletRpcClient, int, Dict[str, Any]]]: + async with test_get_any_service_client(WalletRpcClient, wallet_rpc_port, root_path) as (wallet_client, config): + wallet_client.fingerprint = fingerprint # type: ignore + assert fingerprint is not None + yield wallet_client, fingerprint, config + + # Monkey patches the functions into the module, the classes returned by these functions can be changed in the class. + # For more information, read the docstring of this function. + chia.cmds.cmds_util.get_any_service_client = test_get_any_service_client + chia.cmds.wallet_funcs.get_wallet_client = test_get_wallet_client # type: ignore[attr-defined] + + +def run_cli_command(capsys: object, chia_root: Path, command_list: List[str]) -> Tuple[bool, str]: + """ + This is just an easy way to run the chia CLI with the given command list. + """ + # we don't use the real capsys object because its only accessible in a private part of the pytest module + argv_temp = sys.argv + try: + sys.argv = ["chia", "--root-path", str(chia_root)] + command_list + exited_cleanly = True + try: + chia_cli() # pylint: disable=no-value-for-parameter + except SystemExit as e: + if e.code != 0: + exited_cleanly = False + output = capsys.readouterr() # type: ignore[attr-defined] + finally: # always reset sys.argv + sys.argv = argv_temp + if not exited_cleanly: # so we can look at what went wrong + print(f"\n{output.out}\n{output.err}") + return exited_cleanly, output.out + + +def cli_assert_shortcut(output: str, strings_to_assert: Iterable[str]) -> None: + """ + Asserts that all the strings in strings_to_assert are in the output + """ + for string_to_assert in strings_to_assert: + assert string_to_assert in output diff --git a/tests/cmds/conftest.py b/tests/cmds/conftest.py new file mode 100644 index 000000000000..8378cd732611 --- /dev/null +++ b/tests/cmds/conftest.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +import tempfile +from pathlib import Path +from typing import Iterator, Tuple + +import pytest + +from chia.util.config import create_default_chia_config +from tests.cmds.cmd_test_utils import TestRpcClients, create_service_and_wallet_client_generators + + +@pytest.fixture(scope="module") # every file has its own config generated, just to be safe +def get_test_cli_clients() -> Iterator[Tuple[TestRpcClients, Path]]: + # we cant use the normal config fixture because it only supports function scope. + with tempfile.TemporaryDirectory() as tmp_path: + root_path: Path = Path(tmp_path) / "chia_root" + root_path.mkdir(parents=True, exist_ok=True) + create_default_chia_config(root_path) + # ^ this is basically the generate config fixture. + global_test_rpc_clients = TestRpcClients() + create_service_and_wallet_client_generators(global_test_rpc_clients, root_path) + yield global_test_rpc_clients, root_path diff --git a/tests/cmds/test_show.py b/tests/cmds/test_show.py new file mode 100644 index 000000000000..ca417c6f84b6 --- /dev/null +++ b/tests/cmds/test_show.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from chia.types.blockchain_format.foliage import FoliageTransactionBlock +from chia.types.blockchain_format.serialized_program import SerializedProgram +from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.types.full_block import FullBlock +from chia.util.ints import uint32, uint64 +from tests.cmds.cmd_test_utils import TestFullNodeRpcClient, TestRpcClients, cli_assert_shortcut, run_cli_command +from tests.cmds.testing_classes import hash_to_height, height_hash +from tests.util.test_full_block_utils import get_foliage, get_reward_chain_block, get_transactions_info, vdf_proof + + +@dataclass +class ShowFullNodeRpcClient(TestFullNodeRpcClient): + async def get_fee_estimate(self, target_times: Optional[List[int]], cost: Optional[int]) -> Dict[str, Any]: + self.add_to_log("get_fee_estimate", (target_times, cost)) + response: Dict[str, Any] = { + "current_fee_rate": 0, + "estimates": [0, 0, 0], + "fee_rate_last_block": 30769.681426718744, + "fees_last_block": 500000000000, + "full_node_synced": True, + "last_block_cost": 16249762, + "last_peak_timestamp": 1688858763, + "last_tx_block_height": 11, + "mempool_fees": 0, + "mempool_max_size": 0, + "mempool_size": 0, + "node_time_utc": 1689187617, + "num_spends": 0, + "peak_height": 11, + "success": True, + "target_times": target_times, + } + return response + + async def get_block(self, header_hash: bytes32) -> Optional[FullBlock]: + # we return a block with the height matching the header hash + self.add_to_log("get_block", (header_hash,)) + height = hash_to_height(header_hash) + foliage = None + for foliage in get_foliage(): + break + assert foliage is not None + r_chain_block = None + for r_chain_block in get_reward_chain_block(height=uint32(height)): + break + assert r_chain_block is not None + foliage_tx_block = FoliageTransactionBlock( + prev_transaction_block_hash=height_hash(height - 1), + timestamp=uint64(100400000), + filter_hash=bytes32([2] * 32), + additions_root=bytes32([3] * 32), + removals_root=bytes32([4] * 32), + transactions_info_hash=bytes32([5] * 32), + ) + tx_info = None + for tx_info in get_transactions_info(height=uint32(height), foliage_transaction_block=foliage_tx_block): + break + assert tx_info is not None + full_block = FullBlock( + finished_sub_slots=[], + reward_chain_block=r_chain_block, + challenge_chain_sp_proof=None, + challenge_chain_ip_proof=vdf_proof(), + reward_chain_sp_proof=None, + reward_chain_ip_proof=vdf_proof(), + infused_challenge_chain_ip_proof=None, + foliage=foliage, + foliage_transaction_block=foliage_tx_block, + transactions_info=tx_info, + transactions_generator=SerializedProgram.from_bytes(bytes.fromhex("ff01820539")), + transactions_generator_ref_list=[], + ) + return full_block + + +RPC_CLIENT_TO_USE = ShowFullNodeRpcClient() # pylint: disable=no-value-for-parameter + + +def test_chia_show(capsys: object, get_test_cli_clients: Tuple[TestRpcClients, Path]) -> None: + test_rpc_clients, root_dir = get_test_cli_clients + # set RPC Client + test_rpc_clients.full_node_rpc_client = RPC_CLIENT_TO_USE + # get output with all options + command_args = [ + "show", + "-s", + "-f", + "--block-header-hash-by-height", + "10", + "-b0x000000000000000000000000000000000000000000000000000000000000000b", + ] + success, output = run_cli_command(capsys, root_dir, command_args) + assert success + # these are various things that should be in the output + assert_list = [ + "Current Blockchain Status: Full Node Synced", + "Estimated network space: 25.647 EiB", + "Block fees: 500000000000 mojos", + "Fee rate: 3.077e+04 mojos per CLVM cost", + f"Tx Filter Hash {bytes32([2] * 32).hex()}", + "Weight 10000", + "Is a Transaction Block?True", + ] + cli_assert_shortcut(output, assert_list) + expected_calls: dict[str, Optional[List[tuple[Any, ...]]]] = { # name of rpc: (args) + "get_blockchain_state": None, + "get_block_record": [(height_hash(height),) for height in [11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 11, 10]], + "get_block_record_by_height": [(10,)], + "get_fee_estimate": [([60, 120, 300], 1)], + "get_block": [(height_hash(11),)], + } # these RPC's should be called with these variables. + test_rpc_clients.full_node_rpc_client.check_log(expected_calls) diff --git a/tests/cmds/testing_classes.py b/tests/cmds/testing_classes.py new file mode 100644 index 000000000000..98c121bc12f6 --- /dev/null +++ b/tests/cmds/testing_classes.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.util.ints import uint8, uint32, uint64 + +# This is a modified version of the TestBlockRecord from test_mempool_manager.py + + +@dataclass(frozen=True) +class TestBlockRecord: + """ + This is a subset of BlockRecord that the cli tests need + """ + + header_hash: bytes32 + height: uint32 + timestamp: Optional[uint64] + prev_transaction_block_height: uint32 + prev_transaction_block_hash: Optional[bytes32] + prev_hash: Optional[bytes32] + weight: uint64 = uint64(10000) + fees: uint64 = uint64(5000) + farmer_puzzle_hash: bytes32 = bytes32([1] * 32) + pool_puzzle_hash: bytes32 = bytes32([2] * 32) + sub_slot_iters: uint64 = uint64(1024) + total_iters: uint64 = uint64(12081) + deficit: uint8 = uint8(0) + + @property + def is_transaction_block(self) -> bool: + return self.timestamp is not None + + +def height_hash(height: int) -> bytes32: + return bytes32(height.to_bytes(32, byteorder="big")) + + +def hash_to_height(int_bytes: bytes32) -> int: + return int.from_bytes(int_bytes, byteorder="big") + + +def create_test_block_record( + *, height: uint32 = uint32(11), timestamp: uint64 = uint64(10040), header_hash: Optional[bytes32] = None +) -> TestBlockRecord: + if header_hash is None: + header_hash = height_hash(height) + else: + height = uint32(hash_to_height(header_hash)) # so the heights make sense + return TestBlockRecord( + header_hash=header_hash, + height=height, + timestamp=timestamp, + prev_transaction_block_height=uint32(height - 1), + prev_transaction_block_hash=height_hash(height - 1), + prev_hash=height_hash(height - 1), + )