Skip to content

Commit

Permalink
Port get_transactions
Browse files Browse the repository at this point in the history
  • Loading branch information
Quexington committed Oct 31, 2024
1 parent 5b06ba0 commit 521fa64
Show file tree
Hide file tree
Showing 8 changed files with 160 additions and 112 deletions.
38 changes: 16 additions & 22 deletions chia/_tests/cmds/wallet/test_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,16 @@
CreateOfferForIDsResponse,
GetHeightInfoResponse,
GetTransaction,
GetTransactions,
GetTransactionsResponse,
GetWalletBalance,
GetWalletBalanceResponse,
GetWallets,
GetWalletsResponse,
SendTransactionResponse,
TakeOfferResponse,
UserFriendlyMemos,
UserFriendlyTransactionRecordWithMetadata,
WalletInfoResponse,
)
from chia.server.outbound_message import NodeType
Expand All @@ -51,7 +55,7 @@
from chia.wallet.trading.trade_status import TradeStatus
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.transaction_sorting import SortKey
from chia.wallet.util.query_filter import HashFilter, TransactionTypeFilter
from chia.wallet.util.query_filter import HashFilter
from chia.wallet.util.transaction_type import TransactionType
from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG, TXConfig
from chia.wallet.util.wallet_types import WalletType
Expand Down Expand Up @@ -110,24 +114,13 @@ def test_get_transactions(capsys: object, get_test_cli_clients: tuple[TestRpcCli

# set RPC Client
class GetTransactionsWalletRpcClient(TestWalletRpcClient):
async def get_transactions(
self,
wallet_id: int,
start: int,
end: int,
sort_key: Optional[SortKey] = None,
reverse: bool = False,
to_address: Optional[str] = None,
type_filter: Optional[TransactionTypeFilter] = None,
confirmed: Optional[bool] = None,
) -> list[TransactionRecord]:
self.add_to_log(
"get_transactions", (wallet_id, start, end, sort_key, reverse, to_address, type_filter, confirmed)
)
async def get_transactions(self, request: GetTransactions) -> GetTransactionsResponse:
self.add_to_log("get_transactions", (request,))
l_tx_rec = []
for i in range(start, end):
t_type = TransactionType.INCOMING_CLAWBACK_SEND if i == end - 1 else TransactionType.INCOMING_TX
tx_rec = TransactionRecord(
assert request.start is not None and request.end is not None
for i in range(request.start, request.end):
t_type = TransactionType.INCOMING_CLAWBACK_SEND if i == request.end - 1 else TransactionType.INCOMING_TX
tx_rec = UserFriendlyTransactionRecordWithMetadata(
confirmed_at_height=uint32(1 + i),
created_at_time=uint64(1234 + i),
to_puzzle_hash=bytes32([1 + i] * 32),
Expand All @@ -143,12 +136,13 @@ async def get_transactions(
trade_id=None,
type=uint32(t_type.value),
name=bytes32([2 + i] * 32),
memos=[(bytes32([3 + i] * 32), [bytes([4 + i] * 32)])],
memos=UserFriendlyMemos([(bytes32([3 + i] * 32), [bytes([4 + i] * 32)])]),
valid_times=ConditionValidTimes(),
to_address="",
)
l_tx_rec.append(tx_rec)

return l_tx_rec
return GetTransactionsResponse(l_tx_rec, request.wallet_id)

async def get_coin_records(self, request: GetCoinRecords) -> dict[str, Any]:
self.add_to_log("get_coin_records", (request,))
Expand Down Expand Up @@ -197,8 +191,8 @@ async def get_coin_records(self, request: GetCoinRecords) -> dict[str, Any]:
expected_calls: logType = {
"get_wallets": [(GetWallets(type=None, include_data=True),)] * 2,
"get_transactions": [
(1, 2, 4, SortKey.RELEVANCE, True, None, None, None),
(1, 2, 4, SortKey.RELEVANCE, True, None, None, None),
(GetTransactions(uint32(1), uint16(2), uint16(4), SortKey.RELEVANCE.name, True, None, None, None),),
(GetTransactions(uint32(1), uint16(2), uint16(4), SortKey.RELEVANCE.name, True, None, None, None),),
],
"get_coin_records": [
(GetCoinRecords(coin_id_filter=HashFilter.include([expected_coin_id])),),
Expand Down
4 changes: 2 additions & 2 deletions chia/_tests/pools/test_pool_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from chia.consensus.constants import ConsensusConstants
from chia.pools.pool_puzzles import SINGLETON_LAUNCHER_HASH
from chia.pools.pool_wallet_info import PoolSingletonState, PoolWalletInfo
from chia.rpc.wallet_request_types import GetWalletBalance, GetWallets
from chia.rpc.wallet_request_types import GetTransactions, GetWalletBalance, GetWallets
from chia.rpc.wallet_rpc_client import WalletRpcClient
from chia.simulator.block_tools import BlockTools, get_plot_dir
from chia.simulator.full_node_simulator import FullNodeSimulator
Expand Down Expand Up @@ -483,7 +483,7 @@ async def test_absorb_self(
with pytest.raises(ValueError):
await client.pw_absorb_rewards(2, uint64(fee))

tx1 = await client.get_transactions(1)
tx1 = (await client.get_transactions(GetTransactions(uint32(1)))).transactions
assert (250_000_000_000 + fee) in [tx.amount for tx in tx1]

@pytest.mark.anyio
Expand Down
53 changes: 36 additions & 17 deletions chia/_tests/wallet/rpc/test_wallet_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
GetPrivateKey,
GetTimestampForHeight,
GetTransaction,
GetTransactions,
GetWalletBalance,
GetWalletBalances,
GetWallets,
Expand Down Expand Up @@ -951,18 +952,20 @@ async def test_get_transactions(wallet_rpc_environment: WalletRpcTestEnvironment

await generate_funds(full_node_api, env.wallet_1, 5)

all_transactions = await client.get_transactions(1)
all_transactions = (await client.get_transactions(GetTransactions(uint32(1)))).transactions
assert len(all_transactions) >= 10
# Test transaction pagination
some_transactions = await client.get_transactions(1, 0, 5)
some_transactions_2 = await client.get_transactions(1, 5, 10)
some_transactions = (await client.get_transactions(GetTransactions(uint32(1), uint16(0), uint16(5)))).transactions
some_transactions_2 = (
await client.get_transactions(GetTransactions(uint32(1), uint16(5), uint16(10)))
).transactions
assert some_transactions == all_transactions[0:5]
assert some_transactions_2 == all_transactions[5:10]

# Testing sorts
# Test the default sort (CONFIRMED_AT_HEIGHT)
assert all_transactions == sorted(all_transactions, key=attrgetter("confirmed_at_height"))
all_transactions = await client.get_transactions(1, reverse=True)
all_transactions = (await client.get_transactions(GetTransactions(uint32(1), reverse=True))).transactions
assert all_transactions == sorted(all_transactions, key=attrgetter("confirmed_at_height"), reverse=True)

# Test RELEVANCE
Expand All @@ -972,13 +975,17 @@ async def test_get_transactions(wallet_rpc_environment: WalletRpcTestEnvironment
1, uint64(1), encode_puzzle_hash(puzhash, "txch"), DEFAULT_TX_CONFIG
) # Create a pending tx

all_transactions = await client.get_transactions(1, sort_key=SortKey.RELEVANCE)
all_transactions = (
await client.get_transactions(GetTransactions(uint32(1), sort_key=SortKey.RELEVANCE.name))
).transactions
sorted_transactions = sorted(all_transactions, key=attrgetter("created_at_time"), reverse=True)
sorted_transactions = sorted(sorted_transactions, key=attrgetter("confirmed_at_height"), reverse=True)
sorted_transactions = sorted(sorted_transactions, key=attrgetter("confirmed"))
assert all_transactions == sorted_transactions

all_transactions = await client.get_transactions(1, sort_key=SortKey.RELEVANCE, reverse=True)
all_transactions = (
await client.get_transactions(GetTransactions(uint32(1), sort_key=SortKey.RELEVANCE.name, reverse=True))
).transactions
sorted_transactions = sorted(all_transactions, key=attrgetter("created_at_time"))
sorted_transactions = sorted(sorted_transactions, key=attrgetter("confirmed_at_height"))
sorted_transactions = sorted(sorted_transactions, key=attrgetter("confirmed"), reverse=True)
Expand All @@ -989,31 +996,43 @@ async def test_get_transactions(wallet_rpc_environment: WalletRpcTestEnvironment
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
await client.send_transaction(1, uint64(1), encode_puzzle_hash(ph_by_addr, "txch"), DEFAULT_TX_CONFIG)
await full_node_api.wait_for_wallet_synced(wallet_node=wallet_node, timeout=20)
tx_for_address = await client.get_transactions(1, to_address=encode_puzzle_hash(ph_by_addr, "txch"))
tx_for_address = (
await client.get_transactions(GetTransactions(uint32(1), to_address=encode_puzzle_hash(ph_by_addr, "txch")))
).transactions
assert len(tx_for_address) == 1
assert tx_for_address[0].to_puzzle_hash == ph_by_addr

# Test type filter
all_transactions = await client.get_transactions(
1, type_filter=TransactionTypeFilter.include([TransactionType.COINBASE_REWARD])
)
all_transactions = (
await client.get_transactions(
GetTransactions(uint32(1), type_filter=TransactionTypeFilter.include([TransactionType.COINBASE_REWARD]))
)
).transactions
assert len(all_transactions) == 5
assert all(transaction.type == TransactionType.COINBASE_REWARD for transaction in all_transactions)
# Test confirmed filter
all_transactions = await client.get_transactions(1, confirmed=True)
all_transactions = (await client.get_transactions(GetTransactions(uint32(1), confirmed=True))).transactions
assert len(all_transactions) == 10
assert all(transaction.confirmed for transaction in all_transactions)
all_transactions = await client.get_transactions(1, confirmed=False)
all_transactions = (await client.get_transactions(GetTransactions(uint32(1), confirmed=False))).transactions
assert len(all_transactions) == 2
assert all(not transaction.confirmed for transaction in all_transactions)

# Test bypass broken txs
await wallet.wallet_state_manager.tx_store.add_transaction_record(
dataclasses.replace(all_transactions[0], type=uint32(TransactionType.INCOMING_CLAWBACK_SEND))
)
all_transactions = await client.get_transactions(
1, type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND]), confirmed=False
dataclasses.replace(
all_transactions[0].to_transaction_record(), type=uint32(TransactionType.INCOMING_CLAWBACK_SEND)
)
)
all_transactions = (
await client.get_transactions(
GetTransactions(
uint32(1),
type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CLAWBACK_SEND]),
confirmed=False,
)
)
).transactions
assert len(all_transactions) == 1


Expand All @@ -1026,7 +1045,7 @@ async def test_get_transaction_count(wallet_rpc_environment: WalletRpcTestEnviro

await generate_funds(full_node_api, env.wallet_1)

all_transactions = await client.get_transactions(1)
all_transactions = (await client.get_transactions(GetTransactions(uint32(1)))).transactions
assert len(all_transactions) > 0
transaction_count = await client.get_transaction_count(1)
assert transaction_count == len(all_transactions)
Expand Down
21 changes: 13 additions & 8 deletions chia/_tests/wallet/vc_wallet/test_vc_wallet.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

import pytest
from chia_rs import G2Element
from chia_rs.sized_ints import uint32
from typing_extensions import Literal

from chia._tests.environments.wallet import WalletEnvironment, WalletStateTransition, WalletTestFramework
from chia._tests.util.time_out_assert import time_out_assert_not_none
from chia.rpc.wallet_request_types import GetWallets, WalletInfoResponse
from chia.rpc.wallet_request_types import GetTransactions, GetWallets, WalletInfoResponse
from chia.rpc.wallet_rpc_client import WalletRpcClient
from chia.simulator.full_node_simulator import FullNodeSimulator
from chia.types.blockchain_format.coin import coin_as_list
Expand Down Expand Up @@ -430,13 +431,17 @@ async def test_vc_lifecycle(wallet_environments: WalletTestFramework) -> None:
assert await wallet_node_1.wallet_state_manager.wallets[env_1.dealias_wallet_id("crcat")].match_hinted_coin(
next(c for tx in txs for c in tx.additions if c.amount == 90), wallet_1_ph
)
pending_tx = await client_1.get_transactions(
env_1.dealias_wallet_id("crcat"),
0,
1,
reverse=True,
type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CRCAT_PENDING]),
)
pending_tx = (
await client_1.get_transactions(
GetTransactions(
uint32(env_1.dealias_wallet_id("crcat")),
uint16(0),
uint16(1),
reverse=True,
type_filter=TransactionTypeFilter.include([TransactionType.INCOMING_CRCAT_PENDING]),
)
)
).transactions
assert len(pending_tx) == 1

# Send the VC to wallet_1 to use for the CR-CATs
Expand Down
16 changes: 13 additions & 3 deletions chia/cmds/wallet_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
CATSpendResponse,
GetNotifications,
GetTransaction,
GetTransactions,
GetWalletBalance,
GetWallets,
SendTransactionResponse,
Expand Down Expand Up @@ -207,9 +208,18 @@ async def get_transactions(
[TransactionType.INCOMING_CLAWBACK_RECEIVE, TransactionType.INCOMING_CLAWBACK_SEND]
)
)
txs: list[TransactionRecord] = await wallet_client.get_transactions(
wallet_id, start=offset, end=(offset + limit), sort_key=sort_key, reverse=reverse, type_filter=type_filter
)
txs = (
await wallet_client.get_transactions(
GetTransactions(
uint32(wallet_id),
start=uint16(offset),
end=uint16(offset + limit),
sort_key=sort_key.name,
reverse=reverse,
type_filter=type_filter,
)
)
).transactions

address_prefix = selected_network_address_prefix(config)
if len(txs) == 0:
Expand Down
55 changes: 53 additions & 2 deletions chia/rpc/wallet_request_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
from chia.wallet.trade_record import TradeRecord
from chia.wallet.trading.offer import Offer
from chia.wallet.transaction_record import TransactionRecord
from chia.wallet.transaction_sorting import SortKey
from chia.wallet.util.clvm_streamable import json_deserialize_with_clvm_streamable
from chia.wallet.util.query_filter import TransactionTypeFilter
from chia.wallet.util.tx_config import TXConfig
from chia.wallet.vc_wallet.vc_store import VCRecord
from chia.wallet.wallet_info import WalletInfo
Expand Down Expand Up @@ -52,6 +54,12 @@ class UserFriendlyMemos:
def __init__(self, unfriendly_memos: list[tuple[bytes32, list[bytes]]]) -> None:
self.unfriendly_memos = unfriendly_memos

def __eq__(self, other: Any) -> bool:
if isinstance(other, UserFriendlyMemos) and other.unfriendly_memos == self.unfriendly_memos:
return True
else:
return False

def __bytes__(self) -> bytes:
raise NotImplementedError("Should not be serializing this object as bytes, it's only for RPC")

Expand All @@ -74,6 +82,9 @@ def from_json_dict(cls, json_dict: dict[str, Any]) -> UserFriendlyMemos:
)


_T_UserFriendlyTransactionRecord = TypeVar("_T_UserFriendlyTransactionRecord", bound="UserFriendlyTransactionRecord")


@streamable
@dataclass(frozen=True)
class UserFriendlyTransactionRecord(TransactionRecord):
Expand All @@ -87,9 +98,11 @@ def to_transaction_record(self) -> TransactionRecord:
return TransactionRecord.from_json_dict_convenience(self.to_json_dict())

@classmethod
def from_transaction_record(cls, tx: TransactionRecord, config: dict[str, Any]) -> UserFriendlyTransactionRecord:
def from_transaction_record(
cls: type[_T_UserFriendlyTransactionRecord], tx: TransactionRecord, config: dict[str, Any]
) -> _T_UserFriendlyTransactionRecord:
dict_convenience = tx.to_json_dict_convenience(config)
return super().from_json_dict(dict_convenience)
return cls.from_json_dict(dict_convenience)


@streamable
Expand Down Expand Up @@ -321,6 +334,44 @@ class GetTransactionResponse(Streamable):
transaction_id: bytes32


@streamable
@dataclass(frozen=True)
class GetTransactions(Streamable):
wallet_id: uint32
start: Optional[uint16] = None
end: Optional[uint16] = None
sort_key: Optional[str] = None
reverse: bool = False
to_address: Optional[str] = None
type_filter: Optional[TransactionTypeFilter] = None
confirmed: Optional[bool] = None

def __post_init__(self) -> None:
if self.sort_key is not None and self.sort_key not in SortKey.__members__:
raise ValueError(f"There is no known sort {self.sort_key}")


# utilities for GetTransactionsResponse
@streamable
@dataclass(frozen=True)
class TransactionRecordMetadata(Streamable):
coin_id: bytes32
spent: bool


@streamable
@dataclass(frozen=True)
class UserFriendlyTransactionRecordWithMetadata(UserFriendlyTransactionRecord):
metadata: Optional[TransactionRecordMetadata] = None


@streamable
@dataclass(frozen=True)
class GetTransactionsResponse(Streamable):
transactions: list[UserFriendlyTransactionRecordWithMetadata]
wallet_id: uint32


@streamable
@dataclass(frozen=True)
class GetNotifications(Streamable):
Expand Down
Loading

0 comments on commit 521fa64

Please sign in to comment.