From 521fa6409277f35126141f58a296fe1672e3c51e Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 31 Oct 2024 08:58:21 -0700 Subject: [PATCH] Port `get_transactions` --- chia/_tests/cmds/wallet/test_wallet.py | 38 ++++++------- chia/_tests/pools/test_pool_rpc.py | 4 +- chia/_tests/wallet/rpc/test_wallet_rpc.py | 53 ++++++++++++------ .../_tests/wallet/vc_wallet/test_vc_wallet.py | 21 ++++--- chia/cmds/wallet_funcs.py | 16 +++++- chia/rpc/wallet_request_types.py | 55 ++++++++++++++++++- chia/rpc/wallet_rpc_api.py | 48 ++++++++-------- chia/rpc/wallet_rpc_client.py | 37 ++----------- 8 files changed, 160 insertions(+), 112 deletions(-) diff --git a/chia/_tests/cmds/wallet/test_wallet.py b/chia/_tests/cmds/wallet/test_wallet.py index 18dd55ef077c..86cd3891131f 100644 --- a/chia/_tests/cmds/wallet/test_wallet.py +++ b/chia/_tests/cmds/wallet/test_wallet.py @@ -31,12 +31,16 @@ CreateOfferForIDsResponse, GetHeightInfoResponse, GetTransaction, + GetTransactions, + GetTransactionsResponse, GetWalletBalance, GetWalletBalanceResponse, GetWallets, GetWalletsResponse, SendTransactionResponse, TakeOfferResponse, + UserFriendlyMemos, + UserFriendlyTransactionRecordWithMetadata, WalletInfoResponse, ) from chia.server.outbound_message import NodeType @@ -51,7 +55,7 @@ from chia.wallet.trading.trade_status import TradeStatus from chia.wallet.transaction_record import TransactionRecord from chia.wallet.transaction_sorting import SortKey -from chia.wallet.util.query_filter import HashFilter, TransactionTypeFilter +from chia.wallet.util.query_filter import HashFilter from chia.wallet.util.transaction_type import TransactionType from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG, TXConfig from chia.wallet.util.wallet_types import WalletType @@ -110,24 +114,13 @@ def test_get_transactions(capsys: object, get_test_cli_clients: tuple[TestRpcCli # set RPC Client class GetTransactionsWalletRpcClient(TestWalletRpcClient): - async def get_transactions( - self, - wallet_id: int, - start: int, - end: int, - sort_key: Optional[SortKey] = None, - reverse: bool = False, - to_address: Optional[str] = None, - type_filter: Optional[TransactionTypeFilter] = None, - confirmed: Optional[bool] = None, - ) -> list[TransactionRecord]: - self.add_to_log( - "get_transactions", (wallet_id, start, end, sort_key, reverse, to_address, type_filter, confirmed) - ) + async def get_transactions(self, request: GetTransactions) -> GetTransactionsResponse: + self.add_to_log("get_transactions", (request,)) l_tx_rec = [] - for i in range(start, end): - t_type = TransactionType.INCOMING_CLAWBACK_SEND if i == end - 1 else TransactionType.INCOMING_TX - tx_rec = TransactionRecord( + assert request.start is not None and request.end is not None + for i in range(request.start, request.end): + t_type = TransactionType.INCOMING_CLAWBACK_SEND if i == request.end - 1 else TransactionType.INCOMING_TX + tx_rec = UserFriendlyTransactionRecordWithMetadata( confirmed_at_height=uint32(1 + i), created_at_time=uint64(1234 + i), to_puzzle_hash=bytes32([1 + i] * 32), @@ -143,12 +136,13 @@ async def get_transactions( trade_id=None, type=uint32(t_type.value), name=bytes32([2 + i] * 32), - memos=[(bytes32([3 + i] * 32), [bytes([4 + i] * 32)])], + memos=UserFriendlyMemos([(bytes32([3 + i] * 32), [bytes([4 + i] * 32)])]), valid_times=ConditionValidTimes(), + to_address="", ) l_tx_rec.append(tx_rec) - return l_tx_rec + return GetTransactionsResponse(l_tx_rec, request.wallet_id) async def get_coin_records(self, request: GetCoinRecords) -> dict[str, Any]: self.add_to_log("get_coin_records", (request,)) @@ -197,8 +191,8 @@ async def get_coin_records(self, request: GetCoinRecords) -> dict[str, Any]: expected_calls: logType = { "get_wallets": [(GetWallets(type=None, include_data=True),)] * 2, "get_transactions": [ - (1, 2, 4, SortKey.RELEVANCE, True, None, None, None), - (1, 2, 4, SortKey.RELEVANCE, True, None, None, None), + (GetTransactions(uint32(1), uint16(2), uint16(4), SortKey.RELEVANCE.name, True, None, None, None),), + (GetTransactions(uint32(1), uint16(2), uint16(4), SortKey.RELEVANCE.name, True, None, None, None),), ], "get_coin_records": [ (GetCoinRecords(coin_id_filter=HashFilter.include([expected_coin_id])),), diff --git a/chia/_tests/pools/test_pool_rpc.py b/chia/_tests/pools/test_pool_rpc.py index 1c51500c9b43..0dc8664ae26c 100644 --- a/chia/_tests/pools/test_pool_rpc.py +++ b/chia/_tests/pools/test_pool_rpc.py @@ -21,7 +21,7 @@ from chia.consensus.constants import ConsensusConstants from chia.pools.pool_puzzles import SINGLETON_LAUNCHER_HASH from chia.pools.pool_wallet_info import PoolSingletonState, PoolWalletInfo -from chia.rpc.wallet_request_types import GetWalletBalance, GetWallets +from chia.rpc.wallet_request_types import GetTransactions, GetWalletBalance, GetWallets from chia.rpc.wallet_rpc_client import WalletRpcClient from chia.simulator.block_tools import BlockTools, get_plot_dir from chia.simulator.full_node_simulator import FullNodeSimulator @@ -483,7 +483,7 @@ async def test_absorb_self( with pytest.raises(ValueError): await client.pw_absorb_rewards(2, uint64(fee)) - tx1 = await client.get_transactions(1) + tx1 = (await client.get_transactions(GetTransactions(uint32(1)))).transactions assert (250_000_000_000 + fee) in [tx.amount for tx in tx1] @pytest.mark.anyio diff --git a/chia/_tests/wallet/rpc/test_wallet_rpc.py b/chia/_tests/wallet/rpc/test_wallet_rpc.py index f189c6e0adf0..db2423a5dc99 100644 --- a/chia/_tests/wallet/rpc/test_wallet_rpc.py +++ b/chia/_tests/wallet/rpc/test_wallet_rpc.py @@ -57,6 +57,7 @@ GetPrivateKey, GetTimestampForHeight, GetTransaction, + GetTransactions, GetWalletBalance, GetWalletBalances, GetWallets, @@ -951,18 +952,20 @@ async def test_get_transactions(wallet_rpc_environment: WalletRpcTestEnvironment await generate_funds(full_node_api, env.wallet_1, 5) - all_transactions = await client.get_transactions(1) + all_transactions = (await client.get_transactions(GetTransactions(uint32(1)))).transactions assert len(all_transactions) >= 10 # Test transaction pagination - some_transactions = await client.get_transactions(1, 0, 5) - some_transactions_2 = await client.get_transactions(1, 5, 10) + some_transactions = (await client.get_transactions(GetTransactions(uint32(1), uint16(0), uint16(5)))).transactions + some_transactions_2 = ( + await client.get_transactions(GetTransactions(uint32(1), uint16(5), uint16(10))) + ).transactions assert some_transactions == all_transactions[0:5] assert some_transactions_2 == all_transactions[5:10] # Testing sorts # Test the default sort (CONFIRMED_AT_HEIGHT) assert all_transactions == sorted(all_transactions, key=attrgetter("confirmed_at_height")) - all_transactions = await client.get_transactions(1, reverse=True) + all_transactions = (await client.get_transactions(GetTransactions(uint32(1), reverse=True))).transactions assert all_transactions == sorted(all_transactions, key=attrgetter("confirmed_at_height"), reverse=True) # Test RELEVANCE @@ -972,13 +975,17 @@ async def test_get_transactions(wallet_rpc_environment: WalletRpcTestEnvironment 1, uint64(1), encode_puzzle_hash(puzhash, "txch"), DEFAULT_TX_CONFIG ) # Create a pending tx - all_transactions = await client.get_transactions(1, sort_key=SortKey.RELEVANCE) + all_transactions = ( + await client.get_transactions(GetTransactions(uint32(1), sort_key=SortKey.RELEVANCE.name)) + ).transactions sorted_transactions = sorted(all_transactions, key=attrgetter("created_at_time"), reverse=True) sorted_transactions = sorted(sorted_transactions, key=attrgetter("confirmed_at_height"), reverse=True) sorted_transactions = sorted(sorted_transactions, key=attrgetter("confirmed")) assert all_transactions == sorted_transactions - all_transactions = await client.get_transactions(1, sort_key=SortKey.RELEVANCE, reverse=True) + all_transactions = ( + await client.get_transactions(GetTransactions(uint32(1), sort_key=SortKey.RELEVANCE.name, reverse=True)) + ).transactions sorted_transactions = sorted(all_transactions, key=attrgetter("created_at_time")) sorted_transactions = sorted(sorted_transactions, key=attrgetter("confirmed_at_height")) sorted_transactions = sorted(sorted_transactions, key=attrgetter("confirmed"), reverse=True) @@ -989,31 +996,43 @@ async def test_get_transactions(wallet_rpc_environment: WalletRpcTestEnvironment await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20) await client.send_transaction(1, uint64(1), encode_puzzle_hash(ph_by_addr, "txch"), DEFAULT_TX_CONFIG) await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20) - tx_for_address = await client.get_transactions(1, to_address=encode_puzzle_hash(ph_by_addr, "txch")) + tx_for_address = ( + await client.get_transactions(GetTransactions(uint32(1), to_address=encode_puzzle_hash(ph_by_addr, "txch"))) + ).transactions assert len(tx_for_address) == 1 assert tx_for_address[0].to_puzzle_hash == ph_by_addr # Test type filter - all_transactions = await client.get_transactions( - 1, type_filter=TransactionTypeFilter.include([TransactionType.COINBASE_REWARD]) - ) + all_transactions = ( + await client.get_transactions( + GetTransactions(uint32(1), type_filter=TransactionTypeFilter.include([TransactionType.COINBASE_REWARD])) + ) + ).transactions assert len(all_transactions) == 5 assert all(transaction.type == TransactionType.COINBASE_REWARD for transaction in all_transactions) # Test confirmed filter - all_transactions = await client.get_transactions(1, confirmed=True) + all_transactions = (await client.get_transactions(GetTransactions(uint32(1), confirmed=True))).transactions assert len(all_transactions) == 10 assert all(transaction.confirmed for transaction in all_transactions) - all_transactions = await client.get_transactions(1, confirmed=False) + all_transactions = (await client.get_transactions(GetTransactions(uint32(1), confirmed=False))).transactions assert len(all_transactions) == 2 assert all(not transaction.confirmed for transaction in all_transactions) # Test bypass broken txs await wallet.wallet_state_manager.tx_store.add_transaction_record( - dataclasses.replace(all_transactions[0], type=uint32(TransactionType.INCOMING_CLAWBACK_SEND)) - ) - all_transactions = await client.get_transactions( - 1, type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND]), confirmed=False + dataclasses.replace( + all_transactions[0].to_transaction_record(), type=uint32(TransactionType.INCOMING_CLAWBACK_SEND) + ) ) + all_transactions = ( + await client.get_transactions( + GetTransactions( + uint32(1), + type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND]), + confirmed=False, + ) + ) + ).transactions assert len(all_transactions) == 1 @@ -1026,7 +1045,7 @@ async def test_get_transaction_count(wallet_rpc_environment: WalletRpcTestEnviro await generate_funds(full_node_api, env.wallet_1) - all_transactions = await client.get_transactions(1) + all_transactions = (await client.get_transactions(GetTransactions(uint32(1)))).transactions assert len(all_transactions) > 0 transaction_count = await client.get_transaction_count(1) assert transaction_count == len(all_transactions) diff --git a/chia/_tests/wallet/vc_wallet/test_vc_wallet.py b/chia/_tests/wallet/vc_wallet/test_vc_wallet.py index 7ac169fe4ed2..f96a460d70be 100644 --- a/chia/_tests/wallet/vc_wallet/test_vc_wallet.py +++ b/chia/_tests/wallet/vc_wallet/test_vc_wallet.py @@ -6,11 +6,12 @@ import pytest from chia_rs import G2Element +from chia_rs.sized_ints import uint32 from typing_extensions import Literal from chia._tests.environments.wallet import WalletEnvironment, WalletStateTransition, WalletTestFramework from chia._tests.util.time_out_assert import time_out_assert_not_none -from chia.rpc.wallet_request_types import GetWallets, WalletInfoResponse +from chia.rpc.wallet_request_types import GetTransactions, GetWallets, WalletInfoResponse from chia.rpc.wallet_rpc_client import WalletRpcClient from chia.simulator.full_node_simulator import FullNodeSimulator from chia.types.blockchain_format.coin import coin_as_list @@ -430,13 +431,17 @@ async def test_vc_lifecycle(wallet_environments: WalletTestFramework) -> None: assert await wallet_node_1.wallet_state_manager.wallets[env_1.dealias_wallet_id("crcat")].match_hinted_coin( next(c for tx in txs for c in tx.additions if c.amount == 90), wallet_1_ph ) - pending_tx = await client_1.get_transactions( - env_1.dealias_wallet_id("crcat"), - 0, - 1, - reverse=True, - type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CRCAT_PENDING]), - ) + pending_tx = ( + await client_1.get_transactions( + GetTransactions( + uint32(env_1.dealias_wallet_id("crcat")), + uint16(0), + uint16(1), + reverse=True, + type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CRCAT_PENDING]), + ) + ) + ).transactions assert len(pending_tx) == 1 # Send the VC to wallet_1 to use for the CR-CATs diff --git a/chia/cmds/wallet_funcs.py b/chia/cmds/wallet_funcs.py index 98416c738030..10fe46507c8c 100644 --- a/chia/cmds/wallet_funcs.py +++ b/chia/cmds/wallet_funcs.py @@ -25,6 +25,7 @@ CATSpendResponse, GetNotifications, GetTransaction, + GetTransactions, GetWalletBalance, GetWallets, SendTransactionResponse, @@ -207,9 +208,18 @@ async def get_transactions( [TransactionType.INCOMING_CLAWBACK_RECEIVE, TransactionType.INCOMING_CLAWBACK_SEND] ) ) - txs: list[TransactionRecord] = await wallet_client.get_transactions( - wallet_id, start=offset, end=(offset + limit), sort_key=sort_key, reverse=reverse, type_filter=type_filter - ) + txs = ( + await wallet_client.get_transactions( + GetTransactions( + uint32(wallet_id), + start=uint16(offset), + end=uint16(offset + limit), + sort_key=sort_key.name, + reverse=reverse, + type_filter=type_filter, + ) + ) + ).transactions address_prefix = selected_network_address_prefix(config) if len(txs) == 0: diff --git a/chia/rpc/wallet_request_types.py b/chia/rpc/wallet_request_types.py index 69c4b83cec45..f6b30467ab48 100644 --- a/chia/rpc/wallet_request_types.py +++ b/chia/rpc/wallet_request_types.py @@ -24,7 +24,9 @@ from chia.wallet.trade_record import TradeRecord from chia.wallet.trading.offer import Offer from chia.wallet.transaction_record import TransactionRecord +from chia.wallet.transaction_sorting import SortKey from chia.wallet.util.clvm_streamable import json_deserialize_with_clvm_streamable +from chia.wallet.util.query_filter import TransactionTypeFilter from chia.wallet.util.tx_config import TXConfig from chia.wallet.vc_wallet.vc_store import VCRecord from chia.wallet.wallet_info import WalletInfo @@ -52,6 +54,12 @@ class UserFriendlyMemos: def __init__(self, unfriendly_memos: list[tuple[bytes32, list[bytes]]]) -> None: self.unfriendly_memos = unfriendly_memos + def __eq__(self, other: Any) -> bool: + if isinstance(other, UserFriendlyMemos) and other.unfriendly_memos == self.unfriendly_memos: + return True + else: + return False + def __bytes__(self) -> bytes: raise NotImplementedError("Should not be serializing this object as bytes, it's only for RPC") @@ -74,6 +82,9 @@ def from_json_dict(cls, json_dict: dict[str, Any]) -> UserFriendlyMemos: ) +_T_UserFriendlyTransactionRecord = TypeVar("_T_UserFriendlyTransactionRecord", bound="UserFriendlyTransactionRecord") + + @streamable @dataclass(frozen=True) class UserFriendlyTransactionRecord(TransactionRecord): @@ -87,9 +98,11 @@ def to_transaction_record(self) -> TransactionRecord: return TransactionRecord.from_json_dict_convenience(self.to_json_dict()) @classmethod - def from_transaction_record(cls, tx: TransactionRecord, config: dict[str, Any]) -> UserFriendlyTransactionRecord: + def from_transaction_record( + cls: type[_T_UserFriendlyTransactionRecord], tx: TransactionRecord, config: dict[str, Any] + ) -> _T_UserFriendlyTransactionRecord: dict_convenience = tx.to_json_dict_convenience(config) - return super().from_json_dict(dict_convenience) + return cls.from_json_dict(dict_convenience) @streamable @@ -321,6 +334,44 @@ class GetTransactionResponse(Streamable): transaction_id: bytes32 +@streamable +@dataclass(frozen=True) +class GetTransactions(Streamable): + wallet_id: uint32 + start: Optional[uint16] = None + end: Optional[uint16] = None + sort_key: Optional[str] = None + reverse: bool = False + to_address: Optional[str] = None + type_filter: Optional[TransactionTypeFilter] = None + confirmed: Optional[bool] = None + + def __post_init__(self) -> None: + if self.sort_key is not None and self.sort_key not in SortKey.__members__: + raise ValueError(f"There is no known sort {self.sort_key}") + + +# utilities for GetTransactionsResponse +@streamable +@dataclass(frozen=True) +class TransactionRecordMetadata(Streamable): + coin_id: bytes32 + spent: bool + + +@streamable +@dataclass(frozen=True) +class UserFriendlyTransactionRecordWithMetadata(UserFriendlyTransactionRecord): + metadata: Optional[TransactionRecordMetadata] = None + + +@streamable +@dataclass(frozen=True) +class GetTransactionsResponse(Streamable): + transactions: list[UserFriendlyTransactionRecordWithMetadata] + wallet_id: uint32 + + @streamable @dataclass(frozen=True) class GetNotifications(Streamable): diff --git a/chia/rpc/wallet_rpc_api.py b/chia/rpc/wallet_rpc_api.py index 03fff48fb968..9bc7a70716b6 100644 --- a/chia/rpc/wallet_rpc_api.py +++ b/chia/rpc/wallet_rpc_api.py @@ -49,6 +49,8 @@ GetTimestampForHeightResponse, GetTransaction, GetTransactionResponse, + GetTransactions, + GetTransactionsResponse, GetWalletBalance, GetWalletBalanceResponse, GetWalletBalances, @@ -66,6 +68,7 @@ SubmitTransactions, SubmitTransactionsResponse, UserFriendlyTransactionRecord, + UserFriendlyTransactionRecordWithMetadata, WalletInfoResponse, ) from chia.server.outbound_message import NodeType @@ -1279,31 +1282,21 @@ async def combine_coins( return CombineCoinsResponse([], []) # tx_endpoint will take care to fill this out - async def get_transactions(self, request: dict[str, Any]) -> EndpointResult: - wallet_id = int(request["wallet_id"]) - - start = request.get("start", 0) - end = request.get("end", 50) - sort_key = request.get("sort_key", None) - reverse = request.get("reverse", False) - - to_address = request.get("to_address", None) + @marshal + async def get_transactions(self, request: GetTransactions) -> GetTransactionsResponse: to_puzzle_hash: Optional[bytes32] = None - if to_address is not None: - to_puzzle_hash = decode_puzzle_hash(to_address) - type_filter = None - if "type_filter" in request: - type_filter = TransactionTypeFilter.from_json_dict(request["type_filter"]) + if request.to_address is not None: + to_puzzle_hash = decode_puzzle_hash(request.to_address) transactions = await self.service.wallet_state_manager.tx_store.get_transactions_between( - wallet_id, - start, - end, - sort_key=sort_key, - reverse=reverse, + wallet_id=request.wallet_id, + start=uint16(0) if request.start is None else request.start, + end=uint16(50) if request.end is None else request.end, + sort_key=request.sort_key, + reverse=request.reverse, to_puzzle_hash=to_puzzle_hash, - type_filter=type_filter, - confirmed=request.get("confirmed", None), + type_filter=request.type_filter, + confirmed=request.confirmed, ) tx_list = [] # Format for clawback transactions @@ -1326,10 +1319,15 @@ async def get_transactions(self, request: dict[str, Any]) -> EndpointResult: continue tx["metadata"]["coin_id"] = coin.name().hex() tx["metadata"]["spent"] = record.spent - return { - "transactions": tx_list, - "wallet_id": wallet_id, - } + return GetTransactionsResponse( + transactions=[ + UserFriendlyTransactionRecordWithMetadata.from_transaction_record( + TransactionRecord.from_json_dict_convenience(tx), self.service.config + ) + for tx in tx_list + ], + wallet_id=request.wallet_id, + ) async def get_transaction_count(self, request: dict[str, Any]) -> EndpointResult: wallet_id = int(request["wallet_id"]) diff --git a/chia/rpc/wallet_rpc_client.py b/chia/rpc/wallet_rpc_client.py index e5fbd66ce81c..ed44dc880666 100644 --- a/chia/rpc/wallet_rpc_client.py +++ b/chia/rpc/wallet_rpc_client.py @@ -61,6 +61,8 @@ GetTransactionMemo, GetTransactionMemoResponse, GetTransactionResponse, + GetTransactions, + GetTransactionsResponse, GetWalletBalance, GetWalletBalanceResponse, GetWalletBalances, @@ -110,7 +112,6 @@ from chia.wallet.trade_record import TradeRecord from chia.wallet.trading.offer import Offer from chia.wallet.transaction_record import TransactionRecord -from chia.wallet.transaction_sorting import SortKey from chia.wallet.util.clvm_streamable import json_deserialize_with_clvm_streamable from chia.wallet.util.query_filter import TransactionTypeFilter from chia.wallet.util.tx_config import CoinSelectionConfig, TXConfig @@ -214,38 +215,8 @@ async def get_wallet_balances(self, request: GetWalletBalances) -> GetWalletBala async def get_transaction(self, request: GetTransaction) -> GetTransactionResponse: return GetTransactionResponse.from_json_dict(await self.fetch("get_transaction", request.to_json_dict())) - async def get_transactions( - self, - wallet_id: int, - start: Optional[int] = None, - end: Optional[int] = None, - sort_key: Optional[SortKey] = None, - reverse: bool = False, - to_address: Optional[str] = None, - type_filter: Optional[TransactionTypeFilter] = None, - confirmed: Optional[bool] = None, - ) -> list[TransactionRecord]: - request: dict[str, Any] = {"wallet_id": wallet_id} - - if start is not None: - request["start"] = start - if end is not None: - request["end"] = end - if sort_key is not None: - request["sort_key"] = sort_key.name - request["reverse"] = reverse - - if to_address is not None: - request["to_address"] = to_address - - if type_filter is not None: - request["type_filter"] = type_filter.to_json_dict() - - if confirmed is not None: - request["confirmed"] = confirmed - - res = await self.fetch("get_transactions", request) - return [TransactionRecord.from_json_dict_convenience(tx) for tx in res["transactions"]] + async def get_transactions(self, request: GetTransactions) -> GetTransactionsResponse: + return GetTransactionsResponse.from_json_dict(await self.fetch("get_transactions", request.to_json_dict())) async def get_transaction_count( self, wallet_id: int, confirmed: Optional[bool] = None, type_filter: Optional[TransactionTypeFilter] = None