Skip to content

Commit

Permalink
Add Click CLI Unit Tests (Proof of Concept) (#15746)
Browse files Browse the repository at this point in the history
* testing utils v1

* testing utils v2

* tests for chia show

* oopsies

* fix chia root not working & rename class

* change to generated full block instead of pre-generated mess

* fix capsys, BASE_LIST and add more comments

* ignore mypy ...
  • Loading branch information
jack60612 authored Jul 26, 2023
1 parent e3f5fdc commit ed4c521
Show file tree
Hide file tree
Showing 4 changed files with 424 additions and 0 deletions.
224 changes: 224 additions & 0 deletions tests/cmds/cmd_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
from __future__ import annotations

import sys
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Tuple, Type, cast

import chia.cmds.wallet_funcs
from chia.cmds.chia import cli as chia_cli
from chia.cmds.cmds_util import _T_RpcClient, node_config_section_names
from chia.consensus.block_record import BlockRecord
from chia.consensus.default_constants import DEFAULT_CONSTANTS
from chia.rpc.data_layer_rpc_client import DataLayerRpcClient
from chia.rpc.farmer_rpc_client import FarmerRpcClient
from chia.rpc.full_node_rpc_client import FullNodeRpcClient
from chia.rpc.rpc_client import RpcClient
from chia.rpc.wallet_rpc_client import WalletRpcClient
from chia.simulator.simulator_full_node_rpc_client import SimulatorFullNodeRpcClient
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.util.config import load_config
from chia.util.ints import uint16, uint32
from tests.cmds.testing_classes import create_test_block_record

# Any functions that are the same for every command being tested should be below.
# Functions that are specific to a command should be in the test file for that command.


@dataclass
class TestRpcClient:
client_type: Type[RpcClient]
rpc_port: Optional[uint16] = None
root_path: Optional[Path] = None
config: Optional[Dict[str, Any]] = None
create_called: bool = field(init=False, default=False)
rpc_log: Dict[str, List[Tuple[Any, ...]]] = field(init=False, default_factory=dict)

async def create(self, _: str, rpc_port: uint16, root_path: Path, config: Dict[str, Any]) -> None:
self.rpc_port = rpc_port
self.root_path = root_path
self.config = config
self.create_called = True

def add_to_log(self, method_name: str, args: Tuple[Any, ...]) -> None:
if method_name not in self.rpc_log:
self.rpc_log[method_name] = []
self.rpc_log[method_name].append(args)

def check_log(self, expected_calls: Dict[str, Optional[List[Tuple[Any, ...]]]]) -> None:
for k, v in expected_calls.items():
assert k in self.rpc_log
if v is not None: # None means we don't care about the value used when calling the rpc.
assert self.rpc_log[k] == v
self.rpc_log = {}


@dataclass
class TestFarmerRpcClient(TestRpcClient):
client_type: Type[FarmerRpcClient] = field(init=False, default=FarmerRpcClient)


@dataclass
class TestWalletRpcClient(TestRpcClient):
client_type: Type[WalletRpcClient] = field(init=False, default=WalletRpcClient)
fingerprint: int = field(init=False, default=0)


@dataclass
class TestFullNodeRpcClient(TestRpcClient):
client_type: Type[FullNodeRpcClient] = field(init=False, default=FullNodeRpcClient)

async def get_blockchain_state(self) -> Dict[str, Any]:
response: Dict[str, Any] = {
"peak": cast(BlockRecord, create_test_block_record()),
"genesis_challenge_initialized": True,
"sync": {
"sync_mode": False,
"synced": True,
"sync_tip_height": 0,
"sync_progress_height": 0,
},
"difficulty": 1024,
"sub_slot_iters": 147849216,
"space": 29569289860555554816,
"mempool_size": 3,
"mempool_cost": 88304083,
"mempool_fees": 50,
"mempool_min_fees": {
# We may give estimates for varying costs in the future
# This Dict sets us up for that in the future
"cost_5000000": 0,
},
"mempool_max_total_cost": 550000000000,
"block_max_cost": DEFAULT_CONSTANTS.MAX_BLOCK_COST_CLVM,
"node_id": "7991a584ae4784ab7525bda352ea9b155ce2ac108d361afc13d5964a0f33fa6d",
}
self.add_to_log("get_blockchain_state", ())
return response

async def get_block_record_by_height(self, height: int) -> Optional[BlockRecord]:
self.add_to_log("get_block_record_by_height", (height,))
return cast(BlockRecord, create_test_block_record(height=uint32(height)))

async def get_block_record(self, header_hash: bytes32) -> Optional[BlockRecord]:
self.add_to_log("get_block_record", (header_hash,))
return cast(BlockRecord, create_test_block_record(header_hash=header_hash))


@dataclass
class TestDataLayerRpcClient(TestRpcClient):
client_type: Type[DataLayerRpcClient] = field(init=False, default=DataLayerRpcClient)


@dataclass
class TestSimulatorFullNodeRpcClient(TestRpcClient):
client_type: Type[SimulatorFullNodeRpcClient] = field(init=False, default=SimulatorFullNodeRpcClient)


@dataclass
class TestRpcClients:
"""
Because this data is in a class, it can be modified by the tests even after the generator is created and imported.
This is important, as we need an easy way to modify the monkey-patched functions.
"""

farmer_rpc_client: TestFarmerRpcClient = TestFarmerRpcClient()
wallet_rpc_client: TestWalletRpcClient = TestWalletRpcClient()
full_node_rpc_client: TestFullNodeRpcClient = TestFullNodeRpcClient()
data_layer_rpc_client: TestDataLayerRpcClient = TestDataLayerRpcClient()
simulator_full_node_rpc_client: TestSimulatorFullNodeRpcClient = TestSimulatorFullNodeRpcClient()

def get_client(self, client_type: Type[_T_RpcClient]) -> _T_RpcClient:
if client_type == FarmerRpcClient:
return cast(FarmerRpcClient, self.farmer_rpc_client) # type: ignore[return-value]
elif client_type == WalletRpcClient:
return cast(WalletRpcClient, self.wallet_rpc_client) # type: ignore[return-value]
elif client_type == FullNodeRpcClient:
return cast(FullNodeRpcClient, self.full_node_rpc_client) # type: ignore[return-value]
elif client_type == DataLayerRpcClient:
return cast(DataLayerRpcClient, self.data_layer_rpc_client) # type: ignore[return-value]
elif client_type == SimulatorFullNodeRpcClient:
return cast(SimulatorFullNodeRpcClient, self.simulator_full_node_rpc_client) # type: ignore[return-value]
else:
raise ValueError(f"Invalid client type requested: {client_type.__name__}")


def create_service_and_wallet_client_generators(test_rpc_clients: TestRpcClients, default_root: Path) -> None:
"""
Create and monkey patch custom generators designed for testing.
These are monkey patched into the chia.cmds.cmds_util module.
Each generator below replaces the original function with a new one that returns a custom client, given by the class.
The clients given can be changed by changing the variables in the class above, after running this function.
"""

@asynccontextmanager
async def test_get_any_service_client(
client_type: Type[_T_RpcClient],
rpc_port: Optional[int] = None,
root_path: Path = default_root,
consume_errors: bool = True,
) -> AsyncIterator[Tuple[_T_RpcClient, Dict[str, Any]]]:
node_type = node_config_section_names.get(client_type)
if node_type is None:
# Click already checks this, so this should never happen
raise ValueError(f"Invalid client type requested: {client_type.__name__}")
# load variables from config file
config = load_config(
root_path,
"config.yaml",
fill_missing_services=issubclass(client_type, DataLayerRpcClient),
)
self_hostname = config["self_hostname"]
if rpc_port is None:
rpc_port = config[node_type]["rpc_port"]
test_rpc_client = test_rpc_clients.get_client(client_type)

await test_rpc_client.create(self_hostname, uint16(rpc_port), root_path, config)
yield test_rpc_client, config

@asynccontextmanager
async def test_get_wallet_client(
wallet_rpc_port: Optional[int] = None,
fingerprint: Optional[int] = None,
root_path: Path = default_root,
) -> AsyncIterator[Tuple[WalletRpcClient, int, Dict[str, Any]]]:
async with test_get_any_service_client(WalletRpcClient, wallet_rpc_port, root_path) as (wallet_client, config):
wallet_client.fingerprint = fingerprint # type: ignore
assert fingerprint is not None
yield wallet_client, fingerprint, config

# Monkey patches the functions into the module, the classes returned by these functions can be changed in the class.
# For more information, read the docstring of this function.
chia.cmds.cmds_util.get_any_service_client = test_get_any_service_client
chia.cmds.wallet_funcs.get_wallet_client = test_get_wallet_client # type: ignore[attr-defined]


def run_cli_command(capsys: object, chia_root: Path, command_list: List[str]) -> Tuple[bool, str]:
"""
This is just an easy way to run the chia CLI with the given command list.
"""
# we don't use the real capsys object because its only accessible in a private part of the pytest module
argv_temp = sys.argv
try:
sys.argv = ["chia", "--root-path", str(chia_root)] + command_list
exited_cleanly = True
try:
chia_cli() # pylint: disable=no-value-for-parameter
except SystemExit as e:
if e.code != 0:
exited_cleanly = False
output = capsys.readouterr() # type: ignore[attr-defined]
finally: # always reset sys.argv
sys.argv = argv_temp
if not exited_cleanly: # so we can look at what went wrong
print(f"\n{output.out}\n{output.err}")
return exited_cleanly, output.out


def cli_assert_shortcut(output: str, strings_to_assert: Iterable[str]) -> None:
"""
Asserts that all the strings in strings_to_assert are in the output
"""
for string_to_assert in strings_to_assert:
assert string_to_assert in output
23 changes: 23 additions & 0 deletions tests/cmds/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

import tempfile
from pathlib import Path
from typing import Iterator, Tuple

import pytest

from chia.util.config import create_default_chia_config
from tests.cmds.cmd_test_utils import TestRpcClients, create_service_and_wallet_client_generators


@pytest.fixture(scope="module") # every file has its own config generated, just to be safe
def get_test_cli_clients() -> Iterator[Tuple[TestRpcClients, Path]]:
# we cant use the normal config fixture because it only supports function scope.
with tempfile.TemporaryDirectory() as tmp_path:
root_path: Path = Path(tmp_path) / "chia_root"
root_path.mkdir(parents=True, exist_ok=True)
create_default_chia_config(root_path)
# ^ this is basically the generate config fixture.
global_test_rpc_clients = TestRpcClients()
create_service_and_wallet_client_generators(global_test_rpc_clients, root_path)
yield global_test_rpc_clients, root_path
118 changes: 118 additions & 0 deletions tests/cmds/test_show.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

from chia.types.blockchain_format.foliage import FoliageTransactionBlock
from chia.types.blockchain_format.serialized_program import SerializedProgram
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.full_block import FullBlock
from chia.util.ints import uint32, uint64
from tests.cmds.cmd_test_utils import TestFullNodeRpcClient, TestRpcClients, cli_assert_shortcut, run_cli_command
from tests.cmds.testing_classes import hash_to_height, height_hash
from tests.util.test_full_block_utils import get_foliage, get_reward_chain_block, get_transactions_info, vdf_proof


@dataclass
class ShowFullNodeRpcClient(TestFullNodeRpcClient):
async def get_fee_estimate(self, target_times: Optional[List[int]], cost: Optional[int]) -> Dict[str, Any]:
self.add_to_log("get_fee_estimate", (target_times, cost))
response: Dict[str, Any] = {
"current_fee_rate": 0,
"estimates": [0, 0, 0],
"fee_rate_last_block": 30769.681426718744,
"fees_last_block": 500000000000,
"full_node_synced": True,
"last_block_cost": 16249762,
"last_peak_timestamp": 1688858763,
"last_tx_block_height": 11,
"mempool_fees": 0,
"mempool_max_size": 0,
"mempool_size": 0,
"node_time_utc": 1689187617,
"num_spends": 0,
"peak_height": 11,
"success": True,
"target_times": target_times,
}
return response

async def get_block(self, header_hash: bytes32) -> Optional[FullBlock]:
# we return a block with the height matching the header hash
self.add_to_log("get_block", (header_hash,))
height = hash_to_height(header_hash)
foliage = None
for foliage in get_foliage():
break
assert foliage is not None
r_chain_block = None
for r_chain_block in get_reward_chain_block(height=uint32(height)):
break
assert r_chain_block is not None
foliage_tx_block = FoliageTransactionBlock(
prev_transaction_block_hash=height_hash(height - 1),
timestamp=uint64(100400000),
filter_hash=bytes32([2] * 32),
additions_root=bytes32([3] * 32),
removals_root=bytes32([4] * 32),
transactions_info_hash=bytes32([5] * 32),
)
tx_info = None
for tx_info in get_transactions_info(height=uint32(height), foliage_transaction_block=foliage_tx_block):
break
assert tx_info is not None
full_block = FullBlock(
finished_sub_slots=[],
reward_chain_block=r_chain_block,
challenge_chain_sp_proof=None,
challenge_chain_ip_proof=vdf_proof(),
reward_chain_sp_proof=None,
reward_chain_ip_proof=vdf_proof(),
infused_challenge_chain_ip_proof=None,
foliage=foliage,
foliage_transaction_block=foliage_tx_block,
transactions_info=tx_info,
transactions_generator=SerializedProgram.from_bytes(bytes.fromhex("ff01820539")),
transactions_generator_ref_list=[],
)
return full_block


RPC_CLIENT_TO_USE = ShowFullNodeRpcClient() # pylint: disable=no-value-for-parameter


def test_chia_show(capsys: object, get_test_cli_clients: Tuple[TestRpcClients, Path]) -> None:
test_rpc_clients, root_dir = get_test_cli_clients
# set RPC Client
test_rpc_clients.full_node_rpc_client = RPC_CLIENT_TO_USE
# get output with all options
command_args = [
"show",
"-s",
"-f",
"--block-header-hash-by-height",
"10",
"-b0x000000000000000000000000000000000000000000000000000000000000000b",
]
success, output = run_cli_command(capsys, root_dir, command_args)
assert success
# these are various things that should be in the output
assert_list = [
"Current Blockchain Status: Full Node Synced",
"Estimated network space: 25.647 EiB",
"Block fees: 500000000000 mojos",
"Fee rate: 3.077e+04 mojos per CLVM cost",
f"Tx Filter Hash {bytes32([2] * 32).hex()}",
"Weight 10000",
"Is a Transaction Block?True",
]
cli_assert_shortcut(output, assert_list)
expected_calls: dict[str, Optional[List[tuple[Any, ...]]]] = { # name of rpc: (args)
"get_blockchain_state": None,
"get_block_record": [(height_hash(height),) for height in [11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 11, 10]],
"get_block_record_by_height": [(10,)],
"get_fee_estimate": [([60, 120, 300], 1)],
"get_block": [(height_hash(11),)],
} # these RPC's should be called with these variables.
test_rpc_clients.full_node_rpc_client.check_log(expected_calls)
Loading

0 comments on commit ed4c521

Please sign in to comment.